字節Seed團隊PHD-Transformer突破預訓練長度擴展!破解KV緩存膨脹難題

機器之心報導

編輯:杜偉

最近,DeepSeek-R1 和 OpenAI o1/03 等推理大模型在後訓練階段探索了長度擴展(length scaling),通過強化學習(比如 PPO、GPRO)訓練模型生成很長的推理鏈(CoT),並在奧數等高難度推理任務上取得了顯著的效果提升。

受此啟發,研究人員開始探索預訓練階段的長度擴展,已有方法包括在序列中插入文本、插入潛在向量(如 Coconut)、複用中間層隱藏狀態(如 CoTFormer)以及將中間隱藏狀態映射為概念(如 COCOMix)。不過,這些方法普遍存在問題,比如需要更大的 KV 緩存導致推理慢 / 佔內存多。

本文中,來自 ByteDance Seed 團隊的研究者提出了更簡單的方法:直接重覆輸入 tokens(1/2/3/4 次),不做中間層處理。他們觀察到了訓練損失和模型性能隨重覆倍數擴展的趨勢,如下圖 1a 和 1b 所示。但是,直接重覆 tokens 也帶來了新問題,包括 KV 緩存規模線性增加,內存壓力大;預填充時間超線性增加;解碼延遲變長。這些都是實現預訓練長度擴展需要重點解決的挑戰。

  • 論文標題:Efficient Pretraining Length Scaling

  • arXiv 地址:https://arxiv.org/pdf/2504.14992

研究者提出了一種推理友好的新穎長度擴展方法,核心是 PHD-Transformer(Parallel Hidden Decoding Transformer),它保持了與原始 transformer 相同的 KV 緩存大小,同時實現有效的長度擴展。PHD-Transformer 通過創新的 KV 緩存管理策略實現了這些能力。

具體來講,研究者將第一個 token 表示原始 token,將重覆的 token 表示為解碼 token。同時僅保留從原始 token 生成的 KV 緩存來用於長距離依賴建模,並在隱藏解碼 token 用於下一個 token 預測之後丟棄它們的 KV 緩存。因此,PHD-Transformer 提供了與原始 transformer 相同的 KV 緩存,同時相較於簡單的 token 重覆實現了顯著的推理加速(如圖 1d 所示)。

另外,為了更好地保留隱藏解碼 token 的 KV 緩存的性能優勢,研究者引入了一種滑動窗口注意力 ——PHD-SWA,保持了這些 token 的局部滑動窗口緩存,在實現顯著性能提升的同時,僅需要

的額外 KV 緩存內存。

研究者還注意到,在 PHD-SWA 中,隱藏解碼 token 的 KV 緩存表現出了順序依賴關係,這導致預填充時間呈線性增長。為瞭解決這個問題,研究者提出了逐塊滑動窗口注意力 —— PHD-CSWA,從而限制了每個塊內的順序依賴關係。

因此,得益於只有最後一個塊的預填充時間呈線性增長,PHD-CSWA 顯著縮短了預填充時間(如圖 1c 所示)。

方法概覽

PHD 的架構下圖 2 所示,與原始 Transformer 相比,PHD 保留了相同的模型架構,僅在輸入序列和注意力矩陣的設計上有所不同。具體而言,他們僅允許原始 token

生成 KV 緩存,並且可以被所有 token 全局關注;同時隱藏狀態的 KV 緩存在並行隱藏解碼後會被立即丟棄。注意力矩陣的策略具體如下: 

研究者在推理過程中實現了與原始 Transformer 相同的 KV 緩存大小和內存訪問模式。雖然需要 K 次 FLOP,但這些計算可以並行處理,從而在內存受限的推理場景中最大限度地降低延遲開銷。該架構的核心優勢在於原始 token 和隱藏解碼 token 之間的解耦。在預填充期間,只有原始 token 需要計算。

這種設計確保預填充時間與原始 Transformer 相同,並且無論擴展因子 K 如何變化,預填充時間都保持不變。而對於損失計算,研究者僅使用 token 的最終副本進行下一個 token 的預測。總之,使用 token 的第一個副本進行 KV 緩存生成,使用 token 的最後一個副本進行下一個 token 的預測。

內核設計

M^ij_mn 的簡單實現會導致注意力層計算量增加 K^2 倍,FFN 層計算量也增加 K 倍。然而,由於注意力是稀疏計算的,

