0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

Faster Transformer v2.1版本源碼解讀

jf_pmFSk4VX ? 來源:后來遇見AI ? 2023-09-19 11:39 ? 次閱讀

寫在前面:本文將對(duì) Faster Transformer v2.1 版本源碼進(jìn)行解讀,重點(diǎn)介紹該版本基于 v1.0 和 v2.0 所做的優(yōu)化內(nèi)容,剖析源碼作者優(yōu)化意圖。

1 v2.1 版本發(fā)布背景

在 FasterTransformer v1.0 中,Nvidia 提供了一個(gè)高度優(yōu)化的 BERT Transformer Encoder 模塊,主要應(yīng)用于序列標(biāo)注推理場(chǎng)景,筆者針對(duì)源碼的實(shí)現(xiàn)邏輯和優(yōu)化技巧進(jìn)行了深度剖析,有興趣的讀者可以移步——【CUDA編程】Faster Transformer v1.0 源碼詳解。
在 FasterTransformer v2.0 中,Nvidia 添加了一個(gè)高度優(yōu)化的 Decoder 模塊和一套推理方案 Decoding 模型。其中,Decoder 相當(dāng)于我們常說的 decoder layer;而 Decoding 則包含了整個(gè)解碼的流程,包括詞嵌入、位置編碼、解碼層和束搜索等過程,相當(dāng)于我們常說的 decoder model。同樣,筆者針對(duì) v2.0 版本新增的內(nèi)容進(jìn)行了優(yōu)化解讀,有興趣的讀者可以移步——【CUDA編程】Faster Transformer v2.0 源碼詳解。
在 FasterTransformer v2.1 中,官方主要添加了 3 塊優(yōu)化內(nèi)容。第一點(diǎn)是考慮到 PyTorch 的用戶越來越多,官方添加了對(duì) PyTorch 的支持,這點(diǎn)不在本文的討論范疇。第二個(gè)特點(diǎn)是支持 Effective Transformer,該優(yōu)化思路來自字節(jié)跳動(dòng)算法團(tuán)隊(duì),計(jì)算模型中去除了 encoder 輸入的無效填充,從而降低了計(jì)算開銷。第三,除了使用束搜索進(jìn)行解碼外,還提供了基于采樣的解碼策略。除此之外,Nvidia 還對(duì) Encoder、Decoder 和 beam search 等諸多模塊的內(nèi)核進(jìn)行了優(yōu)化,進(jìn)一步提高了 FasterTransformer 的推理速度。因此本文的解讀也主要聚焦于 3 個(gè)方面:Effective Transformer、sampling decoding、內(nèi)核優(yōu)化,針對(duì)其他未發(fā)生變更的內(nèi)容,請(qǐng)讀者閱讀筆者的前兩篇文章。

2 整體架構(gòu)

同前面兩個(gè)版本一樣,v2.1 的底層由 CUDA 和 cuBLAS 實(shí)現(xiàn),提供 C++ APITensorFlow/PyThorch OP。用戶可以將它們集成到 TensorFlow、PyTorch 或其他在本機(jī) C++ 中構(gòu)建的推理服務(wù)代碼中。此外官方還提供了一些簡(jiǎn)單的示例代碼來演示如何使用 Encoder、Decoder 以及在 C++、TensorFlow 和 PyTorch 中執(zhí)行 Decoding 過程。下面是整體架構(gòu)圖:
5c3ec9b2-5698-11ee-939d-92fbcf53809c.png

源碼地址如下,有興趣的讀者可以前往下載

https://github.com/NVIDIA/FasterTransformer/tree/v2.1/

3 Effective Transformer

關(guān)于 Transformer Encoder 的邏輯筆者在之前的文章中有詳細(xì)闡述,這里筆者不打算再重復(fù)講解,解讀重心會(huì)放在“Effective”的部分。當(dāng)使用 Transformer 對(duì)一批輸入序列進(jìn)行編碼時(shí),我們通常將輸入序列視為一個(gè)矩陣,其列數(shù)等于所有序列的最大長度。Faster Transformer 可以非常有效地處理所有序列長度大致相同的情況。然而,如果同一批中序列的長度變化很大,將它們填充到相同的長度意味著對(duì)內(nèi)存和計(jì)算資源的巨大浪費(fèi)??紤]下面一個(gè)例子:

bert_input = [["Hi"], ["Picking"], ["The", "seed", "of", "Job's", "tears"]]
bert_tokens = [[1], [2], [3,4,5,6,7]]
bert_tokens_padded = [[1, 0, 0, 0, 0], [2, 0, 0, 0, 0], [3, 4, 5, 6, 7]]
bert_tokens_mask = [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 1, 1]]

對(duì)于輸入的 3 個(gè)樣本來說,實(shí)際有效的 word 只有 1+1+5=7 個(gè),但是我們要用一個(gè) 3 * 5 的矩陣來計(jì)算,中間其實(shí)有一半的元素是無效的,這些無效元素既浪費(fèi)了內(nèi)存又占用了計(jì)算資源。所以我們?cè)谙?,能不能就?[1,2,3,4,5,6,7] 這 7 個(gè)元素來參與計(jì)算?
在 Effective Transformer 中,會(huì)根據(jù)不同的計(jì)算階段,動(dòng)態(tài)刪除和恢復(fù)填充值,從而減少資源占用。

3.1 計(jì)算偏移量

通常上游傳過來的 tensor 一般是 padded 之后的規(guī)則矩陣,假設(shè)形狀為 [batch_size, seq_len, hidden_units] 記為 tensor_padded,刪除填充值后的形狀為 [valid_word_num, hidden_units] 記為 tensor,要想動(dòng)態(tài)的刪除和恢復(fù)填充值,也就是說要找到 tensor_padded 和 tensor 的對(duì)應(yīng)關(guān)系,源碼中用 build_sequence_length_padding_offset_kernelLauncher 函數(shù)解決這個(gè)問題。

/**
 * @brief 計(jì)算偏移量
 * 
 * @param sequence_length         [batch_size,]
 * @param batch_size        
 * @param max_seq_len             
 * @param valid_word_num          [1,]
 * @param tmp_mask_offset         [valid_word_num]
 * @return __global__ 
 */
__global__ void build_sequence_length_padding_offset(const int* sequence_length, 
  const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset)
{
  // do cumulated sum
  int total_seq_len = 0;
  int cum_offset = 0;
  int index = 0;
  for(int i = 0; i < batch_size; i++) 
  {
    const int seq_len = sequence_length[i];
    for(int j = 0; j < seq_len; j++)
    {
      tmp_mask_offset[index] = cum_offset;
      index++;
    }
    cum_offset += max_seq_len - seq_len;
    total_seq_len += seq_len;
  }
  valid_word_num[0] = total_seq_len;
}

void build_sequence_length_padding_offset_kernelLauncher(const int* sequence_length, 
  const int batch_size, const int max_seq_len, int* valid_word_num, int* tmp_mask_offset,
  cudaStream_t stream)
{
  build_sequence_length_padding_offset<<<1, 1, 0, stream>>>(sequence_length, 
    batch_size, max_seq_len, valid_word_num, tmp_mask_offset);
}

可以看到函數(shù)體內(nèi)部執(zhí)行了一個(gè) 1*1 的核函數(shù),也就是說這個(gè)核函數(shù)完全沒有并行,全部邏輯在一個(gè)線程里完成??赡苡凶x者要問,既然不并行,那為啥不在主機(jī)端直接用 C++ 完成,這是因?yàn)?tmp_mask_offset 和 valid_word_num 這都是設(shè)備端的變量,如果在主機(jī)端計(jì)算還需要一次內(nèi)存拷貝操作,而 host-device 內(nèi)存拷貝是比較耗時(shí)的,所以干脆就在設(shè)備端開一個(gè)線程算了。
核函數(shù)內(nèi)的計(jì)算邏輯比較簡(jiǎn)單,直接看下面的圖就可以了。
5c57bb84-5698-11ee-939d-92fbcf53809c.png根據(jù)樣本長度 sequence_length 計(jì)算了兩個(gè)指標(biāo) valid_word_num 和 id_offset,用于后面動(dòng)態(tài)刪除和恢復(fù)矩陣,tensor 的行索引加上 id_offset 就得到了對(duì)應(yīng)行在 tensor_padded 中的行索引。

3.2 刪除填充值

以形狀為 [batch_size, seq_len, hidden_units] 的 tensor_padded 矩陣為例,如果計(jì)算過程只是在最后一個(gè)維度,比如右乘一個(gè)形如 [hidden_uints, new_units] 的矩陣,那其實(shí)完全可以去掉 tensor_padded 中的填充值之后再計(jì)算。這里官方提供了兩個(gè)核函數(shù)進(jìn)行刪除填充值的操作:第一個(gè)就是單純的刪除填充值函數(shù) remove_sequence_length_padding_kernelLauncher,第二個(gè)是和 transpose 操作融合后的 transpose_rebuild_padding,讀者看起來可能會(huì)一臉懵逼,為什么刪除函數(shù)命名要用 rebuild?沒錯(cuò),這里應(yīng)該是源碼作者筆誤,導(dǎo)致了掛羊頭賣狗肉的行為。

3.2.1 remove_sequence_length_padding_kernelLauncher

單純地刪除填充值的計(jì)算邏輯很簡(jiǎn)單,就是按兩步,找到行索引,按索引拷貝元素即可,直接看代碼。

template
__global__ void remove_sequence_length_padding(const T* src, T* tgt,
                                              const int* tmp_mask_offset,
                                              int* mask_offset,
                                              const int n)
{
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  mask_offset[bid] = tmp_mask_offset[bid];
  const int src_seq_id = bid + mask_offset[bid];
  const int tgt_seq_id = bid;
  for(int i = tid; i < n; i += blockDim.x)
  {
    tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
  }
}

template
void remove_sequence_length_padding_kernelLauncher(const T* src, T* tgt, 
                                                  const int* tmp_mask_offset, 
                                                  int* mask_offset,
                                                  const int m, const int n, cudaStream_t stream)
{
  // src: [batch_size*max_seq_len, hidden_dim]
  // tgt: [valid_word_num, hidden_dim]
  // m: valid_word_num
  // n: hidden_dim
  remove_sequence_length_padding<<>>(src, tgt, tmp_mask_offset, mask_offset, n);
}

remove_sequence_length_padding 核函數(shù)的 grid_size 設(shè)置為 valid_word_num,也就是 tensor 的行數(shù),每個(gè) block 處理一行元素,block_size 取 256,每個(gè)線程處理對(duì)應(yīng)行的一個(gè)或多個(gè)元素,步長為 256??梢钥吹?,源矩陣中的行索引 src_seq_id 就等于目標(biāo)矩陣的行索引 tgt_seq_id 加上 offset。

