從線性注意力視角揭秘視覺Mamba,清華、阿里合作提出全新MILA模型

AIxiv專欄是機器之心發佈學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯繫報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

論文第一作者為清華大學自動化系博士生韓東辰,指導老師為黃高副教授。他的主要研究方向包括高效模型架構設計、多模態大模型等。

Mamba 是一種具有線性計算複雜度的狀態空間模型,它能夠以線性計算複雜度實現對輸入序列的有效建模,在近幾個月受到了廣泛的關注。

本文給出了一個十分有趣的發現:強大的 Mamba 模型與通常被認為性能不佳的線性注意力有著內在的相似性:本文用統一的公式表述了 Mamba 中的核心模塊狀態空間模型(SSM)和線性注意力,揭示了二者之間的密切聯繫,並探究了是哪些特殊的屬性和設計導致了 Mamba 的成功。

實驗結果表明,等效遺忘門和宏觀結構設計是 Mamba 成功的關鍵因素。本文通過分析自然地提出了一個新的模型結構:Mamba-Inspired Linear Attention(MILA),它同時繼承了 Mamba 和線性注意力的優點,在各種視覺任務中表現出超越現有的視覺 Mamba 模型的精度,同時保持了線性注意力優越的並行計算與高推理速度。

  • 論文鏈接:https://arxiv.org/abs/2405.16605

  • 代碼鏈接:https://github.com/LeapLabTHU/MLLA

  • 影片講解:https://www.bilibili.com/video/BV1NYzAYxEbZ

最近,以 Mamba 為例的狀態空間模型引起了廣泛的研究興趣。不同於 Transformer 的平方複雜度,Mamba 模型能夠以線性複雜度實現有效的序列建模,在長文本、高解像度圖像、影片等長序列建模和生成領域表現出很大的潛力。

然而,Mamba 並不是第一個實現線性複雜度全局建模的模型。早期的線性注意力使用線性歸一化代替 Softmax 注意力中的 Softmax 操作,將計算順序從 (QK) V 更改為 Q (KV) ,從而將計算複雜度降低為線性。然而,之前的許多工作表明線性注意的表達能力不足,難以取得令人滿意的效果。

令人驚訝的是,本文發現高性能的 Mamba 和表達能力不足的線性注意力的公式之間存在深層次的關聯。因此,一個引人思考的研究問題是:是什麼因素導致了 Mamba 的成功和它相較於線性注意力的顯著優勢?

從這個問題出發,本文在以下幾個方面進行了探索:

1. 揭示了 Mamba 與 Linear Attention Transformer 之間的關係:Mamba 和 Linear Attention Transformer 可以使用統一的公式表示。進一步地,Mamba 可以視為具有若干特殊設計的線性注意力,其特殊設計為:輸入門 (input gate)、遺忘門 (forget gate)、快捷連接 (shortcut)、無注意力的歸一化、single-head 和更先進的宏觀架構。

2. 實驗證明,遺忘門和宏觀架構很大程度上是 Mamba 性能成功的關鍵。然而,遺忘門會導致循環計算,可能並不適合視覺模型。本文發現,適當的位置編碼能夠在視覺任務中替代遺忘門的作用,同時保持並行計算和快速的推理。

3. 提出了一系列名為 MILA 的 Linear Attention Transformer 模型,它引入了 Mamba 的設計思想,並且比原始 Mamba 模型更適合視覺任務。

一、線性注意力與狀態空間模型回顧

本文首先簡略回顧線性注意力和狀態空間模型的數學表達。本部分公式較多,詳細推導請參考論文或影片講解。

1. 線性注意力

對於輸入序列,單頭線性注意力可以表達為:

可以看到,線性注意力通過先計算 K 和 V 的乘積,將計算複雜度降低到。上式中,每個 Q 擁有全局感受野,可以與所有的 K、V 進行信息交互。實際應用中,線性注意力也可以應用在自回歸的模型中,限制每個 token 只能與之前的 token 進行信息交互:

這種因果的線性注意力範式可以進一步寫成循環形式:

2. 狀態空間模型

對於實數序列輸入,Mamba 所採用的狀態空間模型可以表達為:

為了方便後續推導,此處對上式進行了 3 處數學表達上的等價變形,具體請參考原論文。等價變形後得到的公式為:

對於向量序列輸入,Mamba 會在每個維度分別應用上式的實數輸入 SSM,從而得到下面狀態空間模型:

值得注意的是,上式嚴格等價於 Mamba 所進行的 SSM 操作,這裏僅僅進行了數學表達形式上的等價變換。

二、Mamba 與線性注意力關係解析

對於輸入序列,Mamba與線性注意力的公式之間有許多相似之處。為了便於比較,本文將二者使用相同的公式進行表達:

以下是上述兩個公式的示意圖:

圖 1:Mamba 與線性注意力操作示意圖圖 1:Mamba 與線性注意力操作示意圖