的注意力可以大幅降低。因此,研究者將原始 token 和隱藏解碼 token 分成兩組,並將它們連接在一起。

下圖 3 展示了 K = 3 的示例,可以得到一個包含 t 個原始 token 的序列和一個包含 2t 個隱藏解碼序列的序列。通過重新排列 token 的位置,研究者將掩碼注意力的位置保留在一個連續塊中,從而優化了注意力計算,將注意力計算複雜度降低到

PHD-SWA 和 PHD-CSWA

與簡單的 token 重覆相比,PHD-Transformer 在保持原始 KV 緩存大小的同時實現了長度擴展。然而通過經驗觀察到,為隱藏解碼 token 保留一些 KV 緩存可以帶來顯著的性能提升。因此,為了在保持效率的同時獲得這些優勢,研究者引入了 PHD-SWA,將滑動窗口注意力限制在 W 個先前的隱藏解碼 token 上。

如下圖 4 所示,PHD-SWA 的注意力模式將對原始 token 的全局訪問與對 W 個最近隱藏解碼 token 的局部訪問相結合。這種改進的注意力機制實現了顯著的性能提升,同時僅需要

的額外 KV 緩存內存。

雖然 PHD-SWA 滑動窗口方法提升了模型性能,但由於隱藏解碼 token 的 KV 緩存中存在順序依賴關係,它會產生 K 倍的預填充開銷。為瞭解決這個問題,研究者引入了 PHD-CSWA,它可以在獨立的塊內處理注意力。 

如下圖 4 所示,PHD-CSWA 將滑動窗口注意力限制在單個塊內運行。這種架構創新將額外的預填充開銷減少到最終塊內的 K 次重覆,而不是整個序列重覆,這使得額外的計算成本幾乎可以忽略不計,同時保留了局部注意力模式的優勢。

實驗結果

在實驗中,研究者使用 OLMo2 作為代碼庫,並在 ARC、HellaSwag、PIQA、Winogrande、MMLU 和 CommonsenseQA 等公開基準測試集上進行了評估。

訓練細節:研究者使用 1.2B 參數規模的模型,它是一個 16 層的密集模型。每個 token 的隱藏層維數設置為 2048,FFN 層的隱藏層大小設置為 16384。同時使用組查詢注意力 (Group-Query Attention,GQA),它包含 32 個查詢頭和 8 個鍵 / 值頭,每個頭的隱藏層維數設置為 64。研究者使用 500B 個 token 訓練該模型。

對於本文提出的 PHD 系列設置,研究者預訓練了以下兩種 PHD-CSWA 變體:

  • PHD-CSWA-2-16-32,其中訓練 token 重覆兩次。保留一個包含 16 個 token 的局部窗口,並將塊大小設置為 32 個 token。 

  • PHD-CSWA-3-16-32,其中訓練 token 重覆三次。局部窗口大小和塊大小與 PHD-CSWA-2-16-32 的設置相同。

PHD-CSWA 在各個基準測試中均實現了持續的性能提升。下圖 5 中展示了訓練曲線,下表 1 中展示了主要結果。本文提出的 PHD-CSWA-2-16-32 在這些基準測試中平均實現了 1.5% 的準確率提升,訓練損失降低了 0.025;而 PHD-CSWA-3-16-32 在這些基準測試中平均實現了 2.0% 的準確率提升,訓練損失降低了 0.034。

研究者還分析了 PHD 和 PHD-SWA 的擴展性能,以分析擴展解碼計算的性能。 訓練細節:使用相同的 550M 模型配置,將窗口大小 W 設置為 16,並在 {2, 3, 5} 範圍內改變擴展因子 K。對於局部窗口大小,研究者在所有實驗中都將窗口大小設置為 16。

PHD-SWA 的性能在增加擴展因子時有效擴展。如下圖 8 所示,使用固定窗口大小時,損失曲線和下遊性能會隨著 token 重覆次數而有效擴展。通過將擴展因子設置為 5,可以實現接近 0.06 的損失降低,同時顯著提升下遊性能。

下表 2 中的定量結果表明,當擴展至 K = 5 時,所有基準測試的平均準確率提高了 1.8%,這證實了本文的方法在更激進的擴展方面仍然有效。

更多實驗結果請參閱原論文。