改變LoRA的初始化方式,北大新方法PiSSA顯著提升微調(diào)效果
隨著大模型的參數(shù)量日益增長(zhǎng),微調(diào)整個(gè)模型的開(kāi)銷逐漸變得難以接受。
為此,北京大學(xué)的研究團(tuán)隊(duì)提出了一種名為 PiSSA 的參數(shù)高效微調(diào)方法,在主流數(shù)據(jù)集上都超過(guò)了目前廣泛使用的 LoRA 的微調(diào)效果。
- 論文: PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models
- 論文鏈接: https://arxiv.org/pdf/2404.02948.pdf
- 代碼鏈接: https://github.com/GraphPKU/PiSSA
如圖 1 所示,PiSSA (圖 1c) 在模型架構(gòu)上和 LoRA [1] 完全一致 (圖 1b),只是初始化 Adapter 的方式不同。LoRA 使用高斯噪聲初始化 A,使用 0 初始化 B。而 PiSSA 使用主奇異值和奇異向量 (Principal Singular values and Singular vectors) 來(lái)初始化 Adapter 來(lái)初始化 A 和 B。
圖 1)從左到右依次為全參數(shù)微調(diào)、LoRA、以及 PiSSA。藍(lán)色代表凍結(jié)的參數(shù),橘黃色代表可訓(xùn)練參數(shù)及它們的初始化方式。相比全參數(shù)微調(diào),LoRA 和 PiSSA 都大幅節(jié)省了可訓(xùn)練參數(shù)量。對(duì)于相同輸入,這三種方法的初始輸出完全相等。然而,PiSSA 凍結(jié)模型的次要成分,直接微調(diào)主成分(前 r 個(gè)奇異值和奇異向量);而 LoRA 可看作凍結(jié)模型的主要部分,而去微調(diào) noise 部分。
在不同的任務(wù)上對(duì)比 PiSSA、LoRA 的微調(diào)效果
研究團(tuán)隊(duì)使用 llama 2-7B、Mistral-7B 以及 Gemma-7B 作為基礎(chǔ)模型,通過(guò)微調(diào)提升它們的數(shù)學(xué)、代碼和對(duì)話能力。其中包括:在 MetaMathQA 上訓(xùn)練,在 GSM8K 和 MATH 數(shù)據(jù)集上驗(yàn)證模型的數(shù)學(xué)能力;在 CodeFeedBack 上訓(xùn)練,在 HumanEval 和 MBPP 數(shù)據(jù)集上驗(yàn)證模型的代碼能力;在 WizardLM-Evol-Instruct 上訓(xùn)練,在 MT-Bench 上驗(yàn)證模型的對(duì)話能力。從下表的實(shí)驗(yàn)結(jié)果可以看出,使用相同規(guī)模的可訓(xùn)練參數(shù),PiSSA 的微調(diào)效果顯著超越了 LoRA,甚至超越了全參數(shù)微調(diào)。
對(duì)比 PiSSA、LoRA 在不同的可訓(xùn)練參數(shù)量下微調(diào)的效果
研究團(tuán)隊(duì)在數(shù)學(xué)任務(wù)上對(duì)模型的可訓(xùn)練參數(shù)量和效果之間的關(guān)系進(jìn)行消融實(shí)驗(yàn)。從圖 2.1 發(fā)現(xiàn)在訓(xùn)練初期,PiSSA 的訓(xùn)練 loss 下降特別快,而 LoRA 存在不下降,甚至略有上升的階段。此外,PiSSA 的訓(xùn)練 loss 全程低于 LoRA,說(shuō)明對(duì)訓(xùn)練集擬合得更好;從圖 2.2、2.3、2.4 可以看出在每種 setting 下,PiSSA 的 loss 始終比 LoRA 低,準(zhǔn)確率始終比 LoRA 高,PiSSA 能夠使用更少的可訓(xùn)練參數(shù)追趕上全參數(shù)微調(diào)的效果。
圖 2.1) 當(dāng)秩為 1 時(shí) PiSSA、LoRA 在訓(xùn)練過(guò)程中的 loss。每幅圖的右上角是前 100 步迭代放大的曲線。其中 PiSSA 用橙色線表示,LoRA 用藍(lán)色線表示,全參數(shù)微調(diào)用綠線展示了最終的 loss 作為參考。秩為 [2,4,8,16,32,64,128] 時(shí)的現(xiàn)象與此一致,詳見(jiàn)文章附錄。
圖 2.2)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 的最終 training loss。
圖 2.3)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微調(diào)的模型在 GSM8K 上的準(zhǔn)確率。
圖 2.4)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微調(diào)的模型在 MATH 上的準(zhǔn)確率。
PiSSA 方法詳解
受到 Intrinsic SAID [2]“預(yù)訓(xùn)練大模型參數(shù)具有低秩性” 的啟發(fā),PiSSA 對(duì)預(yù)訓(xùn)練模型的參數(shù)矩陣
進(jìn)行奇異值分解,其中前 r 個(gè)奇異值和奇異向量用來(lái)初始化適配器 (adapter) 的兩個(gè)矩陣
和
,
;剩余的奇異值和奇異向量用來(lái)構(gòu)造殘差矩陣
,使得
。因此,適配器中的參數(shù)包含了模型的核心參數(shù),而殘差矩陣中的參數(shù)是修正參數(shù)。通過(guò)微調(diào)參數(shù)量較小的核心適配器 A、B,凍結(jié)參數(shù)量較大的殘差矩陣
,就達(dá)成了用很少的參數(shù)近似全參數(shù)微調(diào)的效果。
盡管同樣受到 Intrinsic SAID [1] 啟發(fā),PiSSA 和 LoRA 背后的原理卻截然不同。
LoRA 認(rèn)為大模型微調(diào)前后矩陣的變化 △W 具有很低的本征秩 r,因此通過(guò)
和
相乘得到的低秩矩陣來(lái)模擬模型的變化 △W。初始階段,LoRA 使用高斯噪聲初始化 A,使用 0 初始化 B,因此
,以此保證模型初始能力沒(méi)有變化,并微調(diào) A 和 B 實(shí)現(xiàn)對(duì) W 進(jìn)行更新。與此相比,PiSSA 不關(guān)心 △W,而是認(rèn)為 W 具有很低的本征秩 r。因此直接對(duì) W 進(jìn)行奇異值分解,分解成主成分 A、B,以及殘差項(xiàng)
,使得
。假設(shè) W 的奇異值分解為
,A、B 使用 SVD 分解后奇異值最大的 r 個(gè)奇異值、奇異向量進(jìn)行初始化:
殘差矩陣使用其余的奇異值、奇異向量進(jìn)行初始化:
PiSSA 直接對(duì) W 的低秩主成分 A、B 進(jìn)行微調(diào),凍結(jié)次要的修正項(xiàng)。相比 LoRA 用高斯噪聲以及 0 初始化適配器參數(shù)、凍結(jié)核心模型參數(shù),PiSSA 收斂更快、效果更好。
PiSSA 的發(fā)音類似 “披薩”(pizza)--- 如果把整個(gè)大模型類比為一個(gè)完整的披薩,PiSSA 切掉其中一角,而且是餡料最豐富的一角(主奇異值、奇異向量),重新烘焙(在下游任務(wù)上微調(diào))成喜歡的口味。
由于 PiSSA 采用了和 LoRA 完全相同的架構(gòu),其可以作為 LoRA 的一種可選初始化方式,在 peft 包中很方便的進(jìn)行修改和調(diào)用 (如以下代碼所示)。相同的架構(gòu)也使得 PiSSA 繼承了大多數(shù) LoRA 的優(yōu)點(diǎn),如:對(duì)殘差模型使用 4bit 量化 [3],減小訓(xùn)練開(kāi)銷;微調(diào)完成后適配器能合并進(jìn)殘差模型,不改變推理過(guò)程的模型架構(gòu);無(wú)需分享完整模型參數(shù),只需要分享參數(shù)量很少的 PiSSA 模塊,使用者直接加載 PiSSA 模塊就能自動(dòng)進(jìn)行奇異值分解以及賦值;一個(gè)模型可以同時(shí)使用多個(gè) PiSSA 模塊等等。一些對(duì) LoRA 方法的改進(jìn),也能與 PiSSA 進(jìn)行結(jié)合:比如不固定每層的秩,通過(guò)學(xué)習(xí)找到最佳的秩 [4];用 PiSSA 指導(dǎo)的更新 [5],從而突破秩的限制等等。
# 在 peft 包中 LoRA 的初始化方式后面增加了一種 PiSSA 初始化選項(xiàng):
if use_lora:
nn.init.normal_(self.lora_A.weight, std=1 /self.r)
nn.init.zeros_(self.lora_B.weight)
elif use_pissa:
Ur, Sr, Vr = svd_lowrank (self.base_layer.weight, self.r, niter=4)
# 注意:由于 self.base_layer.weight 的維度是 (out_channel,in_channel, 所以 AB 的順序相比圖示顛倒了一下)
self.lora_A.weight = torch.diag (torch.sqrt (Sr)) @ Vh.t ()
self.lora_B.weight = Ur @ torch.diag (torch.sqrt (Sr))
self.base_layer.weight = self.base_layer.weight - self.lora_B.weight @ self.lora_A.weight
對(duì)比高中低奇異值微調(diào)效果實(shí)驗(yàn)
為了驗(yàn)證使用不同大小奇異值、奇異向量初始化適配器對(duì)模型的影響,研究人員分別使用高、中、低奇異值初始化 LLaMA 2-7B、Mistral-7B-v0.1、Gemma-7B 的適配器,然后在 MetaMathQA 數(shù)據(jù)集上進(jìn)行微調(diào),實(shí)驗(yàn)結(jié)果展示在圖 3 中。從圖中可以看出,使用主要奇異值初始化的方法訓(xùn)練損失最小,在 GSM8K 和 MATH 驗(yàn)證集上的準(zhǔn)確率更高。這一現(xiàn)象驗(yàn)證了微調(diào)主要奇異值、奇異向量的有效性。
圖 3)從左到右依次為訓(xùn)練 loss、在 GSM8K 上的準(zhǔn)確率、在 MATH 上的準(zhǔn)確率。其中藍(lán)色表示最大奇異值、橙色表示中等奇異值、綠色表示最小奇異值。
快速奇異值分解
PiSSA 繼承了 LoRA 的優(yōu)點(diǎn),使用起來(lái)方便,效果超越 LoRA。代價(jià)是在初始化階段,需要對(duì)模型進(jìn)行奇異值分解。雖然僅需要在初始化時(shí)分解一次,但是仍然可能需要幾分鐘甚至幾十分鐘的開(kāi)銷。因此,研究人員使用一種快速奇異值分解 [6] 方法替代標(biāo)準(zhǔn)的 SVD 分解,通過(guò)下表的實(shí)驗(yàn)可以看出,僅需幾秒鐘的時(shí)間,就能逼近標(biāo)準(zhǔn) SVD 分解的訓(xùn)練集擬合效果。其中 Niter 表示迭代次數(shù),Niter 越大,時(shí)間越久但是誤差越小。Niter = ∞表示標(biāo)準(zhǔn) SVD。表格中的平均誤差表示快速奇異值分解與標(biāo)準(zhǔn) SVD 得到的 A、B 之間的平均 L_1 距離。
總結(jié)與展望
本工作對(duì)預(yù)訓(xùn)練模型的權(quán)重進(jìn)行奇異值分解,通過(guò)將其中最重要的參數(shù)用于初始化一個(gè)名為 PiSSA 的適配器,微調(diào)這個(gè)適配器來(lái)近似微調(diào)完整模型的效果。實(shí)驗(yàn)表明,PiSSA 比 LoRA 收斂更快,最終效果更好,唯一的代價(jià)僅是需要幾秒的 SVD 初始化過(guò)程。
那么,您愿意為了更好的訓(xùn)練效果,多花幾秒鐘時(shí)間,一鍵更改 LoRA 的初始化為 PiSSA 嗎?
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