3.2.2 transpose_rebuild_padding

關(guān)于這個(gè)函數(shù)名的問題筆者前面已經(jīng)吐槽過了,到此為止,這并不影響實(shí)際應(yīng)用,不然也沒法通過測(cè)試上線。。。這個(gè)函數(shù)內(nèi)部實(shí)現(xiàn)了兩個(gè)操作:transpose、刪除填充值,下面來看一下代碼。

template
__global__
void transpose_rebuild_padding(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head,
  const int* mask_offset)
{
  // TODO: optimize this kernel? 
  // do remove_sequence_length_padding
  const int tid = threadIdx.x; // batch * seq_len or valid_word_num
  const int bid = blockIdx.x; // head_num * size_per_head

  const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
  const int src_seq_id = (bid + mask_offset[bid]) % seq_len;

  const int dst_seq_id = bid;

  const int head_id = tid / size_per_head;
  const int hidden_id = tid % size_per_head;
  dst[dst_seq_id * head_num * size_per_head + tid] = src[ src_batch_id * head_num * seq_len * size_per_head +
    head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id];
}

transpose_rebuild_padding<<>>(transpose_dst_, dst, 
            batch_size, seq_len, head_num, size_per_head, param_.sequence_id_offset);

函數(shù)體內(nèi)部有兩行關(guān)于 bid 和 tid 的注釋,這個(gè)是源碼作者的注釋,應(yīng)該是寫混了,不用去看,請(qǐng)聽筆者一一解釋。
首先這個(gè)函數(shù)的應(yīng)用場(chǎng)景是在 之后,也就是說這里的源矩陣是轉(zhuǎn)置之后的矩陣,形狀為 [batch_size, head_num, seq_len, size_per_head],計(jì)算目的有兩個(gè):transpose、和刪除填充值。gird_size 設(shè)置為 valid_word_num,block_size 設(shè)置為 head_num * size_per_head,一個(gè) thread 處理一個(gè)元素。函數(shù)內(nèi)部通過偏移量確定源矩陣的 batch_id 和 seq_id,有了索引后,再按照矩陣維度順序把形如 [batch_size, head_num, seq_len, size_per_head] 的源矩陣,轉(zhuǎn)換為形如 [valid_word_num, head_num, size_per_head] 的無填充矩陣。

3.4 恢復(fù)填充值

我們知道,attention 操作通常包括:Q/K/V 線性變換、、transpose 以及 attention out 線性變換等幾個(gè)步驟。其中 中有 2 個(gè) Strided Batched Gemm 操作,這兩個(gè)矩陣乘法是涉及 seq_len 維度的,因?yàn)橐?jì)算 word 與 word 間的相似度以及 scores 的加權(quán)平均值,所以在計(jì)算前要先恢復(fù)填充值矩陣。這里有一點(diǎn)要說明,并不是說這一步計(jì)算必須得有填充值,是因?yàn)橛辛颂畛渲悼梢哉{(diào)用矩陣乘法 API,從而實(shí)現(xiàn)更好的并行化計(jì)算,如果舍棄矩陣乘法寫一個(gè)函數(shù)根據(jù)偏移量逐個(gè) word 計(jì)算,也是可行的,但是性能極差,所以還不如動(dòng)態(tài)恢復(fù)矩陣。
關(guān)于恢復(fù)填充值的操作,源碼提供了兩個(gè) kernel,第一個(gè)就是單純的恢復(fù)填充值函數(shù) rebuild_sequence_length_padding_kernelLauncher,第二個(gè)是和 add bias 操作融合后的 add_QKV_bias_rebuild_padding。

3.4.1 rebuild_sequence_length_padding_kernelLauncher

恢復(fù)填充值的操作和刪除填充值的操作是互逆的,只需要把握一點(diǎn)即可:

template
__global__ void rebuild_sequence_length_padding(const T* src, T* tgt,
                                            const int* mask_offset,
                                            const int n)
{
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  const int tgt_seq_id = bid + mask_offset[bid];
  const int src_seq_id = bid;

  for(int i = tid; i < n; i += blockDim.x)
  {
    tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
  }
}

template
void rebuild_sequence_length_padding_kernelLauncher(const T* src, T* tgt, 
                                                  const int* mask_offset, const int m, 
                                                  const int n, cudaStream_t stream)
{
  // src: [valid_word_num, hidden_dim]
  // tgt: [batch_size*max_seq_len, hidden_dim]
  rebuild_sequence_length_padding<<>>(src, tgt, mask_offset, n);
}

核函數(shù)執(zhí)行配置參數(shù)與刪除填充值的核函數(shù)一致,都是一個(gè) block 處理一行元素。可以看到,源矩陣中的行索引 tgt_seq_id 就等于目標(biāo)矩陣的行索引 src_seq_id 加上 offset。

3.4.2 add_QKV_bias_rebuild_padding

顧名思義,核函數(shù)內(nèi)部實(shí)現(xiàn)了兩個(gè)計(jì)算邏輯:add Q/K/V bias 和恢復(fù)填充值,我們來看下源碼。

template
__global__
void add_QKV_bias_rebuild_padding(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_, 
  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int* mask_offset)
{
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  const int bdim = blockDim.x;

  const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
  const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
  const int tgt_head_id = tid / size_per_head;
  const int tgt_hidden_id = tid % size_per_head;

  const int src_id = bid * bdim + tid;
  const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + 
                    tgt_head_id * seq_len * size_per_head + 
                    tgt_seq_id * size_per_head + 
                    tgt_hidden_id;
  
  q_buf_[tgt_id] = Q[src_id] + bias_Q[tid];
  k_buf_[tgt_id] = K[src_id] + bias_K[tid];
  v_buf_[tgt_id] = V[src_id] + bias_V[tid];
}

核函數(shù) grid_size 設(shè)置為 valid_word_num,一個(gè) block 內(nèi)部處理 hidden_uints = head_num*size_per_head 個(gè)元素。由于模型整體超參數(shù)要求 hidden_uints 不大于 1024,所以這里 block_size 直接就設(shè)置為 hidden_units,一個(gè) thread 就處理一個(gè)元素。
核函數(shù)內(nèi)部最重要的邏輯就是 tgt_id 的計(jì)算,有兩點(diǎn)需要把握。首先是 tgt_batch_id 和 tgt_seq_id 的確定,通過源矩陣的索引加偏移量后計(jì)算得到。然后是 tgt_id 的確定,按照 [batch_size, head_num, seq_len, size_per_head] 的順序計(jì)算即可。
不熟悉 attention 計(jì)算邏輯的讀者可能會(huì)問,為什么就加了一個(gè) add bias 的操作,核函數(shù)變得如此復(fù)雜?這其實(shí)是因?yàn)檫@里面其實(shí)還隱藏了一個(gè) transpose 操作,把原來形狀如 [valid_word_num, head_num, size_per_head] 的矩陣轉(zhuǎn)換成了形狀如 [batch_size, head_num, seq_len, size_per_head] 的矩陣,筆者在前面的文章說過,這是“多頭”獨(dú)立計(jì)算的邏輯使然。

4 采樣解碼

4.1 采樣解碼原理

關(guān)于解碼策略,筆者在上一篇文章中介紹了貪心搜索(greedy search)和束搜索(beam search)兩種方法,這兩種方法統(tǒng)稱為基于搜索的解碼策略,其解碼目標(biāo)都是最大化生成概率,即高概率的 word 相比于低概率 word 有壓倒性優(yōu)勢(shì),在解碼過程中低概率的 word 絕不可能被選中。
除了基于搜索的解碼策略以外,基于概率采樣的解碼策略也被廣泛應(yīng)用,相比基于搜索的解碼方法,通過采樣生成的文本通常具有更高的多樣性,同時(shí)也在一定程度上緩解了生成通用和重復(fù)文本的問題。常用的基于概率采樣的解碼策略分為以下四種:隨機(jī)采樣帶溫度的隨機(jī)采樣、Top-k 采樣、Top-p 采樣

4.1.1 隨機(jī)采樣

在解碼時(shí)每個(gè) step 都從當(dāng)前概率分布 中按照概率隨機(jī)采樣一個(gè)詞,即。
相比于按概率“掐尖”,這樣會(huì)增大所選詞的范圍,引入更多的隨機(jī)性。這個(gè)方法是谷歌開放式聊天機(jī)器人 Meena[DialoGPT、Meena] 采用的方式。當(dāng)時(shí)那篇論文的結(jié)論就是這種隨機(jī)采樣的方法遠(yuǎn)好于 Beam Search。但這種隨機(jī)采樣也是有局限性的,容易產(chǎn)生上下文無關(guān)前后不一致的問題。而在開放閑聊領(lǐng)域,生成文本的長度都比較短,這種問題就被自然的淡化了。

4.1.2 帶溫度的隨機(jī)采樣

盡管隨機(jī)采樣在一定程度上能避免生成重復(fù)的文本,但是,由于從整個(gè)詞表中采樣可能會(huì)采到與上下文無關(guān)的詞,因此,隨機(jī)采樣得到的文本上下文常常不連貫。為了使得模型盡可能避免采樣到低概率的詞,一個(gè)有效的辦法是設(shè)置一個(gè)名為“溫度”(temperature)的參數(shù)來控制概率分布的彌散程度,該參數(shù)用 表示,是一個(gè)大于 的實(shí)數(shù)。形式化地說,生成過程中需要將概率分布的計(jì)算方式修改為:

當(dāng) 時(shí),即為原始的概率分布;當(dāng) 時(shí),得到的概率分布將更加尖銳,彌散程度更小,采樣的隨機(jī)性降低;當(dāng)時(shí),使用隨機(jī)采樣解碼的效果近似于貪心搜素;當(dāng) 時(shí),得到的概率分布彌散程度更小,采樣的隨機(jī)性升高;當(dāng) 時(shí),使用隨機(jī)采樣解碼的效果則近似于從均勻分布中隨機(jī)采樣。因此,合理設(shè)置 可以避免隨機(jī)采到概率較小的詞。

4.1.3 Top-k 采樣

除了設(shè)置溫度來調(diào)整概率分布的彌散程度,Top-k 采樣近來也被廣泛使用。具體來說,在每個(gè) step,解碼器首先選擇概率最高的 k 個(gè) word 作為候選 word 構(gòu)成一個(gè)集合,然后將這個(gè)子集中 word 的概率再歸一化,最后從新的概率分布中采樣。這個(gè)辦法據(jù)說可以獲得比 Beam Search 好很多的效果,但也有一個(gè)問題,就是這個(gè) k 值不太好選。因?yàn)閷?shí)際應(yīng)用中概率分布變化比較大,有時(shí)候可能很均勻,有的時(shí)候比較集中。對(duì)于集中的情況還好說,當(dāng)分布均勻時(shí),一個(gè)較小的 k 容易丟掉很多優(yōu)質(zhì)候選詞。但如果 k 定的太大,這個(gè)方法又會(huì)退化回普通采樣。

