微調大模型,AMD MI300X就夠了!跟著這篇博客微調Llama 3.1 405B,效果媲美H100

機器之心報導

機器之心編輯部

隨著 AI 模型的參數量越來越大,對算力的需求也水漲船高。

比如最近,Llama-3.1 登上了最強開源大模型的寶座,但超大杯 405B 版本的內存就高達 900 多 GB,這對算力構成了更加苛刻的挑戰。

如何降低算力的使用成本和使用門檻,已經成為許多公司尋求突破的關鍵。Felafax 就是其中的一家創業公司,致力於簡化 AI 訓練集群的搭建流程。

Nikhil Sonti 和 Nikhin Sonti 創立了 Felafax,他們的口號是在構建開源 AI 平台,為下一代 AI 硬件服務,將機器學習的訓練成本降低 30%。

與英偉達相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性價比,按每美元計算,其性能表現更為出色。

最近,Felafax 的聯合創始人 Nikhil Sonti 發佈了一篇博客,詳細分享了如何通過 8 張 AMD MI300X GPU 和 JAX 微調 LLaMA 3.1 405B 模型的方法,所有代碼現已開源。

Github 鏈接:https://github.com/felafax/felafax

機器之心對博客內容進行了不改變原意的編譯、整理,以下是博客內容:

JAX 尤其適合非英偉達硬件

JAX 是一個強大的機器學習庫,結合了類似 NumPy 的 API、自動微分功能以及 Google 的 XLA 編譯器。它在模型並行化方面提供了優秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓練。

在使用 AMD 硬件時,JAX 有幾個明顯的優勢:

  • 多硬件並行支持:JAX 採用 XLA(加速線性代數)編譯器,將計算編譯為硬件無關的中間表示(HLO),這意味著同樣的 JAX 代碼無需修改便可高效運行在不同硬件後端,包括 AMD GPU。

  • 獨立於底層硬件:XLA 編譯器的優化策略是通用的,不針對某個特定的硬件平台。這使得任何支持 XLA 的硬件設備(如 CPU、GPU、TPU)都能受益於這些優化,獲得更好的性能表現。

  • 極高的適應性:從 NVIDIA 轉移到 AMD(或其他硬件)時,JAX 只需做極少的代碼改動。而相較之下,PyTorch 與英偉達的 CUDA 生態系統緊密耦合,遷移過程相對複雜。

因此,JAX 成為了我們在非英偉達硬件上的最佳選擇。

拉取 Docker 鏡像:

docker pull rocm/jax:latest

啟動 Docker 容器:

# Pull the Docker Image:docker pull rocm/jax:latest 
# Start the Docker Container:docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
# Verify the Installation: python3 -c 'import jax; print(jax.devices())'

驗證安裝

python3 -c 'import jax; print (jax.devices ())'

訓練使用了一個配備了 8 張 AMD MI300x GPU 的 AMD 節點。每張 MI300x 擁有 192GB 的 HBM3 內存,性能表現與最新的英偉達 H100 GPU 相比非常出色。

與英偉達 H100 的比較,來源:TensorWave

與英偉達 H100 的比較,來源:TensorWave

訓練 LLaMA 405B:性能與可擴展性

使用 JAX,可以成功地在 AMD GPU 上訓練 LLaMA 405B 模型。我們使用 LoRA 微調,將所有模型權重和 LoRA 參數都設為 bfloat16,LoRA rank 設為 8,LoRA alpha 設為 16:

  • 模型大小:LLaMA 模型的權重佔用了約 800GB 的顯存。

  • LoRA 權重 + 優化器狀態:大約佔用了 400GB 的顯存。

  • 顯存總使用量:佔總顯存的 77%,約 1200GB。

  • 限制:由於 405B 模型的規模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。

  • JIT 編譯:由於空間限制,無法運行 JIT 編譯版本;它可能需要比急切模式稍多的空間。

  • 訓練速度:使用 JAX 急切模式,約為 35 tokens / 秒。

  • 內存效率:穩定在約 70% 左右。

  • 擴展性:在 8 張 GPU 上,使用 JAX 的擴展性接近線性。

由於硬件和顯存的限制,我們無法運行 JIT 編譯版本的 405B 模型,整個訓練過程是在 JAX 的急切模式下執行的,因此還有很大的進步空間。 

下圖中顯示了在一次微調訓練步驟中,8 張 GPU 的顯存利用率和 rocm-smi 輸出:

GPU 利用率:

顯存利用率:

rocm-smi 輸出:

訓練設置 

將 LLaMA 3.1 從 PyTorch 移植到 JAX 

此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都運行在 NVIDIA GPU 上,但實際上還有一些同樣強大且性價比更高的替代方案。例如,在 Google TPU 上訓練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。

然而,支持非 NVIDIA 硬件的開發工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓練 Llama 3.1,但過程並不順利。XLA 與 PyTorch 的集成不夠完善,缺少一些關鍵的庫(如 bitsandbytes 無法正常運行),同時還遇到了一些難以解決的 HuggingFace 錯誤。

為此,他決定調整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄製了詳細的教程影片,並開源了所有代碼:

  • 方法演示:https://dub.sh/felafax-demo

  • 代碼倉庫:https://github.com/felafax/felafax

加載模型,並把模型參數分片

處理像 LLaMA 405B 這樣的超大模型,需要在多個設備之間高效地進行參數分片。以下是如何通過 JAX 實現這一點的。

在 JAX 中進行參數分片

為了將巨大的 LLaMA 405B 模型高效地分佈到 8 張 AMD GPU 上,需要使用 JAX 的設備網格(device mesh)功能。

