RAG新突破:塊狀注意力機制實現超低延遲檢索增強

AIxiv專欄是機器之心發佈學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯繫報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

在工業場景中,往往會利用檢索技術來為大語言模型添加一些來自外部數據庫的知識文檔,從而增強大語言模型的回覆可信度。一般來說,RAG 被公認是最有效的為 LLM 注入特定領域知識的方式。

然而,RAG 也有其不足之處。通常來說,在實際應用中,為確保能召回包含正確知識的文檔,對於每個用戶的查詢,會檢索多個文檔(一般在 5 到 30 個之間),並把這些文檔整合到輸入提示中供大語言模型處理。這樣一來,輸入提示的序列長度增加,使得推理效率大幅降低。具體來講,以首次生成標記的時間(湯臣FT)來衡量,RAG 大語言模型的推理延遲比非 RAG 大語言模型高很多。

由於數據庫中同一文檔經常會被不同 query 召回,大家很自然的會想到:是否能夠把已經算好的文檔表示(KV states)存在緩存中,以供二次使用?很遺憾, 由於自回歸注意力機制的限制,大語言模型中每個文檔的 KV States 都與上下文相關,所以遇到新的 query 時,模型必須重新編碼 KV states 才能確保準確預測。

最近,論文《Block-Attention for Efficient RAG》為檢索增強 (RAG) 場景實現了一種塊狀注意力機制,Block-Attention,通過分塊獨立編碼檢索到的文檔,使得模型無需重覆編碼計算已經在其他 query 中已經見過的文檔,從而實現線上推理效率的有效提升。在實驗中,該方法能夠讓使用 RAG 技術的模型與不使用 RAG 的模型有幾乎一樣的響應速度。同時,該方法甚至還能略微提升在 RAG 場景下的模型準確率。

  • 論文標題:Block-Attention for Efficient RAG

  • 論文地址:https://arxiv.org/pdf/2409.15355

如下圖所示,該方法把整個輸入序列分成若幹個 block,每個 block 獨立計算其 KV States,只有最後一個 block 能夠關注其他 blocks(在 RAG 場景中,最後一個 block 即用戶的輸入)。在 RAG 場景中,block-attention 讓模型不再需要重覆計算已經在其他 query 中見過的文檔。

Block-Attention 的實現並不複雜:1)獨立編碼除最後一個 block 以外的所有 blocks;2)為每個 blocks 重新計算位置編碼;3)將所有 blocks 拚接在一起,並計算最後一個 block 的 KV State。然而直接把模型不加任何修改的從 self-attention 切換到 block-attention 會導致大語言模型懵圈,畢竟模型在訓練階段從來沒見過 block-attention 方式編碼的輸入。一個量化的對比是,直接切換為 block-attention 會讓 Llama3-8B 在四個 RAG 數據集上的平均準確率由 67.9% 下降至 48.0%。

為了讓模型適應 block-attention,作者們對模型進行了進一步微調,作者們發現在 100-1000 步微調之後,模型就能快速適應 block-attention,在四個 RAG 數據集上的平均準確率恢復至 68.4%。另外,block-attention 方式的模型在 KV cache 技術的幫助下,能達到與無 RAG 模型相似的效率。在用戶輸入長度為 50 而 prompt 總長度為 32K 的極端情況下,block-attention model 的首字延時(Time To First Token, 湯臣FT)和首字浮點運算數(FLOPs To Frist Token, (FLOPs-TFT)分別能降低至 self-attention model 的 1.3% 和 0.2%,與無 RAG 模型的效率基本持平。

推理流程

關於 block-attention 的實現和詳細推導,讀者們請移步原文,這裏主要介紹 block-attention 模型的推理流程。如下圖所示,首先從緩存中查詢並提取前 K 個 block 的 KV states。然後,根據每個 block 在輸入序列中的位置,作者們對每個 block 的位置編碼進行了重新計算。具體的操作過程詳見論文的公式 3。最後,根據前 k-1 個 KV States 計算最後一個數據塊的鍵值狀態以及模型的輸出。

實驗結果

在實驗中,作者們主要想探究兩個問題的答案:1)在 RAG 場景中,block-attention 模型能否達到與自 self-attention 相同的準確率?2)block-attention 對效率的提升有多大?

對於問題一,上圖給出了答案。作者們根據實驗結果給出了三個結論:

1. 直接從 self-attention 切換到 block-attention 是不可取的,因為這會導致準確率急劇下降。例如,對於 Llama3-8B 和 Mistral-7B 模型,去除微調過程會導致在所有四個基準上平均絕對性能下降 21.99%。

2. 然而,如果作者們在微調階段使用塊注意力機制,那麼得到的模型與自注意力模型的性能幾乎相同,甚至在某些數據集上略好。例如,Mistral-7B-block-ft 在四個基準上的性能優於自回歸方式訓練的模型,平均準確率由 59.6% 上升至 62.3%。

3. 位置重新編碼操作對於 block-attention 模型至關重要。去除它會導致性能顯著下降 —— 在四個數據集上準確率平均下降 4%。

對於效率的提升,作者們也通過另一組實驗進行了驗證。他們將用戶的問題長度固定在 50 個 token,然後逐漸增加被召回文檔的數量,讓輸入序列總長度從 50 一直增加到 32K。模型在不同 prompt 長度下的首字延時(Time To First Token, 湯臣FT)和首字浮點運算數(FLOPs To Frist Token, (FLOPs-TFT)如下圖所示。顯然,加速效果令人滿意:當輸入序列的長度為 512 時,使用 block-attention 可以將 湯臣FT 減少 48%,將 FLOPs-TFT 減少 90.1%。隨著總長度的增加,block-attention 模型的 湯臣FT 和 FLOPs-湯臣F 保持基本不變的趨勢。當總長度達到 32K 時,加速效果可以達到驚人的 98.7%,FLOPs-TFT 的消耗甚至減少了 99.8%。作者們將此實驗結果總結為:「文本越長,block-attention 越重要」。

作者們最後還指出,block-attention 在很多場景中都有著重要作用,並不局限於 RAG。由於一些保密原因,作者們暫時無法透露在其他工業應用中是如何使用它的。作者們期待社區的研究人員能夠進一步探索 block-attention 的潛力,並將其應用於合適的場景。