Meta探索大模型記憶層,擴展至1280億個參數,優於MoE

機器之心報導

編輯:小舟、陳陳

預訓練語言模型通常在其參數中編碼大量信息,並且隨著規模的增加,它們可以更準確地回憶和使用這些信息。對於主要將信息編碼為線性矩陣變換權重的密集深度神經網絡來說,參數大小的擴展直接與計算和能量需求的增加相關。語言模型需要學習的一個重要信息子集是簡單關聯。雖然前饋網絡原則上(給定足夠的規模)可以學習任何函數,但使用聯想記憶(associative memory)會更高效。

記憶層(memory layers)使用可訓練的鍵值查找機制向模型添加額外的參數,而不會增加 FLOP。從概念上講,稀疏激活的記憶層補充了計算量大的密集前饋層,提供了廉價地存儲和檢索信息的專用容量。

最近,Meta 的一項新研究使記憶層超越了概念驗證,證明了它們在大型語言模型(LLM)擴展中的實用性。

  • 論文標題:Memory Layers at Scale

  • 論文地址:https://arxiv.org/pdf/2412.09764

  • 項目地址:https://github.com/facebookresearch/memory

在下遊任務中,通過改進的記憶層增強的語言模型的性能優於計算預算兩倍以上的密集模型,以及在計算和參數相當的專家混合(MoE)模型。

這項工作表明,當記憶層得到充分改進和擴展時,可以用於增強密集神經網絡,從而帶來巨大的性能提升。通過用記憶層替換一個或多個 transformer 層的前饋網絡(FFN)來實現這一點(保持其他層不變)。這些優勢在各種基本模型大小(從 1.34 億到 80 億參數)和內存容量(最多 1280 億參數)中都是一致的。這意味著存儲容量實現了兩個數量級的飛躍。

記憶增強架構

可訓練的記憶層類似於注意力機制。給定一個查詢

。輸出是值的軟組合,根據 q 和相應鍵之間的相似性進行加權。

,以及值

,一組鍵

在使用時,記憶層與注意力層之間存在兩個區別。

  • 首先,記憶層中的鍵和值是可訓練參數,而不是激活參數;

  • 其次,記憶層在鍵和值的數量方面通常具有更大的規模,因此稀疏查詢和更新是必需的。

該研究將鍵-值對的數量擴展到數百萬。在這種情況下,只有 top-k 最相似的鍵和相應的值被輸出。一個簡單的記憶層可以用下面的等式來描述:

其中,I 是一組指標,

,輸出

擴展記憶層

擴展記憶層時面臨的一個瓶頸是「查詢 – 鍵」檢索機制。簡單的最近鄰搜索需要比較每一對查詢 – 鍵,這對於大型記憶來說很快就變得不可行。雖然可以使用近似向量相似性技術,但當鍵正在不斷訓練並需要重新索引時,將它們整合起來是一個挑戰。相反,本文採用了可訓練的「product-quantized」鍵。

並行記憶。記憶層是記憶密集型的,主要是由於可訓練參數和相關優化器狀態的數量龐大導致的。該研究在多個 GPU 上並行化嵌入查找和聚合,記憶值在嵌入維度上進行分片。在每個步驟中,索引都從進程組中收集,每個 worker 進行查找,然後將嵌入的部分聚合到分片中。此後,每個 worker 收集與其自身索引部分相對應的部分嵌入。該過程如圖 2 所示。

共享記憶。深度網絡在不同層上以不同的抽像級別對信息進行編碼。向多個層添加記憶可能有助於模型以更通用的方式使用其記憶。與以前的工作相比,該研究在所有記憶層中使用共享記憶參數池,從而保持參數數量相同並最大化參數共享。

該研究通過引入具有 silu 非線性的輸入相關門控來提高記憶層的訓練性能。等式 (1) 中的輸出變為:

其中 silu (x) = x sigmoid (x),⊙是元素的乘法(參見圖 3)。其中 silu (x) = x sigmoid (x),⊙是元素的乘法(參見圖 3)。

實驗及結果

首先,該研究固定記憶大小,並與密集基線以及參數大致匹配的 MOE 和 PEER 模型進行比較。

從表 1 中我們可以看出,Memory 模型比密集基線模型有了大幅改進,在 QA 任務上的表現通常與密集參數數量為其兩倍的模型相當。

Memory+ (有 3 個記憶層)比 Memory 有了進一步的改進,其性能通常介於計算能力高出其 2 到 4 倍的密集模型之間。

對於相同數量的參數,PEER 架構的表現與 Memory 模型相似,但落後於 Memory+。MOE 模型的表現遠不及 Memory 變體。

圖 4 顯示了不同大小的 Memory、MOE 和密集模型在 QA 任務上的擴展性能。圖 4 顯示了不同大小的 Memory、MOE 和密集模型在 QA 任務上的擴展性能。
圖 1 表明 Memory+ 模型的實際 QA 性能隨著記憶大小的增加而不斷的增加。圖 1 表明 Memory+ 模型的實際 QA 性能隨著記憶大小的增加而不斷的增加。

在 6400 萬個鍵(1280 億個記憶參數)下,1.3B Memory 模型的性能接近 Llama2 7B 模型,後者使用了 10 倍以上的 FLOPs(見表 2)。

最後,本文在 8B 基礎模型和 4096^2 個記憶值的基礎上 (64B 記憶參數)擴展了 Memory+ 模型,表 2 報告了結果,發現記憶增強模型的表現明顯優於密集基線。