LIama 3+Mamba強強聯手!蒸餾到線性RNN,推理速度提升1.6倍

基爾西 發自 凹非寺

量子位 | 公眾號 QbitAI

把Llama 3蒸餾到Mamba,推理速度最高可提升1.6倍!

而且性能不減,甚至表現比原始模型還要優異。

這是來自Together AI的新作,通過蒸餾將Transformer和Mamba模型結合到了一起,同時還為混合模型涉及了推理加速算法

提出Mamba架構的大神、FlashAttention作者Tri Dao,也參與了這一項目。

Together AI創始人兼CEO表示,Transformer和Mamba的混合,是未來大模型的一大發展方向。

將Transformer蒸餾進Mamba

在蒸餾正式開始之前,需要先進行從Transformer到線性RNN的初始化。

作者觀察到,Transformer的注意力機制與RNN的計算之間存在一定的相似性。

因此可以將Transformer的注意力線性化,從而建立二者的聯繫。

利用這種對應關係,可以將預訓練的Transformer模型的參數複製到Mamba模型中。

在完成參數初始化後,作者採用了一個三階段的蒸餾流程進一步提升Mamba模型的性能,使其更好地學習Transformer的知識。

第一階段是基於偽標籤的蒸餾——使用預訓練的Transformer教師模型在無標籤數據上生成偽標籤,然後讓Mamba學生模型在這些偽標籤上訓練。

這一過程的損失函數結合了KL散度損失和交叉熵損失,分別用於模仿教師模型輸出分佈以及偽標籤的擬合。

第二階段是在指令數據集上進行的監督微調,使用帶標籤的指令數據集(如OpenHermes 2.5)進行訓練。

最後一個階段,是用人類反饋數據,通過基於獎勵模型進行優化。

作者收集了人類對模型輸出的反饋數據,然後據此構建一個獎勵模型並使用 RL 算法(如 PPO)來優化模型在該獎勵模型下的表現。

在8塊80G A100 GPU上,每個混合模型的整個蒸餾過程,只需不到五天的時間。

通過以上的蒸餾過程,作者得到了Transformer-Mamba混合模型,之後又提出了Speculative Decoding(推測解碼)算法來加速推理過程。

混合模型推理加速算法

推測解碼算法的基本思想是使用一個輕量級的Draft模型來預測多個token,然後再用驗證模型(Verifier)來驗證這些預測。

這樣可以顯著提高解碼的並行性,加速生成過程。

Draft模型通常是一個小的Transformer,根據當前的上下文預測出接下來的K個token。

對於預測出的K個token,Transformer層可以直接並行地處理這K個token,計算它們的隱狀態;

Mamba層則需要按照順序依次處理每個token,首先計算當前token的隱狀態,並將其與之前的隱狀態進行比較。

  • 如果當前token是正確的,則將其添加到已接受的序列中,並更新最新的隱狀態(但不保存中間狀態)。

  • 如果當前token是錯誤的,則停止處理後續token,並將最新的隱狀態回退到上一個已接受的token處。

如果序列中的所有K個token都被接受,則將它們添加到輸出序列中,並繼續預測下一組token。

如果有token被拒絕,則從第一個被拒絕的token處截斷預測序列,並返回初始步驟從該位置開始重新預測。

Llama 3推理速度提升1.6倍

測試結果表明,混合模型在單論(AlpacaEval)和多輪(MT-Bench)聊天對話任務上與Llama-3相當甚至更優。

並且還對不同混合比例的模型表現進行了測試,發現其中按照1:1比例混合的模型表現最佳。

在零樣本的通用 NLP 任務評測中,混合模型的平均成績優於同等規模的RNN模型。

在少樣本的OpenLLM Leaderboard榜單上,混合模型的表現與最好的開源RNN模型相當,並在GSM8K和CRUX任務上超過了對應的Instruct模型。

除了模型性能,作者也對推測解碼算法帶來的加速效果進行了測試。

首先測試的是純Mamba模型,結果在2.8B和7B的模型上,相比原來的解碼方式,推理速度提升了1.7-2.6倍。

進一步地,作者在蒸餾的Zephyr和Llama混合模型上進行了測試,結果Zephyr混合模型的推理速度提升了1.8倍以上,Llama混合模型也有1.6倍左右的加速。

論文地址:

https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models