0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

簡要介紹einsum表示法的概念,通過真實(shí)例子展示了einsum的表達(dá)力

zhKF_jqr_AI ? 來源:未知 ? 作者:李倩 ? 2018-10-04 08:50 ? 次閱讀

編者按:FAIR研究科學(xué)家Tim Rockt?schel簡要介紹了einsum表示法的概念,并通過真實(shí)例子展示了einsum的表達(dá)力。

當(dāng)我和同事聊天的時(shí)候,我意識到不是所有人都了解einsum,我開發(fā)深度學(xué)習(xí)模型時(shí)最喜歡的函數(shù)。本文打算改變這一現(xiàn)狀,讓所有人都了解它!愛因斯坦求和約定(einsum)在numpy和TensorFlow之類的深度學(xué)習(xí)庫中都有實(shí)現(xiàn),感謝Thomas Viehmann,最近PyTorch也實(shí)現(xiàn)了這一函數(shù)。關(guān)于einsum的背景知識,我推薦閱讀Olexa Bilaniuk的numpy的愛因斯坦求和約定以及Alex Riley的einsum基本指南。這兩篇文章介紹了numpy中的einsum,我的這篇文章則將演示在編寫優(yōu)雅的PyTorch/TensorFlow模型時(shí),einsum是多么有用(我將使用PyTorch作為例子,不過很容易就可以翻譯到TensorFlow)。

1. einsum記法

如果你像我一樣,發(fā)現(xiàn)記住PyTorch/TensorFlow中那些計(jì)算點(diǎn)積、外積、轉(zhuǎn)置、矩陣-向量乘法、矩陣-矩陣乘法的函數(shù)名字和簽名很費(fèi)勁,那么einsum記法就是我們的救星。einsum記法是一個(gè)表達(dá)以上這些運(yùn)算,包括復(fù)雜張量運(yùn)算在內(nèi)的優(yōu)雅方式,基本上,可以把einsum看成一種領(lǐng)域特定語言。一旦你理解并能利用einsum,除了不用記憶和頻繁查找特定庫函數(shù)這個(gè)好處以外,你還能夠更迅速地編寫更加緊湊、高效的代碼。而不使用einsum的時(shí)候,容易出現(xiàn)引入不必要的張量變形或轉(zhuǎn)置運(yùn)算,以及可以省略的中間張量的現(xiàn)象。此外,einsum這樣的領(lǐng)域特定語言有時(shí)可以編譯到高性能代碼,事實(shí)上,PyTorch最近引入的能夠自動生成GPU代碼并為特定輸入尺寸自動調(diào)整代碼的張量理解(Tensor Comprehensions)就基于類似einsum的領(lǐng)域特定語言。此外,可以使用opt einsum和tf einsum opt這樣的項(xiàng)目優(yōu)化einsum表達(dá)式的構(gòu)造順序。

比方說,我們想要將兩個(gè)矩陣A ∈ ?I × K和B ∈ ?K × J相乘,接著計(jì)算每列的和,最終得到向量c ∈ ?J。使用愛因斯坦求和約定,這可以表達(dá)為:

