剛剛,Thinking Machines Lab首次發長文,揭開LLM推理不確定性真相
就在今天,由 OpenAI 前 CTO Mira Murati 成立于今年 2 月的人工智能初創公司 Thinking Machines Lab,發了第一篇文章 ——《克服 LLM 推理中的不確定性》(Defeating Nondeterminism in LLM Inference)。
這篇博客屬于 Thinking Machines Lab 新提出的博客欄目 Connectionism,意為「連接主義」。該公司表示:「我們相信,分享才能讓科學更好地發展。Connectionism 將涵蓋與我們的研究一樣廣泛的主題:從核函數數值計算到提示工程。Connectionism 這一名稱可以追溯到 AI 的早期年代。它曾是 20 世紀 80 年代的一個研究分支,專注于神經網絡及其與生物大腦的相似性。」
此外,Thinking Machines Lab 聯合創始人、著名技術博主翁荔(Lilian Weng)還在轉推中透露了一個消息,Connection Machine,即「連接機」,難道他們的產品要來了?
真是讓人期待呢。
地址:https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
博客主要作者為 Horace He,這位 PyTorch 核心開發者于今年 3 月從 Meta 離職,加入了 Thinking Machines。
接下來看博客原文內容。
可復現性(reproducibility)是科學進步的基石。然而,從大語言模型中獲得可復現的結果卻非常困難。
例如,你可能會發現:即使是向 ChatGPT 提出同一個問題多次,也可能得到不同的回答。這本身并不令人意外,因為語言模型生成結果的過程涉及采樣 —— 這個過程會將模型的輸出轉換為一個概率分布,并以概率方式選擇一個 token。
更令人驚訝的是,即使我們將溫度參數調到 0(理論上使采樣過程變為確定性),大語言模型的 API 在實際中仍然不是確定性的。研究者已經對此有諸多討論。
即使是在你自己的硬件上,使用開源推理庫(如 vLLM 或 SGLang)運行推理,采樣過程依然不是確定性的。
為什么大語言模型的推理引擎不是確定性的呢?
一個常見的假設是:浮點運算的非結合性(non-associativity)與并發執行的某種組合會導致不確定性,這取決于哪個并發核心首先完成。我們將這種解釋稱為「LLM 推理不確定性的『并發 + 浮點』假設」。例如,一篇最近的 arXiv 論文(arXiv:2506.09501)寫道:
GPU 中的浮點運算具有非結合性(non-associativity),意味著 (a+b)+c≠a+(b+c),這是由于精度有限和舍入誤差所致。這一特性會直接影響 transformer 架構中注意力得分和 logit 的計算,因為在多線程中進行的并行操作,其執行順序不同會導致結果差異。
雖然這個假設并不完全錯誤,但它并沒有揭示事情的全貌。
例如,即使在 GPU 上,對相同的數據反復進行相同的矩陣乘法運算,每次的結果也都是每一位都相同的。我們確實在使用浮點數,GPU 也確實具有高度并發性。
那為什么在這個測試中卻看不到不確定性呢?
要理解大語言模型推理不確定性的真正原因,我們必須更深入地探究。
不幸的是,甚至連「LLM 推理是確定性」的這一說法的定義都很難明確。或許令人困惑的是,以下這些看似矛盾的說法實際上同時都是真實的:
- GPU 上的一些核(kernel)是不確定性的。
- 然而,語言模型在前向傳播過程中使用的所有核都是確定性的。
- 此外,像 vLLM 這樣的 LLM 推理服務器的前向傳播過程,也可以被認為是確定性的。
- 盡管如此,從使用推理服務器的任何用戶的角度來看,結果卻是不確定性的。
在這篇文章中,我們將解釋為什么「并發 + 浮點」假設沒有達到目的,揭露 LLM 推理不確定性背后的真正罪魁禍首,并解釋如何克服不確定性并在 LLM 推理中獲得真正可重復的結果。
原罪:浮點數的非結合性
在討論不確定性之前,有必要先解釋一下為什么存在數值差異。畢竟,我們通常將機器學習模型視為遵循交換律或結合律等結構性規則的數學函數。我們的機器學習庫難道不應該提供數學上正確的結果嗎?
罪魁禍首是浮點非結合性(floating-point non-associativity)。也就是說,對于浮點數 a、b、c,有:
諷刺的是,正是打破結合律讓浮點數變得有用。
浮點數之所以有用,是因為它們允許動態的精度。為了便于解釋,我們將使用十進制(而不是二進制),其中浮點數的格式為:尾數 * 10^ 指數。這里還將使用 3 位數字作為尾數,1 位數字作為指數。(注:在計算機科學中,尾數(mantissa,或有效數)是浮點數中用來表示精度的部分,它決定了數字的有效數字位數和精度。)
例如,對于值 3450,我們可以將其精確表示為 3.45*10^3。我們也可以將更小的值(例如 0.486)表示為 4.86*10^-1。這樣,浮點數既可以表示非常小的值,也可以表示非常大的值。在科學領域,我們可以說浮點數使我們能夠保持有效數的個數恒定。
如果兩個浮點數的指數相同,它們的加法運算看起來與整數加法類似。例如:
但是,如果兩個浮點數的指數不同,例如 1230 和 23.4,又會發生什么情況呢?理論上,它們的和應該是 1253.4。然而,由于浮點數運算只能保留 3 位有效數字,因此結果會被舍入為 1.25×103(或 1250)。
表示 1230 需要 3 位有效數字,表示 23.4 也需要 3 位有效數字。但是,這兩個數相加的結果(1253.4)卻需要 5 位有效數字才能精確表示。因此,我們的浮點數格式必須舍棄最后兩位(34)。某種程度上,這相當于我們在相加之前,將原來的 23.4 四舍五入為 20.0。
然而,這樣做會導致信息丟失。請注意,只要我們對兩個不同階位(即不同指數)的浮點數進行加法運算,就會發生這種情況。而實際應用中,我們經常需要對不同指數的浮點數進行加法運算。事實上,如果我們能夠保證所有浮點數的指數都相同,那么我們完全可以只使用整數!
換句話說,每次以不同順序相加浮點數時,結果都有可能完全不同。舉個極端的例子,對于某個數組,根據加法順序的不同,其求和結果可能出現 102 種不同的結果。
雖然這是導致輸出結果不一致的根本原因,但它并不能直接解釋不確定性行為的來源。它也無法幫助我們理解為什么浮點數的加法順序會改變、這種情況在什么時候發生、以及我們如何避免它。
答案藏在核函數(kernel)的實現方式中。
為什么核函數計算中數字加法順序并非總是固定的?
如前所述,解釋核函數計算中數字加法順序不一致的一個常見原因是「并發性 + 浮點運算」假設。
該假設認為,如果并發線程的執行順序是不可預測的,并且累加操作的順序依賴于并發線程的執行順序(例如原子加法 /atomic adds),那么最終的累加結果也會變得不可預測。
然而,令人困惑的是,盡管這種現象會導致核函數計算結果的不確定性,但并發機制(以及原子加法)實際上與大型語言模型推理中的不確定性無關!
為了解釋真正的罪魁禍首是什么,我們首先需要了解為什么現代 GPU 核函數很少需要使用原子加法。
什么時候需要使用原子加法操作?
GPU 通常會同時在多個核心(即流處理器)上并行運行程序。由于這些核心之間沒有內置同步機制,因此如果它們需要相互通信,就會很麻煩。例如,如果所有核心都需要對同一個元素進行累加,就可以使用原子加法(有時也稱為 fetch-and-add)。原子加法是不確定性的,結果的累加順序完全取決于哪個核心先完成計算。
具體來說,假設你要使用 100 個核心對一個包含 100 個元素的向量進行求和(例如 torch.sum ())。雖然可以并行加載所有 100 個元素,但最終我們必須將結果匯總為一個值。一種實現方法是使用某種原子加法操作,硬件保證所有加法操作都會執行,但并不保證執行順序。
原子加法操作可以確保每個核心的計算結果都能最終反映在總和中。但是,它并不能保證這些結果的累加順序。累加順序完全取決于哪個核心先完成計算,這是一種不確定性行為。
因此,多次執行相同的并行程序可能會產生不同的結果。這通常就是人們所說的不確定性,即,使用完全相同的輸入數據執行兩次相同的程序,但最終結果卻可能不同。這被稱為運行間不確定性(run-to-run nondeterminism),例如,運行兩次完全相同的 Python 腳本,即使依賴庫版本完全相同,結果也可能不同。
雖然并發的原子加法操作會使核函數的執行結果變得不可預測,但對于大多數核函數來說,原子加法并非必需。
事實上,在 LLM 的典型前向傳播過程中,通常根本不需要使用原子加法。這可能令人感到意外,因為并行化計算中的歸約操作通常可以從原子加法中獲益。但實際上,原子加法在大多數情況下并非必需,主要原因有兩點。
1. 通常情況下,批處理維度上的并行性已經足夠,因此我們無需在歸約維度上進行并行化。
2. 隨著時間的推移,大多數神經網絡庫都采用了各種策略,以在不犧牲性能的情況下實現結果的可預測性。
由于上述兩個因素,對于絕大多數神經網絡操作來說,不使用原子加法幾乎不會帶來性能損失。
當然,仍然有少數常見操作在不使用原子加法時會遭遇顯著的性能下降。例如,PyTorch 中的 scatter_add(即 a [b] += c)。不過,在大語言模型中唯一常用且依賴原子加法的操作,是 FlashAttention 的反向傳播(backward)。
然而,LLM 的前向傳播過程中并不涉及任何需要原子加法的操作。因此,LLM 的前向過程本質上是運行間確定的(即每次運行結果一致)。
維基百科上寫道:一個確定性算法是在給定特定輸入的情況下,始終產生相同輸出的算法。而在這里,只要輸入完全相同(即推理服務器處理的請求完全一致),前向傳播就總是會生成完全相同的輸出。
然而,前向傳播本身是確定性的并不意味著整個系統也是確定性的。比如,如果某個請求的輸出依賴于并行用戶的請求(例如 batch-norm 這樣的操作),那么由于每個請求都無法預知其他并發請求的內容,從單個請求的視角來看,整個 LLM 推理過程就會是不確定性的。
事實證明,我們的請求輸出確實依賴于其他并發用戶的請求。但這并不是因為跨 batch 泄露了信息,而是因為我們的前向傳播過程缺乏批次不變性(batch invariance),這導致同一個請求的輸出會受到前向傳播中 batch size(batch size)變化的影響。
批次不變性與確定性
為了說明什么是批次不變性,我們可以簡化問題,只關注矩陣乘法(matmul)。你可以假設所有的 matmul 實現都是運行間確定的,也就是說,同樣的輸入,每次運行都會得到相同的結果。
但它們并不是批次不變的。換句話說,當 batch size 發生變化時,batch 中的每個元素可能會得到不同的計算結果。
從數學角度來看,這是一種相當反常的性質。理論上,矩陣乘法在 batch 維度上應當是獨立的,batch 中其他元素的存在與否,或 batch 的大小,都不應影響某個具體元素的計算結果。
然而,我們通過實驗證據可以發現,現實情況并非如此。
請注意,這里的確定性是指每次運行結果都相同。如果你多次運行該腳本,它會始終返回相同的結果。
但是,如果將非批處理不變的核函數用作更大推理系統的一部分,則整個系統可能變得不確定性。當你向推理端點發送請求時,從用戶角度來看,服務器的負載情況是不可預測的。負載決定了核函數的 batch size,從而影響每個請求的最終結果。
如果你把某種核函數不具備不變性的屬性(例如:batch size)與該屬性本身的不確定性(例如:服務器負載情況)組合在一起,就會得到一個不確定性的系統。
換句話說,幾乎所有大語言模型推理端點之所以是不確定的,主要原因就是負載(以及由此決定的 batch size)本身具有不確定性!這種不確定性并非僅限于 GPU,使用 CPU 或 TPU 運行的 LLM 推理端點也會存在同樣的問題。因此,如果我們想避免推理服務器中的不確定性,就必須確保核函數對 batch size 具有不變性。
為了理解如何實現這一點,我們首先需要了解為什么核函數默認情況下并不具備批處理不變性。
我們如何使核具有批次不變性?
為了確保 Transformer 模型的實現與 batch size 無關,我們必須確保模型中的每個核心模塊都與 batch size 無關。幸運的是,我們可以假設每個逐點運算(pointwise operation)都與 batch size 無關。因此,我們只需要擔心涉及的 3 個操作:RMSNorm、矩陣乘法和注意力。
巧合的是,這些操作的難度正好是依次遞增的。要想在保持合理性能的同時實現批次不變性,每一種操作都需要一些額外的考量。我們先從 RMSNorm 開始談起。
RMSNorm
RMSNorm 實現方式:
批次不變性的要求是,無論核函數的 batch size 如何,每個元素的歸約順序都必須保持不變。需要注意的是,這并不意味著我們必須始終使用相同的歸約策略。例如,即使我們改變了要進行歸約的元素數量,只要歸約順序不變,我們的算法仍然可以滿足批處理不變性的要求。
因此,只有當 batch size 影響到歸約策略時,我們才會打破批次不變性。
讓我們來看一下 RMSNorm 的標準并行化策略。一般來說,并行算法都會從盡量減少核心之間的通信中獲益。在這里,為了方便討論,你可以假設我們所說的核心(cores)就是指 SM(Streaming Multiprocessors,流處理多處理器)。更具體地說,這里重要的性質是:核函數啟動的線程塊(threadblocks)數量多于 SM 的數量。
基于這一點,一種可行的策略就是:將每個 batch 元素分配給一個核心,就像上圖展示的那樣。
當我們增加 batch size 時,并不會影響歸約策略;如果 batch size = 200 已經能為核函數提供足夠的并行性,那么 batch size = 2000 顯然也同樣能夠提供足夠的并行性。
另一方面,減小 batch size 也會帶來一些挑戰。由于我們為每個批次元素分配一個核心,減小 batch size 會導致核心數量大于批次元素數量,從而造成部分核心閑置。遇到這種情況,優秀的核函數工程師會采用前面提到的解決方案之一(原子加法或分段求和),從而保持良好的并行性,進而提升性能。然而,這會改變求和策略,導致該核函數不再具備 batch size 不變的特性。
最簡單的解決方案就是直接忽略這些情況。這并不是完全不合理的,因為當 batch size 很小時,核函數通常本來就能很快執行,因此即使出現一些減速,也不會造成災難性的影響。
如果我們必須優化這種場景,一種方法是:始終使用一種在極小 batch size 下也能提供足夠并行度的歸約策略。這樣的策略會在 batch size 較大時導致過度并行,從而無法達到峰值性能,但它可以讓我們在整個 batch size 范圍內都獲得尚可(雖然不是最佳)的性能表現。
批次不變矩陣乘法
從本質上講,你可以把矩陣乘法看作是一次逐點運算后接一次歸約。那么,如果我們通過將輸出劃分為小塊來并行化矩陣乘法,就能得到一種類似的數據并行核函數策略,使得每一次歸約都在單個核心內完成。
與 RMSNorm 類似,矩陣乘法的批次維度(M 和 N)也可能變得過小,迫使我們必須沿歸約維度(K)進行拆分。盡管有兩個批次維度,矩陣乘法仍然需要每個核心有更多的工作量才能有效利用張量核心。例如,對于一個 [1024, K] x [K, 1024] 的矩陣乘法和一個標準的 [128, 128] 二維 tile 大小,數據并行策略最多只能將其分配到 64 個核心上,這不足以使 GPU 達到飽和。
在矩陣乘法中沿歸約維度進行拆分被稱為 Split-K 矩陣乘法。與 RMSNorm 的情況一樣,使用這種策略會破壞批次不變性。
矩陣乘法還有一個額外的復雜性,即張量核心指令。對于歸約操作,我們可以一次只處理一行;但高效的矩陣乘法核函數必須一次性操作一整個 tile。
每條張量核心指令(例如 wgmma.mma_async.sync.aligned.m64n128k16)在內部可能有不同的歸約順序。選擇不同張量核心指令的一個原因可能是 batch size 非常小。例如,如果我們使用的張量核心 PTX 指令操作的是一個長度為 256 的 tile,但 batch size 只有 32,那我們幾乎浪費了所有的計算資源!當 batch size 為 1 時,最快的核函數通常根本不使用張量核心。
因此,確保矩陣乘法批次不變性的最簡單方法是:編譯一個固定的核函數配置,并將其用于所有形狀的計算。盡管這會損失一些性能,但在 LLM 推理場景下,這種損失通常不是災難性的。特別是,Split-K 策略在 M 和 N 維度都很小時才最被需要,而幸運的是,在我們的應用場景中,N 維度(即模型維度)通常都相當大!
批次不變性注意力機制
在實現了矩陣乘法的批次不變性之后,注意力機制又引入了兩個額外的難題 —— 這也很貼切,因為它正好包含兩次矩陣乘法。
1. 與 RMSNorm 和矩陣乘法僅在特征維度上進行歸約不同,注意力機制現在需要在特征維度和序列維度上都進行歸約。
2. 因此,注意力機制必須處理各種影響序列處理方式的推理優化(例如分塊預填充、前綴緩存等)。
因此,為了在 LLM 推理中實現確定性,我們的數值計算必須對兩個因素保持不變:一是單次處理的請求數量,二是每個請求在推理引擎中的切分方式。
我們首先來了解一下注意力機制的標準并行策略,該策略最初由 FlashAttention-2 提出。與 RMSNorm 和矩陣乘法類似,其默認策略是數據并行策略。由于歸約是沿著鍵 / 值(K/V)張量進行的,因此數據并行策略只能沿著查詢(Q)張量進行并行化。
例如,根據推理引擎的選擇,一個序列可能被分成幾個部分處理(如在分塊預填充中),也可能一次性處理完畢(如果預填充未被分割)。為了實現批次不變性,對于一個給定的 token,其歸約順序必須獨立于其所在序列中同時被處理的其他 token 的數量。
如果你將 KV 緩存中的 K/V 值與當前正在處理的 token 的 K/V 值分開進行歸約(就像在 vLLM 的 Triton 注意力核函數中那樣),這個目標就無法實現。例如,在處理序列中的第 1000 個查詢 token 時,無論 KV 緩存中有 0 個 token(預填充階段)還是 999 個 token(解碼階段),其歸約順序都必須完全相同。
為解決此問題,我們可以在注意力核函數運行前就更新 KV 緩存和頁表,從而確保無論處理多少個 token,我們的鍵和值始終具有一致的內存布局。
加上這一額外處理(以及前文提到的所有措施,如使用一致的 tile 大小),我們便能實現一個批次不變性的注意力機制!
然而,這里存在一個重要問題。與矩陣乘法不同,LLM 推理中的注意力計算形狀通常確實需要一個拆分 - 歸約核函數(split-reduction kernel),這類核函數常被稱為 Split-KV 或 FlashDecoding。這是因為如果我們不沿著歸約維度進行并行,就只能沿著批次維度、頭維度和查詢長度維度進行并行。
在注意力的解碼階段,查詢長度非常小(通常為 1),因此除非 batch size 非常大,否則我們往往無法使 GPU 達到飽和狀態。不幸的是,這種情況不像在 RMSNorm 和矩陣乘法中那樣容易被忽略。例如,如果你的 KV 緩存非常長,即使只處理一個請求,注意力核函數的計算也可能耗時很長。
此外,常用于注意力的拆分 - 歸約策略也給批次不變性帶來了挑戰。例如,FlashInfer 的平衡調度算法會選擇能夠使 GPU 所有核心飽和的最大拆分大小,這使得其歸約策略并非批次不變的。然而,與 RMSNorm / 矩陣乘法不同,無論 batch size 如何,僅僅選擇一個固定的拆分數量是不夠的。
相反,為了實現批次不變性,我們必須采用固定拆分大小策略。換言之,我們固定的不是拆分的數量,而是每個拆分塊的大小,這樣最終會得到一個可變的拆分數量。通過這種方式,我們可以保證無論正在處理多少個 token,我們總是執行完全相同的歸約順序。
實現
我們基于 vLLM,通過利用其 FlexAttention 后端和 torch.Library,提供了一個確定性推理的演示。通過 torch.Library,我們能夠以一種非侵入式的方式替換掉大部分相關的 PyTorch 算子。
你可以在 thinking-machines-lab/batch-invariant-ops 找到「批次不變性」核函數庫,以及在「確定性」模式下運行的 vLLM 示例。
地址:https://github.com/thinking-machines-lab/batch_invariant_ops
實驗
完成結果的不確定性程度如何?
我們使用 Qwen3-235B-A22B-Instruct-2507 模型,在溫度為 0 的設置下,使用提示詞「Tell me about Richard Feynman」(非思考模式)采樣了 1000 次完成結果,每次生成 1000 個 token。
令人驚訝的是,我們得到了 80 個不同的完成結果,其中最常見的一個出現了 78 次。
通過觀察這些結果的差異,我們發現它們在前 102 個 token 上實際上是完全相同的!
首次出現差異是在第 103 個 token。所有的結果都生成了「Feynman was born on May 11, 1918, in」這個序列。然而,接下來,其中 992 次結果生成了「Queens, New York」,而另外 8 次則生成了「New York City」。
然而,當我們啟用批次不變性核函數后,全部 1000 次結果都變得完全相同。這正是我們期望采樣器應有的表現,但若不使用我們的批次不變性核函數,就無法實現確定性結果。
性能
目前,我們還沒有投入精力優化批次不變性核函數的性能。不過,我們還是進行了一些實驗來驗證其性能是否仍在可用范圍內。
我們搭建了一個配備單塊 GPU 的 API 服務器,運行 Qwen-3-8B 模型,并請求生成 1000 個序列,輸出長度控制在 90 到 110 個 token 之間。
性能下降的主要原因在于 vLLM 中的 FlexAttention 集成尚未經過深度優化。盡管如此,我們看到其性能并未出現災難性下降。
真正的在策略強化學習
正如研究人員所指出的,訓練和推理之間的數值差異會隱式地將我們的在策略強化學習(on-policy RL)轉變為離策略強化學習(off-policy RL)。
當然,如果我們甚至無法從兩次相同的推理請求中獲得每一位都相同的結果,那么在訓練和推理之間獲得每一位都相同的結果也是不可能的。因此,確定性推理使我們能夠修改訓練堆棧,從而在采樣和訓練之間獲得每一位都相同的結果,最終實現真正的在策略強化學習。
我們在 Bigmath 上,使用 RLVR 設置進行了實驗,其中強化學習策略由 Qwen 2.5-VL instruct 8B 模型初始化,最大 rollout 長度為 4096。
如果我們不使用離策略校正(即重要度加權)進行訓練,我們的獎勵會在訓練中途崩潰;而添加離策略校正項則可以使訓練順利進行。但是,如果我們在采樣器和訓練器之間實現了每一位都相同的結果,我們就完全處于在策略狀態(即 KL 散度為 0),同樣可以順利地進行訓練。
我們還可以繪制采樣器和訓練器之間對數概率的 KL 散度,其中所有 3 次運行都表現出顯著不同的行為。在使用重要度加權運行時,KL 散度保持在 0.001 左右,并伴有偶爾的峰值。然而,在不使用重要度加權的情況下運行,最終會導致 KL 散度在大約與獎勵崩潰同一時間出現峰值。當然,在運行「真正的在策略強化學習」時,我們的 KL 散度始終保持為 0,這表明訓練策略和采樣策略之間不存在任何差異。
總結
現代軟件系統往往由多層抽象構成。在機器學習中,當我們遇到不確定性和一些微妙的數值差異時,人們往往會傾向于視而不見。
畢竟,我們的系統本來就是「概率性的」,再多一點不確定性又有何妨?單元測試掛掉時,把 atol/rtol 調大點有什么問題?訓練器和采樣器之間的對數概率差異,應該不是真正的 bug 吧?
我們拒絕這種消極心態。只要稍微多做一些努力,我們就能理解不確定性的根源,甚至真正解決它們!
我們希望這篇博文能為社區提供一套可靠的思路,幫助大家在推理系統中應對不確定性,并激勵更多人深入理解自己的系統。