比Stable Diffusion便宜118倍,1890美元訓出11.6億參數高質量文生圖模型

近日,來自加州大學爾灣分校等機構的研究人員,利用延遲掩蔽、MoE、分層擴展等策略,將擴散模型的訓練成本降到了1890美元。

訓練一個擴散模型要多少錢?

之前最便宜的方法(Wuerstchen)用了28400美元,而像Stable Diffusion這樣的模型還要再貴一個數量級。

大模型時代,一般人根本玩不起。想要各種文生小姐姐,還得靠廠商們負重前行

為了降低這龐大的開銷,研究者們嘗試了各種方案。

比如,原始的擴散模型從噪聲到圖像大約需要1000步,目前已經被減少到20步左右,甚至更少。

當擴散模型中的基礎模塊逐漸由Unet(CNN)替換為DiT(Transformer)之後,一些根據Transformer特性來做的優化也跟了上來。

比如量化,比如跳過Attention中的一些冗餘計算,比如pipeline。

而近日,來自加州大學爾灣分校等機構的研究人員,把「慳錢」這個目標直接向前推進了一大步:

論文地址:https://arxiv.org/abs/2407.15811論文地址:https://arxiv.org/abs/2407.15811

——從頭開始訓練一個11.6億參數的擴散模型,只需要1890美元!

對比SOTA有了一個數量級的提升,讓普通人也看到了能摸一摸預訓練的希望。

更重要的是,降低成本的技術並沒有影響模型的性能,11.6億個參數給出了下面這樣非常不錯的效果。

除了觀感,模型的數據指標也很優秀,比如下表給出的FID分數,非常接近Stable Diffusion 1.5和DALL·E 2。

相比之下,Wuerstchen的降成本方案則導致自己的考試分數不甚理想。

慳錢的秘訣

抱著「Stretching Each Dollar」的目標,研究人員從擴散模型的基礎模塊DiT入手。 

首先,序列長度是Transformer計算成本的大敵,需要除掉。

對於圖像來說,就需要在不影響性能的情況下,儘量減少參加計算的patch數量(同時也減少了內存開銷)。

減少圖像切塊數可以有兩種方式,一是增大每塊的尺寸,二是幹掉一部分patch(mask)。

因為前者會顯著降低模型性能,所以我們考慮進行mask的方式。

最樸素的mask(Naive token masking)類似於卷積UNet中隨機裁剪的訓練,但允許對圖像的非連續區域進行訓練。

而之前最先進的方法(MaskDiT),在輸出之前增加了一個恢復重建的結構,通過額外的損失函數來訓練,希望通過學習彌補丟掉的信息。

這兩種mask都為了降低計算成本,在一開始就丟棄了大部分patch,信息的損失顯著降低了Transformer的整體性能,即使MaskDiT試圖彌補,也只是獲得了不太多的改進。

——丟掉信息不可取,那麼怎樣才能減小輸入又不丟信息呢?

延遲掩蔽

本文提出了一種延遲掩蔽策略(deferred masking strategy),在mask之前使用混合器(patch-mixer)進行預處理,把被丟棄patch的信息嵌入到倖存的patch中,從而顯著減少高mask帶來的性能下降。

在本架構中,patch-mixer是通過注意力層和前饋層的組合來實現的,使用二進製掩碼進行mask,整個模型的損失函數為:

與MaskDiT相比,這裏不需要額外的損失函數,整體設計和訓練更加簡單。

而混合器本身是個非常輕量的結構,符合慳錢的標準。

微調

由於非常高的掩蔽比(masking ratio)會顯著降低擴散模型學習圖像中全局結構的能力,並引入訓練到測試的分佈偏移,所以作者在預訓練(mask)後進行了小幅度的微調(unmask)。

另外,微調還可以減輕由於使用mask而產生的任何不良生成偽影。

MoE和分層擴展

MoE能夠增加模型的參數和表達能力,而不會顯著增加訓練成本。

作者使用基於專家選擇路由的簡化MoE層,每個專家確定路由到它的token,而不需要任何額外的輔助損失函數來平衡專家之間的負載。

此外,作者還考慮了分層縮放方法,線性增加Transformer塊的寬度(即注意力層和前饋層中的隱藏層尺寸)。