4.1.4 Top-p 采樣

相比于 Top-k 方法從概率最高的 k 個(gè)候選詞中采樣,它不再取一個(gè)固定的 k,而是固定候選集合的概率密度和在整個(gè)概率分布中的比例。也就是構(gòu)造一個(gè)最小候選集,使得

Top-p 采樣根據(jù)生成概率從高到低在詞表上選擇累積概率恰好超過 的候選 word 作為采樣集合,再從這些候選 word 中采樣出最終的結(jié)果。選出來這個(gè)集合之后也和 Top-k 采樣一樣,重新歸一化集合內(nèi) word 的概率,并把集合外詞的概率設(shè)為。

4.2 調(diào)用鏈

sampling decoding 模塊的核心邏輯封裝在 decoding_sampling.h 文件的 DecodingSampling 類中,計(jì)算邏輯都在 forward 函數(shù)里,具體調(diào)用鏈如下:

DecodingSampling->DecodingSampling()  // 構(gòu)造函數(shù)
DecodingSampling->forward()
    ->init_kernelLauncher
    ->loop for step
        ->embedding_lookup_sine_position_encoding_kernel_launcher
        ->loop for decoder layer
            ->decoder->initialize
            ->decoder->forward
        ->decoder->decoder_norm1
        ->cublasGemmEx
        ->sampling_kernel_kernelLauncher

4.3 DecodingSampling 構(gòu)造函數(shù)

構(gòu)造函數(shù)內(nèi)部首先進(jìn)行了 candidate_num_ 和 probability_threshold_ 的判斷,不能同時(shí)為 0 或同時(shí)不為 0,這兩個(gè)參數(shù)分別代表 Top-k 采樣的 k 和 Top-p 采樣的 p,意思是源碼提供了兩種采樣解碼策略,在初始化的時(shí)候必須確定使用哪一種。
接下來就是一些內(nèi)存分配的工作,和 v2.0 版本基本一致,筆者根據(jù)源碼繪制了一份內(nèi)存分布圖如下。5c5b78c8-5698-11ee-939d-92fbcf53809c.png首先在構(gòu)造函數(shù)內(nèi)部初始化了 2 個(gè)二級(jí)指針 K_cache_ 和 V_cache,這個(gè)指針是用來存儲(chǔ)解碼過程中每個(gè) step 對(duì)應(yīng)的輸入 tensor 經(jīng)過 Dense 變換后的 key 矩陣和 value 矩陣,用于 self-attention 的??梢钥吹竭@兩個(gè)指針申請(qǐng)的內(nèi)存大小和之前 v2.0 版本 DecodingOpenNMT 類中有所不同,DecodingOpenNMT 中有兩個(gè)元素,DecodingSampling 中只有一個(gè),這是因?yàn)?DecodingOpenNMT 的解碼策略只有一個(gè) beam search,beam search 每輪結(jié)束后取 TopK 的時(shí)候會(huì)打亂順序,需要一個(gè)元素暫存當(dāng)前每個(gè) beam 的 Key 和 Value,等 TopK 確定后再根據(jù) parent_ids 更新 Key 和 Value。而使用采樣解碼策略就不存在這個(gè)問題,所以一個(gè)元素足矣。

K_cache_ = new DataType_ *[1];
V_cache_ = new DataType_ *[1];

然后就是一系列 buffer size 的計(jì)算,用于內(nèi)存申請(qǐng)和分配的,結(jié)合筆者整理的內(nèi)存分布圖可以非常容易的理解。

4.4 forward 函數(shù)

這個(gè)函數(shù)的計(jì)算邏輯過于復(fù)雜,不適合單列一節(jié),大致過程見調(diào)用鏈,當(dāng)筆者把 forward 內(nèi)部調(diào)用鏈中子模塊講清楚的時(shí)候,forward 也就清晰了。

4.5 init_kernelLauncher

關(guān)于初始化函數(shù),源碼針對(duì) Top-k 和 Top-p 兩種不同的解碼策略分別給了一個(gè)核函數(shù),我們分別來研究一下。

if(args_.candidate_num_ != 0)
{
    init_kernelLauncher(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_,
        args_.start_id_, args_.batch_size_, 1, decoding_params.stream);
}
else if(args_.probability_threshold_ != 0.0)
{
    topp_initialization_kernelLauncher(finished_buf_,
                                        decoding_params.sequence_length, 
                                        word_ids_buf_,
                                        topp_id_vals_buf_,
                                        topp_offset_buf_,
                                        args_,
                                        decoding_params.stream);
}

4.5.1 Top-k 采樣初始化

Top-k 采樣初始化時(shí)依然調(diào)用的是 v2.0 版本的 DecodingOpenNMT 中的 init_kernel 核函數(shù),只是把 beam_width 設(shè)置為 1 表示不使用 beam search,這個(gè)函數(shù)主要實(shí)現(xiàn)以下幾個(gè)功能:

decoding_params.sequence_length 初始化為 0

finished_buf_ 初始化為 false

word_ids 初始化為 start_id

cum_log_probs 初始化為 0

4.5.2 Top-p 采樣初始化

Top-p 采樣初始化主要做了以下工作:

decoding_params.sequence_length 初始化為 0

finished_buf_ 初始化為 false

word_ids 初始化為 start_id

topp_offset_buf 初始化為 [0, vocab_size, ..., batch_size * vocab_size]

topp_id_val_buf 初始化為 [[0, 1, ..., vocab_size-1], [0, 1, ..., vocab_size-1], ..., [0, 1, ..., vocab_size-1]],其實(shí)就是 batch_size 個(gè)索引向量。

/**
* @brief top-p 初始化
* 
* @param finished                  [batch_size,]
* @param sequence_length           [batch_size,]
* @param word_ids                  [batch_size,]
* @param topp_id_val_buf           [batch_size, vocab_size]
* @param topp_offset_buf           [batch_size + 1 向上取 4 的倍數(shù), ]
* @param batch_size 
* @param vocab_size 
* @param start_id 
* @return __global__ 
*/
__global__ void topp_initialization_kernel(bool* finished,
                                        int* sequence_length, 
                                        int* word_ids,
                                        int* topp_id_val_buf,
                                        int* topp_offset_buf,
                                        const int batch_size, 
                                        const int vocab_size,
                                        const int start_id)
{
    int tid = threadIdx.x;
    int bid = blockIdx.x;

    if(bid == 0)
    {
        for(int i = tid; i < batch_size + 1; i+= blockDim.x)
        {
            topp_offset_buf[i] = i * vocab_size;
        }
        
        for(int i = tid; i < batch_size; i+= blockDim.x)
        {
            finished[i] = false;
            sequence_length[i] = 0;
            word_ids[i] = start_id; 
        }
    }

    int index = tid + bid * blockDim.x;
    while(index < batch_size * vocab_size)
    {
        topp_id_val_buf[index] = index % vocab_size;
        index += blockDim.x * gridDim.x;
    }
}

topp_initialization_kernel<<<32, 512, 0, stream>>>(finished, sequence_length, word_ids, 
                                                    topp_id_val_buf, topp_offset_buf,
                                                    args.batch_size_, args.vocab_size_,
                                                    args.start_id_);

4.6 embedding_lookup_sine_position_encoding_kernel_launcher

該核函數(shù)做了兩項(xiàng)工作:詞嵌入(embedding lookup)、位置嵌入(sine_position),在 v2.0 版本這倆功能是通過兩個(gè)核函數(shù)實(shí)現(xiàn)的,v2.1 版本把這兩個(gè)核函數(shù)進(jìn)行了融合,這塊內(nèi)容本來筆者是計(jì)劃放在第 5 節(jié)來介紹的,但是為了保證采樣模塊的完整性,就先在這里說了。

  template 
  __global__ void embedding_lookup_sine_position_encoding_kernel(T* from_tensor,
                                                                const T* embedding_table, 
                                                                const T* position_encoding,
                                                                const int* word_ids,
                                                                const int hidden_units)
  {
      const int tid = threadIdx.x;
      const int bid = blockIdx.x;
      const int write_pos = tid + bid * blockDim.x;
      // 1. lookup the table
      // 2. multiply hidden_dim**0.5
      // 3. add the position encoding
      from_tensor[write_pos] = embedding_table[word_ids[bid] * hidden_units + tid] * 
                                (T)sqrtf(float(hidden_units)) + position_encoding[tid];
  }

  template 
  void embedding_lookup_sine_position_encoding_kernel_launcher(T* from_tensor,
                                                              const T* embedding_table, 
                                                              const T* position_encoding,
                                                              const int* word_ids,
                                                              const int batch_size,
                                                              const int hidden_units, 
                                                              cudaStream_t stream)
  {
      assert(hidden_units <= 1024);
      dim3 grid(batch_size);
      dim3 block(hidden_units);
      embedding_lookup_sine_position_encoding_kernel<<>>(from_tensor,
                                                                                  embedding_table,
                                                                                  position_encoding,
                                                                                  word_ids,
                                                                                  hidden_units);
  }

核函數(shù)的 grid_size 和 block_size 分別設(shè)置為 batch_size 和 hidden_units,在函數(shù)內(nèi)部做了以下三件事:

根據(jù) embedding_table 查表賦值,把 word_id 轉(zhuǎn)化為詞向量

詞向量的值乘以一個(gè)修正系數(shù) hidden_uints ** 0.5,達(dá)到縮放效果,這一步在 v2.0 版本中是在 sine_position_encoder_kernel 核函數(shù)中進(jìn)行的。

加上位置編碼向量,在 v2.0 版本的 sine_position_encoder_kernel 中是直接計(jì)算了一個(gè)位置編碼值加上去的,但是這里為了節(jié)省計(jì)算時(shí)間,直接通過查表實(shí)現(xiàn),這就要求函數(shù)入?yún)⒌臅r(shí)候把提前計(jì)算好的 position_encoding 傳進(jìn)來,這種做法其實(shí)挺好的,因?yàn)楸旧砦恢镁幋a就是與數(shù)據(jù)無關(guān)的,完全可以提前算好,缺點(diǎn)就是會(huì)增加內(nèi)存占用,用空間換時(shí)間。

4.7 Top-k 采樣解碼

