谷歌提出遞歸混合模型(MoR):通過參數共享和自適應計算提升Transformer效率
Transformer架構從根本上改變了人工智能的格局。從支撐大型語言模型(LLMs)的對話能力到實現實時語言翻譯,Transformer已成為現代自然語言處理(NLP)應用的核心。然而,其成功背后伴隨著巨大的代價:訓練和部署往往需要超大規模數據中心的計算和內存資源。
這種計算需求給人工智能發展帶來了巨大的經濟壓力。雖然科技巨頭能夠承擔大規模GPU集群的成本,但小型組織和獨立研究人員往往難以跟上步伐。因此,提高Transformer模型的效率已成為關鍵目標——不僅是為了提升AI能力,也是為了降低與訓練和部署這些模型相關的財務和資源成本。
為應對這一挑戰,已經出現了幾種創新方法:
- 遞歸Transformer:巧妙地在多個層中重用同一組權重的模型,大幅減少參數數量
- 參數共享技術:如層綁定(layer tying)等在層間共享權重以縮小模型規模的方法
- 提前退出機制:允許模型對簡單輸入跳過部分層,節省寶貴的計算資源
雖然這些方法各自解決了效率的特定方面,但它們通常要么專注于減少參數,要么專注于節省計算資源,很少能同時兼顧兩者。而谷歌的遞歸混合模型(MoR) 正是在這一背景下應運而生。
MoR 通過將參數共享和自適應計算統一到單一框架中,實現了范式轉變。它不僅僅是另一個增量改進——更是對如何讓Transformer既高效又強大的根本性重新思考。
理解遞歸混合模型
MoR示意圖中不同顏色表示對每個token應用的遞歸次數不同。
MoR 的核心基于一個看似簡單卻意義深遠的見解:不同的token需要不同程度的“思考”來有效處理。就像人類會在復雜概念上投入更多腦力,而對簡單詞匯則一帶而過一樣,MoR 允許Transformer按token自適應分配計算資源。
該架構通過三個相互關聯的機制實現這一目標:
- 參數高效的遞歸塊:有效重用權重
- 動態路由系統:基于token復雜度分配計算資源
- 優化的鍵值緩存:減少內存開銷
參數共享策略
基礎:參數共享的工作原理
傳統Transformer為每個層使用獨特的參數,使得模型深度與參數數量呈線性關系。MoR 通過引入遞歸塊打破了這一范式——遞歸塊是可在多個步驟中重用的共享參數集。
這就像用一套單獨的刀叉勺吃飯,而不是用一把勺子完成整個用餐過程。傳統Transformer依賴不同的工具(層)執行每個操作,而遞歸塊則在多個步驟中有效地重用同一工具。
四種策略:對比分析
MoR 探索了四種不同的參數共享策略,每種策略都有其獨特特點:
1. 循環策略(Cycle Strategy)
在循環方法中,權重在層間循環共享,形成重復模式。雖然簡單直接,但這種方法可能會造成信息瓶頸,因為早期層需要同時處理初始處理和復雜推理任務。
這里的信息瓶頸源于在所有遞歸步驟中循環重用固定、相同的參數集。這迫使相同的共享層處理從初始特征提取到遞歸后期復雜迭代推理的廣泛計算需求。這種僵化、無差別的重用限制了模型學習不同處理深度所需的多樣化或專門化特征的能力。因此,共享參數塊成為瓶頸,阻礙了信息在網絡中高效且專門化的轉換。這正是“中間循環”(Middle-Cycle)等策略被開發的原因——通過允許第一層和最后一層使用獨特參數,在效率和表征靈活性之間取得平衡。
2. 序列策略(Sequence Strategy)
序列策略在進入下一層之前連續重用同一層。這允許在每種層類型中實現更深的專門化,但可能限制模型在不同處理類型之間切換的能力。
3. 中間循環策略(Middle-Cycle Strategy)——最優選擇
這正是MoR 的優勢所在。中間循環策略為關鍵的輸入處理層和輸出生成層保留獨特權重,而在中間層之間循環共享權重。這種策略通過在最重要的位置保持完整能力,同時在最有益的地方實現參數共享,平衡了效率和性能。
4. 中間序列策略(Middle-Sequence Strategy)
與中間循環策略類似,但中間層采用序列共享方式。雖然可行,但其性能不如中間循環策略。
為什么中間循環策略占優
實證結果一致表明中間循環策略是最優的。這反映了對Transformer信息處理方式的深刻理解:
- 輸入層需要獨特參數來處理多樣化的原始輸入表征
- 中間層可以共享參數,因為它們執行相似的抽象推理任務
- 輸出層需要獨特參數將抽象表征映射回具體輸出
這種架構選擇減少了約三分之二的參數,同時保持了具有競爭力的性能——這是效率方面的顯著成就。
路由機制:自適應token級計算
固定計算的問題
傳統Transformer對每個token施加相同的計算量,無論其復雜度如何。處理“the”和處理“量子糾纏”所獲得的計算預算相同。這種統一方法本質上是浪費的。
將全部處理能力用于像“.”、“The”或“—”這樣的token可能被認為是不必要的——就像上面提到的類比一樣。
MoR 通過復雜的路由機制解決了這一問題,該機制根據token的復雜度動態為單個token分配不同的遞歸深度。該系統包括兩種主要路由策略:
專家選擇路由(Expert-Choice Routing):層級token選擇
專家選擇路由將每個遞歸深度視為一個“專家”,由專家選擇最具挑戰性的token進行處理。其工作原理如下:
- token評分:在每個遞歸步驟r,輕量級路由器檢查隱藏狀態并計算每個token的復雜度分數
- Top-k選擇:只有評分最高的k個token進入下一個遞歸深度
- 層級過濾:token必須在深度r被選中,才有資格進入深度r+1
這形成了一個自然的層級結構:
- 所有token都經過第一次遞歸(基礎處理)
- 只有復雜token進入第二次遞歸(中間推理)
- 最具挑戰性的token進入最深層遞歸(高級推理)
這里,信息泄露是專家選擇路由中的一個問題,即選擇top-k專家的過程無意中使用了未來token的信息,違反了因果關系。這可能會影響模型在推理時的性能,解決方法是使用輔助機制或正則化來確保因果處理。
token選擇路由(Token-Choice Routing):基于承諾的處理
token選擇路由采用不同的方法,從一開始就為每個token分配完整的遞歸路徑:
- 初始評估:在第一次遞歸步驟中,路由器分析每個token并將其分配給特定專家(遞歸深度)
- 完整路徑執行:token隨后沿著其分配的路徑完成所有指定的遞歸步驟
- 無信息泄露:與專家選擇路由不同,未來token不會影響早期路由決策,保持因果一致性
這種方法避免了專家選擇路由中可能出現的因果違規,但需要仔細的負載均衡,以防止某些專家負擔過重而其他專家閑置。
輔助組件:使路由工作
兩種路由策略都需要復雜的輔助機制:
對于專家選擇路由:
- 輔助損失:通過懲罰不良路由決策,確保路由器做出可靠選擇
- 線性路由器架構:輕量級、高效的路由器,無需大量計算開銷即可確定token復雜度
對于token選擇路由:
- 平衡損失:通過鼓勵token在專家間均勻分布,防止負載不平衡
- Z損失:通過防止路由權重變得過于極端來穩定訓練
內存優化:KV緩存策略
KV緩存策略:矩陣中的每個方塊表示一個token(行)是否關注另一個token的緩存鍵(列)。在“遞歸級KV緩存”(上)中,僅緩存每個遞歸步驟中當前選中(未被丟棄)的token的鍵(藍色),且注意力僅限于這些條目。在“遞歸KV共享”(下)中,所有先前token的鍵在第一次遞歸步驟中被緩存(紫色),并在后續遞歸步驟中共享以進行注意力計算。
KV緩存的挑戰
鍵值(KV)緩存對Transformer效率至關重要,它存儲計算出的注意力鍵和值以避免重復計算。然而,傳統緩存方法可能消耗大量內存,尤其是對于長序列。MoR 引入了兩種創新緩存策略,與其遞歸架構無縫協作。
遞歸級緩存(Recursion-wise Caching):選擇性內存管理
在遞歸級緩存中,MoR 為每個遞歸步驟維護單獨的KV緩存,但巧妙之處在于:只有在每個深度實際被處理的token,其KV對才會在該級別被緩存。
這創建了一個自然的層級內存結構:
- 深度1:所有token都有緩存的KV對(完整序列)
- 深度2:只有復雜token有緩存的KV對(縮減序列)
- 深度3:只有最復雜的token有緩存的KV對(最小序列)
內存節省非常顯著。對于具有N_r次遞歸的模型,總KV內存使用量降至約為普通Transformer的??(N_r + 1)/(2N_r)?
?。對于3次遞歸模型,這意味著KV內存使用量減少67%。
遞歸KV共享(Recursive KV Sharing):最大化內存效率
遞歸KV共享通過在所有后續遞歸深度中重用第一次遞歸步驟的KV對,進一步優化了內存效率。這種方法利用了所有token都經過第一個遞歸塊這一事實。
其工作原理如下:
- 第一次遞歸:所有token生成并緩存其KV對
- 后續遞歸:重用這些緩存的KV對,僅更新查詢向量
- 注意力計算:后期遞歸可以使用共享的KV對關注完整上下文
這種策略提供了最大的內存效率,將KV內存使用量減少到原始需求的??1/N_r?
?。對于3次遞歸模型,這意味著內存減少83%。
其代價是性能略有下降,因為鍵和值表征不會隨遞歸深度演變。然而,實證結果表明這種下降很小,使其成為內存受限環境的理想選擇。
架構整合:各部分如何協同工作
統一框架
MoR 的成功不在于任何單一創新,而在于它如何將多種效率技術整合到一個連貫的架構中。讓我們追蹤token序列如何流經MoR 模型:
步驟1:初始處理
所有token通過獨特的第一層進入,接受標準Transformer處理。這確保了強大的輸入表征,不受參數共享的限制。
步驟2:路由決策
路由器分析每個token的隱藏狀態并做出路由決策:
- 專家選擇:計算復雜度分數并選擇top-k token進入下一次遞歸
- token選擇:為每個token分配特定的遞歸深度路徑
步驟3:遞歸處理
token沿著其分配的遞歸路徑流動:
- 簡單token(如“the”、“and”)可能在1-2次遞歸后完成處理
- 復雜token(如技術術語、命名實體)接受3-4次遞歸
- 每次遞歸應用相同的共享參數塊,但處理的token集逐漸篩選
步驟4:內存管理
KV緩存策略啟動:
- 遞歸級:為每個活躍token集維護單獨的緩存
- 遞歸共享:在所有遞歸深度中重用初始KV對
步驟5:輸出生成
處理后的token通過獨特的最后一層進行輸出生成,確保盡管中間層有參數共享,仍能獲得高質量結果。
計算效率提升
這種整合方法在多個維度上實現了效率提升:
1. 參數效率:通過遞歸塊減少約67%的獨特參數
與傳統Transformer架構相比,MoR模型在不犧牲性能的情況下顯著減少了參數數量。標準的Vanilla Transformer有3.15億參數,而具有3次遞歸的MoR模型(Expert Cache M-Cyc 3)僅使用1.18億獨特參數——減少了67%。盡管如此,MoR在驗證困惑度和少樣本準確率方面始終達到或超過基線。例如,在相同訓練計算量下,具有2次遞歸的MoR(M-Cyc 2)實現了更高的少樣本準確率(43.1%),超過3.15億參數的Vanilla基線(42.3%),同時使用的參數幾乎少了50%。在更大規模(3.6億+)下,MoR繼續與Vanilla模型匹敵或超越,同時保持約三分之一的獨特參數。
2. 計算效率:通過選擇性token處理減少FLOPs
3. 內存效率:大幅降低KV緩存需求
4. 吞吐量提升:連續的深度批處理消除了GPU空閑時間
實際意義和現實應用
部署場景
MoR 的效率提升使其在多種部署場景中特別有價值:
邊緣計算
憑借參數減少和內存效率,MoR 模型可以在資源受限的設備上運行,而傳統大型模型在這些設備上是不可能運行的。
云成本優化
更小的內存占用和更高的吞吐量直接轉化為AI服務提供商更低的云計算成本。
實時應用
吞吐量的提升使實時處理成為可能,如實時翻譯、交互式聊天機器人和動態內容生成。
訓練效率
MoR 的優勢不僅限于推理,還延伸到訓練:
- 減少內存需求:訓練期間降低峰值內存使用
- 改進FSDP效率:參數共享與全分片數據并行(FSDP)訓練配合極佳
- 更快收斂:自適應計算可帶來更高效的學習
技術挑戰與解決方案
專家選擇路由中的因果性
專家選擇路由面臨一個根本挑戰:訓練時,路由器在做選擇決策時可以看到未來的token,但推理時卻不能。這種“信息泄露”可能導致推理性能不佳。
MoR 嘗試通過以下方式解決:
- 輔助損失:訓練路由器僅基于可用信息做出良好決策
- 輔助路由器:專門為推理時行為訓練的獨立路由器
token選擇路由中的負載均衡
token選擇路由可能遭受負載不平衡問題,即某些專家(深度遞歸)接收大量token,而其他專家接收很少。這導致資源利用效率低下。
解決方案包括:
- 平衡損失:明確懲罰不均衡的token分布
- 無損失方法:使用路由器偏置鼓勵均衡分配
- 動態容量:根據實際負載調整專家容量
KV緩存一致性
當token在遞歸處理中提前退出時,其KV對可能在后續token的注意力計算中缺失。MoR 通過以下方式解決這一問題:
- 結構化緩存:在遞歸深度間維持一致的緩存結構
- 選擇性注意力:將注意力限制在可用的緩存對
- 共享上下文:確保所有token都能訪問必要的上下文信息
展望未來:未來意義
縮放定律與MoR
MoR 表現出與傳統Transformer不同的縮放行為。研究表明,MoR 從增大模型規模中獲益比從增加訓練數據中更多,這為遞歸架構提出了新的計算最優縮放策略。
這對我們進行模型開發的方式具有深遠意義:
- 資源分配:傾向于更大的模型和更短的訓練周期
- 架構設計:優先考慮參數效率而非原始參數數量
- 訓練策略:關注模型容量而非數據量
整合機會
MoR 的模塊化設計使其與其他效率技術兼容:
- 量化:遞歸塊可以被量化以進一步節省內存
- 剪枝:結構化剪枝可以進一步降低計算需求
在這種情況下,剪枝可能不是減少計算的有效方法。由于遞歸Transformer已經通過多次重用同一層來節省內存,引入剪枝可能會干擾這一機制。在某種意義上,剪枝和遞歸可能相互沖突。我非常好奇看到這方面的實證結果。
- 蒸餾:知識蒸餾可以從更大的模型轉移能力
多模態擴展
遞歸架構本質上與模態無關。未來的工作可以將MoR 擴展到:
- 視覺Transformer:對不同圖像區域的自適應處理
- 音頻處理:對不同音頻片段的可變計算
- 多模態模型:處理文本、圖像和音頻的統一架構
結論
遞歸混合模型(MoR) 提供的不僅僅是另一種效率機制——它代表了如何構建強大且經濟可持續的AI系統的戰略性重新思考。通過在單一框架中結合參數共享、自適應計算和內存優化,MoR 為在不犧牲性能的情況下降低計算成本開辟了道路。
其意義可能超出學術研究。MoR 的效率提升可能有助于降低部署大型語言模型的財務和基礎設施要求。這使得小型組織、研究實驗室和初創公司更有可能使用原本無法獲得的先進AI能力。
隨著AI發展的繼續推進,MoR 表明創新并不總是需要擴大規模——有時,它來自優化資源的使用方式。從這個意義上說,MoR 支持向經濟上可擴展的AI轉變,在這種轉變中,性能和資源效率是一致的。
參考文獻
Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation
本文轉載自????AIGC深一度??