這一表達(dá)式指明了c中的每個(gè)元素ci是如何計(jì)算的,列向量Ai:乘以行向量B:j,然后求和。注意,在愛因斯坦求和約定中,我們省略了求和符號Sigma,因?yàn)槲覀冸[式地累加重復(fù)的下標(biāo)(這里是k)和輸出中未指明的下標(biāo)(這里是i)。當(dāng)然,einsum也能表達(dá)更基本的運(yùn)算。比如,計(jì)算兩個(gè)向量a, b ∈ ?J的點(diǎn)積可以表達(dá)為:

在深度學(xué)習(xí)中,我經(jīng)常碰到的一個(gè)問題是,變換高階張量到向量。例如,我可能有一個(gè)張量,其中包含一個(gè)batch中的N個(gè)訓(xùn)練樣本,每個(gè)樣本是一個(gè)長度為T的K維詞向量序列,我想把詞向量投影到一個(gè)不同的維度Q。如果將這個(gè)張量記作T ∈ ?N × T × K,將投影矩陣記作W ∈ ?K × Q,那么所需計(jì)算可以用einsum表達(dá)為:

最后一個(gè)例子,比方說有一個(gè)四階張量T ∈ ?N × T × K × M,我們想要使用之前的投影矩陣將第三維投影至Q維,并累加第二維,然后轉(zhuǎn)置結(jié)果中的第一維和最后一維,最終得到張量C ∈ ?M × Q × N。einsum可以非常簡潔地表達(dá)這一切:

注意,我們通過交換下標(biāo)n和m(Cmqn而不是Cnqm),轉(zhuǎn)置了張量構(gòu)造結(jié)果。

2. Numpy、PyTorch、TensorFlow中的einsum

einsum在numpy中實(shí)現(xiàn)為np.einsum,在PyTorch中實(shí)現(xiàn)為torch.einsum,在TensorFlow中實(shí)現(xiàn)為tf.einsum,均使用一致的簽名einsum(equation, operands),其中equation是表示愛因斯坦求和約定的字符串,而operands則是張量序列(在numpy和TensorFlow中是變長參數(shù)列表,而在PyTorch中是列表)。例如,我們的第一個(gè)例子,cj= ∑i∑kAikBkj寫成equation字符串就是ik,kj -> j。注意這里(i, j, k)的命名是任意的,但需要一致。

PyTorch和TensorFlow像numpy支持einsum的好處之一是einsum可以用于神經(jīng)網(wǎng)絡(luò)架構(gòu)的任意計(jì)算圖,并且可以反向傳播。典型的einsum調(diào)用格式如下:

上式中?是占位符,表示張量維度。上面的例子中,arg1和arg3是矩陣,arg2是二階張量,這一einsum運(yùn)算的結(jié)果(result)是矩陣。注意einsum處理的是可變數(shù)量的輸入。在上面的例子中,einsum指定了三個(gè)參數(shù)之上的操作,但它同樣可以用在牽涉一個(gè)參數(shù)、兩個(gè)參數(shù)、三個(gè)以上參數(shù)的操作上。學(xué)習(xí)einsum的最佳途徑是通過學(xué)習(xí)一些例子,所以下面我們將展示一下,在許多深度學(xué)習(xí)模型中常用的庫函數(shù),用einsum該如何表達(dá)(以PyTorch為例)。

2.1 矩陣轉(zhuǎn)置

import torch

a = torch.arange(6).reshape(2, 3)

torch.einsum('ij->ji', [a])

tensor([[ 0., 3.],

[ 1., 4.],

[ 2., 5.]])

2.2 求和

a = torch.arange(6).reshape(2, 3)

torch.einsum('ij->', [a])

tensor(15.)

2.3 列求和

a = torch.arange(6).reshape(2, 3)

torch.einsum('ij->j', [a])

tensor([ 3., 5., 7.])

2.4 行求和

a = torch.arange(6).reshape(2, 3)

torch.einsum('ij->i', [a])

tensor([ 3., 12.])

2.5 矩陣-向量相乘

a = torch.arange(6).reshape(2, 3)

b = torch.arange(3)

torch.einsum('ik,k->i', [a, b])

tensor([ 5., 14.])

2.6 矩陣-矩陣相乘

a = torch.arange(6).reshape(2, 3)

b = torch.arange(15).reshape(3, 5)

torch.einsum('ik,kj->ij', [a, b])

tensor([[ 25., 28., 31., 34., 37.],

[ 70., 82., 94., 106., 118.]])

2.7 點(diǎn)積

向量:

a = torch.arange(3)

b = torch.arange(3,6) # [3, 4, 5]

torch.einsum('i,i->', [a, b])

tensor(14.)

矩陣:

a = torch.arange(6).reshape(2, 3)

b = torch.arange(6,12).reshape(2, 3)

torch.einsum('ij,ij->', [a, b])

tensor(145.)

2.8 哈達(dá)瑪積

a = torch.arange(6).reshape(2, 3)

b = torch.arange(6,12).reshape(2, 3)

torch.einsum('ij,ij->ij', [a, b])

tensor([[ 0., 7., 16.],

[ 27., 40., 55.]])

2.9 外積

a = torch.arange(3)

b = torch.arange(3,7)

torch.einsum('i,j->ij', [a, b])

tensor([[ 0., 0., 0., 0.],

[ 3., 4., 5., 6.],

[ 6., 8., 10., 12.]])

2.10 batch矩陣相乘

a = torch.randn(3,2,5)

b = torch.randn(3,5,3)

torch.einsum('ijk,ikl->ijl', [a, b])

tensor([[[ 1.0886, 0.0214, 1.0690],

[ 2.0626, 3.2655, -0.1465]],

[[-6.9294, 0.7499, 1.2976],

[ 4.2226, -4.5774, -4.8947]],

[[-2.4289, -0.7804, 5.1385],

[ 0.8003, 2.9425, 1.7338]]])

2.11 張量縮約

batch矩陣相乘是張量縮約的一個(gè)特例。比方說,我們有兩個(gè)張量,一個(gè)n階張量A ∈ ?I1× ? × In,一個(gè)m階張量B ∈ ?J1× ? × Jm。舉例來說,我們?nèi) = 4,m = 5,并假定I2= J3且I3= J5。我們可以將這兩個(gè)張量在這兩個(gè)維度上相乘(A張量的第2、3維度,B張量的3、5維度),最終得到一個(gè)新張量C ∈ ?I1× I4× J1× J2× J4,如下所示:

