摘要
在過(guò)去幾年中,如何擴(kuò)展Transformer使之能夠處理更長(zhǎng)的序列一直是一個(gè)重要問(wèn)題,因?yàn)檫@能提高Transformer語(yǔ)言建模性能和高分辨率圖像理解能力,以及解鎖代碼、音頻和視頻生成等新應(yīng)用。然而增加序列長(zhǎng)度,注意力層是主要瓶頸,因?yàn)樗倪\(yùn)行時(shí)間和內(nèi)存會(huì)隨序列長(zhǎng)度的增加呈二次(平方)增加。FlashAttention利用GPU非勻稱(chēng)的存儲(chǔ)器層次結(jié)構(gòu),實(shí)現(xiàn)了顯著的內(nèi)存節(jié)?。◤钠椒皆黾愚D(zhuǎn)為線(xiàn)性增加)和計(jì)算加速(提速2-4倍),而且計(jì)算結(jié)果保持一致。但是,F(xiàn)lashAttention仍然不如優(yōu)化的矩陣乘法(GEMM)操作快,只達(dá)到理論最大FLOPs/s的25-40%。作者觀(guān)察到,這種低效是由于GPU對(duì)不同thread blocks和warps工作分配不是最優(yōu)的,造成了利用率低和不必要的共享內(nèi)存讀寫(xiě)。因此,本文提出了FlashAttention-2以解決這些問(wèn)題。
簡(jiǎn)介
如何擴(kuò)展Transformer使之能夠處理更長(zhǎng)的序列一直是一個(gè)挑戰(zhàn),**因?yàn)槠浜诵淖⒁饬拥倪\(yùn)行時(shí)間和內(nèi)存占用量隨輸入序列長(zhǎng)度成二次增加。**我們希望能夠打破2k序列長(zhǎng)度限制,從而能夠訓(xùn)練書(shū)籍、高分辨率圖像和長(zhǎng)視頻。此外,寫(xiě)作等應(yīng)用也需要模型能夠處理長(zhǎng)序列。過(guò)去一年中,業(yè)界推出了一些遠(yuǎn)超之前長(zhǎng)度的語(yǔ)言模型:GPT-4為32k,MosaicML的MPT為65k,以及Anthropic的Claude為100k。
雖然相比標(biāo)準(zhǔn)Attention,F(xiàn)lashAttention快了2~4倍,節(jié)約了10~20倍內(nèi)存,但是離設(shè)備理論最大throughput和flops還差了很多。本文提出了FlashAttention-2,它具有更好的并行性和工作分區(qū)。實(shí)驗(yàn)結(jié)果顯示,F(xiàn)lashAttention-2在正向傳遞中實(shí)現(xiàn)了約2倍的速度提升,達(dá)到了理論最大吞吐量的73%,在反向傳遞中達(dá)到了理論最大吞吐量的63%。在每個(gè)A100 GPU上的訓(xùn)練速度可達(dá)到225 TFLOPs/s。
本文主要貢獻(xiàn)和創(chuàng)新點(diǎn)為:
1. 減少了non-matmul FLOPs的數(shù)量(消除了原先頻繁rescale)。雖然non-matmul FLOPs僅占總FLOPs的一小部分,但它們的執(zhí)行時(shí)間較長(zhǎng),這是因?yàn)镚PU有專(zhuān)用的矩陣乘法計(jì)算單元,其吞吐量高達(dá)非矩陣乘法吞吐量的16倍。因此,減少non-matmul FLOPs并盡可能多地執(zhí)行matmul FLOPs非常重要。
2. 提出了在序列長(zhǎng)度維度上并行化。該方法在輸入序列很長(zhǎng)(此時(shí)batch size通常很?。┑那闆r下增加了GPU利用率。即使對(duì)于單個(gè)head,也在不同的thread block之間進(jìn)行并行計(jì)算。
3. 在一個(gè)attention計(jì)算塊內(nèi),將工作分配在一個(gè)thread block的不同warp上,以減少通信和共享內(nèi)存讀/寫(xiě)。
動(dòng)機(jī)
為了解決這個(gè)問(wèn)題,研究者們也提出了很多近似的attention算法,然而目前使用最多的還是標(biāo)準(zhǔn)attention。FlashAttention利用tiling、recomputation等技術(shù)顯著提升了計(jì)算速度(提升了2~4倍),并且將內(nèi)存占用從平方代價(jià)將為線(xiàn)性代價(jià)(節(jié)約了10~20倍內(nèi)存)。雖然FlashAttention效果很好,但是仍然不如其他基本操作(如矩陣乘法)高效。例如,其前向推理僅達(dá)到GPU(A100)理論最大FLOPs/s的30-50%(下圖);反向傳播更具挑戰(zhàn)性,在A(yíng)100上僅達(dá)到最大吞吐量的25-35%。相比之下,優(yōu)化后的GEMM(矩陣乘法)可以達(dá)到最大吞吐量的80-90%。通過(guò)觀(guān)察分析,這種低效是由于GPU對(duì)不同thread blocks和warps工作分配不是最優(yōu)的,造成了利用率低和不必要的共享內(nèi)存讀寫(xiě)。
Attention forward speed on A100 GPU. (Source: Figure 5 of the paper.)
背景知識(shí)
下面介紹一些關(guān)于GPU的性能和計(jì)算特點(diǎn),有關(guān)Attention和FlashAttention的詳細(xì)內(nèi)容請(qǐng)參考第一篇文章
FlashAttention圖解(如何加速Attention)
GPU
GPU performance characteristics.GPU主要計(jì)算單元(如浮點(diǎn)運(yùn)算單元)和內(nèi)存層次結(jié)構(gòu)。大多數(shù)現(xiàn)代GPU包含專(zhuān)用的低精度矩陣乘法單元(如Nvidia GPU的Tensor Core用于FP16/BF16矩陣乘法)。內(nèi)存層次結(jié)構(gòu)分為高帶寬內(nèi)存(High Bandwidth Memory, HBM)和片上SRAM(也稱(chēng)為shared memory)。以A100 GPU為例,它具有40-80GB的HBM,帶寬為1.5-2.0TB/s,每個(gè)108個(gè)streaming multiprocessors共享的SRAM為192KB,帶寬約為19TB/s。
這里忽略了L2緩存,因?yàn)椴荒苤苯颖挥?a href="http://wenjunhu.com/v/tag/1730/" target="_blank">程序員控制。
CUDA的軟件和硬件架構(gòu)
從Hardware角度來(lái)看:
Streaming Processor(SP):是最基本的處理單元,從fermi架構(gòu)開(kāi)始被叫做CUDA core。
Streaming MultiProcessor(SM):一個(gè)SM由多個(gè)CUDA core(SP)組成,每個(gè)SM在不同GPU架構(gòu)上有不同數(shù)量的CUDA core,例如Pascal架構(gòu)中一個(gè)SM有128個(gè)CUDA core。
SM還包括特殊運(yùn)算單元(SFU),共享內(nèi)存(shared memory),寄存器文件(Register File)和調(diào)度器(Warp Scheduler)等。register和shared memory是稀缺資源,這些有限的資源就使每個(gè)SM中active warps有非常嚴(yán)格的限制,也就限制了并行能力。
從Software(編程)角度來(lái)看:
CUDA軟件示例
thread是最基本的執(zhí)行單元(the basic unit of execution)。
warp是SM中最小的調(diào)度單位(the smallest scheduling unit on an SM),一個(gè)SM可以同時(shí)處理多個(gè)warp
thread block是GPU執(zhí)行的最小單位(the smallest unit of execution on the GPU)。
一個(gè)warp中的threads必然在同一個(gè)block中,如果block所含thread數(shù)量不是warp大小的整數(shù)倍,那么多出的那個(gè)warp中會(huì)剩余一些inactive的thread。也就是說(shuō),即使warp的thread數(shù)量不足,硬件也會(huì)為warp湊足thread,只不過(guò)這些thread是inactive狀態(tài),但也會(huì)消耗SM資源。
thread:一個(gè)CUDA并行程序由多個(gè)thread來(lái)執(zhí)行
warp:一個(gè)warp通常包含32個(gè)thread。每個(gè)warp中的thread可以同時(shí)執(zhí)行相同的指令,從而實(shí)現(xiàn)SIMT(單指令多線(xiàn)程)并行。
thread block:一個(gè)thread block可以包含多個(gè)warp,同一個(gè)block中的thread可以同步,也可以通過(guò)shared memory進(jìn)行通信。
grid:在GPU編程中,grid是一個(gè)由多個(gè)thread block組成的二維或三維數(shù)組。grid的大小取決于計(jì)算任務(wù)的規(guī)模和thread block的大小,通常根據(jù)計(jì)算任務(wù)的特點(diǎn)和GPU性能來(lái)進(jìn)行調(diào)整。
Hardware和Software的聯(lián)系:
SM采用的是Single-Instruction Multiple-Thread(SIMT,單指令多線(xiàn)程)架構(gòu),warp是最基本的執(zhí)行單元,一個(gè)warp包含32個(gè)并行thread,這些thread以不同數(shù)據(jù)資源執(zhí)行相同的指令。
當(dāng)一個(gè)kernel被執(zhí)行時(shí),grid中的thread block被分配到SM上,大量的thread可能被分到不同的SM上,但是一個(gè)線(xiàn)程塊的thread只能在一個(gè)SM上調(diào)度,SM一般可以調(diào)度多個(gè)block。每個(gè)thread擁有自己的程序計(jì)數(shù)器和狀態(tài)寄存器,并且可以使用不同的數(shù)據(jù)來(lái)執(zhí)行指令,從而實(shí)現(xiàn)并行計(jì)算,這就是所謂的Single Instruction Multiple Thread。
一個(gè)CUDA core可以執(zhí)行一個(gè)thread,一個(gè)SM中的CUDA core會(huì)被分成幾個(gè)warp,由warp scheduler負(fù)責(zé)調(diào)度。GPU規(guī)定warp中所有thread在同一周期執(zhí)行相同的指令,盡管這些thread執(zhí)行同一程序地址,但可能產(chǎn)生不同的行為,比如分支結(jié)構(gòu)。一個(gè)SM同時(shí)并發(fā)的warp是有限的,由于資源限制,SM要為每個(gè)block分配共享內(nèi)存,也要為每個(gè)warp中的thread分配獨(dú)立的寄存器,所以SM的配置會(huì)影響其所支持的block和warp并發(fā)數(shù)量。
GPU執(zhí)行模型小結(jié):
GPU有大量的threads用于執(zhí)行操作(an operation,也稱(chēng)為a kernel)。這些thread組成了thread block,接著這些blocks被調(diào)度在SMs上運(yùn)行。在每個(gè)thread block中,threads被組成了warps(32個(gè)threads為一組)。一個(gè)warp內(nèi)的threads可以通過(guò)快速shuffle指令進(jìn)行通信或者合作執(zhí)行矩陣乘法。在每個(gè)thread block內(nèi)部,warps可以通過(guò)讀取/寫(xiě)入共享內(nèi)存進(jìn)行通信。每個(gè)kernel從HBM加載數(shù)據(jù)到寄存器和SRAM中,進(jìn)行計(jì)算,最后將結(jié)果寫(xiě)回HBM中。
FlashAttention
FlashAttention應(yīng)用了tiling技術(shù)來(lái)減少內(nèi)存訪(fǎng)問(wèn),具體來(lái)說(shuō):
1. 從HBM中加載輸入數(shù)據(jù)(K,Q,V)的一部分到SRAM中
2. 計(jì)算這部分?jǐn)?shù)據(jù)的Attention結(jié)果
3. 更新輸出到HBM,但是無(wú)需存儲(chǔ)中間數(shù)據(jù)S和P
下圖展示了一個(gè)示例:首先將K和V分成兩部分(K1和K2,V1和V2,具體如何劃分根據(jù)數(shù)據(jù)大小和GPU特性調(diào)整),根據(jù)K1和Q可以計(jì)算得到S1和A1,然后結(jié)合V1得到O1。接著計(jì)算第二部分,根據(jù)K2和Q可以計(jì)算得到S2和A2,然后結(jié)合V2得到O2。最后O2和O1一起得到Attention結(jié)果。
值得注意的是,輸入數(shù)據(jù)K、Q、V是存儲(chǔ)在HBM上的,中間結(jié)果S、A都不需要存儲(chǔ)到HBM上。通過(guò)這種方式,F(xiàn)lashAttention可以將內(nèi)存開(kāi)銷(xiāo)降低到線(xiàn)性級(jí)別,并實(shí)現(xiàn)了2-4倍的加速,同時(shí)避免了對(duì)中間結(jié)果的頻繁讀寫(xiě),從而提高了計(jì)算效率。
FlashAttention-2
經(jīng)過(guò)鋪墊,正式進(jìn)入正文。我們先講述FlashAttention-2對(duì)FlashAttention的改進(jìn),從而減少了非矩陣乘法運(yùn)算(non-matmul)的FLOPs。然后說(shuō)明如何將任務(wù)分配給不同的thread block進(jìn)行并行計(jì)算,充分利用GPU資源。最后描述了如何在一個(gè)thread block內(nèi)部分配任務(wù)給不同的warps,以減少訪(fǎng)問(wèn)共享內(nèi)存次數(shù)。這些優(yōu)化方案使得FlashAttention-2的性能提升了2-3倍。
Algorithm
FlashAttention在FlashAttention算法基礎(chǔ)上進(jìn)行了調(diào)整,減少了非矩陣乘法運(yùn)算(non-matmul)的FLOPs。這是因?yàn)楝F(xiàn)代GPU有針對(duì)matmul(GEMM)專(zhuān)用的計(jì)算單元(如Nvidia GPU上的Tensor Cores),效率很高。以A100 GPU為例,其FP16/BF16矩陣乘法的最大理論吞吐量為312 TFLOPs/s,但FP32非矩陣乘法僅有19.5 TFLOPs/s,即每個(gè)no-matmul FLOP比mat-mul FLOP昂貴16倍。為了確保高吞吐量(例如超過(guò)最大理論TFLOPs/s的50%),我們希望盡可能將時(shí)間花在matmul FLOPs上。
Forward pass
通常實(shí)現(xiàn)Softmax算子為了數(shù)值穩(wěn)定性(因?yàn)橹笖?shù)增長(zhǎng)太快,數(shù)值會(huì)過(guò)大甚至溢出),會(huì)減去最大值:
這樣帶來(lái)的代價(jià)就是要對(duì)遍歷3次。
為了減少non-matmul FLOPs,本文在FlashAttention基礎(chǔ)上做了兩點(diǎn)改進(jìn):
簡(jiǎn)單示例的FlashAttention完整計(jì)算步驟(紅色部分表示V1和V2區(qū)別):
FlashAttention-2的完整計(jì)算步驟(紅色部分表示V1和V2區(qū)別):
有了上面分析和之前對(duì)FlashAttention的講解,再看下面?zhèn)未a就沒(méi)什么問(wèn)題了。
Causal masking是attention的一個(gè)常見(jiàn)操作,特別是在自回歸語(yǔ)言建模中,需要對(duì)注意力矩陣S應(yīng)用因果掩碼(即任何S ,其中 > 的條目都設(shè)置為?∞)。
1. 由于FlashAttention和FlashAttention-2已經(jīng)通過(guò)塊操作來(lái)實(shí)現(xiàn),對(duì)于所有列索引都大于行索引的塊(大約占總塊數(shù)的一半),我們可以跳過(guò)該塊的計(jì)算。這比沒(méi)有應(yīng)用因果掩碼的注意力計(jì)算速度提高了1.7-1.8倍。
2. 不需要對(duì)那些行索引嚴(yán)格小于列索引的塊應(yīng)用因果掩碼。這意味著對(duì)于每一行,我們只需要對(duì)1個(gè)塊應(yīng)用因果掩碼。
Parallelism
FlashAttention在batch和heads兩個(gè)維度上進(jìn)行了并行化:使用一個(gè)thread block來(lái)處理一個(gè)attention head,總共需要thread block的數(shù)量等于batch size × number of heads。每個(gè)block被調(diào)到到一個(gè)SM上運(yùn)行,例如A100 GPU上有108個(gè)SMs。當(dāng)block數(shù)量很大時(shí)(例如≥80),這種調(diào)度方式是高效的,因?yàn)閹缀蹩梢杂行Ю肎PU上所有計(jì)算資源。
但是在處理長(zhǎng)序列輸入時(shí),由于內(nèi)存限制,通常會(huì)減小batch size和head數(shù)量,這樣并行化成都就降低了。因此,F(xiàn)lashAttention-2還在序列長(zhǎng)度這一維度上進(jìn)行并行化,顯著提升了計(jì)算速度。此外,當(dāng)batch size和head數(shù)量較小時(shí),在序列長(zhǎng)度上增加并行性有助于提高GPU占用率。
Work Partitioning Between Warps
上一節(jié)討論了如何分配thread block,然而在每個(gè)thread block內(nèi)部,我們也需要決定如何在不同的warp之間分配工作。我們通常在每個(gè)thread block中使用4或8個(gè)warp,如下圖所示。
Work partitioning between different warps in the forward pass
論文中原話(huà)是”However, this is inefficient since all warps need to write their intermediate results out toshared memory, synchronize, then add up the intermediate results.”,說(shuō)的是shared memory而非HBM,但是結(jié)合下圖黃色框部分推斷,我認(rèn)為是HBM。
-
存儲(chǔ)器
+關(guān)注
關(guān)注
38文章
7509瀏覽量
163975 -
gpu
+關(guān)注
關(guān)注
28文章
4749瀏覽量
129034 -
矩陣
+關(guān)注
關(guān)注
0文章
423瀏覽量
34578
原文標(biāo)題:FlashAttention2詳解(性能比FlashAttention提升200%)
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論