從公式和示意圖可以看到,Mamba 的 SSM 操作與線性注意力有深刻的聯繫。具體來說,SSM 中的 C 類似於線性注意力中的 Q,B 類似於 K^T ,x 類似於 V ,h 類似於 S。因此,Mamba 和線性注意力有著非常密切的關係,Mamba 可以被認為是一種特殊的線性注意力。此外,基於公式和示意圖中還可以發現二者的幾個不同點:

(1) 在 Mamba 中,

會與

逐位相乘。由於

是每一位嚴格大於零的向量,因此可將其視為一個等效的輸入門,可以控制

輸入SSM的比例。

(2) 在 Mamba 中,有額外的

逐位相乘。在Mamba的實現中,

每一位都是0到1之間的實數,因此

實際控制對於之前的狀態空間

的衰減程度,因此可將其理解為等效的遺忘門。

(3) Mamba 中,有一個額外的可學習的 shortcut,

(4) 線性注意力中,有一個保證注意力之和為 1 的歸一化分母

,Mamba 中沒有這樣的歸一化。

除此之外,該圖和公式中的線性注意力都是單頭設計,因為僅有一組 Q 和 K。所以可以認為 Mamba 等效於單頭線性注意力,而沒有採用多頭設計(即多組 Q 和 K)。進一步,除了核心操作不同之外,Mamba 和傳統的線性注意力模型在宏觀結構上也有區別。二者的宏觀結構如下圖,Mamba 採用比較符合的結構,包含線性層、卷積、SSM 等。

圖 2:線性注意力模型、Mamba 和 MILA 的宏觀模型架構圖 2:線性注意力模型、Mamba 和 MILA 的宏觀模型架構

總而言之,Mamba 可以視為具有 6 種特殊設計的線性注意力模型,其特殊設計為:輸入門、遺忘門、shortcut、無注意力歸一化、單頭設計、更先進的宏觀結構。

三、實驗

Mamba 被視為 Transformer 的一種有力挑戰者,而線性注意力通常性能不佳。在之前的分析中,本文發現這兩種性能差距很大的模型具有深刻的相似性,並指出了他們之間的 6 個不同設計。接下來,本文通過實驗來驗證究竟是哪些設計導致了二者之間如此大的性能差距。

1. 核心驗證實驗

本文使用線性注意力作為 baseline 模型,在其基礎上引入每一個不同設計,並在 ImageNet 上實驗驗證模型性能的變化。結果如下圖所示:

圖 3:每個不同設計的影響圖 3:每個不同設計的影響

可以看到,Mamba 的等效遺忘門和宏觀設計對於模型性能最為關鍵,而其他設計影響不大或者不如線性注意力。同時,本文發現,由於遺忘門必須採用循環計算,引入遺忘門使得模型推理速度明顯下降。遺忘門帶來的循環計算對於語言模型等自回歸模型是合適的,因為模型在推理時本來就需要不斷自回歸循環計算。然而,這種模式對於圖像等非因果並不自然,因為它不僅限制了模型的感受野,還極大降低了模型的推理速度。本文發現,在視覺任務中,適當的位置編碼能夠引入類似遺忘門的位置信息,同時保持全局感受野、並行計算和更快的推理速度。

圖 4:在視覺模型中用位置編碼代替遺忘門圖 4:在視覺模型中用位置編碼代替遺忘門

2. MILA 模型

基於以上分析和驗證,本文將 Mamba 和線性注意力的優秀設計結合起來,將 Mamba 的兩項核心設計的精髓引入線性注意力,構建了 Mamba-Inspired Linear Attention (MILA) 模型。MILA 能夠以線性複雜度實現全局建模,同時享有並行計算和更快的推理速度,在多種視覺任務上都取得了優於各類視覺 Mamba 模型的效果。以下是一些實驗結果:

圖 5:ImageNet 分類實驗圖 5:ImageNet 分類實驗
圖 6:模型推理速度和性能的 Trade-off

圖 6:模型推理速度和性能的 Trade-off

圖 7:高解像度下遊任務 —— 物體檢測圖 7:高解像度下遊任務 —— 物體檢測

四、總結

(1) Mamba 可以視為具有若干特殊設計的線性注意力,其特殊設計為:輸入門 (input gate)、遺忘門 (forget gate)、快捷連接 (shortcut)、無注意力的歸一化、單頭設計 (single-head) 和更先進的宏觀架構。

(2) 實驗證明,遺忘門和宏觀架構很大程度上是 Mamba 性能成功的關鍵。然而,遺忘門會導致循環計算,可能並不適合視覺模型。本文發現,適當的位置編碼在視覺任務中替代遺忘門的作用,同時保持並行計算和快速的推理。

(3) 本文提出了一系列名為 MILA 的 Linear Attention Transformer 模型,它繼承了 Mamba 的核心優點,並且比原始 Mamba 模型更適合視覺任務。