由於視覺模型中的更深層傾向於學習更複雜的特徵,因此在更深層中使用更多的參數將帶來更好的性能。

實驗設置

作者使用兩種DiT的變體:DiT-Tiny/2和DiT-Xl/2,patch大小為2。

使用具有餘弦學習率衰減和高權重衰減的AdamW優化器訓練所有模型。

模型前端使用Stable-Diffusion-XL模型中的四通道變分自動編碼器(VAE)來提取圖像特徵,另外還測試了最新的16通道VAE在大規模訓練(慳錢版)中的性能。

作者使用EDM框架作為所有擴散模型的統一訓練設置,使用FID以及CLIP分數來衡量圖像生成模型的性能。

文本編碼器選擇了最常用的CLIP模型,儘管T5-xxl這種較大的模型在文本合成等具有挑戰性的任務上表現更好,但為了慳錢的目標,這裏沒有採用。

訓練數據集

使用三個真實圖像數據集(Conceptual Captions、Segment Anything、TextCaps),包含2200萬個圖像文本對。

由於SA1B不提供真實的字幕,這裏使用LLaVA模型生成的合成字幕。作者還在大規模訓練中添加了兩個包含1500萬個圖像文本對的合成圖像數據集:JourneyDB和DiffusionDB。

對於小規模消融,研究人員通過從較大的COYO-700M數據集中對10個CIFAR-10類的圖像進行二次采樣,構建了一個名為cifar-captions的文本到圖像數據集。

評估

使用DiT-Tiny/2模型和cifar-captions數據集(256×256解像度)進行所有評估實驗。

對每個模型進行60K優化步驟的訓練,並使用AdamW優化器和指數移動平均值(最後10K步平滑係數為0.995)。

延遲掩蔽

實驗的基線選擇我們上面提到的Naive masking,而本文的延遲掩蔽則加入一個輕量的patch-mixer,參數量小於主幹網絡的10%。

一般來說,丟掉的patch越多(高masking ratio),模型的性能會越差,比如MaskDiT在超過50%後表現大幅下降。

這裏的對比實驗採用預設的超參數(學習率1.6×10e-4、0.01的權重衰減和餘弦學習率)來訓練兩個模型。

上圖的結果顯示了延遲屏蔽方法在FID、Clip-FID和Clip score三個指標上都獲得了提升。

並且,與基線的性能差距隨著掩蔽率的增加而擴大。在掩蔽率為75%的情況下,樸素掩蔽會將FID分數降低至 16.5,而本文的方法則達到5.03,更接近於無掩蔽時的FID分數(3.79)。

超參數

沿著訓練LLM的一般思路,這裏比較兩個任務的超參數選擇。

首先,在前饋層中,SwiGLU激活函數優於GELU。其次,較高的權重衰減會帶來更好的圖像生成性能。

另外,與LLM訓練不同的是,當對AdamW二階矩 (β) 使用更高的運行平均係數時,本文的擴散模型可以達到更好的性能。

最後,作者發現使用少量的訓練步驟,而將學習率增加到最大可能值(直到訓練不穩定)也顯著提高了圖像生成性能。

混合器的設計

大力出奇蹟一般都是對的,作者也觀察到使用更大的patch-mixer後,模型性能得到持續改善。

然而,本著慳錢的目的,這裏還是選擇使用小型的混合器。

作者將噪聲分佈修改為 (−0.6, 1.2),這改善了字幕和生成圖像之間的對齊。

如下圖所示,在75% masking ratio下,作者還研究了採用不同patch大小所帶來的影響。

當連續區域變多(patch變大)時,模型的性能會下降,因此保留隨機屏蔽每個patch的原始策略。

分層縮放

這個實驗訓練了DiT-Tiny架構的兩種變體,一種具有恒定寬度,另一種採用分層縮放的結構。

兩種方法都使用Naive masking,並調整Transformer的尺寸,保證兩種情況下的模型算力相同,同時執行相同的訓練步驟和訓練時間。

由上表結果可知發現,在所有三個性能指標上,分層縮放方法都優於基線的恒定寬度方法,這表明分層縮放方法更適合DiT的掩蔽訓練。

參考資料: 

https://arxiv.org/abs/2407.15811 

本文來自微信公眾號「新智元」,作者:新智元,36氪經授權發佈。