最近一段時間在做商品理解的工作,主要內(nèi)容是從商品標(biāo)題里識別出商品的一些屬性標(biāo)簽,包括不限于品牌、顏色、領(lǐng)型、適用人群、尺碼等等。這類任務(wù)可以抽象成命名實體識別(Named Entity Recognition, NER)工作,一般用序列標(biāo)注(Sequence Tagging)的方式來做,是比較成熟的方向。
▲ 商品理解示例,品牌:佳豐;口味:蒜香味
本文主要記錄下做這個任務(wù)上遇到的問題,踩的坑,模型的效果等。
主要內(nèi)容:
- 怎么構(gòu)建命名實體識別(NER)任務(wù)的標(biāo)注數(shù)據(jù)
-
BertCRF 訓(xùn)練單標(biāo)簽識別過程及踩坑
- BertCRF 訓(xùn)練超多標(biāo)簽識別過程及踩坑
- CascadeBertCRF 訓(xùn)練超多標(biāo)簽識別過程及踩坑
NER任務(wù)標(biāo)注數(shù)據(jù)方法
其實對 NER 任務(wù)來說,怎么獲取標(biāo)注數(shù)據(jù)是比較重要、比較耗時費力的工作。針對商品理解任務(wù)來說,想要獲取大量的標(biāo)注數(shù)據(jù)一般可以分為 3 種途徑:
-
花錢外包,靠外包人肉打標(biāo),羨慕有錢的公司。
-
抓取其他平臺的數(shù)據(jù),這塊也可以分成兩種情況,第一種是既抓標(biāo)題又抓標(biāo)簽-標(biāo)簽值,比如 標(biāo)題:珍味來(zhenweilai)小黃魚(燒烤味),品牌:珍味來(zhenweilai),口味:燒烤味,得到的數(shù)據(jù)直接可以訓(xùn)練模型了;第二種是只抓 標(biāo)簽-標(biāo)簽值,把所有類目下所有常見的標(biāo)簽抓下來,不抓標(biāo)題,然通過一些手段把標(biāo)簽掛到自己平臺的標(biāo)題上,構(gòu)造訓(xùn)練數(shù)據(jù);第一種抓取得數(shù)據(jù)準(zhǔn),但很難找到資源給抓,即使找到了也非常容易被風(fēng)控;第二種因為請求量小,好抓一點,但掛標(biāo)簽這一步的準(zhǔn)確度會影響后面模型的效果。
- 用自己平臺的商品標(biāo)題去請求一些開放 NER 的 api,比如阿里云、騰訊云、百度 ai 等,有些平臺的 api 是免費的,有些 api 每天可以調(diào)用一定次數(shù),可以白嫖,對于電商領(lǐng)域,阿里云的 NER 效果比其他家好一些。
BertCRF單標(biāo)簽NER模型
這部分主要記錄 BertCRF 在做單一標(biāo)簽(品牌)識別任務(wù)時踩的一些坑。
先把踩的坑列一下:
-
怎么輕量化構(gòu)建 NER 標(biāo)注數(shù)據(jù)集。
-
bert tokenizer 標(biāo)題轉(zhuǎn) id 時,品牌值的 start idx、end idx 和原始的對不上,巨坑。
- 單一標(biāo)簽很容易過擬合,會把不帶品牌的標(biāo)題里識別出一些品牌,識別出來的品牌也不對。
2.1 輕量化構(gòu)建標(biāo)注數(shù)據(jù)集
上面講到構(gòu)建 NER 標(biāo)注數(shù)據(jù)的常見 3 種方法,先把第一種就排除,因為沒錢打標(biāo);對于第三種,我嘗試了福報廠的 NER api,分基礎(chǔ)版 和 高級版,但評估下來發(fā)現(xiàn)不是那么準(zhǔn)確,召回率沒有達(dá)到要求,也排除了;
那就剩第二種方案了,首先嘗試了第二種里的第一種情況,既抓標(biāo)題又抓標(biāo)簽,很快發(fā)現(xiàn)就被風(fēng)控了,不管用自己寫的腳本還是公司的采集平臺,都繞不過風(fēng)控,便放棄了;所以就只抓標(biāo)簽-標(biāo)簽值,后面再用規(guī)則的方法掛到商品標(biāo)題上。
只抓標(biāo)簽和標(biāo)簽值相當(dāng)于構(gòu)建類目下標(biāo)簽知識庫了,有了類目限定之后,通過規(guī)則掛靠在商品標(biāo)題上時,會提高掛靠的準(zhǔn)確率。比如“夏季清涼短款連衣裙”,其中包含標(biāo)簽“裙長”:“短款”,如果不做類目限定,就會用規(guī)則掛出多個標(biāo)簽“衣長”:“短款”,“褲長”:“短款”,“裙長”:“短款”等等,類目限制就可以把一些非此類目的標(biāo)簽排除掉。
通過規(guī)則掛靠出的數(shù)據(jù)也會存在一些 bad case,盡管做了類目限制,但也有一定的標(biāo)錯樣本;組內(nèi)其他同學(xué)在做大規(guī)模對比學(xué)習(xí)模型,于是用規(guī)則掛靠出的結(jié)果標(biāo)題——標(biāo)簽:標(biāo)簽值走一遍對比學(xué)習(xí)模型,把標(biāo)題向量和標(biāo)簽值向量相似得分高的樣本留下當(dāng)做優(yōu)質(zhì)標(biāo)注數(shù)據(jù)。
▲ 輕量化構(gòu)建NER標(biāo)注數(shù)據(jù)
通過以上步驟,不需要花費很多人力,自己一人就可以完成整個流程,減少了很多人工標(biāo)注、驗證的工作;得到的數(shù)據(jù)也足夠優(yōu)質(zhì)。
2.2 正確打標(biāo)label index
NER 任務(wù)和文本分類任務(wù)很像,文本分類任務(wù)是句子或整篇粒度,NER 是 token 或者 word 粒度的文本分類。
所以 NER 任務(wù)的訓(xùn)練數(shù)據(jù)和文本分類任務(wù)相似,但有一點點不同。對于文本分類任務(wù),一整個標(biāo)題有 1 個 label。
▲ 文本分類任務(wù)token和label對應(yīng)關(guān)系
對于 NER 任務(wù),一整個標(biāo)題有一串 label,每個 tokend 都有一個 label。在做品牌識別時,設(shè)定 label 有 3 種取值。
"UNK":0," B_brand":1,"I_brand":2,其中 B_brand 代表品牌的起始位置,I_brand 代表品牌的中間位置。
▲品牌NER任務(wù)token和label對應(yīng)關(guān)系
搞清了 NER 任務(wù)的 label 形式之后,接下來就是怎么正確的給每個樣本打上 label,一般先聲明個和 title 長度一樣的全 0 列表,遍歷,把相應(yīng)位置置 1 或者 2 就可以得到樣本 label,下面是一個基礎(chǔ)的例子
a={
"title":"潘頓特級初榨橄欖油",
"att_name":"品牌",
"att_value":"潘頓",
"start_idx":0,
"end_idx":2
}
defset_label(text):
title=text['title']
label=[0]*len(title)
foridxinrange(text['start_idx'],text['end_idx']):
ifidx==text['start_idx']:
label[idx]=1
else:
label[idx]=2
returnlabel
text_label=set_label(a)
print(text_label)
但這里需要把 title 進(jìn)行 tokenizer id 化,bert tokenizer 之后的 id 長度可能會和原來的標(biāo)題長度不一致,包含有些英文會拆成詞綴,空格也會被丟棄,導(dǎo)致原始的 start_idx 和 end_idx 發(fā)生偏移,label 就不對了。
這里先說結(jié)論:強(qiáng)烈建議使用 list(title)全拆分標(biāo)題,再使用 tokenizer.convert_tokens_to_ids 的方式 id 化?。?!
剛開始沒有使用上面那種方式,用的是 tokenizer(title)進(jìn)行 id 化再計算偏移量,重新對齊 label,踩了 2 個坑
-
tokenizer 拆分英文變成詞綴,start index 和 end index 會發(fā)生偏移,盡管有offset_mapping 可以記錄偏移的對應(yīng)關(guān)系,但真正回退偏移時還會遇到問題;
- 使用 tokenizer(title)的方式,預(yù)測的時候會遇到?jīng)]法把 id 變成 token;比如下面這個例子,
fromtransformersimportAutoTokenizer
tokenizer=AutoTokenizer.from_pretrained('../bert_pretrain_model')
input_id=tokenizer('呫頓')['input_ids']
token=[tokenizer.convert_ids_to_tokens(w)forwininput_id[1:-1]]
#['[UNK]','頓']
因為“呫”是生僻字,使用 convert_ids_to_tokens 是沒法知道原始文字是啥的,有人可能會說,預(yù)測出 index 之后,直接去標(biāo)題里拿字不就行了,不用 convert_ids_to_tokens;上面說過,預(yù)測出來的 index 和原始標(biāo)題的文字存在 offset,這樣流程就變成
▲ 使用tokenizer id化label對應(yīng)關(guān)系
所以,還是強(qiáng)烈建議使用 list(title)全拆分標(biāo)題,再使用 tokenizer.convert_tokens_to_ids 的方式 id 化?。?!
這樣就不存在偏移的問題,start idx 和 end idx 不會變化,預(yù)測的時候不需要使用 convert_ids_to_tokens,直接用 index 去列表里 token list 取字
正確打標(biāo) label 非常重要,不然訓(xùn)練的模型就會很詭異。建議在代碼里加上校驗語句,不管使用哪種方法,有考慮不全的地方,就會報錯
assertattribute_value==title[text['start_idx']:text['end_idx']]
2.3 BertCRF模型結(jié)構(gòu)
Pytorch 寫 BertCRF 很簡單,可能會遇到 CRF 包安裝問題,可以不安裝,直接把 crf.py 文件拷貝到項目里引用。
classBertCRF(nn.Module):
def__init__(self,num_labels):
super(BertCRF,self).__init__()
self.config=BertConfig.from_pretrained('../xxx/config.json')
self.bert=BertModel.from_pretrained('../xxx')
self.dropout=nn.Dropout(self.config.hidden_dropout_prob)
self.classifier=nn.Linear(self.config.hidden_size,num_labels)
self.crf=CRF(num_tags=num_labels,batch_first=True)
defforward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None,
):
outputs=self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
sequence_output=outputs[0]
sequence_output=self.dropout(sequence_output)
logits=self.classifier(sequence_output)
outputs=(logits,)
iflabelsisnotNone:
loss=self.crf(emissions=logits,tags=labels,mask=attention_mask)
outputs=(-1*loss,logits)
returnoutputs
2.4 緩解過擬合問題
只做一個標(biāo)簽(品牌)識別時,訓(xùn)練集是 標(biāo)題-品牌值 pair 對,每個樣本都有品牌值。由于品牌長尾現(xiàn)象嚴(yán)重,這里對熱門品牌的數(shù)據(jù)進(jìn)行了采樣,1 個品牌最少包含 100 個標(biāo)題,最多包含 300 個標(biāo)題,數(shù)據(jù)分布如下
模型關(guān)鍵參數(shù)
max_seq_length=50
train_batch_size=256
epochs=3
learning_rate=1e-5
crf_learning_rate=5e-5
第一版模型訓(xùn)練之后,驗證集 F1 0.98,通過分析驗證數(shù)據(jù)的 bad case,發(fā)現(xiàn)模型對包含品牌的標(biāo)題預(yù)測效果還不錯,但是對不包含品牌的標(biāo)題,幾乎全軍覆沒,都會抽出 1、2 個字出來,模型過擬合了。而且抽出的字一般都是標(biāo)題前 1、2 個字,這與商品品牌一般都在標(biāo)題前面有關(guān)。
針對過擬合問題及表現(xiàn)的現(xiàn)象,嘗試了 2 種方法:
-
既然對沒有品牌的標(biāo)題一般都抽出前 1、2 個字,那在訓(xùn)練的時候把品牌從前面隨機(jī)插入到標(biāo)題中間、尾部等位置,是不是可以緩解。
- 構(gòu)建訓(xùn)練集的時候加入一些負(fù)樣本,負(fù)樣本里 label 都是 0,不包含品牌,正負(fù)樣本比 1:1。
法 1 訓(xùn)練之后,沒有解決問題,而且過擬合問題更加嚴(yán)重了
法 2 訓(xùn)練之后,過擬合問題解決了,增加了近 1 倍樣本,訓(xùn)練時間翻倍。
BertCRF 模型訓(xùn)練完之后,通過分析 bad case,會發(fā)現(xiàn)有的數(shù)據(jù)模型預(yù)測是對的,標(biāo)注時標(biāo)錯了,模型有一定的糾錯能力,transformer 強(qiáng)??!
美國新安怡(fsoothielp)安撫奶嘴。標(biāo)注品牌:”soothie“;預(yù)測品牌:“新安怡(fsoothielp)”
美羚富奶羊羊羊粉 2 段。標(biāo)注品牌:“羊羊羊”,預(yù)測品牌:“美羚”
針對 BertCRF 在 Finetune 時有 2 種方式,一種是 linear probe,只訓(xùn)練 CRF 和線性層,凍結(jié) Bert 預(yù)訓(xùn)練參數(shù),這種方式訓(xùn)練飛快;另一種是不凍結(jié) Bert 參數(shù),模型所有參數(shù)都更新,訓(xùn)練很慢。
一般在 Bert 接下游任務(wù)時,我都會選擇第二種全部訓(xùn)練的方式,不凍結(jié)參數(shù),雖然訓(xùn)練慢,但擬合能力強(qiáng);尤其是用 bert-base 這類預(yù)訓(xùn)練模型時,這些模型在電商領(lǐng)域直接適配并不會很好,更新 bert 預(yù)訓(xùn)練參數(shù),能讓模型向電商標(biāo)題領(lǐng)域進(jìn)行遷移。
BertCRF多標(biāo)簽NER模型
這部分主要記錄 BertCRF 訓(xùn)練超多標(biāo)簽識別時,遇到的問題,模型的效果等。
先把踩的坑列一下:
-
爆內(nèi)存問題,因為要訓(xùn)練多標(biāo)簽,所以訓(xùn)練數(shù)據(jù)很多,千萬級別,dataloader 過程中內(nèi)存不夠。
-
爆顯存問題,CRF 的坑,下面會細(xì)說。
- 訓(xùn)練完的模型,預(yù)測時召回能力不強(qiáng),準(zhǔn)確率夠用。
多標(biāo)簽和單標(biāo)簽時,模型的結(jié)構(gòu)不變,和上面的代碼一模一樣。
3.1 爆內(nèi)存問題
和單標(biāo)簽一樣,也對每個標(biāo)簽值進(jìn)行了采樣,減少標(biāo)簽值的長尾分布現(xiàn)象。1 個標(biāo)簽值最少包含 100 個標(biāo)題,最多包含 300 個標(biāo)題。數(shù)據(jù)分布如下
一個標(biāo)簽有多個標(biāo)簽值,比如“顏色”:“紅”,“黃”,“綠”,...等。一個標(biāo)簽有 2 個 label 值,B 代表起始位置,I 代表終止位置,所以整體有 1212 + 1 個類別,1 代表 UNK。
單類別負(fù)采樣后訓(xùn)練數(shù)據(jù)總共 200w 左右,多類別時沒負(fù)采樣訓(xùn)練數(shù)據(jù) 900 多 w,數(shù)據(jù)量多了 4 倍,原有的 dataset 沒有優(yōu)化內(nèi)存,到多標(biāo)簽這里就爆內(nèi)存了。
把特征處理的模塊從__init__里轉(zhuǎn)移到__getitem__函數(shù)里,這樣就可以減少很多內(nèi)存使用了
舊版本的 dataset 函數(shù)
classMyDataset(Dataset):
def__init__(self,text_list,tokenizer,max_seq_len):
self.input_ids=[]
self.token_type_ids=[]
self.attention_mask=[]
self.labels=[]
self.input_lens=[]
self.len=len(text_list)
fortextintqdm(text_list):
input_ids,input_mask,token_type_ids,input_len,label_ids=feature_process(text,tokenizer,max_seq_len)
self.input_ids.append(input_ids)
self.token_type_ids.append(token_type_ids)
self.attention_mask.append(input_mask)
self.labels.append(label_ids)
self.input_lens.append(input_len)
def__getitem__(self,index):
tmp_input_ids=torch.tensor(self.input_ids[index]).to(device)
tmp_token_type_ids=torch.tensor(self.token_type_ids[index]).to(device)
tmp_attention_mask=torch.tensor(self.attention_mask[index]).to(device)
tmp_labels=torch.tensor(self.labels[index]).to(device)
tmp_input_lens=torch.tensor(self.input_lens[index]).to(device)
returntmp_input_ids,tmp_attention_mask,tmp_token_type_ids,tmp_input_lens,tmp_labels
def__len__(self):
returnself.len
新版本的 dataset 函數(shù)
classMyDataset(Dataset):
def__init__(self,text_list,tokenizer,max_seq_len):
self.text_list=text_list
self.len=len(text_list)
self.tokenizer=tokenizer
self.max_seq_len=max_seq_len
def__getitem__(self,index):
raw_text=self.text_list[index]
input_ids,input_mask,token_type_ids,input_len,label_ids=feature_process(raw_text,
self.tokenizer,
self.max_seq_len)
tmp_input_ids=torch.tensor(input_ids).to(device)
tmp_token_type_ids=torch.tensor(token_type_ids).to(device)
tmp_attention_mask=torch.tensor(input_mask).to(device)
tmp_labels=torch.tensor(label_ids).to(device)
tmp_input_lens=torch.tensor(input_len).to(device)
returntmp_input_ids,tmp_attention_mask,tmp_token_type_ids,tmp_input_lens,tmp_labels
def__len__(self):
returnself.len
可以看到新版本比舊版本減少了 5 個超大的 list,爆內(nèi)存的問題就解決了,雖然這塊會有一定的速度損失。
3.2 爆顯存問題
當(dāng)標(biāo)簽個數(shù)少時,BertCRF 模型最大 tensor 是 bert 的 input,包含 input_ids,attention_mask,token_type_ids三個tensor,維度是(batch size,sequence length,hidden_size=768),對于商品標(biāo)題數(shù)據(jù) sequence length=50,顯存占用大小取決于 batch size,僅做品牌識別,16G 顯存 batch size=300,32G 顯存 batch size=700。
但當(dāng)標(biāo)簽個數(shù)多時,BertCRF 模型最大 tensor 來自 CRF 這貨了,這貨具體原理不展開,后面會單獨寫一期,只講下這貨代碼里的超大 tensor。
CRF 在做 forward 時,函數(shù)_compute_normalizer 里的 next_score shape 是(batch_size, num_tags, num_tags),當(dāng)做多標(biāo)簽時,num_tags=1212,(batch_size, 1212, 1212)>>(batch_size, 50, 768),這個 tensor 遠(yuǎn)遠(yuǎn)大于 bert 的輸入了,多標(biāo)簽時,16G 顯存 batch size=32,32G 顯存 batch size=80
#shape:(batch_size,num_tags,num_tags)
next_score=broadcast_score+self.transitions+broadcast_emissions
排查到爆顯存的原因之后,也沒找到好的優(yōu)化辦法,CRF 這貨在多標(biāo)簽時太慢了,又占顯存。
3.3 模型效果
經(jīng)過近 4 天的顯卡火力全開之后,1k+ 類別的模型訓(xùn)練完成了。使用測試數(shù)據(jù)對模型進(jìn)行驗證,得到 3 個結(jié)論
-
模型沒有過擬合,盡管訓(xùn)練數(shù)據(jù)沒有負(fù)樣本
-
模型預(yù)測準(zhǔn)確率高,但召回能力不強(qiáng)
- 模型對單標(biāo)簽樣本預(yù)測效果好,多標(biāo)簽樣本預(yù)測不全,僅能預(yù)測 1~2 個,和 2 類似
先說一下模型為什么沒有出現(xiàn)單標(biāo)簽時的過擬合問題,因為在近 1k 個標(biāo)簽?zāi)P陀?xùn)練時,學(xué)習(xí)難度直接上去了,模型不會很快的收斂,單標(biāo)簽時任務(wù)過于簡單,容易出現(xiàn)過擬合。
驗證模型效果時,先定義怎么算正確:假設(shè)一個標(biāo)題包含 3 個標(biāo)簽,預(yù)測時要把這 3 個標(biāo)簽都識別出來,并且標(biāo)簽值也要對的上,才算正確;怎么算錯誤:識別的標(biāo)簽個數(shù)少于真實的標(biāo)簽個數(shù),識別的標(biāo)簽值和真實的對不上都算錯誤。
使用 105w 驗證數(shù)據(jù),整體準(zhǔn)確率 803388/1049268=76.5%,如果把預(yù)測不全,但預(yù)測對的樣本也算進(jìn)來的話,準(zhǔn)確率(803388+76589)/1049268=83.9%。
對 bad case 進(jìn)行分析,模型對于 1 個標(biāo)題中含有多個標(biāo)簽時,識別效果不好,表現(xiàn)現(xiàn)象是識別不全,一般只識別出 1 個標(biāo)簽,統(tǒng)計驗證數(shù)據(jù)里標(biāo)簽個數(shù)和樣本個數(shù)的關(guān)系,這個指標(biāo)算是標(biāo)簽個數(shù)維度的召回率
多標(biāo)簽樣本是指一個標(biāo)題中包含多個標(biāo)簽,比如下面這個商品包含 5 個標(biāo)簽。
標(biāo)題:“吊帶潮流優(yōu)雅純色氣質(zhì)收腰高腰五分袖喇叭袖連體褲 2018 年夏季”。
標(biāo)簽:袖長:五分袖;上市時間:2018年夏季;風(fēng)格:優(yōu)雅;圖案:純色;腰型:高腰。
可以看到對于標(biāo)簽數(shù)越多的標(biāo)題,模型的識別效果越不好,然后我分析了訓(xùn)練數(shù)據(jù)的標(biāo)簽個數(shù)個樣本數(shù)的關(guān)系,可以看到在訓(xùn)練數(shù)據(jù)里,近 90% 的樣本僅只有一個標(biāo)簽,模型對多標(biāo)簽識別效果不好主要和這個有關(guān)系。
所以在構(gòu)建數(shù)據(jù)集時,可以平衡一下樣本數(shù),多加一些多標(biāo)簽的樣本到訓(xùn)練集,這樣對多標(biāo)簽樣本的適配能力也會增強(qiáng)。
但多標(biāo)簽樣本本身收集起來會遇到困難,于是我又發(fā)現(xiàn)了一個新的騷操作
沒法獲得更多的多標(biāo)簽樣本提升模型的召回能力咋辦呢?模型不是對單標(biāo)簽樣本很牛 b 嘛,那在預(yù)測的時候,每次如果有標(biāo)簽提取出來,就從標(biāo)題里把已經(jīng)預(yù)測出的標(biāo)簽值刪掉,繼續(xù)預(yù)測,循環(huán)預(yù)測,直到預(yù)測是空終止。
第一次預(yù)測
input title:吊帶潮流優(yōu)雅純色氣質(zhì)收腰高腰五分袖喇叭袖連體褲2018年夏季
predict label:袖長:五分袖
把五分袖從標(biāo)題里刪除,進(jìn)行第二次預(yù)測
input title:吊帶潮流優(yōu)雅純色氣質(zhì)收腰高腰喇叭袖連體褲2018年夏季
predict label:上市時間:2018年夏季
把2018年夏季從標(biāo)題里刪除,進(jìn)行第三次預(yù)測
input title:吊帶潮流優(yōu)雅純色氣質(zhì)收腰高腰喇叭袖連體褲
predict label:風(fēng)格:優(yōu)雅
把優(yōu)雅從標(biāo)題里刪除,進(jìn)行第四次預(yù)測
input title:吊帶潮流純色氣質(zhì)收腰高腰喇叭袖連體褲
predict label:圖案:純色
把純色從標(biāo)題里刪除,進(jìn)行第五次預(yù)測
吊帶潮流氣質(zhì)收腰高腰喇叭袖連體褲
predict label:腰型:高腰
把高腰從標(biāo)題里刪除,進(jìn)行第六次預(yù)測
input title:吊帶潮流氣質(zhì)收腰喇叭袖連體褲
predict label:預(yù)測為空
可以看到,標(biāo)簽被一個接一個的準(zhǔn)確預(yù)測出,這種循環(huán)預(yù)測是比較耗時的,離線可以,在線吃不消;能找到更多 多標(biāo)簽數(shù)據(jù)補充到訓(xùn)練集里是正確的方向。
多標(biāo)簽 CRF 爆顯存,只能設(shè)定小 batch size 慢慢跑的問題不能解決嘛?當(dāng)然可以,卷友們提出了一種多任務(wù)學(xué)習(xí)的方法,CRF 只學(xué)習(xí) token 是不是標(biāo)簽實體,通過另一個任務(wù)區(qū)分 token 屬于哪個標(biāo)簽類別。
CascadeBertCRF多標(biāo)簽?zāi)P?/span>
4.1 模型結(jié)構(gòu)
在標(biāo)簽數(shù)目過多時,BertCRF 由于 CRF 這貨的問題,導(dǎo)致模型很耗顯存,訓(xùn)練也很慢,這種方式不太科學(xué),也會影響效果。
從標(biāo)簽過多這個角度出發(fā),卷友們提出把 NER 任務(wù)拆分成多任務(wù)學(xué)習(xí),一個任務(wù)負(fù)責(zé)識別 token 是不是實體,另一個任務(wù)判斷實體屬于哪個類別。
這樣 NER 任務(wù)的 lable 字典就只有"B"、"I"、"UNK"三個值了,速度嗖嗖的;而判斷實體屬于哪個類別用線性層就可,速度也很快,模型顯存占用很少。
▲ 左單任務(wù)NER模型;右多任務(wù)NER模型
Cascade 的意思是級聯(lián)。就是把 BERT 的 token 向量過一遍 CRF 之后,再過一遍 Dense 層分類。但這里面有一些細(xì)節(jié)。
訓(xùn)練時,BERT 的 tokenx 向量過一遍 Dense 層分類,但不是所有 token 都計算 loss,是把 CRF 預(yù)測是實體的 token 拿出來算 loss,CRF 預(yù)測不是實體的不計算 loss,一個實體有多個 token,每個 token 都計算 loss;預(yù)測時,把實體的每個 token 分類結(jié)果拿出來,設(shè)計了三種類別獲取方式。
比如“蒜香味”在模型的 CRF 分支預(yù)測出是實體,標(biāo)簽對應(yīng) "B"、"I"、"I";接下要解析這個實體屬于哪個類別,在 Dense 分支預(yù)測的結(jié)果可能會有四種
-
“蒜香味”對應(yīng)的 Dense 結(jié)果是 “unk”、“unk”、“unk”,沒識別出實體類別
-
“蒜香味”對應(yīng)的 Dense 結(jié)果是 “口味”、“口味”、“口味”,每個 token 都對
-
“蒜香味”對應(yīng)的 Dense 結(jié)果是 “unk”、"口味"、"口味",有的 token 對,有的token 沒識別出
- “蒜香味”對應(yīng)的 Dense 結(jié)果是 “unk”、“品牌”、“口味”,有的 token 對了,有的 token 沒識別出,有的 token 錯了
針對上面 4 中結(jié)果,可以看到 4、3、2 越來越嚴(yán)謹(jǐn)。在評估模型效果時,采用 2 是最嚴(yán)的,就是預(yù)測的 CRF 結(jié)果要對,Dense 結(jié)果中每個 token 都要對,才算完全正確;3 和 4 越來越寬松。
4.2 模型代碼
importtorch
fromcrfimportCRF
fromtorchimportnn
fromtorch.nnimportCrossEntropyLoss
fromtransformersimportBertModel,BertConfig
classCascadeBertCRF(nn.Module):
def__init__(self,bio_num_labels,att_num_labels):
super(CascadeBertCRF,self).__init__()
self.config=BertConfig.from_pretrained('../bert_pretrain_model/config.json')
self.bert=BertModel.from_pretrained('../bert_pretrain_model')
self.dropout=nn.Dropout(self.config.hidden_dropout_prob)
self.bio_classifier=nn.Linear(self.config.hidden_size,bio_num_labels)#crf預(yù)測字是不是標(biāo)簽
self.att_classifier=nn.Linear(self.config.hidden_size,att_num_labels)#預(yù)測標(biāo)簽屬于哪個類別
self.crf=CRF(num_tags=bio_num_labels,batch_first=True)
defforward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
bio_labels=None,
att_labels=None,
):
outputs=self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
sequence_output=outputs[0]
sequence_output=self.dropout(sequence_output)
bio_logits=self.bio_classifier(sequence_output)#(batchsize,sequencelength,bio_num_labels)
num_bio=bio_logits.shape[-1]
reshape_bio_logits=bio_logits.view(-1,num_bio)#(batchsize*sequencelength,bio_num_labels)
pred_bio=torch.argmax(reshape_bio_logits,dim=1)#ner預(yù)測的bio結(jié)果
no_zero_pred_bio_index=torch.nonzero(pred_bio)#取出ner結(jié)果非0的token
att_logits=self.att_classifier(sequence_output)#(batchsize,sequencelength,att_num_labels)
num_att=att_logits.shape[-1]#att_num_labels
att_logits=att_logits.view(-1,num_att)#(batchsize*sequencelength,att_num_labels)
outputs=(bio_logits,att_logits)
ifbio_labelsisnotNoneandatt_labelsisnotNone:
select_att_logits=torch.index_select(att_logits,0,no_zero_pred_bio_index.view(-1))
select_att_labels=torch.index_select(att_labels.contiguous().view(-1),0,no_zero_pred_bio_index.view(-1))
loss_fct=CrossEntropyLoss()
select_att_loss=loss_fct(select_att_logits,select_att_labels)
bio_loss=self.crf(emissions=bio_logits,tags=bio_labels,mask=attention_mask)
loss=-1*bio_loss+select_att_loss
outputs=(loss,-1*bio_loss,bio_logits,select_att_loss,att_logits)
returnoutputs
4.3 模型效果
上面提到評估 Dense 的結(jié)果會遇到 4 種情況,使用第 4 種方式進(jìn)行指標(biāo)評估;NER 的識別效果和上面一致。
使用 105w 驗證數(shù)據(jù),整體準(zhǔn)確率 792386/1049268=75.5%,比 BertCRF 低 1 個點;把預(yù)測不全,但預(yù)測對的樣本也算進(jìn)來的話,準(zhǔn)確率(147297+792386)/1049268=89.6%,比 BertCRF 高 5 個點;
標(biāo)簽個數(shù)和預(yù)測標(biāo)簽個數(shù)的對照關(guān)系:
CascadeBertCRF 模型的召回率比 BertCRF 要低,但模型的準(zhǔn)確率會高一些。CascadeBertCRF 相比 BertCRF,主要是提供了一種超多實體識別的訓(xùn)練思路,且模型的效果沒有損失,訓(xùn)練速度和推理速度有大幅提高。
把實體從標(biāo)題里刪掉訓(xùn)練預(yù)測的方法也同樣適用 CascadeBertCRF。
審核編輯 :李倩
-
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1208瀏覽量
24713 -
標(biāo)簽
+關(guān)注
關(guān)注
0文章
137瀏覽量
17892
原文標(biāo)題:NER | 商品標(biāo)題屬性識別探索與實踐
文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論