一行代碼Post-Train任意長序列!360智腦開源360-LLaMA-Factory
AIxiv專欄是機器之心發佈學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯繫報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
項目核心開發者 Haosheng Zou 本科畢業於清華大學電子系,博士畢業於清華大學計算機系朱軍教授組,目前在 360 智腦從事長文本和強化學習等後訓練工作。開發者 Xiaowei Lv 目前在人民大學信息學院研二在讀。Fenrui Xiao、Junchen Liu、Qi An 和 Xiaodong Sun 等在開發測試中亦有貢獻。
大模型長序列的處理能力已越來越重要,像複雜長文本任務、多幀影片理解任務、以及 OpenAI 近期發佈的 o1、o3 系列模型的高計算量模式,需要處理的輸入 + 輸出總 token 數從幾萬量級上升到了幾百萬量級。面對模型日益增長的長序列需求,在預訓練(Pre-Training)和後訓練(Post-Training)階段,所用的平台框架都需要支持更長序列數據的訓練。不同於預訓練階段基於 Megatron-LM 定製開發的常見選擇,後訓練階段因後訓練算法的多樣性(比如僅 DPO 就有幾十個變種)和訓練需求的靈活性,至今沒有一個框架同時在並行策略、後訓練算法、GPU 顯存優化和簡單易用這 4 個方面上全部做到兼容并包。
在所有開源的後訓練框架中,LLaMA-Factory 是用戶最多的框架之一(GitHub star 數已 37k 多),保持長期迭代更新,支持豐富的模型和後訓練算法,有各種 GPU 顯存優化技巧和簡單易用的方式。然而,LLaMA-Factory 在長序列後訓練上支持仍有所欠缺,尚不支持長序列的關鍵技術 —— 序列並行。
項目主頁:https://github.com/Qihoo360/360-LLaMA-Factory
最近,360 智腦基於 LLaMA-Factory 開源了 360-LLaMA-Factory,加入了序列並行功能,一行代碼即可支持任意長序列的後訓練(Post-Training)—— 僅需額外指定序列並行一個參數:
sequence_parallel_size: 16
按需增加序列並行的 GPU 卡數,即可在任意長度的序列上 SFT 或 DPO。
360-LLaMA-Factory 的實現經過了嚴格的正確性驗證,已在主倉 Pull Request 中審核過。正式合併進 LLaMA-Factory 主倉之前,可先使用 360-LLaMA-Factory。
1、項目背景與項目簡介
360 智腦早在 2023 年就開始了長文本大模型的研發,到目前為止已經成功應用於開源並更新了兩個版本的 360Zhinao-7B-Chat-360k 模型,以及近日發佈的長思維鏈推理模型 360gpt2-o1。在 360-LLaMA-Factory 中,我們將 360 智腦內部長序列後訓練能力系統性地整合進了 LLaMA-Factory 中,用戶僅需額外添加一行代碼,即可進行理論上任意長度的長序列後訓練(增加序列並行的 GPU 卡數即可):
sequence_parallel_size: 16
在原先使用 LLaMA-Factory 的基礎上,只需額外增加一個參數
通過這種方式,360-LLaMA-Factory 將 LLaMA-Factory 的序列並行也做到了簡單易用和兼容并包,和 LLaMA-Factory 的其他功能完全兼容。
粗粒度地測試 8 卡 80G 的全參數後訓練(不考慮除了 zero3-offload 和 gradient checkpointing 外的任何優化技巧),360-LLaMA-Factory 至少可以訓到 SFT 210k (7B) / 128k (72B) 和 DPO 84k (7B) / 46k (72B)。若加上注掉 logits = logits.float () 和 DPO 預計算等技巧,2 卡序列並行即可解決諸多常見的訓練需求。360-LLaMA-Factory 讓序列並行也真正成為了簡單好用、效果也好的後訓練工具。
作為開源社區的一份子,360-LLaMA-Factory 離不開 LLaMA-Factory、ring-flash-attention 和 EasyContext 等開源項目的開創性工作,我們的底層開發部分依賴了這些工作,但也有我們自己在具體實現方式上的不同和見解。我們相信我們的代碼實現已做到儘可能好的模塊化和儘可能少的原始代碼修改,且嚴格檢查過正確性,因此也已向 LLaMA-Factory 主倉提交了 Pull Request,初步審核通過。我們樂於同開源社區共建完善這項工作。
2、長序列及其後訓練
2.1 長序列大模型的訓練:預訓練 vs 後訓練
隨著大模型訓練數據長度的增長,預訓練和後訓練平台框架都需要支持長序列數據訓練。
-
預訓練階段,英偉達的 Megatron-LM 憑藉豐富高效的並行策略與出色的 GPU 顯存優化,成為主流框架,基於它的定製開發往往是最通用的解法, Megatron-LM 本身已實現了序列並行(Megatron-LM 稱之為 context parallelism,其他工作一般稱為 sequence parallelism)。
-
後訓練階段情況相對複雜。後訓練算法多樣,如 DPO 就有諸多變種,且訓練需求靈活多變,不同場景對算法、資源、並行性等要求各異。因此,至今沒有一個框架能在並行策略、後訓練算法、GPU 顯存優化和易用性這四個關鍵方面做到近乎完美的兼容。雖有框架在部分方面表現尚可,但總體仍存在短板,這也限制了模型在長序列數據後訓練上的進一步發展。
2.2 長序列的通解 —— 序列並行及其難點
長序列後訓練面臨的關鍵瓶頸是:序列長度增延長,激活顯存會大幅上升。雖然有 unsloth、liger kernel、LoRA 等多種降低顯存佔用的技巧,但均未從根本上解決序列長度增加的本質問題,其效果存在明確上限。
序列並行(sequence parallelism)被認為是解決長序列訓練問題的通解,它通過把一條長序列切分到不同的顯卡上進行計算,從而避免了每張顯卡處理過長的序列,從根本上解決了 「每張顯卡處理的序列長度增加」 的問題。然而,序列並行的實現難度較大,需要在切分後的序列之間進行通信計算 attention,需要侵入修改原始的 attention 函數。在開源的 Megatron-LM 中,序列並行也是所有並行策略中最後才添加的,LLaMA-Factory 之前還沒有支持序列並行。
2.3 序列並行後訓練的相關工作
我們調研了其他一些支持序列並行的開源框架,有些實現上有錯或小 bug、導致支持的後訓練算法不全;有些更新維護不及時、訓練較新的模型不方便、顯示進度條等易用性不足。有的與 LLaMA-Factory 相比繼承依賴更少,支持功能較少但更乾淨、更適合定製開發,有不同的使用場景。此外,各家的序列並行具體實現也不盡相同。詳見下面的表 1 和 GitHub README,有未調研到的也請包涵並聯繫 360-LLaMA-Factory。
表 1:一些支持序列並行的後訓練框架對比
3、360-LLaMA-Factory 框架解析
360-LLaMA-Factory 系統性地為 LLaMA-Factory 增加了序列並行的支持。以下將簡要介紹 360-LLaMA-Factory 框架中的模塊化修改和執行流程。
3.1 360-LLaMA-Factory 的框架和模塊化封裝
360-LLaMA-Factory 將序列並行的代碼做到了儘可能好的模塊化和儘可能少的原始代碼修改。
我們認為序列並行本質上應認為是對模型的修改,因此在 model_args 中增加了參數並抽像為 apply_sequence_parallel 修改模型的函數。
# src/llamafactory/model/loader.py
sequence_parallel_group = apply_sequence_parallel(model_args) # 序列並行monkey patch,改動attention計算
…
model.sequence_parallel_group = sequence_parallel_group # 維護模型的序列並行組,不開則為None
相應地,數據處理部分也要相應地修改,我們將 zigzag ring attention 所需的數據處理抽像成了一個 decorator,裝飾原來的數據處理函數。背後,這會將先 shuffle、packing、預處理好的數據進一步做好序列並行的準備:先將每行 pad 或截斷到指定的訓練長度,再按 zigzag 切分並按順序寫入數據集,最後在訓練時用 SequentialSampler 讀取訓練數據。
# src/llamafactory/data/loader.py
@sequence_parallel_decorator
def get_dataset(…)
loss 計算則需要在 Trainer 中做序列並行組內的 reduce 彙總和計算。
# src/llamafactory/train/sft/trainer.py
dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(label_num, op=dist.ReduceOp.SUM, group=sp_group)
loss /= label_num
# src/llamafactory/train/dpo/trainer.py
dist.all_reduce(policy_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(policy_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
3.2 360-LLaMA-Factory 的 SFT 和 DPOTrainer
除了統一的模塊化抽像,序列並行也需要對 360-LLaMA-Factory 的 Trainer 稍做定製化的修改,以適配各底層庫。針對最普遍的後訓練需求 SFT 和 DPO(及其變種),我們對 360-LLaMA-Factory 中的 SFT 和 DPOTrainer 做了儘可能少且清晰的修改。
其中,dummy_forward 是因為我們發現基於目前的底層序列並行實現,在第一次 forward 時 DPO loss 不等於 log (sigmoid (0)),但學習率設為 0 時之後的 DPO loss 全都等於。因此,訓練最開始時先做且僅做一次假前傳,不對正式訓練循環造成任何影響。
從 SFT 和 DPO 的序列並行對比圖中,可以清晰地看出 360-LLaMA-Factory 序列並行帶來的改動。
圖 3:360-LLaMA-Factory SFT 序列並行對比
圖 4:360-LLaMA-Factory DPO 序列並行對比
4、360-LLaMA-Factory 效果驗證
內部 360-LLaMA-Factory 的早期版本已訓練了開源的 360Zhinao2-7B-Chat-360k。
為驗證本次開源的 360-LLaMA-Factory 的正確性,我們用總量為 30 條的小數據集,驗證了序列並行開與不開的對比情況下,訓練曲線的差別,以此來確保 360-LLaMA-Factory 所有實現的正確性。從下圖可見,序列並行對訓練曲線的影響幾乎可以忽略不計,DPO 稍有一定數值誤差,但我們也仔細檢查了該誤差與 DeepSpeed Ulysses 的誤差範圍一致,很可能部分是並行計算本身的隨機性導致的,亦可參考 ring-flash-attention 的詳細說明。
圖 5:360-LLaMA-Factory SFT 和 DPO 序列並行開關對比
為便於對比效果,我們基於第三方全尺寸開源模型粗粒度壓測了最大訓練長度,如下表 2、表 3 所示,可見 8 卡 80G 的序列並行上限已可滿足幾十至幾百 k 超長序列的需求:
表 2:第三方開源模型多尺寸 SFT 長度壓測
表 3:第三方開源模型多尺寸 DPO 長度壓測
5、總結
360 智腦開源了 360-LLaMA-Factory,支持了序列並行,僅需額外 1 個參數控制。基於 LLaMA-Factory 和 ring-flash-attention 開發,360-LLaMA-Factory 的實現模塊化、效果正確且在長序列上有效。
歡迎開發者們使用和開發。在本倉庫(https://github.com/Qihoo360/360-LLaMA-Factory)下提交序列並行相關的 issue 或 PR 即可。
也歡迎研究者們,尤其是依賴長序列大模型的研究者們,在研究中使用我們的代碼,可以這樣引用我們的工作:
@software{360-llama-factory,
author = {Haosheng Zou, Xiaowei Lv, Shousheng Jia and Xiangzheng Zhang},
title = {360-LLaMA-Factory},
url = {https://github.com/Qihoo360/360-LLaMA-Factory},
year = {2024}
}
建議同時引用 LLaMA-Factory 和 ring-flash-attention 相關工作。