英偉達聯手MIT清北發佈SANA 1.5,線性擴散Transformer再刷文生圖新SOTA

SANA 1.5是一種高效可擴展的線性擴散Transformer,針對文本生成圖像任務進行了三項創新:高效的模型增長策略、深度剪枝和推理時擴展策略。這些創新不僅大幅降低了訓練和推理成本,還在生成質量上達到了最先進的水平。

近年來,文本生成圖像的技術不斷突破,但隨著模型規模的擴大,計算成本也隨之急劇上升。

為此,英偉達聯合MIT、清華、北大等機構的研究人員提出了一種高效可擴展的線性擴散Transformer——SANA,在大幅降低計算需求的情況下,還能保持有競爭力的性能。

SANA1.5在此基礎上,聚焦了兩個關鍵問題:

論文鏈接:https://arxiv.org/pdf/2501.18427論文鏈接:https://arxiv.org/pdf/2501.18427

SANA 1.5:高效模型擴展三大創新

SANA 1.5在SANA 1.0(已被ICLR 2025接收)的基礎上,有三項關鍵創新。

首先,研究者提出了一種高效的模型增長策略,使得SANA可以從1.6B(20層)擴展到4.8B(60層)參數,同時顯著減少計算資源消耗,並結合了一種節省內存的8位優化器。

與傳統的從頭開始訓練大模型不同,通過有策略地初始化額外模塊,可以讓大模型保留小模型的先驗知識。與從頭訓練相比,這種方法能減少60%的訓練時間。

其二,引入了模型深度剪枝技術,實現了高效的模型壓縮。通過識別並保留關鍵的塊,實現高效的模型壓縮,然後通過微調快速恢復模型質量,實現靈活的模型配置。

其三,研究者提出了一種推理期間擴展策略,引入了重覆采樣策略,使得SANA在推理時通過計算而非參數擴展,使小模型也能達到大模型的生成質量。

通過生成多個樣本,並利用基於視覺語言模型(VLM)的選擇機制,將GenEval分數從0.72提升至0.80。

與從頭開始訓練大模型不同,研究者首先將一個包含N個Transformer層的基礎模型擴展到N+M層(在實驗中,N=20,M=40),同時保留其學到的知識。

在推理階段,採用兩種互補的方法,實現高效部署:

  • 模型深度剪枝機制:識別並保留關鍵的Transformer塊,從而在小的微調成本下,實現靈活的模型配置。
  • 推理時擴展策略:借助重覆采樣和VLM引導選擇,在計算資源和模型容量之間權衡。

同時,內存高效CAME-8bit優化器讓單個消費級GPU上微調十億級別的模型成為可能。

下圖展示了這些組件如何在不同的計算資源預算下協同工作,實現高效擴展。

模型增長

研究者提出一種高效的模型增長策略,目的是對預訓練的DiT模型進行擴展,把它從𝑁層增加到𝑁+𝑀層,同時保留模型已經學到的知識。

研究過程中,探索了三種初始化策略,最終選定部分保留初始化方法。這是因為該方法既簡單又穩定。

在這個策略里,預訓練的N層繼續發揮特徵提取的作用,而新增加的M層一開始是隨機初始化,從恒等映射起步,慢慢學習優化特徵表示。

實驗結果顯示,與循環擴展和塊擴展策略相比,這種部分保留初始化方法在訓練時的動態表現最為穩定。

模型剪枝

本文提出了一種模型深度剪枝方法,能高效地將大模型壓縮成各種較小的配置,同時保持模型質量。

受Minitron啟發,通過輸入輸出相似性模式分析塊的重要性:

這裏的

表示第i個transformer的第t個token。

模型的頭部和尾部塊的重要性較高,而中間層的輸入和輸出特徵相似性較高,表明這些層主要用於逐步優化生成的結果。根據排序後的塊重要性,對transformer塊進行剪枝。

剪枝會逐步削弱高頻細節,因為,在剪枝後進一步微調模型,以彌補信息損失。

