擴散LLM推理用上類GRPO強化學習!優於單獨SFT,UCLA、Meta新框架d1開源

機器之心報導

編輯:陳陳、杜偉

大語言模型的推理能力,不再是 AR(自回歸)的專屬。擴散模型現在也能「動腦子」,新框架 d1 讓它們學會瞭解數學、懂邏輯、會思考。

當前,強化學習(RL)方法在最近模型的推理任務上取得了顯著的改進,比如 DeepSeek-R1、Kimi K1.5,顯示了將 RL 直接用於基礎模型可以取得媲美 OpenAI o1 的性能。

不過,基於 RL 的後訓練進展主要受限於自回歸的大語言模型(LLM),它們通過從左到右的序列推理來運行。

與此同時,離散擴散大語言模型(dLLM)成為有潛力的語言建模的非自回歸替代。不像以因果方式逐 token 生成文本的自回歸模型那樣,dLLM 通過迭代去噪過程生成文本,在多步驟操作中優化序列的同時並通過雙向注意力利用過去和未來的上下文。其中,LLaDA 等開放的掩碼 dLLM 實現了媲美同尺寸自回歸模型的性能,而 Mercury 等閉源 dLLM 進一步展現了出色的推理延遲。

然而,頂級的開源 dLLM 並沒有使用 RL 後訓練,使得這一有潛力的研究方向還有很大的挖掘空間。這一範式轉變引出了重要的問題:RL 後訓練如何在非自回歸上下文中高效地實現?

RL 算法適應掩碼 dLLM 面臨一些獨特的挑戰,原因在於自回歸模型採用的已有方法(如 PPO、GRPO)通過計算生成序列的對數概率來估計和優化策略分佈,導致無法直接應用於 dLLM。雖然這種計算在自回歸模型中通過序列因式分解很容易實現,但 dLLM 由於它們的迭代、非序列生成過程而缺乏這種自然分解。

為瞭解決這些問題,來自 UCLA 和 Meta AI 的研究者提出了一個兩階段後訓練框架 d1,從而可以在掩碼 dLLM 中進行推理。在第一階段,模型在高質量推理軌跡中進行監督微調;在第二即 RL 階段,研究者引入了用於掩碼 dLLM 的新穎策略梯度方法 diffu-GRPO,它利用提出的高效一步(one-step)對數概率估計在 GRPO 的基礎上創建。

研究者表示,他們的估計器利用了隨機提示詞掩碼,作為策略優化的一種正則化,使得可以擴展 per batch 的梯度更新數量並減少 RL 訓練所需的在線生成數量。這將極大地降低計算時間。

  • 論文標題:d1: Scaling Reasoning in Diffusion Large Language Models via Reinforcement Learning

  • 論文地址:https://arxiv.org/pdf/2504.12216

  • 項目主頁:https://dllm-reasoning.github.io/

  • GitHub 地址:https://github.com/dllm-reasoning/d1

在實驗部分,研究者使用 LLaDA-8B-Instruct 作為基礎模型實例化 d1。他們將 d1-LLaDA 的性能與基礎 LLaDA 模型以及僅使用 SFT 和僅使用 diffu-GRPO 訓練的 LLaDA 模型進行比較。結果表明,d1 在四個數學和邏輯推理基準測試中始終優於基礎模型,如下圖 1 所示。d1-LLaDA 同樣優於僅使用 SFT 方法和僅使用 diffu-GRPO 方法的模型。

方法概覽

d1 是一個兩階段框架,通過依次結合監督微調(SFT)和在線強化學習(RL)來增強預訓練掩碼 dLLMs 的推理性能。

其中,在線強化學習(特別是 GRPO 算法)已被證明能有效提升離線訓練語言模型的性能。然而,GRPO 的學習策略並不能直接泛化到 dLLMs。

GRPO 的目標函數(如公式 3 所示)需要同時計算當前策略 π_θ 和舊策略 π_θold 在以下兩個層面的(對數)似然比:

token 層面(用於優勢權重計算);

序列層面(用於反向 KL 散度項)。

核心問題在於:研究者需要高效計算 dLLMs 生成內容的逐 token 對數概率和序列對數概率。

自回歸(AR)模型,如 Transformer,直接對每個 token 的對數概率進行建模,並且可以通過鏈式法則使用一次前向傳遞輕鬆計算出序列級別的對數概率

。同樣,KL 項可以分解為

與 AR 模型不同,dLLMs 不遵循序列對數概率的順序分解。同時,每個 token 的對數概率計算成本也很高,因為解碼過程中需要多次調用掩碼預測器 f_θ。基於此,該研究提出了一個高效的對數概率估計器。

對於序列對數概率,該研究使用均場近似方法,將其分解為獨立的每個 token 對數概率的乘積。

對於每個 token 的對數概率,該研究引入了一種估計方法,該方法僅調用一次 f_θ。

基於新引入的對數概率估計器,該研究將 GRPO 擴展到掩碼 dLLMs,推導出 diffu-GRPO 的損失函數。

算法如下圖所示。

實驗結果

表 1 報告了基線模型 LLaDA-8B-Instruct 與採用不同後訓練優化方案的模型,在四項任務上的零樣本性能對比。

圖 3 繪製了有效 token 的平均數量:圖 3 繪製了有效 token 的平均數量:

基於實驗,該研究得出以下主要發現:

diffu-GRPO 在所有 12 種設置中都一致優於基礎的 LLaDA 和 SFT(監督式微調)。diffu-GRPO 和 SFT 都相較於 LLaDA-8B-Instruct 基線有所提升,但 diffu-GRPO 顯示出更持續且幅度更大的增益。具體來說,diffu-GRPO 在所有 12 種設置中都優於 LLaDA-8B-Instruct 和 SFT,而 SFT 僅在其中的 7 種設置中優於 LLaDA-8B-Instruct,這表明 diffu-GRPO 相比於單獨的 SFT 實現了更強的整體性能提升

LLaDA+diffu-GRPO 在所有設置中都優於基礎的 LLaDA-8B-Instruct 模型,而 d1-LLaDA 在每種情況下都超過了 LLaDA+SFT。這表明,無論初始化是來自預訓練模型還是經過 SFT 調整的檢查點,diffu-GRPO 都能提供可靠的性能提升

d1 訓練方案實現了最顯著的性能提升。通過先進行監督微調(SFT)、再結合 diffu-GRPO 訓練所形成的 d1-LLaDA 模型,產生了超越單一方法的疊加增益。這種組合式方法在 12 個實驗設置中有 11 項優於純 diffu-GRPO 方案,表明兩個訓練階段存在協同效應。

定性結果表明,在 SFT 和 d1-LLaDA 生成中出現了頓悟時刻。儘管與 LLaDA-8B-Instruct 相比,生成序列長度為 128 和 256 的性能隨著 SFT、diffu-GRPO 和 d1 有所提高,但從質的方面看,在生成的推理軌跡中並未觀察到顯著差異。然而當序列長度達到 512 時,該研究開始觀察到 SFT 和 d1-LLaDA 模型展現出兩種關鍵能力:自我修正機制和回溯行為。