矩陣乘法可以算得更快了!港中文10頁論文證明:能源、時間均可節省
金磊 發自 凹非寺
量子位 | 公眾號 QbitAI
天下苦大模型矩陣乘法久矣。
畢竟不論是訓練還是推理過程,矩陣乘法作為最主要的計算操作之一,往往都需要消耗大量的算力。
那麼就沒有一種更「快、好、省」的方法來搞這事情嗎?
有的,香港中文大學最新一篇僅10頁的論文,便提出了一種新算法:
-
能源可節省:5%-10%
-
時間可節省:5%

論文作者之一的Dmitry Rybin表示:
這項研究對數據分析、芯片設計、無線通信和LLM訓練都有著深遠的影響!

這麼算矩陣乘法,更快!
矩陣乘法是計算機科學和數值線性代數中的核心問題之一。
自從Strassen和Winograd的開創性工作以來,研究者們一直在探索如何減少矩陣乘法所需的計算量。
儘管這類運算在統計、數據分析、深度學習和無線通信等領域有著廣泛應用,例如協方差矩陣的計算和線性回歸中的關鍵步驟,但對於具有特殊結構的矩陣乘法(如計算矩陣與其轉置的乘積XXt)的研究相對較少。
從理論角度看,計算XXt與一般矩陣乘法具有相同的漸近複雜度,因此只能通過常數因子優化來提升速度。
因此,這篇論文《XXt Can Be Faster》提出了一種名為RXTX的新算法,通過結合機器學習搜索方法和組合優化技術,顯著提升了XXt的計算效率。

我們先來瞭解一下RXTX。
整體來看,這個基於4×4分塊矩陣的遞歸乘法,通過機器學習搜索與組合優化相結合的方法發現。
算法主要包含以下關鍵步驟:
-
分塊與遞歸調用
:將矩陣X劃分為16個4×4子塊,通過8次遞歸調用處理子問題,並計算26個一般矩陣乘積m1至m26。

-
對稱乘積計算
:直接計算8個子塊的對稱乘積s1至m8。
-
結果組合
:通過線性組合上述乘積結果,得到最終的XXt矩陣各分塊元素C11至C44。


與此前最先進的算法(基 Strassen的遞歸分治)相比,RXTX的遞歸關係式為 R(n)=8R(n/4) + 26M(n/4),而原算法為 S(n) = 4S(n/2) + 2M(n/2)。
這一設計使得RXTX的漸近乘法常數為 26/41≈0.6341,比原算法的2/3≈0.6667降低了約5%。
接下來,我們來看下乘法次數與運算總量分析。
通過論文中的定理1的推導,RXTX的乘法次數表達式為:

實驗數據表明,當n為4的冪次時,RXTX的乘法次數比原算法低5%,且隨著n增大,這一優勢持續保持:


通過優化加法步驟(利用公共子表達式減少加法次數),RXTX的總運算量表達式為:

而原算法的總運算量包含對數項,導致其增長更快。
實驗顯示,當n≥256時,RXTX的總運算量優於原算法;當n≥1024時,顯著優於樸素算法:


在6144×6144矩陣的測試中,RXTX的平均運行時間為2.524秒,比BLAS的預設實現快9%,且在99%的測試中表現更優:

儘管運行時間受硬件和內存管理影響,但理論分析表明,當n≥256時,RXTX即可展現速度優勢。
值得一提的是,RXTX的發現得益於機器學習與組合優化的結合,具體流程如下:
-
RL代理生成候選乘積:通過強化學習策略生成大量可能的秩-1雙線性乘積。
-
MILP枚舉與篩選:
-
MILP-A:枚舉候選乘積與目標表達式(XXt的各分塊)之間的線性關係。
-
MILP-B:選擇最小的乘積子集,確保所有目標表達式可通過線性組合表示。
-
大鄰域搜索迭代:通過迭代優化,逐步減少冗餘乘積,提升算法效率。
這一方法借鑒了AlphaTensor的思路,但通過限制候選空間為二維張量,顯著降低了計算複雜度,使得MILP求解器(如 Gurobi)能夠高效處理。
論文地址:
https://arxiv.org/abs/2505.09814
參考鏈接:
[1]https://x.com/DmitryRybin1/status/1923349883945181392
[2]https://x.com/vikhyatk/status/1923541713618129273