FasterTransformer BERT
FasterTransformer BERT 包含優(yōu)化的 BERT 模型、高效的 FasterTransformer 和 INT8 量化推理。
模型結(jié)構(gòu)
標(biāo)準(zhǔn)的 BERT 和 高效的 FasterTransformer
FasterTransformer 編碼器支持以下配置。
- Batch size (B1): 批量大小 <= 4096
- Sequence length (S): 序列長(zhǎng)度 <= 4096。對(duì)于 INT8 模型,當(dāng) S > 384 時(shí) S 需要是 32 的倍數(shù)。
- Size per head (N): 小于 128 的偶數(shù)。
- Head number (H): 在 FP32 下滿足 H * N <= 1024 或在 FP16 下滿足 H * N <= 2048 的任何數(shù)字。
- Data type: FP32, FP16, BF16, INT8 and FP8 (Experimental).
- 如果內(nèi)存足夠,任意層數(shù)(N1)
在 FasterTransformer v1.0 中,我們提供了高度優(yōu)化的 BERT 等效編碼器模型。接下來,基于Effective Transformer的思想,我們?cè)? FasterTransformer v2.1 中通過去除無用的 padding 來進(jìn)一步優(yōu)化BERT推理,并提供 Effective FasterTransformer。在 FasterTransformer v3.0 中,我們提供了 INT8 量化推理以獲得更好的性能。
在 FasterTransformer v3.1 中,我們優(yōu)化了 INT8 Kernel 以提高 INT8 推理的性能,并將 TensorRT 的多頭注意力插件集成到 FasterTransformer 中。在 FasterTransformer v4.0 中,我們添加了多頭注意力 Kernel 支持 V100 的 FP16 模式和 T4, A100 的 INT8 模式。下圖演示了除 INT8 外的這些優(yōu)化的流程圖。在FasterTransformer v5.0中,我們重構(gòu)了代碼,將 mask building 和 padding 移動(dòng)到 Bert 的 forward 函數(shù)中,并在 Ampere GPU 上基于稀疏特性來加速GEMM。在 FasterTransformer v5.1 中,我們支持對(duì) Bert FP16 進(jìn)行進(jìn)行多節(jié)點(diǎn)多 GPU 推理。
BERT 模型是 google 在2018年提出的。FasterTransformer 的encoder 相當(dāng)于 BERT 模型,但是做了很多優(yōu)化。圖 1 最左邊的流程顯示了 FasterTransformer 中的優(yōu)化。經(jīng)過優(yōu)化后,F(xiàn)asterTransformer 僅使用 8 或 6 個(gè) gemms(藍(lán)色塊)和 6 個(gè)自定義 CUDA kernel(綠色塊)來實(shí)現(xiàn)一個(gè) transformer 塊。
對(duì)于 Effective FasterTransformer,主要思想是去除句子的填充以防止計(jì)算無用的標(biāo)記。當(dāng)一個(gè) Batch 的平均序列長(zhǎng)度遠(yuǎn)小于最大序列長(zhǎng)度時(shí),此方法可以節(jié)省大量時(shí)間。圖 2 顯示了我們使用的想法和偏移量(橙色)。要實(shí)現(xiàn) Effective FasterTransformer,我們需要考慮兩個(gè)問題。首先,我們需要去除 BERT 之前的 padding,離開 BERT 之后重建 padding 以保持結(jié)果的形狀。這很簡(jiǎn)單,帶來的開銷基本可以忽略。第二個(gè)問題是多頭注意力的計(jì)算。一個(gè)天真的解決方案是在多頭注意力之前重建填充并在多頭注意力之后移除填充,如圖 1 的第二個(gè)流程圖所示。因?yàn)槲覀兛梢詫⑦@些重建/移除融合到其他 kernel 中,額外的開銷也是可以忽略的。
為了進(jìn)一步提高多頭注意力的性能,我們集成了 TensorRT 的多頭注意力,將整個(gè)注意力計(jì)算融合到一個(gè) kernel 中。源代碼在這里。該 kernel 同時(shí)支持 Effective FasterTransformer 和標(biāo)準(zhǔn) BERT 模型。圖 1 中的第三個(gè)和第四個(gè)流程圖顯示了工作流程。有了這樣的 kernel ,我們就不用擔(dān)心多頭注意力的填充問題了。這個(gè) kernel 需要另一個(gè)偏移量,如圖 2 所示。
第一個(gè)偏移量 [0, 0, 1, 3, 3, 3]比較好理解,直接和[0, 1, 2, 3, 4, 5]迭代就可以得到原始的位置了。第二個(gè)偏移量是從0位置開始,記錄連續(xù)的原始token個(gè)數(shù),比如我們將[0, 2, 3, 6]做差分,得到[2, 1, 3]也對(duì)應(yīng)了原始的數(shù)據(jù)中每行做的padding的tokn數(shù)目。
此外,我們發(fā)現(xiàn) padding 會(huì)影響某些任務(wù)的準(zhǔn)確性,盡管它們應(yīng)該是無用的。因此,我們建議刪除下游任務(wù)最終輸出中的填充。
編碼器的參數(shù)、輸入和輸出:
-
Constructor of BERT
-
Input of BERT
-
Output of BERT
上面聲明了 Bert 模型的輸入?yún)?shù),以及輸入和輸出Tensor的shape。
此外,注意到 TensorRT 的多頭注意力Kernel雖然功能很強(qiáng)大但是也有一些限制。首先,這個(gè)kernel需要 Turing 或者更新架構(gòu)的 GPU,并且每個(gè)頭的大小必須是64。當(dāng)條件不滿足時(shí),我們使用FasterTransformer的原始多頭注意力實(shí)現(xiàn)。其次,它需要一個(gè)額外的序列長(zhǎng)度偏移量,如Figure2所示,更多的細(xì)節(jié)在這里 。當(dāng)輸入有 padding 時(shí),序列長(zhǎng)度偏移的形狀為 。假設(shè)這里有3個(gè)序列,長(zhǎng)度分別為 , , ,然后 padding 之后的序列長(zhǎng)度為 。那么序列長(zhǎng)度偏移時(shí) 。即,序列長(zhǎng)度偏移記錄了每個(gè)句子的序列長(zhǎng)度。當(dāng)我們有 padding 時(shí),我們將 padding 視為一些獨(dú)立的句子。
在 FasterTransformer v4.0 中,我們實(shí)現(xiàn)了兩條 INT8 推理的流水線,如圖 3 所示。對(duì)于 int8_mode == 1 (int8v1),我們不量化殘差連接,使用 int32 作為 int8 gemms 的輸出,并對(duì)權(quán)重采用逐通道的量化方式。對(duì)于 int8_mode == 2 (int8v2),我們量化殘差連接,使用 int8 作為 int8 gemms 的輸出,并對(duì)權(quán)重采用逐張量的量化。一般來說,int8_mode == 1 的精度更高,而 int8_mode == 2 的性能更好。
Figure 3
對(duì)于 INT8 推理,需要量化模型。我們提供了 TensorFlow 量化工具和示例代碼,同時(shí)還提供了帶有 TensorRT 量化工具的 PyTorch 示例代碼。請(qǐng)先參考bert-quantization/bert-tf-quantization
和examples/pytorch/bert/bert-quantization-sparsity
中的README
。
在 FasterTransformer v5.0 中,我們支持稀疏 gemm 以利用 Ampere GPU 的稀疏特性。我們還提供了一個(gè)關(guān)于 PyTorch 的示例。
在 FasterTransformer v5.1 中,我們支持 BERT 模型的多 GPU 多節(jié)點(diǎn)推理。
優(yōu)化點(diǎn)解讀
優(yōu)化主要是針對(duì) Figure 1 也就是 BERT 的編碼器模塊的各個(gè)組件來講(我這里忽略了 Figure1 的和 padding 相關(guān)的組建的講解,感興趣的讀者可以自己看看 FasterTransformer)。
import torch.nn as nn
class Attention(nn.Module):
"""
Compute 'Scaled Dot Product Attention
"""
def forward(self, query, key, value, mask=None, dropout=None):
scores = torch.matmul(query, key.transpose(-2, -1)) \\
/ math.sqrt(query.size(-1))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value),
class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""
def __init__(self, h, d_model, dropout=0.1):
super().__init__()
assert d_model % h == 0
# We assume d_v always equals d_k
self.d_k = d_model // h
self.h = h
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
self.output_linear = nn.Linear(d_model, d_model)
self.attention = Attention()
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linear_layers, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch.
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
# 3) "Concat" using a view and apply a final linear.
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
return self.output_linear(x)
Compute Q, K, V by three GEMMs or one Batch GEMM
add_QKV_bias 優(yōu)化
這個(gè)是針對(duì)上面forward函數(shù)中 (1) 這部分存在的分別對(duì) Q, K, V進(jìn)行bias_add以及transpose的優(yōu)化,將其融合成一個(gè)cuda kernel。
對(duì)于FP32,F(xiàn)asterTransformer是啟動(dòng) batch_size *seq_len *3 個(gè) Block, 每個(gè) Block 里面啟動(dòng) head_num *size_per_head 個(gè)線程只處理一個(gè)token(對(duì)應(yīng) head_num *size_per_head 次計(jì)算)的 bias_add 計(jì)算。我們注意到這里還將輸入的shape進(jìn)行了改變,也就是將原始的[batch_size, seq_length, head_num * size_per_head] -> [batch_size, seq_length, head_num, size_per_head](對(duì)應(yīng) .view(batch_size, -1, self.h, self.d_k))->[batch_size, head_num, seq_length, size_per_head](對(duì)應(yīng).transpose(1, 2))。
而對(duì)于FP16模式,F(xiàn)asterTransformer是啟動(dòng) batch_size *seq_len 個(gè) Block,,每個(gè) Block 里面啟動(dòng) head_num *size_per_head 個(gè)線程同時(shí)處理QKV的同一個(gè)token(對(duì)應(yīng)head_num * size_per_head次計(jì)算)并使用了half2相關(guān)的數(shù)學(xué)函數(shù)。這樣不僅僅可以達(dá)到2倍于half的訪存帶寬和計(jì)算吞吐,還可以極大地減少指令的發(fā)射數(shù)量。
高效的softmax kernel
這里我沒有怎么看,因?yàn)閛neflow已經(jīng)有一個(gè)比FasterTransformer更好的softmax kernel實(shí)現(xiàn)了。
transpose kernel
這個(gè) kernel 是對(duì)應(yīng)上面 BERT 的 Encoder 部分的:
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
這里的 x 的 shape 仍然和之前的 q 的 shape 一致, 為[batch_size, head_num, seq_length, size_per_head]。因?yàn)锳ttetion 層不會(huì)改變輸入的形狀,因?yàn)?Attention 的計(jì)算過程是:q *k 轉(zhuǎn)置(.transpose(2, 3)),除以 d_k **0.5,輸出維度是 [b, head_num , seq_length, seq_length] 即單詞和單詞直接的相似性 ,然后對(duì)最后一個(gè)維度進(jìn)行 softmax 操作得到 [b, head_num, seq_length, seq_length] , 最后和 v(shape 也是 [batch_size, head_num, seq_length, size_per_head]) 做一個(gè)矩陣乘法,結(jié)果的 shape 和輸入的 shape 形狀都是:[batch_size, head_num, seq_length, size_per_head] 。因此這里的 x.transpose(1, 2) 就是把 shape 為 [batch_size, head_num, seq_length, size_per_head] 的 x 重新排列為 [batch_size, head_num, size_per_head, seq_length]。然后 x.contiguous().view(batch_size, -1, self.h *self.d_k) 進(jìn)一步將 shape 重新排列為 [batch_size, seq_length, head_num * size_per_head] 。
對(duì)于 FP32 模式,啟動(dòng) batch_size *head_num *seq_length 個(gè) Block , 然后每個(gè) Block 啟動(dòng) size_per_head 個(gè)線程處理一個(gè)序列(一個(gè)序列對(duì)應(yīng) size_per_head 個(gè)元素)。如下:
const int seq_per_block = 1; grid.x = batch_size * head_num * seq_len / seq_per_block; block.x = seq_per_block * size_per_head; transpose
而 transpose 的kernel實(shí)現(xiàn)也比較簡(jiǎn)單,根據(jù)blockIdx.x計(jì)算下batch_id和seq_id以及head_id(輸入 x 的 shape 為 [batch_size, head_num, seq_length, size_per_head]):
`template
global
void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len))/ seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head
head_id * size_per_head + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
`
對(duì)于 half 來說,采用和 add_QKV_bias 一樣的優(yōu)化方式,每個(gè) block 處理 4 個(gè)sequence。具體來說,就是現(xiàn)在啟動(dòng) batch_size *head_num *seq_len / 4 個(gè) Block, 每個(gè) Block 使用 2 *size_per_head 個(gè)線程處理 4 個(gè)序列。為什么 2 *size_per_head 個(gè)線程可以處理 4 個(gè)序列(一個(gè)序列對(duì)應(yīng) size_per_head 個(gè)元素),原因是因?yàn)槭褂昧?half2 來做數(shù)據(jù)讀取。half 類型的 kernel 實(shí)現(xiàn)如下:
` inline device
int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
}
template<>
global
void transpose(__half* src, __half* dst,
const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int batch_id = tid / (head_num * seq_len * size_per_head);
int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head);
int seq_id = (tid % (seq_len * size_per_head)) / size_per_head;
int id = tid % size_per_head;
int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head);
half2 * src_ptr = (half2* )src;
half2 * dst_ptr = (half2* )dst;
dst_ptr[target_id] = src_ptr[tid];
}
`
trt_add_QKV_bias 和 TensorRT fused multi-head attention kernel
實(shí)際上從 Figure1 也可以看出我們上面講到的 batch GEMM,softmax, GEMM,transpose 等操作都可以被合成一個(gè)超大的 cuda kernel,進(jìn)一步進(jìn)行優(yōu)化,也就是這里的 TensorRT fused multi-head attention kernel。這個(gè)是將 TensorRT 的這個(gè)插件作為第三方倉(cāng)庫(kù)引入到 FasterTransformer 進(jìn)行加速的,具體的代碼我沒有研究過,這里就不展開了。
現(xiàn)在 MultiHeadAttention 部分涉及到的優(yōu)化其實(shí)就講完了,我們接著看一下FasterTransformer 對(duì) BERT Encoder 的其它部分的優(yōu)化。我們這里貼一下 Transformer 的結(jié)構(gòu)圖:
在 MultiHeadAttention 的后面接了一個(gè) Add & Norm,這里的 Add 其實(shí)就是殘差,Norm 就是 LayerNorm。所以 Encoder 部分的兩個(gè) Add & Norm 可以總結(jié)為:
add_bias_input_layernorm
對(duì)于 softmax 和 layernorm 我還沒看 FasterTransformer 的源碼,后續(xù)研究了之后再分享。
總的來說就是 add_bias_input_layernorm 這個(gè)優(yōu)化把殘差連接和LayerNorm fuse到一起了,性能更好并且降低了kernel launch的開銷。
add_bias_act
在上圖的 Feed Forward 的實(shí)現(xiàn)中,還有一個(gè) bias_add 和 gelu 激活函數(shù)挨著的 pattern ,所以 FasterTransformer 實(shí)現(xiàn)了這個(gè) add_bias_act kernel 將兩個(gè)操作融合起來,常規(guī)操作。
-
編碼器
+關(guān)注
關(guān)注
45文章
3645瀏覽量
134578 -
gpu
+關(guān)注
關(guān)注
28文章
4742瀏覽量
128972 -
CUDA
+關(guān)注
關(guān)注
0文章
121瀏覽量
13641
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論