多虧Transformer,Mamba更強了!僅用1%計算量達新SOTA

明敏 發自 凹非寺

量子位 | 公眾號 QbitAI

Attention is all you need.

至少在矩陣這兒是。

Mamba架構最新進展:僅需1%計算量,新模型性能達SOTA

能做到這一點,還多虧了Transformer。

通過將Transformer模型中的知識有效遷移到Mamba等替代架構中,模型能在保持較低計算成本的同時,性能更好。

這就是由Mamba主創之一Albert Gu領銜的最新成果。

值得一提的是,這種方法還適用於Mamba以外的非Transformer架構。

從Transformer到SSMs

Transformer由於依賴二次自注意力機制,所需計算量很大。

二次自注意力機制能讓模型在處理序列數據時有效捕捉序列內部的長距離依賴關係,但是由於二次時間複雜度(如果輸入規模翻倍,模型計算所需時間增加4倍),導致處理長序列的計算成本很高。

為瞭解決這個問題,學界提出了很多新架構,比如Mamba、RWKV等,它們的微調和推理成本更低。

考慮到Transformer模型預訓練已經投入了大量計算資源,研究人員想到,為什麼不能在此基礎上進行提升?

所以在本項研究中,他們提出了一種蒸餾方法MOHAWK,利用Transformer預訓練模型來訓練SSMs模型。

其核心在於注意力機制、線性注意力、Mamba的結構化掩碼注意力SMA等,都是跨輸入長度維度的序列轉換。因此它們都有各自的矩陣混合器,比如softmax。

通過將注意力和SSMs視為通過應用不同類別的矩陣來混合不同token嵌入的序列變換,序列模型架構可以分解為獨立序列混合和通道混合塊。

比如Transformer由注意力(序列混合器)和MLP(通道混合器)塊組成,使用這種分解可以蒸餾模型的每個元素。

具體蒸餾分為三個階段

第一階段:矩陣對齊(Matrix Orientation)。對齊序列變換矩陣本身。

第二階段:隱藏狀態對齊(Hidden-State Alignment)。對齊網絡每個單獨層的隱藏狀態表示,且不犧牲預先學習的表示。

第三階段:權重轉移和知識蒸餾(Weight-Transfer and Knowledge Distillation)。通過一個端到端訓練階段,將權重轉移,最終使用只有一小部分訓練數據來蒸餾網絡的最終輸出。

利用這個方法來實際修改一個模型,比如Phi-Mamba。

它結合了Mamba-2和Phi-1.5。

通過MOHAWK方法,該模型從預訓練的Transformer模型中學習,同時作為狀態空間模型,它在處理長序列上比傳統Transformer架構更高效。

該模型僅使用3B token進行蒸餾,數據量為從頭訓練模型的1%,但是性能達到開源非Transformer架構中的SOTA。

實驗發現,隱藏狀態對齊更好,可以提高後續階段的性能。

研究團隊也發佈了混合Phi-Mamba-1.5B,通過5B token蒸餾,模型與類似混合模型表現相當,但是注意力層只用了4層

值得一提的是,這種蒸餾方法不止適用於Mamba。

該研究由CUM助理教授、Cartesia AI聯合創始人及首席科學家Albert Gu領銜。

去年,他和FlashAttention作者Tri Dao一起提出了Mamba,成為第一個真正實現匹配Transformer性能的線性時間序列模型。

論文地址:

https://arxiv.org/abs/2408.10189