使用與大模型相同的訓練損失來監督剪枝後的模型。剪枝模型的適配過程非常簡單,僅需100步微調,剪枝後的1.6B參數模型就能達到與完整的4.8B參數模型相近的質量,並且優於SANA 1.0的1.6B模型。

推理時擴展

SANA 1.5經過充分訓練,在高效擴展的基礎上,生成能力有了顯著提升。受LLM推理時擴展的啟發,研究者也想通過這種方式,讓SANA 1.5表現得更好。

對SANA和很多擴散模型來說,增加去噪步數是一種常見的推理時擴展方法。但實際上,這個方法不太理想。一方面,新增的去噪步驟沒辦法修正之前出現的錯誤;另一方面,生成質量很快就會達到瓶頸。

相較而言,增加采樣次數是更有潛力的方向。

研究者用視覺語言模型(VLM)來判斷生成圖像和文本提示是否匹配。他們以NVILA-2B為基礎模型,專門製作了一個數據集對其進行微調。

微調後的VLM能自動比較並評價生成的圖像,經過多輪篩選,選出排名top-N的候選圖像。這不僅確保了評選結果的可靠性,還能有效過濾與文本提示不匹配的圖像。

模型增長、模型深度剪枝和推理擴展,構成了一個高效的模型擴展框架。三種方法協同配合,證明了精心設計的優化策略,遠比單純增加參數更有效。

  • 模型增長策略探索了更大的優化空間,挖掘出更優質的特徵表示。
  • 模型深度剪枝精準識別並保留了關鍵特徵,從而實現高效部署。
  • 推理時間擴展表明,當模型容量有限時,借助額外的推理時間和計算資源,能讓模型達到與大模型相似甚至更好的效果。

為了實現大模型的高效訓練與微調,研究者對CAME進行擴展,引入按塊8位量化,從而實現CAME-8bit優化器。

CAME-8bit相比AdamW-32bit減少了約8倍的內存使用,同時保持訓練的穩定性。

該優化器不僅在預訓練階段效果顯著,在單GPU微調場景中更是意義非凡。用RTX 4090這樣的消費級GPU,就能輕鬆微調SANA 4.8B。

研究揭示了高效擴展不僅僅依賴於增加模型容量。通過充分利用小模型的知識,並設計模型的增長-剪枝,更高的生成質量並不一定需要更大的模型。

SANA 1.5 評估結果

實驗表明,SANA 1.5的訓練收斂速度比傳統方法(擴大規模並從頭開始訓練)快2.5倍。

訓練擴展策略將GenEval分數從0.66提升至0.72,並通過推理擴展將其進一步提高至0.80,在GenEval基準測試中達到了最先進的性能。

模型增長

將SANA-4.8B與當前最先進的文本生成圖像方法進行了比較,結果如表所示。

從SANA-1.6B到4.8B的擴展帶來了顯著的改進:GenEval得分提升0.06(從0.66增加到0.72),FID降低0.34(從5.76降至5.42),DPG得分提升0.2(從84.8增加到85.0)。

和當前最先進的方法相比,SANA-4.8B模型的參數數量少很多,卻能達到和大模型一樣甚至更好的效果。

SANA-4.8B的GenEval得分為0.72,接近Playground v3的0.76。

在運行速度上,SANA-4.8B的延遲比FLUX-dev(23.0秒)低5.5倍;吞吐量為0.26樣本/秒,是FLUX-dev(0.04樣本/秒)的6.5倍,這使得SANA-4.8B在實際應用中更具優勢。

模型剪枝

為了和SANA 1.0(1.6B)公平比較,此次訓練的SANA 1.5(4.8B)模型,沒有用高質量數據做監督微調。

所有結果都是針對512×512尺寸的圖像評估得出的。經過修剪和微調的模型,僅用較低的計算成本,得分就達到了0.672,超過了從頭訓練模型的0.664。

推理時擴展

將推理擴展應用於SANA 1.5(4.8B)模型,並在GenEval基準上與其他大型圖像生成模型進行了比較。