前面說過 Top-k 采樣解碼是先選取概率最大的 k 個(gè) word 再進(jìn)行采樣,所以需要先計(jì)算概率,計(jì)算概率必然要先根據(jù) logits 值計(jì)算 Softmax,但是我們知道 Softmax 函數(shù)是單調(diào)的,其實(shí)就相當(dāng)于一個(gè)指數(shù)映射后的歸一化操作。那既然是單調(diào)函數(shù),我們完全可以直接根據(jù) logits 直接選出 TopK,然后再計(jì)算 Softmax,這樣可以把 Softmax 的規(guī)模從 vocab_size 縮減到 k,這是一個(gè)非??捎^的縮減量。

4.7.1 update_logits_without_softmax

這個(gè)核函數(shù)完成了 logits 的 add bias 操作,其實(shí)是 decoder out 的線性變換的內(nèi)容,前面只進(jìn)行了矩陣乘法,在這個(gè)核函數(shù)中把偏置項(xiàng)加上,另外核函數(shù)內(nèi)部還加了一個(gè)停止符判斷。

template 
__global__ void update_logits_kernel_without_softmax(T* logits, const T* bias, const int end_id, const bool* finished, const int n)
{
  int bid = blockIdx.x;
  bool finish = finished[bid];
  int offset = bid * n;

  for(int tid = threadIdx.x; tid < n; tid += blockDim.x)
  {
    if(finish)
      logits[offset + tid] = (tid == end_id) ? FLT_MAX : -1 * FLT_MAX;
    else
      logits[offset + tid] += bias[tid];
  }
}

void update_logits_without_softmax(float* logits, const float* bias, const int end_id, const bool* finished, 
  const int m, const int n, cudaStream_t stream)
{
  dim3 grid(m);
  dim3 block(min(n, 1024));
  /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */
  update_logits_kernel_without_softmax<<>>(logits, bias, end_id, finished, n);
}

這里加完偏置項(xiàng)就可以直接用于 TopK 采樣了。

4.7.2 topK_sampling_kernel_kernelLauncher

根據(jù) topK_sampling_kernel_kernelLauncher 函數(shù)邏輯可以看出,采樣過程由兩個(gè)核函數(shù)完成:beam_topK_kernel 和 sampling。函數(shù)內(nèi)部首先判斷了 candidate_num 的值,貌似目前只支持 1、2、4 三種情況,這里源碼為什么要用宏的模式,因?yàn)榫幾g期要對(duì)模板進(jìn)行實(shí)例化,要求 K(candidate_num) 在編譯期就得確定,而源碼中的 candidate_num 顯然是一個(gè)運(yùn)行期才確定的參數(shù),所以只好犧牲編譯期,多實(shí)例化幾個(gè)模板(如 1、2、4,分別對(duì)應(yīng) 1 個(gè)函數(shù)),等到運(yùn)行期的時(shí)候匹配真實(shí)的 candidate_num,去執(zhí)行對(duì)應(yīng)的模板函數(shù)。

#define CASE_K(K) 
  case K : 
    beam_topK_kernel<<>>(log_probs, 
        topk_tmp_id_buf, topk_tmp_val_buf, vocab_size, 0.0f); 
  break; 

template 
void topK_sampling_kernel_kernelLauncher(T* log_probs,
                                        int* topk_tmp_id_buf,
                                        T* topk_tmp_val_buf,
                                        int* ids,
                                        int* sequence_length,
                                        bool* finished_buf,
                                        int random_num,
                                        DecodingSamplingArguments args,
                                        cudaStream_t stream)
{
    const int batch_size = args.batch_size_;
    const int vocab_size = args.vocab_size_;
    const int candidate_num = args.candidate_num_;
    const int end_id = args.end_id_;
    const int block_size = 256;
    switch(candidate_num)
    {
        CASE_K(1);
        CASE_K(2);
        CASE_K(4);
        default:
            printf("[ERROR] Topk kernel does not support candidate_num = %d 
", candidate_num);
            exit(0);
            break;
    }
    sampling <<< batch_size, candidate_num, 0, stream>>> (topk_tmp_id_buf, topk_tmp_val_buf, 
                                                            ids, sequence_length, finished_buf,
                                                            candidate_num, random_num, end_id, vocab_size);
}

4.7.2.1 求TopK

在 v2.0 版本的 beam search 中也有求 TopK 的操作,不過當(dāng)時(shí)那個(gè)計(jì)算思路就很粗糙,簡(jiǎn)單粗暴,總共分為兩個(gè) kernel,在第一個(gè) kernel 里面,先是用塊內(nèi)規(guī)約求出當(dāng)前線程對(duì)應(yīng)值的最大值,把最大值存起來,然后變量賦值為極小值,然后線程內(nèi)部直接循環(huán) K 次,最后獲得了 grid_size 個(gè) TopK,然后再第二個(gè) kernel 中把這個(gè)范圍再縮小到 TopK??梢钥吹竭@是一種 native 的求 TopK 思路,在 v2.1 版本,求 TopK 的思路有所優(yōu)化。
TopK 問題是一個(gè)經(jīng)典算法問題,通常我們通過維護(hù)一個(gè)小根堆,堆里存了 K 個(gè)數(shù)據(jù),每次新數(shù)據(jù)跟堆頂數(shù)據(jù)比較,大于堆頂元素就替換掉堆頂元素,然后重新建堆,遍歷完所有元素后,堆中元素就是 TopK。這里源碼中也使用了這個(gè)思路,但是并沒有使用堆結(jié)構(gòu),而是定義了一個(gè)結(jié)構(gòu)體 TopK,應(yīng)該是作者嫌麻煩,因?yàn)?K 實(shí)在太小,就不折騰了,我們來看一下這個(gè)結(jié)構(gòu)體。

template
struct TopK
{
    int p[MAX_K];
    T u[MAX_K];

    __device__ __forceinline__ void insert(T elem, int elem_id)
    {
        // 把插入元素跟最后一個(gè)元素比較,如果插入元素更大,則替換掉最后一個(gè)元素
        if (elem > u[MAX_K-1] || (p[MAX_K-1] == -1) || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1])))
        //if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1])))
        {
            u[MAX_K-1] = elem;
            p[MAX_K-1] = elem_id;
        }
        // 冒泡排序,把 TopK 中的元素進(jìn)行排序
        for(int k = MAX_K - 2; k >= 0; --k)
        {
            if ((u[k+1] > u[k]) || (p[k] == -1) || ((u[k+1] == u[k])&&(p[k+1] < p[k])))
            //if ((u[k+1] > u[k]) || ((u[k+1] == u[k])&&(p[k+1] < p[k])))
            {
                T u2 = u[k];
                int p2 = p[k]; 
                u[k] = u[k+1];
                p[k] = p[k+1];
                u[k+1] = u2;
                p[k+1] = p2;
            }
        }
    }

    __device__ __forceinline__ void init()
    {
      #pragma unroll
      for(int i = 0; i < MAX_K; i++)
      {
        p[i] = -1;
        u[i] = -FLT_MAX;
      }
    }
};

可以看到,結(jié)構(gòu)體中有兩個(gè)長度為 MAX_K 的數(shù)組變量,p 用來存索引,u 用來存值,一一對(duì)應(yīng)并按值降序排列。為啥弄兩個(gè)數(shù)組?是因?yàn)檫@里我們還需要元素的位置,也就是 word_id,這兩個(gè)數(shù)組同步更新。除了成員變量以外還有兩個(gè)成員函數(shù),一個(gè)是初始化函數(shù) init 主要用來初始化 p 和 u,另一個(gè)是 insert 函數(shù)用來“插入元素”和“建堆”。insert 函數(shù)中首先比較最后一個(gè)元素和新插入元素,滿足以下任意條件后,將用新插入的元素替換掉 TopK 中最后一個(gè)元素。

插入元素大于最后一個(gè)元素

最后一個(gè)元素是初始化的標(biāo)識(shí),也就是數(shù)組沒有滿

插入元素等于最后一個(gè)元素,但是插入元素的索引更小

插入元素后,還得“建堆”保證堆頂元素最小,前面說過這里直接用排序代替“建堆”,所以源碼就提供了一個(gè)冒泡排序,排序完成后,數(shù)組中的元素恢復(fù)降序排列。
TopK 結(jié)構(gòu)介紹完之后,下面就是如何使用 TopK 結(jié)構(gòu)完成對(duì) logits 的求 TopK 操作。源碼中使用 beam_topK_kernel 核函數(shù)來求 TopK,grid_size 設(shè)置為 batch_size,block_size 設(shè)置為 256,也就是說一個(gè) block 內(nèi)要處理 vocab_size 個(gè)元素,從中選出 TopK,每個(gè)線程處理 vocab_size / 256 個(gè)元素,步長為 256。

template
__launch_bounds__(THREADBLOCK_SIZE)
__global__
void beam_topK_kernel(const T* log_probs, 
                        int* topk_tmp_id_buf,
                        T* topk_tmp_val_buf,
                        const int vocab_size,
                        T diversity_rate)
{
    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int thread_id = threadIdx.x;
    int block_id = blockIdx.x;
    TopK partial;
    
    #pragma unroll
    for(int i = 0; i < MAX_K; ++i)
    {
        partial.p[i] = -1;
        partial.u[i] = -FLT_MAX;
    }

    #pragma unroll
    for(int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
    {
        int index = elem_id + block_id * vocab_size;        
        partial.insert(log_probs[index], index);
    }

    TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op);

    if (thread_id == 0)
    {
        int index = block_id * MAX_K;
        
        #pragma unroll
        for(int i = 0; i < MAX_K; ++i)
        {
            topk_tmp_id_buf[index + i] = total.p[i];
            topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i;
        }
    }
}

核函數(shù)內(nèi)部首先使用 cub 庫進(jìn)行了塊內(nèi)規(guī)約前的準(zhǔn)備,這個(gè)我們暫且不去看,之后內(nèi)部定義了一個(gè)寄存器變量 partial,partial 存儲(chǔ)了當(dāng)前線程處理元素的 TopK,相當(dāng)于當(dāng)前線程下的小根堆,隨后對(duì) partial 進(jìn)行初始化,這塊其實(shí)可以直接調(diào)用成員函數(shù) init 的,但是作者估計(jì)忘記還有這個(gè)函數(shù)了就又手寫了一遍。然后就是對(duì)當(dāng)前線程待處理的元素進(jìn)行遍歷,讓 partial 來 insert 待處理元素,全部 insert 一遍后的 partial 其實(shí)就存儲(chǔ)了當(dāng)前線程處理的所有元素的 TopK。但是我們的目標(biāo)是要獲取整個(gè) block 內(nèi)的全局 TopK,所以我們還需要進(jìn)行一次“大合并”,把所有的 TopK 合并成一個(gè),這實(shí)際相當(dāng)于一次塊內(nèi)規(guī)約操作,只是我們還需要定義一個(gè)操作函數(shù),顯然這個(gè)操作函數(shù)的輸入是兩個(gè) TopK 類型的變量,輸出是 TopK 類型,其計(jì)算邏輯就是把兩個(gè) TopK 合并成 1 個(gè) TopK。源碼提供了一個(gè) reduce_topk_op 函數(shù)來完成這個(gè)任務(wù)。