a = torch.randn(2,3,5,7)

b = torch.randn(11,13,3,17,5)

torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape

torch.Size([2, 7, 11, 13, 17])

2.12 雙線性變換

如前所述,einsum可用于超過兩個(gè)張量的計(jì)算。這里舉一個(gè)這方面的例子,雙線性變換。

a = torch.randn(2,3)

b = torch.randn(5,3,7)

c = torch.randn(2,7)

torch.einsum('ik,jkl,il->ij', [a, b, c])

tensor([[ 3.8471, 4.7059, -3.0674, -3.2075, -5.2435],

[-3.5961, -5.2622, -4.1195, 5.5899, 0.4632]])

3. 案例

3.1 TreeQN

我曾經(jīng)在實(shí)現(xiàn)TreeQN( arXiv:1710.11417)的等式6時(shí)使用了einsum:給定網(wǎng)絡(luò)層l上的低維狀態(tài)表示zl,和激活a上的轉(zhuǎn)換函數(shù)Wa,我們想要計(jì)算殘差鏈接的下一層狀態(tài)表示。

在實(shí)踐中,我們想要高效地計(jì)算大小為B的batch中的K維狀態(tài)表示Z ∈ ?B × K,并同時(shí)計(jì)算所有轉(zhuǎn)換函數(shù)(即,所有激活A(yù))。我們可以將這些轉(zhuǎn)換函數(shù)安排為一個(gè)張量W ∈ ?A × K × K,并使用einsum高效地計(jì)算下一層狀態(tài)表示。

import torch.nn.functional as F

def random_tensors(shape, num=1, requires_grad=False):

tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]

return tensors[0] if num == 1else tensors

# 參數(shù)

# -- [激活數(shù) x 隱藏層維度]

b = random_tensors([5, 3], requires_grad=True)

# -- [激活數(shù) x 隱藏層維度 x 隱藏層維度]

W = random_tensors([5, 3, 3], requires_grad=True)

def transition(zl):

# -- [batch大小 x 激活數(shù) x 隱藏層維度]

return zl.unsqueeze(1) + F.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b)

# 隨機(jī)取樣仿造輸入

# -- [batch大小 x 隱藏層維度]

zl = random_tensors([2, 3])

transition(zl)

3.2 注意力

讓我們再看一個(gè)使用einsum的真實(shí)例子,實(shí)現(xiàn)注意力機(jī)制的等式11-13(arXiv:1509.06664):

用傳統(tǒng)寫法實(shí)現(xiàn)這些可要費(fèi)不少力氣,特別是考慮batch實(shí)現(xiàn)。einsum是我們的救星!

# 參數(shù)

# -- [隱藏層維度]

bM, br, w = random_tensors([7], num=3, requires_grad=True)

# -- [隱藏層維度 x 隱藏層維度]

WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)

# 注意力機(jī)制的單次應(yīng)用

def attention(Y, ht, rt1):

# -- [batch大小 x 隱藏層維度]

tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr])

Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM)

# -- [batch大小 x 序列長度]

at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w]))

# -- [batch大小 x 隱藏層維度]

rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br)

# -- [batch大小 x 隱藏層維度], [batch大小 x 序列維度]

return rt, at

# 取樣仿造輸入

# -- [batch大小 x 序列長度 x 隱藏層維度]

Y = random_tensors([3, 5, 7])

# -- [batch大小 x 隱藏層維度]

ht, rt1 = random_tensors([3, 7], num=2)

rt, at = attention(Y, ht, rt1)

4. 總結(jié)

einsum是一個(gè)函數(shù)走天下,是處理各種張量操作的瑞士軍刀。話雖如此,“einsum滿足你一切需要”顯然夸大其詞了。從上面的真實(shí)用例可以看到,我們?nèi)匀恍枰趀insum之外應(yīng)用非線性和構(gòu)造額外維度(unsqueeze)。類似地,分割、連接、索引張量仍然需要應(yīng)用其他庫函數(shù)。

使用einsum的麻煩之處是你需要手動實(shí)例化參數(shù),操心它們的初始化,并在模型中注冊這些參數(shù)。不過我仍然強(qiáng)烈建議你在實(shí)現(xiàn)模型時(shí),考慮下有哪些情況適合使用einsum.

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報(bào)投訴
  • 函數(shù)
    +關(guān)注

    關(guān)注

    3

    文章

    4331

    瀏覽量

    62630
  • 深度學(xué)習(xí)
    +關(guān)注

    關(guān)注

    73

    文章

    5503

    瀏覽量

    121175

