英偉達又賺到了!FlashAttention3來了:H100利用率飆升至75%

機器之心報導

編輯:陳陳、小舟

740 TFLOPS!迄今最強 FlashAttention 來了。

隨著大型語言模型(LLM)加速落地,擴展模型上下文窗口變得越來越重要。然而,Transformer 架構的核心 —— 注意力層的時間複雜度和空間複雜度與輸入序列長度的平方成正比。這使得擴展模型上下文窗口存在挑戰。

2022 年,一種快速、內存高效的注意力算法 ——FlashAttention 問世,該算法無需任何近似即可加速注意力並減少內存佔用。

FlashAttention 對注意力計算進行重新排序的算法,並利用 tiling 和重計算來顯著加快計算速度,將內存使用量從序列長度的二次減少到線性。

2023 年,研究團隊宣佈推出 FlashAttention-2,在算法、並行化和工作分區等方面有了顯著改進。

現在,來自 Meta、英偉達、Together AI 等機構的研究者宣佈推出 FlashAttention-3,它採用了加速 Hopper GPU 注意力的三種主要技術:

  • 通過 warp-specialization 重疊整體計算和數據移動;

  • 交錯分塊 matmul 和 softmax 運算;

  • 利用硬件支持 FP8 低精度的不連貫處理。

FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高達 740 TFLOPS,即 H100 理論最大 FLOPS 利用率為 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。

FlashAttention-3 的改進將帶來:

  • 更高效的 GPU 利用率:H100 理論最大 FLOPS 利用率為 75%,而之前僅為 35%。這使得 LLM 的訓練和運行速度比以前的版本快得多。

  • 較低精度下更好的性能:FlashAttention-3 可以在保持精度的同時使用較低精度的數字 (FP8)。這可以實現更快的處理速度並可能降低內存使用量,從而為運行大規模人工智能操作的客戶節省成本並提高效率。

  • 能夠在 LLM 中使用更長的上下文:通過加速注意力機制,FlashAttention-3 使 AI 模型能夠更有效地處理更長的文本片段。這使得應用程序能夠理解並生成更長、更複雜的內容而不會減慢速度。

論文標題:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

論文地址:https://tridao.me/publications/flash3/flash3.pdf

論文作者之一 、FlashAttention1-3 版本的參與者 Tri Dao 表示:FlashAttention 被廣泛用於加速 Transformers,已經使注意力速度提高了 4-8 倍,但尚未利用現代 GPU。因而他們發佈了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高達 740 TFLOPS(75% 實用性),FP8 接近 1.2 PFLOPS!

Hopper GPU 硬件特性:WGMMA、TMA、FP8

雖然 FlashAttention-2 在 Ampere (A100) GPU 上可以實現 70% 的理論最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能來最大限度地提高性能。接下來文章描述了一些新的 Hopper 特定功能,以及它們為何如此重要。

首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),該功能利用了 Hopper 架構上新的張量內核,比 Ampere 架構具有更高的吞吐量。

然後是 TMA(Tensor Memory Accelerator),這是一個特殊的硬件單元,可以加速全局內存和共享內存之間的數據傳輸,用於處理所有索引計算和邊界外預測。這樣一來寄存器就釋放了,寄存器是增加 tile 大小和效率的寶貴資源。

低精度 FP8,讓 Tensor Core 吞吐量翻了一倍。

FlashAttention-3 充分利用了 Hopper 架構的所有這些新功能。

異步:GEMM 和 Softmax 重疊

注意力機制主要有兩個操作,GEMM 和 softmax。為什麼要將它們重疊?

問題在於在現代加速器上,非矩陣乘法(matmul)運算比矩陣乘法運算慢。特殊函數如指數運算(如 softmax 函數)的吞吐量甚至低於浮點乘加操作;這些運算是由多功能單元處理的,這是一個與浮點乘加或矩陣乘加不同的單元。

理想情況下,研究者希望矩陣乘法和 softmax 能夠並行操作。當 Tensor Cores 忙於矩陣乘法時,多功能單元應當在計算指數運算! 

Inter-warpgroup 重疊

重疊 GEMM 和 softmax 最簡單的方法是什麼都不做,warp 調度程序會免費完成部分重疊。下圖說明了 pingpong 調度,其中相同的顏色表示相同的迭代。

Intra-warpgroup 重疊

即使在一個 warpgroup 中,研究者也可以在運行該 warpgroup 的 GEMM 時運行 softmax 的某些部分。如圖所示,相同的顏色表示相同的迭代。

這種 pipeline 流程可以將 FP16 注意力前向傳播的吞吐量從大約 620 TFLOPS 提高到 640-660 TFLOPS,但代價是更高的寄存器壓力,因而需要更多的寄存器來同時保存 GEMM 的累加器以及 Softmax 的輸入 / 輸出。

低精度:使用非相干處理減少量化誤差

激活 LLM 可能存在一些極端值,導致量化困難,從而產生較大的量化誤差。本文採用非相干處理(incoherent processing),該技術通過將查詢和鍵與一個隨機正交矩陣相乘來「分散(spread out)」極端值,從而減少量化誤差。特別地,該研究使用了 Hadamard 變換,它可以在每個注意力頭中以 O (d log d) 的時間複雜度完成,而不是 O (d^2),其中 d 是頭部維度。

研究者發現非相干處理可以將量化誤差減少很多,具體的數值誤差比較見下表。

實驗

文中展示了 FlashAttention-3 的一些結果,並將其與 FlashAttention-2 以及 Triton 和 cuDNN 中的實現進行了比較(兩者都已經使用了 Hopper GPU 的新硬件功能)。

在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。

對於 FP8,FlashAttention-3 接近 1.2 PFLOPS。