template
__device__ __forceinline__ TopK reduce_topk_op(const TopK& a, const TopK& b)
{
    TopK res = a;
    for(int i = 0; i < MAX_K; ++i)
        res.insert(b.u[i], b.p[i]);
    return res;
}

可以看到,reduce_topk_op 是通過遍歷一個(gè) TopK 變量 b 的元素,不斷 insert 到另一個(gè) TopK 變量 a 的拷貝 res 中實(shí)現(xiàn)的合并工作。
有了操作函數(shù)以后,直接調(diào)用 cub 庫的塊內(nèi)規(guī)約 API 就完成了塊內(nèi)規(guī)約,獲取了整個(gè) block 內(nèi)的全局 TopK total。當(dāng) thread_id == 0 時(shí),把這 k 個(gè)元素對(duì)應(yīng)的 logit 和 word_id 寫入 topk_tmp_val_buf 和 topk_tmp_id_buf 中。這里還有個(gè) diversity_rate 參數(shù),這應(yīng)該是一個(gè)修正系數(shù),但是筆者發(fā)現(xiàn)源碼中實(shí)際設(shè)置為 0.0f 并沒有啟用。

4.7.2.2 采樣

前面介紹過采樣原理,獲取 TopK 之后,計(jì)算每個(gè) word 的概率,然后在 TopK 中歸一化,最后根據(jù)歸一化后的概率采樣。其實(shí)就是先 Softmax 后采樣,我們來看一下源碼。

// Sampling kernels
template
__global__ void sampling(int* topk_tmp_id_buf, 
                        T* topk_tmp_val_buf, 
                        int* ids, 
                        int* sequence_length, 
                        bool* finished_buf,
                        const int candidate_num, 
                        int random_num,
                        const int end_id,
                        const int vocab_size)
{
    int tid = threadIdx.x;
    int bid = blockIdx.x;
    __shared__ T sum;
    __shared__ T rand_num;

    if(tid < candidate_num)
    {
        T max_val = topk_tmp_val_buf[bid * candidate_num];
        topk_tmp_val_buf[bid * candidate_num + tid] = __expf(topk_tmp_val_buf[bid * candidate_num + tid] - max_val);
    }
    
    if(tid == 0)
    {
        sum = 0.0f;
        for(int i = 0; i < candidate_num; i++)
        {
            sum = sum + topk_tmp_val_buf[bid * candidate_num + i];
        }
        
        curandState_t local_state;
        curand_init((T)random_num, bid, 0, &local_state);
        rand_num = (T)curand_uniform(&local_state) * sum;

        ids[bid] = topk_tmp_id_buf[bid * candidate_num + candidate_num - 1] % vocab_size;
        for(int i = 0; i < candidate_num; i++)
        {
            rand_num = rand_num - topk_tmp_val_buf[bid * candidate_num + i];
            if(rand_num <= 0.0f){
                ids[bid] = topk_tmp_id_buf[bid * candidate_num + i] % vocab_size;
                break;
            }
        }

        sequence_length[bid] = finished_buf[bid] ? sequence_length[bid] : sequence_length[bid] + 1;
        finished_buf[bid] = ids[bid] == end_id ? 1 : 0;
    }
}

核函數(shù)中 grid_size 和 block_size 分別設(shè)置為 batch_size 和 candidate_num,當(dāng)前線程就只處理對(duì)應(yīng)一個(gè)元素,先根據(jù)索引從 topk_tmp_val_buf 中獲取 TopK 中的最大值,然后讓當(dāng)前元素減去最大值然后求指數(shù),再存入 topk_tmp_val_buf。在 0 號(hào)線程內(nèi)循環(huán)求規(guī)約和,得到 sum,這時(shí)候其實(shí)已經(jīng)可以開始采樣了,沒有必要非得歸一化。源碼中調(diào)用 cuda 隨機(jī)數(shù)生成庫的 API 從均勻分布中隨機(jī)一個(gè) 0~1 之間的數(shù)再乘以 sum,得到一個(gè) 0~sum 之間的數(shù) rand_num,要知道 TopK 中各元素是降序排列的,我可以把他當(dāng)成 k 個(gè)相互連接的組合線段記作(其中每個(gè)子線段記作),把 rand_num 當(dāng)成一根長度為 rand_num 的線段記作,并將其與 的最左側(cè)對(duì)齊,那么 的右端點(diǎn)落在 的哪個(gè)子線段中就認(rèn)為采樣選中了哪個(gè) word,筆者給出如下示意圖。
5c6db13c-5698-11ee-939d-92fbcf53809c.png

隨后根據(jù)采樣選中的 word_id 對(duì) sequence_length 和 finished_buf 進(jìn)行更新,至此當(dāng)前 step 的采樣解碼就完成了。

4.8 Top-p 采樣解碼

Top-p 采樣與 Top_k 采樣有所不同,不再從固定 k 個(gè)候選詞中采樣,而是根據(jù)生成概率從高到低在詞表上選擇累積概率恰好超過 的候選 word 作為采樣集合,從這個(gè)集合中采樣。所以采樣前必須先計(jì)算每個(gè) word 對(duì)應(yīng)的概率并進(jìn)行排序,即要計(jì)算 Softmax,再按概率值排序。

4.8.1 update_logits_without_log

源碼中使用 update_logits_kernel_without_log 核函數(shù)來計(jì)算 Softmax,順帶加一個(gè)上一步?jīng)]進(jìn)行的 add bias 操作。這個(gè)核函數(shù)比較簡(jiǎn)單是個(gè)老生常談的 Softmax kernel,只需要注意一點(diǎn),計(jì)算完 softmax 后不要取對(duì)數(shù)即可,具體計(jì)算邏輯筆者就不啰嗦了,讀者有興趣可以看筆者前面的文章。

4.8.2 topP_sampling_kernel_kernelLauncher

Softmax 后拿到詞表內(nèi)每個(gè) word 的概率,在進(jìn)行采樣前還要進(jìn)行排序。

4.8.2.1 排序

這里排序是個(gè)大工程,因?yàn)?vocab_size 通常會(huì)很大,源碼中使用了 cub 庫中的 API 進(jìn)行排序。

template
void topP_sampling_kernel_kernelLauncher(const T* log_probs,
                                        const int* id_vals,
                                        T* sorted_log_probs,
                                        int* sorted_id_vals, 
                                        int* topp_offset_buf,
                                        void* temp_storage,
                                        bool* finished_buf,
                                        int step,
                                        DecodingSamplingArguments args,
                                        int* output_ids, 
                                        int* sequence_length, 
                                        cudaStream_t stream)
{
    // sort_kernel<<>>(log_probs, 
    //                                             id_vals,
    //                                             sorted_log_probs,
    //                                             sorted_id_vals,
    //                                             vocab_size);
    cub::SortPairsDescending(temp_storage, 
                                                        args.temp_storage_size_,
                                                        log_probs, 
                                                        sorted_log_probs,
                                                        id_vals, 
                                                        sorted_id_vals, 
                                                        args.vocab_size_ * args.batch_size_,
                                                        args.batch_size_, 
                                                        topp_offset_buf, topp_offset_buf + 1);
                                                        
    top_p_sampling<<<1, args.batch_size_, 0, stream>>>(sorted_log_probs, 
                                                        sorted_id_vals,
                                                        output_ids + (step - 1) * args.batch_size_,
                                                        sequence_length,
                                                        finished_buf,
                                                        args.vocab_size_, 
                                                        step,
                                                        args.probability_threshold_,
                                                        args.end_id_);
}

下面我們對(duì) cub::SortPairsDescending 函數(shù)的主要參數(shù)進(jìn)行介紹:

d_temp_storage:設(shè)備可以訪問的臨時(shí)內(nèi)存,當(dāng)設(shè)置為 NULL 時(shí),所需的分配大小將寫入 temp_storage_bytes,并且不執(zhí)行任何工作。所以在真正執(zhí)行函數(shù)前,我們需要先傳一下 NULL 獲取 temp_storage_bytes 然后再開始真正的執(zhí)行排序

temp_storage_bytes:臨時(shí)內(nèi)存的大小

d_keys_in:指向排序過程中的比較依據(jù),也就是說排序是根據(jù)這個(gè)指針指向的數(shù)據(jù)的來進(jìn)行的,這里我們將它設(shè)置為概率值 log_probs

d_keys_out:排序后的輸出,這里我們用 sorted_log_probs 來接收

d_values_in:與 key 一一對(duì)應(yīng),這里我們把他設(shè)置為概率值對(duì)應(yīng)的索引 id_vals,其實(shí)就是 word_id

d_values_out:排序后的輸出,這里我們用 sorted_id_vals 來接收

num_items:待排序的元素?cái)?shù)目,這里應(yīng)該是 batch_size * vocab_size

num_segments:待排序的批次,也就是分為多少個(gè)組,這里是對(duì)每個(gè)樣本單獨(dú)排序,所以取 batch_size

d_begin_offsets:每個(gè)分組的起始索引,為了方便 end_offset 的設(shè)置,這個(gè)變量對(duì)應(yīng)的元素?cái)?shù)量通常是 num_segments + 1,前面 num_segments 個(gè)元素都是分組的起始索引,最后一個(gè)元素設(shè)為 num_items,這里我們?cè)O(shè)置為 topp_offset_buf,前面已經(jīng)完成初始化

d_end_offsets:每個(gè)分組的結(jié)束索引,注意這里是“顧頭不顧尾”的模式,所以直接可以設(shè)置為 d_begin_offsets + 1,這里我們?cè)O(shè)置為 topp_offset_buf + 1

參數(shù)意義介紹完畢后,其實(shí)函數(shù)的作用也就清晰了,就是分組降序排序,每一組對(duì)應(yīng) batch 內(nèi)的一個(gè)樣本,也就是 vocab_size 個(gè)元素,最終我們獲取到了 batch 內(nèi)每個(gè)樣本下排序后的待采樣 word 的概率值 sorted_log_probs和 sorted_id_vals。

4.8.2.2 采樣

根據(jù)采樣原理,拿到排序結(jié)果后,我們需要根據(jù) p 值進(jìn)行候選集的確定,然后在候選集的內(nèi)部進(jìn)行采樣。
源碼中提供了核函數(shù) top_p_sampling 進(jìn)行采樣工作,grid_size 設(shè)置為 1,block_size 設(shè)置為 bacth_size,即在一個(gè) block 內(nèi)完成計(jì)算,每個(gè)線程承擔(dān)一個(gè)樣本的計(jì)算任務(wù)。