原文標(biāo)題:einsum滿足你一切需要:深度學(xué)習(xí)中的愛因斯坦求和約定

文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)part1.rar
    發(fā)表于 12-20 12:09

    Proteus仿真實(shí)例大全分享

    [url=Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子) https://bbs.elecfans.com/jishu_339609_1_1.html (出處: 中國電子技術(shù)論壇)]Proteus
    發(fā)表于 06-20 03:27

    簡單介紹CAN總線的相關(guān)概念

    基于STM32的CAN總線通信學(xué)習(xí)筆記本文主要簡單介紹CAN總線的相關(guān)概念,以及通信協(xié)議等知識,和使用STM32自帶的bxCAN外設(shè)進(jìn)行CAN總線編程實(shí)驗(yàn),以及編程心得。1. CAN總線簡要
    發(fā)表于 08-19 07:23

    教您如何清理筆記本鍵盤(真實(shí)例子)

    教您如何清理筆記本鍵盤(真實(shí)例子) 前些日子的某天晚上,不小心把一大杯桔子汁碰倒,小銀男的鍵盤里被灌了一些。老婆聞訊趕來幫我把里面的水
    發(fā)表于 01-19 11:20 ?1027次閱讀

    關(guān)于ARM的22個(gè)常用概念介紹

    本文簡要介紹ARM的22個(gè)常用的概念。
    發(fā)表于 06-18 14:35 ?2972次閱讀

    Simulink建模仿真實(shí)例快速入門

    Simulink建模仿真實(shí)例詳解Simulink建模仿真實(shí)例詳解Simulink建模仿真實(shí)例詳解Simulink建模仿真實(shí)例詳解
    發(fā)表于 12-28 18:15 ?0次下載

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)part4

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)。
    發(fā)表于 04-18 10:02 ?226次下載

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)part3

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)。
    發(fā)表于 04-18 10:02 ?223次下載

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)part1

    Proteus單片機(jī)仿真實(shí)例大全(幾百個(gè)例子)。
    發(fā)表于 04-18 10:02 ?726次下載

    簡要介紹操作系統(tǒng)虛擬化的概念,以及實(shí)現(xiàn)操作系統(tǒng)虛擬化的技術(shù)

    本文簡要介紹操作系統(tǒng)級虛擬化的概念,并簡要闡述實(shí)現(xiàn)操作系統(tǒng)虛擬化所用到的技術(shù)Namespac
    的頭像 發(fā)表于 01-10 15:00 ?1.3w次閱讀
    <b class='flag-5'>簡要</b><b class='flag-5'>介紹</b><b class='flag-5'>了</b>操作系統(tǒng)虛擬化的<b class='flag-5'>概念</b>,以及實(shí)現(xiàn)操作系統(tǒng)虛擬化的技術(shù)

    差分吸收介紹

    本文首先介紹差分吸收概念,然后解釋差分吸收的原理,然后分析
    的頭像 發(fā)表于 08-05 11:31 ?8088次閱讀
    差分吸收<b class='flag-5'>法</b><b class='flag-5'>介紹</b>

    C語言指針的表達(dá)實(shí)例程序說明

    本文檔的主要內(nèi)容詳細(xì)介紹的是C語言指針的表達(dá)實(shí)例程序說明。
    發(fā)表于 11-05 17:07 ?4次下載
    C語言指針的<b class='flag-5'>表達(dá)</b>式<b class='flag-5'>實(shí)例</b>程序說明

    MATLAB根軌跡仿真實(shí)例

    本文檔的主要內(nèi)容詳細(xì)介紹的是MATLAB根軌跡仿真實(shí)例
    發(fā)表于 07-03 08:00 ?1次下載
    MATLAB根軌跡仿<b class='flag-5'>真實(shí)例</b>

    電磁仿真實(shí)例教程

    電磁仿真實(shí)例教程免費(fèi)下載。
    發(fā)表于 04-21 10:57 ?41次下載

    基于TPU-MLIR:詳解EinSum的完整處理過程!

    EinSum介紹EinSum(愛因斯坦求和)是一個(gè)功能強(qiáng)大的算子,能夠簡潔高效地表示出多維算子的乘累加過程,對使用者非常友好。本質(zhì)上,EinSum
    的頭像 發(fā)表于 02-19 13:08 ?699次閱讀
    基于TPU-MLIR:詳解<b class='flag-5'>EinSum</b>的完整處理過程!