前言
這篇文章的主要內(nèi)容是,解讀 AlphaTensor 這篇論文的主要思想,如何通過強化學習來探索發(fā)現(xiàn)更高效的矩陣乘算法。
1、二進制加法和乘法
這一節(jié)簡單介紹一下計算機是怎么實現(xiàn)加法和乘法的。
以 2 + 5 和 2 * 5 為例。
我們知道數(shù)字在計算機中是以二進制形式表示的。
整數(shù)2的二進制表示為:0010
整數(shù)5的二進制表示為:0101
1.1、二進制加法
二進制加法很簡單,也就是兩個二進制數(shù)按位相加,如下圖所示:
當然具體到硬件實現(xiàn)其實是包含了異或運算和與運算,具體細節(jié)可以閱讀文末參考的資料。
1.2、二進制乘法
二進制乘法其實也是通過二進制加法來實現(xiàn)的,如下圖所示:
乘法在硬件上的實現(xiàn)本質(zhì)是移位相加。
對于二進制數(shù)來說乘數(shù)和被乘數(shù)的每一位非0即1。
所以相當于乘數(shù)中的每一位從低位到高位,分別和被乘數(shù)的每一位進行與運算并產(chǎn)生其相應(yīng)的局部乘積,再將這些局部乘積左移一位與上次的和相加。
從乘數(shù)的最低位開始:
若為1,則復(fù)制被乘數(shù),并左移一位與上一次的和相加;
若為0,則直接將0左移一位與上一次的和相加;
如此循環(huán)至乘數(shù)的最高位。
從二進制乘法的實現(xiàn)也可以看出來,加法比乘法操作要快。
1.3、用加法替換乘法的簡單例子
上面這個公式相信大家都很熟悉了,式子兩邊是等價的
左邊包含了2次乘法和1次加法(減法也可以看成加法)
右邊則包含了1次乘法和2次加法
可以看到通過數(shù)學上的等價變換,增加了加法的次數(shù)同時減少了乘法的次數(shù)。
2、矩陣乘算法
對于兩個大小分別為 Q x R 和 R x P 的矩陣相乘,通用的實現(xiàn)就需要 Q * P * R 次乘法操作(輸出矩陣大小 Q x P,總共 Q * P 個元素,每個元素計算需要 R 次乘法操作)。
根據(jù)前面 1.2內(nèi)容可知,乘法比加法慢,所以如果能減少的乘法次數(shù)就能有效加速矩陣乘的運算。
2.1、通用矩陣乘算法
首先來看一下通用的矩陣乘算法:
如上圖所示,兩個大小為2x2矩陣做乘法,總共需要8次乘法和4次加法。
2.2、Strassen 矩陣乘算法
上圖所示即為 Strassen 矩陣乘算法,和通用矩陣乘算法不一樣的地方是,引入了7個中間變量 m,只有在計算這7個中間變量才會用到乘法。
簡單用 c1 驗證一下:
可以看到 Strassen 算法總共包含7次乘法和18次加法,通過數(shù)學上的等價變換減少了1次乘法同時增加了14次加法。
3、AlphaTensor 核心思想解讀
3.1、將矩陣乘表示為3維張量
首先來看下論文中的一張圖
圖中下方是3維張量,每個立方體表示3維張量一個坐標點。
其中張量每個位置的值只能是 0 或者 1,透明的立方體表示 0,紫色的立方體表示 1。
現(xiàn)在將圖簡化一下,以[a,b,c]這樣的維度順序,將張量以維度a平攤開,這樣更容易理解:
這個3維張量怎么理解呢?
比如對于 c1,我們知道 c1 的計算需要用到 a1,a2,b1,b3,對應(yīng)到3維張量就是:
而從上圖可知,對于兩個 2 x 2 的矩陣相乘,3維張量大小為 4 x 4 x 4。
一般的,對于兩個 n x n 的矩陣相乘,3維張量大小為 n^2 x n^2 x n^2。
更一般的,對于兩個 n x m 和 m x p 的矩陣相乘,3維張量大小為 n*m x m*p x n*p。
然后論文中為了簡化理解,都是以 n x n 矩陣乘來講解的,論文中以
表示 n x n 矩陣乘的3維張量,下文中為了方便寫作以 Tn 來表示。
3.2、3維張量分解
然后論文中提出了一個假設(shè):
如果能將3維張量 Tn 分解為 R 個秩1的3維張量(R rank-one terms)的和的話,那么對于任意的 n x n 矩陣乘計算就只需要 R 次乘法。
如上圖公式所示,就是表示的這個分解,其中的
就表示的一個秩1的3維張量,是由 u^(r) 、 v^(r) 和 ?w^(r) 這3個一維向量做外積得到的。
這具體怎么什么理解呢?我們回去看一下 Strassen 矩陣乘算法:
上圖左邊就是 Strassen 矩陣乘算法的計算過程,右邊的 U,V 和 W 3個矩陣,各自分別對應(yīng)左邊 U -> a, V -> b 和 W -> m。
具體又怎么理解這三個矩陣呢?
我們在圖上加一些標注來解釋,其中 U , V 和 W 矩陣每一列從左到右按順序,就對應(yīng)上文提到的,u^(r) 、 v^(r) 和 ?w^(r) 這3個一維向量。
然后矩陣 U 每一列和 [a1,a2,a3,a4] 做內(nèi)積,矩陣 V 每一列和 [b1,b2,b3,b4] 做內(nèi)積,然后內(nèi)積結(jié)果相乘就得到 [m1,m2,m3,m4,m5,m6,m7]了。
最后矩陣 W 每一行和 [m1,m2,m3,m4,m5,m6,m7] 做內(nèi)積就得到 [c1,c2,c3,c4]。
接著再看一下的 U,V 和 W 這三個矩陣第一列的外積結(jié)果
如下圖所示:
可以看到 U,V 和 W 三個矩陣每一列對應(yīng)的外積的結(jié)果就是一個3維張量,那么這些3維張量全部加起來就會得到 Tn 么?下面我們來驗證一下:
?
可以看到這些外積的結(jié)果全部加起來就恰好等于 Tn:
?
所以也就證實了開頭的假設(shè):
如果能將表示矩陣乘的3維張量 Tn 分解為 R 個秩1的3維張量(R rank-one terms)的和,那么對于任意的 n x n 矩陣乘計算就只需要 R 次乘法。
因此也就很自然的可以想到,如果能找到更優(yōu)的張量分解,也就是讓 R 更小的話,那么就相當于找到乘法次數(shù)更小的矩陣乘算法了。
通過強化學習探索更優(yōu)的3維張量分解
將探索3維張量分解過程變成游戲
論文中是采用了強化學習這個框架,來探索對3維張量Tn的更優(yōu)的分解。強化學習的環(huán)境是一個單玩家的游戲(a single-player game, TensorGame)。
首先定義這個游戲進行 t 步之后的狀態(tài)為 St:
然后初始狀態(tài) S0 就設(shè)置為要分解的3維張量 Tn:
?
對于游戲中的每一步t,玩家(就是本論文提出的 AlphaTensor)會根據(jù)當前的狀態(tài)選擇下一步的行動,也就是通過生成新的三個一維向量從而得到新的秩1張量:
?
接著更新狀態(tài) St減去這個秩1張量:
?
玩家的目標就是,讓最終狀態(tài) St=0同時盡量的減少游戲的步數(shù)。
當?shù)竭_最終狀態(tài) St=0 之后,也就找到了3維張量Tn的一個分解了:
?
還有些細節(jié)是,對于玩家每一步的選擇都是給一個 -1 的分數(shù)獎勵,其實也很容易理解,也就是玩的步數(shù)越多,獎勵越低,從而鼓勵玩家用更少的步數(shù)完成游戲。
而且對于一維向量的生成,也做了限制
?
就是生成這些一維向量的值,只限定在比如 [?2,??1,?0,?1,?2] 這5個離散值之內(nèi)。
AlphaTensor 簡要解讀
論文中是怎么說的,在游戲過程中玩家 AlphaTensor 是通過一個深度神經(jīng)網(wǎng)絡(luò)來指導蒙特卡洛樹搜索(MonteCarlo tree search)。關(guān)于這個蒙特卡洛樹搜索,我不是很了解這里就不做解讀了,有興趣的讀者可以閱讀文末參考資料。
首先看下深渡神經(jīng)網(wǎng)絡(luò)部分:
?
深度神經(jīng)網(wǎng)絡(luò)的輸入是當前的狀態(tài) St也就是需要分解的張量(上圖中的最右邊的粉紅色立方體)。輸出包含兩個部分,分別是 Policy head 和 Value head。
其中 Policy head 的輸出是對于當前狀態(tài)可以采取的潛在下一步行動,也就是一維向量(u(t),?v(t),?w(t)) 的候選分布,然后通過采樣得到下一步的行動。
然后 Value head 應(yīng)該是對于給定的當前的狀態(tài) St ,估計游戲完成之后的最終獎勵分數(shù)的分布。
接下來簡要解讀一下整個游戲的流程,還有深度神經(jīng)網(wǎng)絡(luò)是如何訓練的:
先看流程圖的上方 Acting 那個方框內(nèi),表示的是用訓練好的網(wǎng)絡(luò)做推理玩游戲的過程。
可以看到最左邊綠色的立方體,也就是待分解的3維張量 Tn變換到粉紅色立方體,論文中提到是作了基的變換,但是這塊感覺如果不是去復(fù)現(xiàn)就不用了解的那么深入,而且我也沒去細看這塊就跳過吧。
然后從最初待分解的 Tn 開始,輸入到神經(jīng)網(wǎng)絡(luò),通過蒙特卡洛樹搜索得到秩1張量,然后減去該張量之后,繼續(xù)將相減的結(jié)果輸入到網(wǎng)路中,繼續(xù)這個過程直到張量相減的結(jié)果為0。
將游戲過程記錄下來,就是流程圖最右邊的 Played game。
然后流程圖下方的 Learning 方框表示的就是訓練過程,訓練數(shù)據(jù)有兩個部分,一個是已經(jīng)玩過的游戲記錄 Played games buffer 還有就是通過人工生成的數(shù)據(jù)。
人工怎么生成訓練數(shù)據(jù)呢?
論文中提到,盡管張量分解是個 NP-hard 的問題,給定一個 Tn 要找其分解很難。但是我們可以反過來用秩1張量來構(gòu)造出一個待分解的張量嘛!簡單來說就是采樣R個秩1張量,然后加起來就能的到分解的張量了。
因為對于強化學習這塊我不是了解的并不深入,所以也就只能作粗淺的解讀。
實驗結(jié)果
最后看一下實驗結(jié)果
表格最左邊一列表示矩陣乘的規(guī)模,最右邊三列表示矩陣乘算法乘法次數(shù)。
第一列表示目前為止,數(shù)學家找到的最優(yōu)乘法次數(shù)。
第2和3列就是 AlphaTensor 找到的最優(yōu)乘法次數(shù)。
可以看到其中有5個規(guī)模,AlphaTensor 能找到更優(yōu)的乘法次數(shù)(標紅的部分):
兩個 4 x 4 和 4 x 4 的矩陣乘,AlphaTensor 搜索出47次乘法;
兩個 5 x 5 和 5 x 5 的矩陣乘,AlphaTensor 搜索出96次乘法;
兩個 3 x 4 和 4 x 5 的矩陣乘,AlphaTensor 搜索出47次乘法;
兩個 4 x 4 和 4 x 5 的矩陣乘,AlphaTensor 搜索出63次乘法;
兩個 4 x 5 和 5 x 5 的矩陣乘,AlphaTensor 搜索出76次乘法;
審核編輯:劉清
評論
查看更多