template
__global__ void top_p_sampling(T* sorted_log_probs, 
                                int* sorted_id_vals,
                                int* ids,
                                int* sequence_length,
                                bool* finished_buf,
                                const int vocab_size,
                                const int random_num,
                                const float prob_threshold, 
                                const int end_id)
{
    int tid = threadIdx.x;
    curandState_t local_state;
    curand_init((T)random_num, tid, 0, &local_state);
    T rand_num = (T)curand_uniform(&local_state) * prob_threshold;
    ids[tid] = sorted_id_vals[vocab_size - 1];

    for(int i = tid * vocab_size; i < tid * vocab_size + vocab_size; i++)
    {
        rand_num = rand_num - sorted_log_probs[i];
        if(rand_num <= 0)
        {
            ids[tid] = sorted_id_vals[i];
            break;
        }
    }

    sequence_length[tid] = finished_buf[tid] ? sequence_length[tid] : sequence_length[tid] + 1;
    finished_buf[tid] = ids[tid] == end_id ? 1 : 0;
}

top_p_sampling<<<1, args.batch_size_, 0, stream>>>(sorted_log_probs, 
                                                  sorted_id_vals,
                                                  output_ids + (step - 1) * args.batch_size_,
                                                  sequence_length,
                                                  finished_buf,
                                                  args.vocab_size_, 
                                                  step,
                                                  args.probability_threshold_,
                                                  args.end_id_);

采樣過程和前面 Top-k 的過程大同小異,有一點(diǎn)區(qū)別就是,不用真的先確定候選集再進(jìn)行采樣,可以直接一步進(jìn)行。先使用 cuda 隨機(jī)數(shù)生成庫的 API 從均勻分布中隨機(jī)一個(gè) 0~1 之間的數(shù)再乘以 p 值(probability_threshold_),這其實(shí)就相當(dāng)于把采樣的概率點(diǎn)縮放到了 p 值范圍內(nèi),然后遍歷 sorted_log_probs 判斷采樣點(diǎn)落在哪個(gè)區(qū)間,就選中了哪個(gè) word,示意圖如下:
5c77af98-5698-11ee-939d-92fbcf53809c.png

采樣完成后把采樣結(jié)果更新到 ids,然后對(duì) sequence_length 和 finished_buf 進(jìn)行更新,至此,當(dāng)前 step 的 Top-p 采樣解碼就完成了。

5 內(nèi)核優(yōu)化

5.1 批量矩陣乘法優(yōu)化

首先我們來看 Encoder 部分的優(yōu)化點(diǎn),Encoder 的計(jì)算較為簡(jiǎn)單,主要集中在 self-attention。在介紹 OpenMultiHeadAttention 之前我們不妨先來看一下其內(nèi)部 buffer 的內(nèi)存分布情況,通過內(nèi)存分布情況的變化,可以看出具體新增和刪減了哪些變量,從這些變量入手可以有助于弄懂具體優(yōu)化邏輯。

buf_ = (DataType_*) allocator_.malloc(sizeof(DataType_) * (buf_size * 7 + qk_buf_size) + sizeof(DataType_*) * 9);
query_buf_ = buf_;
key_buf_ = buf_ + buf_size;
value_buf_ = buf_ + 2 * buf_size;
q_buf_ = buf_ + 3 * buf_size;
k_buf_ = buf_ + 4 * buf_size;
v_buf_ = buf_ + 5 * buf_size;
qk_buf_ = buf_ + 6 * buf_size;
transpose_dst_ = qk_buf_ + qk_buf_size;
qkv_kernel_ = (DataType_**)(transpose_dst_ + buf_size);
qkv_input_ = qkv_kernel_ + 3;
qkv_buf_ = qkv_input_ + 3;

從 OpenMultiHeadAttention 構(gòu)造函數(shù)中可以發(fā)現(xiàn),v2.1 版本多申請(qǐng)了 sizeof(DataType_*) * 9 大小的設(shè)備內(nèi)存,也就是說多了 9 個(gè)二級(jí)指針。從下面內(nèi)存分配可以看出,這 9 個(gè)二級(jí)指針分別分給了 3 個(gè)變量:qkv_kernel_、qkv_input_、qkv_buf_。這三個(gè)變量在 initialize 進(jìn)行初始化,其中 qkv_kernel_ 對(duì)應(yīng) self-attention 操作中對(duì)輸入 tensor 進(jìn)行線性變換的 3 個(gè)權(quán)重參數(shù);qkv_input_ 對(duì)應(yīng) 3 個(gè)輸入 tensor,在 self-attention 中全部都是 from_tenosr;qkv_buf_ 對(duì)應(yīng)的是 3 個(gè) buffer,用于存儲(chǔ) QKV 的中間計(jì)算結(jié)果。

const DataType_* hA[] {param_.self_attention.query_weight.kernel, 
                            param_.self_attention.key_weight.kernel, 
                            param_.self_attention.value_weight.kernel,
                            param_.from_tensor, param_.from_tensor, param_.from_tensor,
                            query_buf_, key_buf_, value_buf_};
cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream);

仔細(xì)一想,這三個(gè)新增變量最終指向的數(shù)據(jù)其實(shí)都是前面已經(jīng)存在的變量,為什么要單獨(dú)搞這幾個(gè)二級(jí)指針呢,而且這幾個(gè)變量統(tǒng)統(tǒng)都和 self-attention 相關(guān)?這應(yīng)該是 self-attention 中使用了某個(gè) API,通過這個(gè) API 可以獲得加速效果,其輸入?yún)?shù)要求是二級(jí)指針。
果不其然,forward 函數(shù)中當(dāng) is_fuse_QKV 為 true 時(shí),調(diào)用了 cuBLAS 中的 cublasGemmBatchedEx 函數(shù)進(jìn)行矩陣乘法,該函數(shù)可以對(duì) batch 級(jí)別的矩陣進(jìn)行乘法運(yùn)算,要求輸入的參數(shù)為 Array of pointers to array 也就是二級(jí)指針的形式,這里源碼把 query、key、value 三個(gè)矩陣當(dāng)成一個(gè) batch 內(nèi)的三個(gè)矩陣,使用 API 一次完成 3 個(gè)矩陣的乘法運(yùn)算,相比與原來的先后調(diào)用三次 cublasGemmEx 函數(shù)計(jì)算乘法,節(jié)省了一定的運(yùn)算時(shí)間。

if(is_fuse_QKV == true)
{
  check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, 
                      CUBLAS_OP_N, CUBLAS_OP_N, 
                      n, m, k, 
                      &alpha, 
                      (const void* const*) qkv_kernel_, AType_, n,
                      (const void* const*) qkv_input_, BType_, k,
                      &beta,
                      (void* const*)qkv_buf_, CType_, n,
                      3, 
                      computeType_,
                      static_cast(cublasAlgo_[3])));
}

關(guān)于批量矩陣乘法這個(gè)優(yōu)化除了 Encoder 以外,在 Decoder 的 self-attention 中也有應(yīng)用,具體各位讀者可以自行閱讀。

5.2 Decoder Attention Opt

Decoder 的兩個(gè)核函數(shù) masked_attention_kernel_opt、 cross_attention_kernel_opt 的優(yōu)化是 decoder 中的主要優(yōu)化內(nèi)容,筆者將以 masked_attention_kernel_opt 為例介紹優(yōu)化技巧,關(guān)于 attention 的原理等內(nèi)容不再贅述。這部分代碼實(shí)現(xiàn)過程說實(shí)話有些繁瑣,會(huì)導(dǎo)致初讀的時(shí)候一臉懵逼,總之優(yōu)化思路就一句話:向量化數(shù)據(jù)訪問提升帶寬

5.2.1 向量化數(shù)據(jù)類型

首先作者定義了一個(gè)數(shù)據(jù)類型 Copy_t,這個(gè)類型的定義過程也比較繁瑣,其內(nèi)存占用的大小是根據(jù) ELEMENTS_PER_WARP_LOAD 動(dòng)態(tài)調(diào)整的,具體代碼如下:

template 
using Copy_half_t =
    typename std::conditional::type
        >::type
    >::type;

template 
using Copy_t = Copy_half_t;

源碼中使用的 std::conditional 是 C++11 引入的類模板,表示的是一種編譯期的分支邏輯,當(dāng)?shù)谝粋€(gè)非類型模板參數(shù)的值為 true 時(shí),type 的類型為第一個(gè)類型模板參數(shù)的類型,為 false 時(shí) type 的類型為第二個(gè)類型模板參數(shù)的值。那么以上代碼的含義就是在一個(gè) warp 內(nèi)處理 ELEMENTS_PER_WARP_LOAD 個(gè) T 類型的元素,Copy_t 類型占用的大小為 sizeof(T) * ELEMENTS_PER_WARP_LOAD / 32。
這里我們假設(shè)數(shù)據(jù)類型 T 以 FP32 為例,ELEMENTS_PER_WARP_LOAD 設(shè)置為 size_per_head 取 64,這樣的話 Copy_t 實(shí)際就是 int2 類型,不要去糾結(jié)為什么是 int2,這里寫 int2 僅僅是因?yàn)樗剂?8 個(gè)字節(jié),寫 float2 等任意占用 8 個(gè)字節(jié)的類型也是一樣的。帶著這個(gè)向量化訪問的思想,我們?cè)賮砜匆幌潞撕瘮?shù) masked_attention_kernel_opt 的代碼,代碼太繁瑣我就不完整貼了,下面只對(duì)主要內(nèi)容進(jìn)行介紹。

typedef Copy_t copy_t;
const int elems_per_thread = size_per_head / WARP_SIZE;

union Access_t
{
  copy_t v;
  T x[elems_per_thread]; // supported size 1,2,4
};
typedef struct Float_n_t
{
  T x[elems_per_thread]; // supported size 1,2,4
} float_n_t;

首先提一下核函數(shù)的 gird_size 和 block_size 分別設(shè)置為 batch_size * head_num 和 256,也就是一個(gè) block 內(nèi)處理一行數(shù)據(jù)(size_per_head個(gè)元素)。在核函數(shù)內(nèi)部定義了一個(gè)類型 copy_t,從模板參數(shù)可以看出,這里是想要一個(gè) Warp 內(nèi)部直接處理 size_per_head 個(gè)元素的,也就是說在這一個(gè) block 內(nèi)一個(gè) warp 就完成了當(dāng)前 step 的計(jì)算任務(wù),其他 warp 在干嘛?后面將會(huì)講到。然后定義了一個(gè)變量 elems_per_thread 表示每個(gè)線程處理的元素?cái)?shù)量。接著定義了一個(gè)聯(lián)合體 Access_t 用來存儲(chǔ) elems_per_thread 個(gè) T 類型的元素,和一個(gè)結(jié)構(gòu)體 Float_n_t 用來存儲(chǔ) n 個(gè) T 類型的元素。

