對抗訓(xùn)練本質(zhì)是為了提高模型的魯棒性,一般情況下在傳統(tǒng)訓(xùn)練的基礎(chǔ)上,添加了對抗訓(xùn)練是可以進(jìn)一步提升效果的,在比賽打榜、調(diào)參時是非常重要的一個trick。對抗訓(xùn)練在CV領(lǐng)域內(nèi)非常常用,那么在NLP領(lǐng)域如何使用呢?本文簡單總結(jié)幾種常用的對抗訓(xùn)練方法。
公式理解:
最大化擾動:挑選一個能使得模型產(chǎn)生更大損失(梯度較大)的擾動量,作為攻擊;
最小化損失:根據(jù)最大的擾動量,添加到輸入樣本后,朝著最小化含有擾動的損失(梯度下降)方向更新參數(shù);
這個被構(gòu)造出來的“對抗樣本”并不能具體對應(yīng)到某個單詞,因此,反過來在推理階段是沒有辦法通過修改原始輸入得到這樣的對抗樣本。
對抗訓(xùn)練有兩個作用,一是 提高模型對惡意攻擊的魯棒性 ,二是 提高模型的泛化能力 。
在CV任務(wù),根據(jù)經(jīng)驗性的結(jié)論,對抗訓(xùn)練往往會使得模型在非對抗樣本上的表現(xiàn)變差,然而神奇的是,在NLP任務(wù)中,模型的泛化能力反而變強(qiáng)了。
常用的幾種對抗訓(xùn)練方法有FGSM、FGM、PGD、FreeAT、YOPO、FreeLB、SMART。本文暫時只介紹博主常用的3個方法,分別是 FGM 、 PGD 和 FreeLB 。
具體實現(xiàn)時,不同的對抗方法會有差異,但是 從訓(xùn)練速度和代碼編輯難易程度的角度考慮,推薦使用FGM和迭代次數(shù)較少的PGD 。
一、FGM算法
FGM的代碼量很少,只需要自行實現(xiàn)簡單的類即可:
importtorch classFGM(): def__init__(self,model): self.model=model self.backup={}#用于保存模型擾動前的參數(shù) defattack( self, epsilon=1., emb_name='word_embeddings'#emb_name表示模型中embedding的參數(shù)名 ): ''' 生成擾動和對抗樣本 ''' forname,paraminself.model.named_parameters():#遍歷模型的所有參數(shù) ifparam.requires_gradandemb_nameinname:#只取wordembedding層的參數(shù) self.backup[name]=param.data.clone()#保存參數(shù)值 norm=torch.norm(param.grad)#對參數(shù)梯度進(jìn)行二范式歸一化 ifnorm!=0andnottorch.isnan(norm):#計算擾動,并在輸入?yún)?shù)值上添加擾動 r_at=epsilon*param.grad/norm param.data.add_(r_at) defrestore( self, emb_name='word_embeddings'#emb_name表示模型中embedding的參數(shù)名 ): ''' 恢復(fù)添加擾動的參數(shù) ''' forname,paraminself.model.named_parameters():#遍歷模型的所有參數(shù) ifparam.requires_gradandemb_nameinname:#只取wordembedding層的參數(shù) assertnameinself.backup param.data=self.backup[name]#重新加載保存的參數(shù)值 self.backup={}
在訓(xùn)練時,只需要額外添加5行代碼:
fgm=FGM(model)#(#1)初始化 forbatch_input,batch_labelindata: loss=model(batch_input,batch_label)#正常訓(xùn)練 loss.backward()#反向傳播,得到正常的grad #對抗訓(xùn)練 fgm.attack()#(#2)在embedding上添加對抗擾動 loss_adv=model(batch_input,batch_label)#(#3)計算含有擾動的對抗樣本的loss loss_adv.backward()#(#4)反向傳播,并在正常的grad基礎(chǔ)上,累加對抗訓(xùn)練的梯度 fgm.restore()#(#5)恢復(fù)embedding參數(shù) #梯度下降,更新參數(shù) optimizer.step() model.zero_grad()
二、PGD算法
Project Gradient Descent(PGD)是一種迭代攻擊算法,相比于普通的FGM 僅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都會將擾動投射到規(guī)定范圍內(nèi)。形式化描述為:
代碼實現(xiàn)如下所示:
importtorch classPGD(): def__init__(self,model): self.model=model self.emb_backup={} self.grad_backup={} defattack(self,epsilon=1.,alpha=0.3,emb_name='word_embeddings',is_first_attack=False): forname,paraminself.model.named_parameters(): ifparam.requires_gradandemb_nameinname: ifis_first_attack: self.emb_backup[name]=param.data.clone() norm=torch.norm(param.grad) ifnorm!=0andnottorch.isnan(norm): r_at=alpha*param.grad/norm param.data.add_(r_at) param.data=self.project(name,param.data,epsilon) defrestore(self,emb_name='word_embeddings'): forname,paraminself.model.named_parameters(): ifparam.requires_gradandemb_nameinname: assertnameinself.emb_backup param.data=self.emb_backup[name] self.emb_backup={} defproject(self,param_name,param_data,epsilon): r=param_data-self.emb_backup[param_name] iftorch.norm(r)>epsilon: r=epsilon*r/torch.norm(r) returnself.emb_backup[param_name]+r defbackup_grad(self): forname,paraminself.model.named_parameters(): ifparam.requires_grad: self.grad_backup[name]=param.grad.clone() defrestore_grad(self): forname,paraminself.model.named_parameters(): ifparam.requires_grad: param.grad=self.grad_backup[name]
pgd=PGD(model) K=3 forbatch_input,batch_labelindata: #正常訓(xùn)練 loss=model(batch_input,batch_label) loss.backward()#反向傳播,得到正常的grad pgd.backup_grad() #累積多次對抗訓(xùn)練——每次生成對抗樣本后,進(jìn)行一次對抗訓(xùn)練,并不斷累積梯度 fortinrange(K): pgd.attack(is_first_attack=(t==0))#在embedding上添加對抗擾動,firstattack時備份param.data ift!=K-1: model.zero_grad() else: pgd.restore_grad() loss_adv=model(batch_input,batch_label) loss_adv.backward()#反向傳播,并在正常的grad基礎(chǔ)上,累加對抗訓(xùn)練的梯度 pgd.restore()#恢復(fù)embedding參數(shù) #梯度下降,更新參數(shù) optimizer.step() model.zero_grad()
三、FreeLB算法
很明顯找到FreeLB與PGD的區(qū)別在于累積的方式:
FreeLB:通過對 K K K 次梯度的平均累積作為擾動更新
PGD:只取最后一次的梯度進(jìn)行更新
實現(xiàn)流程如下圖所示:
審核編輯:劉清
-
算法
+關(guān)注
關(guān)注
23文章
4626瀏覽量
93157 -
nlp
+關(guān)注
關(guān)注
1文章
489瀏覽量
22066
原文標(biāo)題:煉丹之道 | NLP中的對抗訓(xùn)練
文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論