在H100發(fā)布之際,英偉達(dá)還帶來一個“重磅產(chǎn)品”——Transformer Engine。在Transformer大火之際推出這么一個產(chǎn)品,無疑是煉丹師福音。
當(dāng)時我還在猜測它會以怎么樣的一種形式呈現(xiàn)給用戶,直到最近公開了倉庫 NVIDIA/TransformerEngine
這其實就是PyTorch的一個拓展,為了利用FP8的特性,針對Transformer里面的Kernel進(jìn)行了重寫,包含了一系列LayerNorm, GeLU, ScaledSoftmax等。
使用方式也是比較簡單,使用該拓展額外包的一層Module來搭建網(wǎng)絡(luò),即可,最后再包一層混合精度訓(xùn)練作用域:
importtorch importtransformer_engine.pytorchaste fromtransformer_engine.commonimportrecipe #Setdimensions. in_features=768 out_features=3072 hidden_size=2048 #Initializemodelandinputs. model=te.Linear(in_features,out_features,use_bias=True) inp=torch.randn(hidden_size,in_features,device="cuda") #創(chuàng)建FP8訓(xùn)練的配置 fp8_recipe=recipe.DelayedScaling(margin=0,interval=1,fp8_format=recipe.Format.E4M3) #FP8的autocast withte.fp8_autocast(enabled=True,fp8_recipe=fp8_recipe): out=model(inp) loss=out.sum() loss.backward()
本篇博客就簡單介紹下Transformer Engine及其對應(yīng)實現(xiàn)原理
官方文檔:https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Transfromer Engine 是干啥的?
在各種以Transformer為基礎(chǔ)的語言模型如GPT3大火后,語言模型的參數(shù)量還在以指數(shù)形式增長:
那么優(yōu)化Transformer性能就顯得格外重要了,其中混合精度訓(xùn)練是一個很實用的技巧
在FP16下,其數(shù)據(jù)范圍還是足夠大的,因此在AMP下,我們只在最后的Loss做了一個scaling,這個步驟足以保證在整個模型運算過程中不會產(chǎn)生溢出
而FP8相比FP16減少了更多有效位,因此不能簡單地復(fù)用FP16下的策略,需要給每個FP8 Tensor單獨設(shè)置一個合適的scale factor。Transformer Engine 需要動態(tài)地對輸入范圍進(jìn)行調(diào)整,如圖所示:
上圖來自H100白皮書內(nèi)(當(dāng)時我還天真的以為有一個專門的硬件做這個處理。。。)
下面我們簡單看下其代碼和實現(xiàn)原理
Kernel實現(xiàn)
具體到每一個算子實現(xiàn)動態(tài)范圍調(diào)整的原理其實很簡單,通過記錄歷史的abs max值,來去調(diào)整最終縮放的范圍。
其主要的Kernel實現(xiàn)都放在了 common 目錄下,我們以gelu這個kernel為例,最終它會調(diào)用到 vectorized_pointwise.h這個文件,我們主要看 unary_kernel
unary_kernel
這個核函數(shù)模板跟常規(guī)的elementwise向量化模板是類似的。
首先會讓每個線程獲取到scale值
ComputeTypes=0; ifconstexpr(is_fp8::value){ //獲取scale值 if(scale!=nullptr)s=*scale; //將scale取倒數(shù)寫回scale_inv if(blockIdx.x==0&&threadIdx.x==0&&scale_inv!=nullptr){ reciprocal (scale_inv,s); } }
其中在循環(huán)里,線程會不斷更新他運算結(jié)果的最大值,并且最終運算結(jié)果要乘上scale值:
//實際運算發(fā)生 ComputeTypetemp=OP(val,p); ifconstexpr(is_fp8::value){ __builtin_assume(max>=0); max=fmaxf(fabsf(temp),max); //縮放 temp=temp*s; }
當(dāng)Kernel主體運算完畢后,再也warp為單位做一個reduce_max,獲取到線程束內(nèi)的最大值,再通過atomicMax原子操作,不斷更新全局最大值:
ifconstexpr(is_fp8::value){ /*warptileamaxreduce*/ max=reduce_max (max,warp_id); if(threadIdx.x==0&&amax!=nullptr){ static_assert(std::is_same ::value); //更新全局最大值 atomicMaxFloat(amax,max); } }
其他layernorm等Kernel也是諸如類似的邏輯,這里就不再展開了
(1) DelayedScaling
從前面的示例代碼我們可以看到一個比較重要的API是DelayedScaling,我們可以根據(jù)官方文檔查看各個參數(shù)含義:
margin 計算scale的偏移量
interval 控制計算scale factor的頻率
fp8_format 使用FP8的格式,F(xiàn)P8有E4M3和E5M2,但是現(xiàn)在不支持純E5M2的格式訓(xùn)練
amax_history_len 記錄abs maxval的歷史窗口大小
amax_compute_algo 在窗口里選擇absmax的算法,'max'則是選擇歷史窗口里最大值,'most_recent'則是選擇近期的值,當(dāng)然你也可以傳一個自定義的函數(shù)
相關(guān)代碼為:
@torch.jit.script def_default_get_amax( amax_history:torch.Tensor, amax_compute_algo:str, )->Tuple[torch.Tensor,torch.Tensor]: """Defaultfunctiontoobtainamaxfromhistory.""" ifamax_compute_algo=="max": amax=torch.max(amax_history,dim=0).values else:#amax_compute_algo=="most_recent" amax=amax_history[0] amax_history=update_amax_history(amax_history) returnamax_history,amax
scaling_factor_compute_algo 計算scale factor的算法
@torch.jit.script def_default_sf_compute( amax:torch.Tensor, scale:torch.Tensor, fp8_max:float, margin:int, )->torch.Tensor: """Defaultfunctiontoconvertamaxtoscalingfactor.""" exp=torch.floor(torch.log2(fp8_max/amax))-margin sf=torch.round(torch.pow(2,torch.abs(exp))) sf=torch.where(amax>0.0,sf,scale) sf=torch.where(torch.isfinite(amax),sf,scale) sf=torch.where(exp0,?1?/?sf,?sf) ????return?sf
override_linear_precision 由3個bool值,分別控制fprop前向,dgrad,wgrad三個矩陣乘是否用更高的精度來計算,默認(rèn)都為False
(2) TransformerEngineBaseModule
相關(guān)的Kernel除了要完成自己的計算任務(wù),也得實時維護(hù)amax這些值,因此也需要對應(yīng)修改nn.Module,這里TransformerEngine繼承了nn.Module,并且增加了一些buffer維護(hù)的機制,這些buffer用于存儲動態(tài)scale的信息:
classTransformerEngineBaseModule(torch.nn.Module,ABC): def__init__(self)->None: ... self.fp8=False self.fp8_meta={} self.fp8_meta["fp8_group"]=None self.fp8_meta["recipe"]=get_default_fp8_recipe() deffp8_init(self,num_gemms:int=1)->None: """Initializefp8relatedmetadataandtensorsduringfprop.""" #Iffp8isn'tenabled,turnoffandreturn. ifnotis_fp8_enabled(): self.fp8=False return #FP8isalreadyenabledandrecipeisthesame,don'tdoanything. ifself.fp8andget_fp8_recipe()==self.fp8_meta["recipe"]: return #SetFP8,recipe,andotherFP8metadata self.fp8=True self.fp8_meta["recipe"]=get_fp8_recipe() self.fp8_meta["num_gemms"]=num_gemms self.fp8_meta["fp8_group"]=get_fp8_group() #SetFP8_MAXpertensoraccordingtorecipe self.fp8_meta["fp8_max_fwd"]=self.fp8_meta["recipe"].fp8_format.value.max_fwd self.fp8_meta["fp8_max_bwd"]=self.fp8_meta["recipe"].fp8_format.value.max_bwd #Allocatescalesandamaxes self.init_fp8_meta_tensors()
而相關(guān)Module如LayerNormMLP繼承該Module,并且傳入fp8_meta信息更新:
classLayerNormMLP(TransformerEngineBaseModule): defforward(...): out=_LayerNormMLP.apply( ..., self.fp8, self.fp8_meta, )
總結(jié)
大致瀏覽完其實思路不復(fù)雜,但感覺還是FP8技術(shù)的不穩(wěn)定,整個項目還是加入了很多限制。得益于PyTorch靈活的外部擴(kuò)展形式,只要不去觸碰框架底層運行機制,僅僅在算子層面上的修改還是相當(dāng)簡單。雖然不具備通用性,但是運算主體就這幾個算子,為了性能也是可以接受的
審核編輯:湯梓紅
-
NVIDIA
+關(guān)注
關(guān)注
14文章
5055瀏覽量
103372 -
英偉達(dá)
+關(guān)注
關(guān)注
22文章
3821瀏覽量
91510 -
Transformer
+關(guān)注
關(guān)注
0文章
145瀏覽量
6026 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13289 -
H100
+關(guān)注
關(guān)注
0文章
31瀏覽量
294
原文標(biāo)題:詳解 NVIDIA H100 TransformerEngine
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論