文章轉(zhuǎn)載于微信公眾號:GiantPandaCV
作者: Pui_Yeung
前言
??量化感知訓練(Quantization Aware Training )是在模型中插入偽量化模塊(fake/_quant module)模擬量化模型在推理過程中進行的舍入(rounding)和鉗位(clamping)操作,從而在訓練過程中提高模型對量化效應(yīng)的適應(yīng)能力,獲得更高的量化模型精度 。在這個過程中,所有計算(包括模型正反向傳播計算和偽量化節(jié)點計算)都是以浮點計算實現(xiàn)的,在訓練完成后才量化為真正的int8模型。??
Pytorch官方從1.3版本開始提供量化感知訓練API,只需修改少量代碼即可實現(xiàn)量化感知訓練。目前torch.quantization仍處于beta階段,不保證API前向、后向兼容性。以下介紹基于Pytorch 1.7,其他版本可能會有差異。
Pytorch量化感知訓練流程
??首先給出提供一個可運行demo,直觀了解量化感知訓練的6個步驟,再進行詳細的介紹
importtorchfromtorch.quantizationimportprepare_qat,get_default_qat_qconfig,convertfromtorchvision.modelsimportquantization# Step1:修改模型#這里直接使用官方修改好的MobileNetV2,下文會對修改點進行介紹model=quantization.mobilenet_v2()print("originalmodel:")print(model)# Step2:折疊算子#fuse_model()在training或evaluate模式下算子折疊結(jié)果不同,#對于QAT,需確保在training狀態(tài)下進行算子折疊assertmodel.trainingmodel.fuse_model()print("fusedmodel:")print(model)#Step3:指定量化方案#通過給模型實例增加一個名為"qconfig"的成員變量實現(xiàn)量化方案的指定#backend目前支持fbgemm和qnnpackBACKEND="fbgemm"model.qconfig=get_default_qat_qconfig(BACKEND)# Step4:插入偽量化模塊prepare_qat(model,inplace=True)print("modelwithobservers:")print(model)#正常的模型訓練,無需修改代碼# Step5:實施量化model.eval()#執(zhí)行convert函數(shù)前,需確保模型在evaluate模式model_int8=convert(model)print("quantizedmodel:")print(model_int8)# Step6:int8模型推理#指定與qconfig相同的backend,在推理時使用正確的算子torch.backends.quantized.engine=BACKEND#目前Pytorch的int8算子只支持CPU推理,需確保輸入和模型都在CPU側(cè)#輸入輸出仍為浮點數(shù)fp32_input=torch.randn(1,3,224,224)y=model_int8(fp32_input)print("output:")print(y)
Step1:修改模型
??Pytorch下需要適當修改模型才能進行量化感知訓練,以下以常用的MobileNetV2為例。官方已修改好的MobileNetV2的代碼,詳見這里(https://github.com/pytorch/vi...)
修改主要包括3點,以下摘取相應(yīng)的代碼進行介紹:
(1)在模型輸入前加入QuantStub(),在模型輸出后加入DeQuantStub()。目的是將輸入從fp32量化為int8,將輸出從int8反量化為fp32。模型的/_/_init/_/_()和forward()修改為:
classQuantizableMobileNetV2(MobileNetV2):def__init__(self,*args,**kwargs):"""MobileNetV2mainclassArgs:InheritsargsfromfloatingpointMobileNetV2"""super(QuantizableMobileNetV2,self).__init__(*args,**kwargs)self.quant=QuantStub()self.dequant=DeQuantStub()defforward(self,x):x=self.quant(x)x=self._forward_impl(x)x=self.dequant(x)returnx
(2)對加法等操作加入偽量化節(jié)點。因為int8數(shù)值進行加法運算容易超出數(shù)值范圍,所以不是直接進行計算,而是進行反量化->計算->量化的操作。以InvertedResidual的修改為例:
classQuantizableInvertedResidual(InvertedResidual):def__init__(self,*args,**kwargs):super(QuantizableInvertedResidual,self).__init__(*args,**kwargs)#加法的偽量化節(jié)點需要記錄所經(jīng)過該節(jié)點的數(shù)值的范圍,因此需要實例化一個對象self.skip_add=nn.quantized.FloatFunctional()defforward(self,x):ifself.use_res_connect:#普通版本MobileNetV2的加法#returnx+self.conv(x)#量化版本MobileNetV2的加法returnself.skip_add.add(x,self.conv(x))else:returnself.conv(x)
(3)將ReLU6替換為ReLU。MobileNet V2使用ReLU6的原因是對ReLU的輸出范圍進行截斷以緩解量化為fp16模型時的精度下降。因為int8量化本身就能確定截斷閾值,所以將ReLU6替換為ReLU以去掉截斷閾值固定為6的限制。官方的修改代碼在建立網(wǎng)絡(luò)后通過/_replace/_relu()將MobileNetV2中的ReLU6替換為ReLU:
model=QuantizableMobileNetV2(block=QuantizableInvertedResidual,**kwargs)_replace_relu(model)
Step2:算子折疊
??算子折疊是將模型的多個層合并成一個層,一般用來減少計算量和加速推理。對于量化感知訓練而言,算子折疊作用是將模型變“薄”,減少中間計算過程的誤差積累。
??以下比較有無算子折疊的結(jié)果(上:無算子折疊,下:有算子折疊,打印執(zhí)行prepare/_qat()后的模型)
?如果不進行算子折疊,每個Conv-BN-ReLU單元一共會插入4個FakeQuantize模塊。而進行算子折疊后,原來Conv2d()被ConvBnReLU2d()代替(3層合并到了第1層),BatchNorm2d()和ReLU()被Inentity()代替(僅作為占位),最終只插入了2個FakeQuantize模塊。
FakeQuantize模塊的減少意味著推理過程中進行的量化-反量化的次數(shù)減少,有利于減少量化帶來的性能損失。
??算子折疊由實現(xiàn)torch.quantization.fuse/_modules()。目前存在的比較遺憾的2點:
??算子折疊不能自動完成,只能由程序員手工指定要折疊的子模型。以torchvision.models.quantization.mobilenet/_v2()中實現(xiàn)的算子折疊函數(shù)為例:
deffuse_model(self):#遍歷模型內(nèi)的每個子模型,判斷類型并進行相應(yīng)的算子折疊forminself.modules():iftype(m)==ConvBNReLU:fuse_modules(m,['0','1','2'],inplace=True)iftype(m)==QuantizableInvertedResidual:#調(diào)用子模塊實現(xiàn)的fuse_model(),間接調(diào)用fuse_modules()m.fuse_model()
??能折疊的算子組合有限。目前支持的算子組合為:ConV + BN、ConV + BN + ReLU、Conv + ReLU、Linear + ReLU、BN + ReLU。如果嘗試折疊ConvTranspose2d、ReLU6等不支持的算子則會報錯。
Step3:指定量化方案
??目前支持fbgemm和qnnpack兩種backend方案。官方推薦x86平臺使用fbgemm方案,ARM平臺使用qnnpack方案。??量化方案通過如下方法指定
model.qconfig=get_default_qat_qconfig(backen)#或model.qconfig=get_default_qat_qconfig(backen)
??即通過給model增加一個名為qconfig為成員變量并賦值。
??量化方案可通過設(shè)置qconfig自定義,本文暫不討論。
Step4:插入偽量化模塊??
通過執(zhí)行prepare/_qat(),實現(xiàn)按qconfig的配置方案給每個層增加FakeQuantize()模塊?每個FakeQuantize()模塊內(nèi)包含相應(yīng)的Observer()模塊,在模型執(zhí)行forward()時自動記錄數(shù)值,供實施量化時使用。
Step5:實施量化??
完成訓練后,通過執(zhí)行convert()轉(zhuǎn)換為真正的int8量化模型。?完成轉(zhuǎn)換后,F(xiàn)akeQuantize()模塊被去掉,原來的ConvBNReLU2d()算子被替換為QuantizedConvReLU2d()算子。
Step6:int8模型推理
??int8模型的調(diào)用方法與普通的fp32模型的調(diào)用無異。需要注意的是,目前量化算子僅支持CPU計算,故須確保輸入和模型都在CPU側(cè)。
??若模型推理中出現(xiàn)報錯,一般是前面的步驟存在設(shè)置不當,參考常見問題第1點。
常見問題
(1) RuntimeError: Could not run XX with arguments from the YY backend. XX is only available for these backends ZZ??
雖然fp32模型和int8模型都能在CPU上推理,但fp32算子僅接受tensor作為輸入,int8算子僅接受quantedtensor作為輸入,輸入和算子的類型不一致導(dǎo)致上述錯誤。
??一般排查方向為:是否完成了模型修改,將加法等操作替換為量化版本;是否正確添加了QuantStub()和DeQuantStub();是否在執(zhí)行convert()前是否執(zhí)行了model.eval()(在traning模型下,dropout無int8實現(xiàn)但沒有被去掉,然而在執(zhí)行推理時會報錯)。
(2) 是否支持GPU訓練,是否支持DistributedDataParallel訓練???
支持。官方有一個完整的量化感知訓練的實現(xiàn),使用了GPU和DistributedDataParallel,可惜在文檔和教程中未提及,參考這里(https://github.com/pytorch/vi.../_quantization.py)。
(3) 是否支持混合精度模型(例如一部分fp32推理,一部分int8推理)???
官方?jīng)]有明確說明,但經(jīng)實踐是可以的。
??模型是否進行量化取決于是否帶qconfig。因此可以將模型定義修改為
classMixModel(nn.Module):def__init__(self):super(MixModel,self).__init__()self.fp32_part=Fp32Model()self.int8_part=Int8Model()defforward(self,x):x=self.int8_part(x)x=self.fp32(x)returnxmix_model=MixModel()mix_model.int8_part.qconfig=get_default_qat_qconfig(BACKEND)prepare_qat(mix_model,inplace=True)
??由此可實現(xiàn)所需的功能。注意將QuantStub()、Dequant()模塊移到Int8Model()中。
(4)精度保持效果如何,如何提升精度???
筆者進行的實驗不多,在做過的簡單的OCR任務(wù)中,可以做到文字檢測和識別模型的指標下降均不超過1個點(量化的int8模型對比正常訓練的fp32模型)。官方教程中提供了分類例子的效果和提升精度的技巧,可供參考。
總結(jié)
??Pytorch官方提供的量化感知訓練API,上手較為簡單,易于集成到現(xiàn)有訓練代碼中。但目前手動修改模型和算子折疊增加了一定的工作量,期待在未來版本的改進。
- END -
推薦閱讀
更多嵌入式AI技術(shù)干貨請關(guān)注嵌入式AI專欄。
審核編輯:符乾江
-
深度學習
+關(guān)注
關(guān)注
73文章
5503瀏覽量
121206 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13235
發(fā)布評論請先 登錄
相關(guān)推薦
評論