通過從2048張生成的圖像中選擇樣本,經過推理擴展的模型在整體準確率上比單張圖像生成提高了8%,在「顏色」「位置」和「歸屬」子任務上提升明顯。

不僅如此,借助推理時擴展,SANA 1.5(4.8B)模型的整體準確率比Playground v3 (24B)高4%。

結果表明,即使模型容量有限,提高推理效率,也能提升模型生成圖像的質量和準確性。

SANA:超高效文生圖

在這裏介紹一下之前的SANA工作。

SANA是一個超高效的文本生成圖像框架,能生成高達4096×4096解像度的圖像,不僅畫質清晰,還能讓圖像和輸入文本精準匹配,而且生成速度超快,在筆記本電腦的GPU上就能運行。

SANA為何如此強大?這得益於它的創新設計:

  • 深度壓縮自動編碼器:傳統自動編碼器壓縮圖像的能力有限,一般只能壓縮8倍。而SANA的自動編碼器能達到32倍壓縮,大大減少了潛在tokens數量,計算效率也就更高了。
  • 線性DiT:SANA用線性注意力替換了DiT中的標準注意力。在處理高解像度圖像時,速度更快,還不會降低圖像質量。
  • 僅解碼文本編碼器:SANA不用T5做文本編碼器了,而是採用現代化的小型僅解碼大模型。同時,通過上下文學習,設計出更貼合實際需求的指令,讓生成的圖像和輸入文本對應得更好。
  • 高效訓練與采樣:SANA提出了Flow-DPM-Solver方法,減少了采樣步驟。再配合高效的字幕標註與選取,讓模型更快收斂。

經過這些優化,SANA-0.6B表現十分出色。

它生成圖像的質量和像Flux-12B這樣的現代大型擴散模型差不多,但模型體積縮小了20倍,數據處理能力卻提升了100倍以上。

SANA-0.6B運行要求不高,在只有16GB顯存的筆記本GPU上就能運行,生成一張1024×1024解像度的圖像,用時不到1秒。

這意味著,創作者們用普通的筆記本電腦,就能輕鬆製作高質量圖像,大大降低了內容創作的成本。

研究者提出新的深度壓縮自動編碼器,將壓縮比例提升到32倍,和壓縮比例為8倍的自動編碼器相比,F32自動編碼器生成的潛在tokens減少了16倍。

這一改進對於高效訓練和超高解像度圖像生成,至關重要。

研究者提出一種全新的線性DiT,用線性注意力替代傳統的二次複雜度注意力,將計算複雜度從原本的O(N²) 降低至O(N)。另一方面,在MLP層引入3×3深度可分卷積,增強潛在tokens的局部信息。

在生成效果上,線性注意力與傳統注意力相當,在生成4K圖像時,推理延遲降低了1.7倍。Mix-FFN結構讓模型無需位置編碼,也能生成高質量圖像,這讓它成為首個無需位置嵌入的DiT變體。

在文本編碼器的選擇上,研究者選用了僅解碼的小型大語言模型Gemma,以此提升對提示詞的理解與推理能力。相較於CLIP和T5,Gemma在文本理解和指令執行方面表現更為出色。

為充分發揮Gemma的優勢,研究者優化訓練穩定性,設計複雜人類指令,借助Gemma的上下文學習能力,進一步提高了圖像與文本的匹配質量。

研究者提出一種自動標註與訓練策略,借助多個視覺語言模型(VLM)生成多樣化的重新描述文本。然後,運用基於CLIPScore的策略,篩選出CLIPScore較高的描述,以此增強模型的收斂性和對齊效果。

在推理環節,相較於Flow-Euler-Solver,Flow-DPM-Solver將推理步驟從28-50步縮減至14-20步,不僅提升了速度,生成效果也更為出色。

參考資料:

https://huggingface.co/papers/2501.18427

https://x.com/xieenze_jr/status/1885510823767875799

https://nvlabs.github.io/SANA/

本文來自微信公眾號「新智元」,作者:英智 好睏,36氪經授權發佈。