部署代碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的設備網格可以幫助我們把可用的設備組織成一個網格,讓我們可以指定如何把模型的參數和計算分配到不同的 GPU 上。

在本文的設置中,需要創建一個形狀為(1, 8, 1)的網格,並將軸分別命名為數據並行(dp)、全分片數據並行(fsdp)和模型並行(mp)。然後,為模型的每個張量定義特定的分片規則,指定這些維度如何沿著這些網格軸進行分片。

DEVICES = jax.devices () DEVICE_COUNT = len (DEVICES) DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

可視化分片

可以使用以下代碼來可視化分片結果,從而方便地驗證分片規則是否按預期應用。

jax.debug.visualize_array_sharding 

分片規則

模型不同組件的分片規則如下所示:

  • 參數如何分片:

參數要在 8 個 GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個軸,按照 PS (“fsdp”, “mp”) 進行分片。在本例中是 8 和 1,因此可以看到該張量在第一個軸上沿著 8 個 GPU 被拆分。

  • Non-Replicated 參數:

沒有任何分片規範的參數會在所有設備上進行複製。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設置分片規範,是 PS (None)。

應用分片函數

在加載模型時,使用以下分片函數逐步對模型權重進行分片:

def make_shard_and_gather_fns (partition_specs):    def make_shard_fn (partition_spec):        out_sharding = NamedSharding (mesh, partition_spec)        def shard_fn (tensor):            return jax.device_put (tensor, out_sharding).block_until_ready ()        return shard_fn

    shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)    return shard_fns

# Create shard functions based on partitioning rulesshard_fns = make_shard_and_gather_fns (partitioning_rules)

這使得我們能夠將每個參數放置在指定的設備上,並按照設定的分片進行處理。

分片訓練 Batch

最初,訓練 Batch 是正常創建的,但在輸入模型之前,需要按照下面的代碼在 GPU 上進行分片:

train_batch = jax.device_put ( train_batch, NamedSharding (self.mesh, PS ("dp", "fsdp")))

在這裏,我們指定訓練 Batch 應該在 “dp” 和 “fsdp” 軸上進行分片,在本例中分別對應於被分成 1 和 8 份,如果把結果可視化出來,如下所示:

分片前:

在調用  jax.device_put 之後:

加入 LoRA

LoRA 通過將權重更新分解為低秩矩陣,減少了可訓練參數的數量,這對於微調大型模型特別有效。以下是在 AMD GPU 上微調 Llama 3.1-405 的 LoRA 的要點:

  • 將 LoRA 參數(lora_a 和 lora_b)與主模型參數分開。

  • 使用 jax.lax.stop_gradient (kernel) 來防止對主模型權重的更新。

  • 使用 lax.dot_general 進行快速、精確控制的矩陣運算。

  • LoRA 輸出在添加到主輸出之前會被縮放為 (self.lora_alpha/self.lora_rank)。

LoRADense 層

在此設定一個自定義的 LoRADense 層,該層集成了 LoRA 參數:

class LoRADense (nn.Module):    features: int    lora_rank: int = 8    lora_alpha: float = 16.0@nn.compactdef __call__(self, inputs: Any) -> Any:# Original kernel parameter (frozen)        kernel = self.param ('kernel', ...)        y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)# LoRA parameters (trainable)        lora_a = self.variable ('lora_params', 'lora_a', ..., ...)        lora_b = self.variable ('lora_params', 'lora_b', ..., ...)# Compute LoRA output        lora_output = lax.dot_general (inputs, lora_a.value, ...)        lora_output = lax.dot_general (lora_output, lora_b.value, ...)# Combine original output with LoRA modifications        y += (self.lora_alpha/self.lora_rank) * lora_output

        return y.astype (self.dtype)

分片 LoRA 參數

為了高效地在設備之間分配 LoRA 參數,我們也通過 JAX 設定了分片規則,這確保了 LoRA 參數與主模型參數的分片一致,優化了內存使用和計算效率。

LoRA A matrices (lora_a)

LoRA A 矩陣(lora_a)

  • 分片規則:PS (“fsdp”, “mp”)

  • 可視化結果:如下圖所示,lora_a 參數被分片為 (8, 1),這意味著第一個軸在 8 個設備上進行分片(”fsdp” 軸),而第二個軸未進行分片。

LoRA B 矩陣(lora_b)

  • 分片規則:PS (“mp”, “fsdp”)

  • 可視化結果:如下圖所示,lora_b 參數被分片為 (1, 8),這意味著第二個軸在 8 個設備上進行分片(fsdp 軸),而第一個軸未進行分片。

這種分片策略優化了參數的分配,減少了通信開銷,並在訓練過程中增強了並行性。它確保每個設備僅持有一部分 LoRA 參數,使得大模型如 LLaMA 405B 的高效擴展成為可能。

僅更新 LoRA 參數 

為了優化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 參數的梯度,保持主模型參數不變。這個方法減少了內存使用,並加速了訓練,因為只更新較少的參數。可以移步 GitHub 倉庫,查看實現細節。

在訓練過程中,每一步都涉及將一批輸入數據通過模型進行處理。由於只有 LoRA 參數是可訓練的,因此模型的預測和計算的損失僅依賴於這些參數,然後對 LoRA 參數進行反向傳播。只更新這些參數簡化了訓練過程,使得在多個 GPU 上高效微調像 LLaMA 405B 這樣的大型模型成為可能。

更多研究細節,請參考原博客。

© THE END