5.2.2 add query bias

核函數(shù)內(nèi)定義了兩個(gè)變量 sq 和 logits,用來存儲(chǔ) attention 的中間結(jié)果。

__shared__ float_n_t sq[block_sz];
__shared__ float logits[1024]; // only use [0 ~ step-1], the step should be smaller than 1024

在 add query bias 之前先計(jì)算了當(dāng)前線程的各類索引以及偏移量 qkv_id,結(jié)合線程網(wǎng)格這個(gè)很好理解,然后根據(jù)偏移量更改了 QKV 相關(guān)的各個(gè)變量的地址方便后續(xù)索引。計(jì)算 add query bias 的過程分為兩步:從 query_buf 和 self_Q_bias 中向量化取值、結(jié)構(gòu)體內(nèi)循環(huán)計(jì)算。

// each warp will have its own copy of sq
query_buf_r.v = *((copy_t *)query_buf + lane_id);
key_buf_r.v = *((copy_t *)key_buf + lane_id);
bias_r.v = *((copy_t *)self_Q_bias + lane_id);
float qb_r[elems_per_thread];
for (int i = 0; i < elems_per_thread; ++i)
{
  qb_r[i] =  (float)query_buf_r.x[i] + (float)bias_r.x[i];
}

可以看到聯(lián)合體 Access_t 中的成員 v,其實(shí)就起到一個(gè)方便占位的作用,當(dāng)然如果沒有的話,筆者認(rèn)為也可以使用下面的方式取值。

Access_t *qbuf = reinterpret_cast(query_buf)
query_buf_r.v = qbuf[lane_id];

5.2.3 add key bias & softmax

我們知道 attention 中 softmax 計(jì)算的對(duì)象是 query 和 key 的乘積,query 我們已經(jīng)拿到了,存在每個(gè) thread 的 qb_r 中。key 需要從 key_cache 中獲取,對(duì)于當(dāng)前 step 而言,query 是固定的,與 from_tensor 對(duì)應(yīng),key 與前面 step 的 from_tensor 也一一對(duì)應(yīng),因此這一步完全是可以并行的,所以作者在這里設(shè)計(jì)成一個(gè) block 內(nèi)總共處理 warp_num 個(gè) step 的計(jì)算,這也回應(yīng)了前面的疑問,明明一個(gè) warp 內(nèi)就能處理一個(gè) step 的計(jì)算,其他 warp 干啥去了。其他 warp 處理其他 step 的計(jì)算去了??偨Y(jié)一下,對(duì)于所有 warp 而言 query 都是一樣的,所以放在寄存器變量 qb_r 中,同時(shí)每個(gè) warp 有各自的 key,通過 ite * offset 計(jì)算偏移量獲取。

//offset for each step
int offset = batch_size * head_num * size_per_head;
bias_r.v = *((copy_t *) self_K_bias + lane_id);
for(int ite = warp_id; ite < step; ite += warp_num)
{
  key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id);
  //for the last step, we should update K + bias_K to the cache
  if(ite == step - 1)
  {
    for (int i = 0; i < elems_per_thread; i++)
    {
      key_val_r.x[i] = (float)key_buf_r.x[i] + (float)bias_r.x[i];
    }
    *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v;
  }
  float val = 0.f;
  for (int i = 0; i < elems_per_thread; i++)
  {
    val = val +  (float)key_val_r.x[i] * qb_r[i] * (float)scalar;
  }
  float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val);
  if (lane_id == 0)
  {
    logits[ite] = qk; 
  }
}
__syncthreads();

拿到 key 之后將其與 query 相乘然后調(diào)用 cub 庫的束內(nèi)規(guī)約 API 計(jì)算出每個(gè) step 下 query 和 key 的向量相似度也就是 attention scores,根據(jù) step 的索引 ite 將其存入 logits 中。

__shared__ float s_max_val, s_sum;

float local_i = -1e20f;
for(int i = tid; i < step; i += blockDim.x)
  local_i = max(local_i, logits[i]);

float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max());
if(tid == 0)
  s_max_val = max_val;
__syncthreads();


float local_o = 0.0f;
for(int i = tid; i < step; i += blockDim.x)
{
  logits[i] = __expf(logits[i] - s_max_val);
  local_o += logits[i];
}
float val = BlockReduce(block_temp_storage).Sum(local_o);

if(tid == 0)
  s_sum = val + 1e-6;
__syncthreads();

float s_sum_inverse = __fdividef(1.0f, s_sum);
for(int i = tid; i < step; i += blockDim.x)
{
  logits[i] = logits[i] * s_sum_inverse;
}
__syncthreads(); 

計(jì)算 Softmax 的過程比較常規(guī),就是三個(gè)步驟:reduceMax、broadcast、reduceSum、broadcast,沒什么好說的,有疑問的讀者可以閱讀筆者上一篇文章。

5.2.4 計(jì)算 attention

得到 attention scores 后,右乘一個(gè) value 矩陣就得到 attention out,其實(shí)就是算加權(quán)平均值。這里計(jì)算思路也和前面計(jì)算 大致相同,都是一個(gè) block 內(nèi)計(jì)算 warp_num 個(gè) step,計(jì)算完成后 sum_r 中存儲(chǔ)了每個(gè)線程負(fù)責(zé)的 step 的對(duì)應(yīng)元素的加權(quán)和,最終的是需要所有 step 的加權(quán)總和,所以作者把 sum_r 放入共享內(nèi)存變量 sq 中暫存,然后再進(jìn)行兩層循環(huán)把其他線程束計(jì)算的 sum_r 全都加到 warp_id == 0 的線程對(duì)應(yīng)的 sum_r 中,此時(shí) warp_id == 0 的線程束內(nèi)各 thread 中的 sum_r 存儲(chǔ)的即為完成加權(quán)求和之后的 attention out,最后將 attention out 更新到 context_buf_ptr 中完成計(jì)算。

// This optimization introduces discrepancy because of different order in FP32 summation
float sum_r[elems_per_thread] = {0.f};
bias_r.v = *((copy_t *) self_V_bias + lane_id);
value_buf_r.v = *((copy_t *)value_buf + lane_id);

for(int ite = warp_id; ite < step; ite += warp_num)
{
    value_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id);
    //for the last step, we should update K + bias_K to the cache
    if(ite == step - 1)
    {
        for (int i = 0; i < elems_per_thread; i++)
        {
            value_val_r.x[i] = (float)value_buf_r.x[i] + (float)bias_r.x[i];
        }
        *((copy_t *)&value_cache[ite * offset] + lane_id) = value_val_r.v;
    }
    for (int i = 0; i < elems_per_thread; ++i)
    {
        sum_r[i] += (float)value_val_r.x[i] * logits[ite]; 
    }
}
for (int i = 0; i < elems_per_thread; i++)
{
    sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i];
}
__syncthreads();
if (warp_id == 0)
{
    #pragma unroll
    for (int j = 1; j < warp_num; j++)
    {
        for (int i = 0; i < elems_per_thread; ++i)
        {
            sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + tid].x[i];
        }
    }
}
__syncthreads();
#pragma unroll
for (int i = 0; i < elems_per_thread; i++)
{
    value_val_r.x[i] = sum_r[i];
}
if (warp_id == 0)
{
    *((copy_t *)context_buf + lane_id) = value_val_r.v;
}

這里有一點(diǎn)需要注意,在把其他線程束計(jì)算的 sum_r 更新到 0 號(hào)線程束內(nèi)時(shí),源碼中使用 sq[j * WARP_SIZE + tid] 進(jìn)行取值,這里容易造成誤解,雖然在 0 號(hào)線程束內(nèi) lane_id 和 tid 的值相等,但是為了便于理解這里建議還是使用 sq[j * WARP_SIZE + lane_id] 取值較好,以免對(duì)不熟悉的讀者造成困擾。

5.3 topK kernel 優(yōu)化

關(guān)于 Top-k 采樣解碼前面已經(jīng)介紹,這里說的 topK 特指 beam search 過程中的求 topK 操作,在 v2.0 版本中,固定設(shè)置 block_size 為 1024,通過第一個(gè) topK kernel 把 ids 的形狀縮小到 [batch_size, grid_size_1st, beam_width],再經(jīng)過第二個(gè) topK kernel 求出最終每個(gè)樣本的 topK。其中關(guān)于 batch_size 維度的每個(gè)樣本的計(jì)算過程是通過循環(huán)實(shí)現(xiàn)的,并行程度不高。在 v2.1 版本,作者更新了 topK kernel,依然通過兩個(gè) kernel (topk_stage_1_opt3 和 topk_stage_2_opt3)完成 topK 計(jì)算。

在 topk_stage_1_opt3 中把 gird_size 設(shè)置為 batch_size * K * BLOCKS_PER_BEAM_,也就是說對(duì)于每一行 vocab_size 個(gè)元素,要使用 BLOCKS_PER_BEAM_ 個(gè) block 參與計(jì)算。

/**
 * @brief 
 * // grid_size = batch_size * K * BLOCKS_PER_BEAM_
 * @tparam T 
 * @tparam BLOCK_SIZE_ 
 * @tparam BLOCKS_PER_BEAM_ 
 * @param log_probs                 [batch_size, beam_width, vocab_size]
 * @param tmp_log_probs             [batch_size, beam_width, vocab_size]
 * @param topk_tmp_id_buf           [batch_size, beam_width, BLOCKS_PER_BEAM_, K]
 * @param topk_tmp_val_buf          [batch_size, beam_width, BLOCKS_PER_BEAM_, K]
 * @param k                          beam_width
 * @param vocab_size 
 * @return __global__ 
 */
template
__global__ void topk_stage_1_opt3(
    const T* __restrict log_probs,
    T* tmp_log_probs,
    int* topk_tmp_id_buf,
    T* topk_tmp_val_buf,
    const int k,    // beam_width
    const int vocab_size
)
{
    typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs
    const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam 
    const int tmp_log_buf_index = row_id * vocab_size; 
    const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
    TopK_2 partial;

    for(int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_)
    {
        int index = elem_id + tmp_log_buf_index;
        tmp_log_probs[index] = log_probs[index]; 
    }


    for(int ite = 0; ite < k; ite++)
    {
        partial.init();
        #pragma unroll
        for(int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_)
        {
            int index = elem_id + tmp_log_buf_index;
            partial.insert(tmp_log_probs[index], index);
        }

        TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2);

        if (tid == 0)
        {
            const int index = tmp_topk_buf_index + ite;
            topk_tmp_id_buf[index] = total.p;
            topk_tmp_val_buf[index] = total.u;
            tmp_log_probs[total.p] = -FLT_MAX;
        }
        __syncthreads();
    }
}

核函數(shù)內(nèi)引入了一個(gè)新的數(shù)據(jù)結(jié)構(gòu) TopK_2,這個(gè)數(shù)據(jù)結(jié)構(gòu)中有兩個(gè)成員屬性,p 和 u,分別代表存儲(chǔ)了概率值和其對(duì)應(yīng)的 word_id;有兩個(gè)成員方法,init 和 insert,分別進(jìn)行初始化和更新,insert 方法非常簡(jiǎn)單,就是單純的把更大的值和索引更新到對(duì)象中。接下來,我們看一下 topk_stage_1_opt3 函數(shù)。

首先計(jì)算了當(dāng)前線程對(duì)應(yīng) log_probs 的各種索引,這個(gè)根據(jù)線程網(wǎng)格不難理解,根據(jù)索引將 log_probs 的值更新到 tmp_log_probs 中,注意這里每個(gè)線程處理元素的步長為 BLOCK_SIZE_ * BLOCKS_PER_BEAM_。

隨后對(duì) K 進(jìn)行循環(huán),在循環(huán)中首先對(duì)該線程處理的所有元素進(jìn)行遍歷,不斷將數(shù)據(jù) insert 到 partial 中,這樣就得到了每個(gè)線程處理元素的最大值,然后再對(duì) partial 進(jìn)行塊內(nèi)規(guī)約,得到每個(gè)線程塊內(nèi)的最大值 total。在 tid == 0 的線程內(nèi)把 total 更新到 topk_tmp_id_buf,再把 tmp_log_probs 中的值置為極小值,循環(huán) K 次上述過程就得到每個(gè)線程塊內(nèi)的 topK,最終一行元素被處理成了 BLOCKS_PER_BEAM_ 個(gè) topK,topk_tmp_val_buf 的形狀為 [batch_size, beam_width, BLOCKS_PER_BEAM_, K]。在第二個(gè) kernel 中我們需要將其縮小到 [batch_size, K],下面來看一下代碼。

/**
 * @brief 
 * // grid_size = batch_size
 * @tparam T 
 * @tparam BLOCK_SIZE_ 
 * @tparam BLOCKS_PER_BEAM_ 
 * @param topk_tmp_id_buf               [batch_size, beam_width, BLOCKS_PER_BEAM_, K]
 * @param topk_tmp_val_buf 
 * @param ids 
 * @param k 
 * @return __global__ 
 */
template
__global__ void topk_stage_2_opt3(
    const int* __restrict topk_tmp_id_buf,
    T* topk_tmp_val_buf,
    int* ids,
    const int k)
{
    const int size = k * k * BLOCKS_PER_BEAM_; 
    const int tid = threadIdx.x;
    const int batch_id = blockIdx.x;

    typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    extern __shared__ char array[];
    T *s_val = topk_tmp_val_buf + batch_id * size;
    int *s_id = (int*)(array);
    
    TopK_2 partial;

    for(int ite = 0; ite < k; ite++)
    {
        partial.init();
        #pragma unroll
        for(int i = tid; i < size; i+= BLOCK_SIZE_)
        {
            partial.insert(s_val[i], i);
        }
    
        TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2);
    
        if(tid == 0) 
        {
            s_id[ite] = total.p;
            s_val[total.p] = -FLT_MAX;
        }
        __syncthreads();
    }
    if(tid < k) ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]];
}

topk_stage_2_opt3 的 grid_size 直接設(shè)置為 batch_size,block_size 設(shè)置為 BLOCK_SIZE_,也就是說,我們?cè)谝粋€(gè) block 內(nèi)求出 topK 就可以完成計(jì)算任務(wù)。計(jì)算思路和 topk_stage_2_opt3 大致相同,都是每次求最大值后把原值置小,循環(huán) K 次即可。定義了一個(gè)共享內(nèi)存變量 s_id 用來存儲(chǔ) topK,最終在 tid < k 的線程分別把 s_id 更新到 ids 完成計(jì)算。
總的來說,更新后的 topK kernel 計(jì)算思路更加清晰,便于理解,是一個(gè)較好的思路,但是筆者還是更推薦使用 Top-k 采樣解碼中的思路來計(jì)算 topK 問題,猜測(cè)這兩種解碼方式的代碼不是同一個(gè)作者編寫的,否則完全可以復(fù)用代碼。

6 小結(jié)

總的來說,v2.1 版本的 Faster Transformer 相比與 v2.0 版本細(xì)節(jié)改動(dòng)還是比較多的,但是整體大框架沒有改變,仍然還是 3 個(gè)主要模塊:Encoder、Decoder、Decoding,新增了 Effective Transoformer 和 sample Decoding 兩個(gè)子模塊。現(xiàn)對(duì)本文總結(jié)如下:

新增了 Effective Transoformer,通過動(dòng)態(tài)刪除和恢復(fù)填充值的方式一定程度上可以節(jié)約 Encoder 部分的計(jì)算資源,當(dāng)一個(gè) batch 內(nèi)樣本長度變化越大,性能提升越明顯。在實(shí)際應(yīng)用中,訓(xùn)練模型階段,其實(shí)在處理數(shù)據(jù)時(shí)一般會(huì)刻意地將一個(gè) batch 的文本長度控制在一個(gè)較小的變化范圍,但在推理階段通常不會(huì)這么干,所以這個(gè) Effective Transoformer 就有用武之地了。

在 Top-k 采樣中,源碼給出了一個(gè)很好的求 topK 的思路,值得學(xué)習(xí)借鑒。

在 Top-p 采樣中,源碼示范了如何調(diào)用 cub 庫 API 進(jìn)行分組排序,值得學(xué)習(xí)借鑒。

在計(jì)算 self-attention 過程中,源碼示范了 cuBLAS 庫的 cublasGemmBatchedEx 函數(shù)調(diào)用方法,將三次串行調(diào)用矩陣乘法 API 縮減為 1 次。

在 Decoder 中,源碼首次引入向量化數(shù)據(jù)類型,提升訪存效率。

審核編輯:彭菁

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 模塊
    +關(guān)注

    關(guān)注

    7

    文章

    2706

    瀏覽量

    47468
  • 源碼
    +關(guān)注

    關(guān)注

    8

    文章

    641

    瀏覽量

    29207
  • 函數(shù)
    +關(guān)注

    關(guān)注

    3

    文章

    4331

    瀏覽量

    62604
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4788

    瀏覽量

    68601
  • Transformer
    +關(guān)注

    關(guān)注

    0

    文章

    143

    瀏覽量

    6005

原文標(biāo)題:【CUDA編程】Faster Transformer v2.1 源碼詳解

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    海康威視發(fā)布iVMS-8700安防綜合管理平臺(tái)V2.1版本

    康威視發(fā)布iVMS-8700安防綜合管理平臺(tái)V2.1版本,以全新的風(fēng)格面世。
    發(fā)表于 07-30 10:20 ?2.8w次閱讀

    請(qǐng)問怎么分辨TI套件TMDSHVMTRPFCKIT屬于1.7、v2.0、v2.1中的哪個(gè)版本

    怎么分辨TI套件TMDSHVMTRPFCKIT屬于1.7、v2.0、v2.1中的哪個(gè)版本
    發(fā)表于 10-17 15:01

    FET335xD V2.1版本V2.2版本有什么區(qū)別嗎

    FET335xD V2.1版本V2.2版本有什么區(qū)別?一開始買的配2.1版本核心板的開發(fā)板能用
    發(fā)表于 01-12 06:22

    ST-Link V2.1如何制作?怎么使用?

    前言ST-Link V2.1簡(jiǎn)介Mass StorageVirtual COM portDebug PortMCOST-Link V2.1原理圖ST-Link Bootloader程序ST-Link
    發(fā)表于 02-18 06:13

    ST-link V2.1版本具有哪些優(yōu)勢(shì)?

    ST-link V2.1版本具有哪些優(yōu)勢(shì)?
    發(fā)表于 02-21 06:29

    SComAssistant V2.1

    SComAssistant V2.1是一個(gè)串口調(diào)試軟件。
    發(fā)表于 03-15 12:19 ?250次下載
    SComAssistant <b class='flag-5'>V2.1</b>

    YESTON盈通 G41T V2.1主板

    YESTON盈通 G41T V2.1主板 主板驅(qū)動(dòng)安裝程序
    發(fā)表于 09-11 12:02 ?28次下載

    LED點(diǎn)陣多功能數(shù)字時(shí)鐘V2.1

    0 0730LED點(diǎn)陣多功能數(shù)字時(shí)鐘V2.1版 20140314.zip
    發(fā)表于 12-30 14:03 ?0次下載

    串口通信助手v2.1

    串口通信助手v2.1,MATLAB源代碼,感興趣的小伙伴們可以看看。
    發(fā)表于 07-25 10:45 ?10次下載

    BlueSkyC51不完全手冊(cè)V2.1

    BlueSkyC51不完全手冊(cè)V2.1
    發(fā)表于 03-14 10:25 ?5次下載

    ST-Link V2.1 制作使用

    ST-Link V2.1 制作使用1、前言新的STM32單片機(jī)可以通過常見的 J-Link, ST-Link, 開源的DAP-Link等設(shè)備下載程序, 淘寶搜索STM32下載器出來各種各樣的玩意
    發(fā)表于 12-22 19:47 ?26次下載
    ST-Link <b class='flag-5'>V2.1</b> 制作使用

    ST-Link V2.1 制作使用

    前言ST-Link V2.1簡(jiǎn)介Mass StorageVirtual COM portDebug PortMCOST-Link V2.1原理圖ST-Link Bootloader程序ST-Link
    發(fā)表于 12-23 19:00 ?33次下載
    ST-Link <b class='flag-5'>V2.1</b> 制作使用

    OBTCAN-IP核用戶手冊(cè)(V2.1)

    OBTCAN-IP核用戶手冊(cè)(V2.1)
    發(fā)表于 06-08 15:29 ?6次下載
    OBTCAN-IP核用戶手冊(cè)(<b class='flag-5'>V2.1</b>)

    外置BFO V2.1貼片版開源分享

    電子發(fā)燒友網(wǎng)站提供《外置BFO V2.1貼片版開源分享.zip》資料免費(fèi)下載
    發(fā)表于 07-25 10:03 ?0次下載
    外置BFO <b class='flag-5'>V2.1</b>貼片版開源分享

    Faster Transformer v1.0源碼詳解

    解讀的內(nèi)容僅限 Faster Transformer v1.0 版本,更高版本
    的頭像 發(fā)表于 09-08 10:20 ?974次閱讀
    <b class='flag-5'>Faster</b> <b class='flag-5'>Transformer</b> <b class='flag-5'>v</b>1.0<b class='flag-5'>源碼</b>詳解