0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線(xiàn)課程
  • 觀(guān)看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

ICLR 2019論文解讀:膠囊圖神經(jīng)網(wǎng)絡(luò)的PyTorch實(shí)現(xiàn)

DPVg_AI_era ? 來(lái)源:lp ? 2019-03-29 10:11 ? 次閱讀

膠囊圖神經(jīng)網(wǎng)絡(luò)(CapsGNN)是在GNN啟發(fā)下誕生了基于圖片分類(lèi)的新框架。CapsGNN在10個(gè)數(shù)據(jù)集中的6個(gè)的表現(xiàn)排名位居前兩名。與所有其他端到端架構(gòu)相比,CapsGNN在所有社交數(shù)據(jù)集中均名列首位。

本日Reddit上熱議的一個(gè)話(huà)題是名為“膠囊圖神經(jīng)網(wǎng)絡(luò)”(CapsGNN)的新框架。從名字不難看出,它是受圖神經(jīng)網(wǎng)絡(luò)(GNN)的啟發(fā),在其基礎(chǔ)上改進(jìn)而來(lái)的成果。

CapsGNN框架的作者為新加坡南洋理工大學(xué)電氣與電子工程學(xué)院的Zhang Xinyi和Lihui Chen,該研究的論文將在ICLR 2019上發(fā)表。

目前,從圖神經(jīng)網(wǎng)絡(luò)(GNN)中學(xué)到的高質(zhì)量節(jié)點(diǎn)嵌入已經(jīng)應(yīng)用于各種基于節(jié)點(diǎn)的應(yīng)用程序中,其中一些程序已經(jīng)實(shí)現(xiàn)了最先進(jìn)的性能。不過(guò),當(dāng)應(yīng)用程序用GNN學(xué)習(xí)的節(jié)點(diǎn)嵌入來(lái)生成圖形嵌入時(shí),標(biāo)量節(jié)點(diǎn)表示可能不足以有效地保留節(jié)點(diǎn)或圖形的完整屬性,從而導(dǎo)致圖形嵌入的性能達(dá)不到最優(yōu)。

膠囊圖神經(jīng)網(wǎng)絡(luò)(CapsGNN)受到了膠囊神經(jīng)網(wǎng)絡(luò)的啟發(fā),利用膠囊的概念來(lái)解決現(xiàn)有基于GNN的圖嵌入算法的缺點(diǎn)。CapsGNN以膠囊形式對(duì)節(jié)點(diǎn)特征進(jìn)行提取,利用路由機(jī)制來(lái)捕獲圖形級(jí)別的重要信息。因此,模型會(huì)為每個(gè)圖生成多個(gè)嵌入,從多個(gè)不同方面捕獲圖的屬性。

CapsGNN中包含的注意力模塊可用于處理各種尺寸的圖,讓模型能夠?qū)W⑻幚韴D的關(guān)鍵部分。通過(guò)對(duì)10個(gè)圖結(jié)構(gòu)數(shù)據(jù)集的廣泛評(píng)估表明,CapsGNN具有強(qiáng)大的機(jī)制,可通過(guò)數(shù)據(jù)驅(qū)動(dòng)捕獲整個(gè)圖的宏觀(guān)屬性。在幾個(gè)圖分類(lèi)任務(wù)上的性能優(yōu)于其他SOTA技術(shù)。

膠囊圖神經(jīng)網(wǎng)絡(luò)基本架構(gòu)

上圖所示為CapsGNN的簡(jiǎn)化版本。它由三個(gè)關(guān)鍵模塊組成:1)基本節(jié)點(diǎn)膠囊提取模塊:GNN用于提取具有不同感受野的局部頂點(diǎn)特征,然后在該模塊中構(gòu)建主節(jié)點(diǎn)膠囊。 2)高級(jí)圖膠囊提取模塊:融合了注意力模塊和動(dòng)態(tài)路由,以生成多個(gè)圖膠囊。 3)圖分類(lèi)模塊:再次利用動(dòng)態(tài)路由,生成用于圖分類(lèi)的類(lèi)膠囊。

注意力模塊

在CapsGNN中,基于每個(gè)節(jié)點(diǎn)提取主膠囊,即主膠囊的數(shù)量取決于輸入圖的大小。在這種情況下,如果直接應(yīng)用路由機(jī)制,則生成的高級(jí)別的膠囊的值將高度依賴(lài)于主膠囊的數(shù)量(圖大?。?,這種情況并不理想。因此,實(shí)驗(yàn)引入一個(gè)注意力模塊來(lái)解決這個(gè)問(wèn)題。

注意力模塊架構(gòu)。首先壓平主膠囊,利用兩層全連接神經(jīng)網(wǎng)絡(luò)產(chǎn)生每個(gè)膠囊的注意力值。利用基于節(jié)點(diǎn)的歸一化(對(duì)每行進(jìn)行歸一化)來(lái)生成最終注意力值。 將標(biāo)準(zhǔn)化值與主膠囊相乘來(lái)計(jì)算標(biāo)度膠囊。

實(shí)驗(yàn)設(shè)置與結(jié)果

我們驗(yàn)證了從CapsGNN中提取的圖嵌入與大量SOTA方法的性能,與一些經(jīng)典方法的最優(yōu)性能做了對(duì)比。此外還進(jìn)行了實(shí)驗(yàn)研究,評(píng)估膠囊對(duì)圖編碼特征效率的影響。我們對(duì)生成的圖/類(lèi)膠囊進(jìn)行了簡(jiǎn)要分析。實(shí)驗(yàn)結(jié)果和分析如下所示。

表1為生物數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果,表2為社會(huì)數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果。對(duì)于每個(gè)數(shù)據(jù)集,以粗體突出顯示前2個(gè)準(zhǔn)確度。

與所有其他算法相比,CapsGNN在10個(gè)數(shù)據(jù)集中的6個(gè)的表現(xiàn)排名位居前兩名,并且在其他數(shù)據(jù)集上也實(shí)現(xiàn)了基本相當(dāng)?shù)慕Y(jié)果。與所有其他端到端架構(gòu)相比,CapsGNN在所有社交數(shù)據(jù)集中均名列首位。

表1:生物數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果

表2:社交數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果

膠囊的效率

在膠囊的效率測(cè)試實(shí)驗(yàn)中,GNN的層數(shù)設(shè)置為L(zhǎng) = 3,每層的通道數(shù)都設(shè)置為Cl = 2。通過(guò)調(diào)整節(jié)點(diǎn)的維度(dn)、圖(dg)、膠囊和圖形、膠囊的數(shù)量(P)來(lái)構(gòu)造不同的CapsGNN。

表3:膠囊效率評(píng)估實(shí)驗(yàn)中經(jīng)過(guò)測(cè)試的體系結(jié)構(gòu)詳細(xì)信息

圖3:特征表示效率的比較。橫軸表示測(cè)試架構(gòu)的設(shè)置,縱軸表示NCI1的分類(lèi)精度。

圖膠囊的可視化

分類(lèi)膠囊的可視化

膠囊圖網(wǎng)絡(luò):基于GNN的高效快捷的新框架

CapsGNN是一個(gè)新框架,將膠囊理論融合到GNN中,來(lái)實(shí)現(xiàn)更高效的圖表示學(xué)習(xí)。該框架受CapsNet的啟發(fā),在原體系結(jié)構(gòu)中引入了膠囊的概念,在從GNN提取的節(jié)點(diǎn)特征的基礎(chǔ)上,以向量的形式提取特征。

利用CapsGNN,一個(gè)圖可以表示為多個(gè)嵌入,每個(gè)嵌入都可以捕獲不同方面的圖屬性。生成的圖形和類(lèi)封裝不僅可以保留與分類(lèi)相關(guān)的信息,還可以保留關(guān)于圖屬性的其他信息,這些信息可能在后續(xù)流程中用到。CapsGNN是一種新穎、高效且強(qiáng)大的數(shù)據(jù)驅(qū)動(dòng)方法,可以表示圖形等高維數(shù)據(jù)。

與其他SOTA算法相比,CapsGNN模型在10個(gè)圖表分類(lèi)任務(wù)中有6個(gè)成功實(shí)現(xiàn)了更好或相當(dāng)?shù)男阅?,在社交?shù)據(jù)集上的表現(xiàn)尤其顯眼。與其他類(lèi)似的基于標(biāo)量的體系結(jié)構(gòu)相比,CapsGNN在編碼特征方面更有效,這對(duì)于處理大型數(shù)據(jù)集非常有用。

關(guān)于開(kāi)源代碼和模型的一些補(bǔ)充信息

運(yùn)行環(huán)境

代碼庫(kù)在Python 3.5.2中實(shí)現(xiàn)。用于開(kāi)發(fā)的軟件包版本如下:

networkx 1.11tqdm 4.28.1numpy 1.15.4pandas 0.23.4texttable 1.5.0scipy 1.1.0argparse 1.1.0torch 0.4.1torch-scatter 1.1.2torch-sparse 0.2.2torch-cluster 1.2.4torch-geometric 1.0.3torchvision 0.2.1

數(shù)據(jù)集

代碼會(huì)從input文件夾中獲取訓(xùn)練圖,圖存儲(chǔ)形式為JSON。用于測(cè)試的圖也存儲(chǔ)為JSON文件。每個(gè)節(jié)點(diǎn)id和節(jié)點(diǎn)標(biāo)簽必須從0開(kāi)始索引。字典的鍵是存儲(chǔ)的字符串,以使JSON能夠序列化排布。

每個(gè)JSON文件都具有以下的鍵值結(jié)構(gòu):

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]], "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"}, "target": 1}

邊緣鍵(edgeskey)具有邊緣列表值,用于描述連接結(jié)構(gòu)。標(biāo)簽鍵具有每個(gè)節(jié)點(diǎn)的標(biāo)簽,這些標(biāo)簽存儲(chǔ)為字典- 在此嵌套字典中,標(biāo)簽是值,節(jié)點(diǎn)標(biāo)識(shí)符是鍵。目標(biāo)鍵具有整數(shù)值,該值代表了類(lèi)成員資格。

輸出

預(yù)測(cè)結(jié)果保存在output目錄中。每個(gè)嵌入都有一個(gè)標(biāo)題和一個(gè)帶有圖標(biāo)識(shí)符的列。最后,預(yù)測(cè)會(huì)按標(biāo)識(shí)符列排序。

訓(xùn)練CapsGNN模型由src /main.py腳本處理,該腳本提供以下命令行參數(shù)。

輸入和輸出選項(xiàng)

--training-graphs STR Training graphs folder. Default is `dataset/train/`. --testing-graphs STR Testing graphs folder. Default is `dataset/test/`. --prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.

模型選項(xiàng)

--epochs INT Number of epochs. Default is 10. --batch-size INT Number fo graphs per batch. Default is 32. --gcn-filters INT Number of filters in GCNs. Default is 2. --gcn-layers INT Number of GCNs chained together. Default is 5. --inner-attention-dimension INT Number of neurons in attention. Default is 20. --capsule-dimensions INT Number of capsule neurons. Default is 8. --number-of-capsules INT Number of capsules in layer. Default is 8. --weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6. --lambd FLOAT Regularization parameter. Default is 1.0. --learning-rate FLOAT Adam learning rate. Default is 0.01.

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀(guān)點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 神經(jīng)網(wǎng)絡(luò)

    關(guān)注

    42

    文章

    4773

    瀏覽量

    100862
  • 數(shù)據(jù)集
    +關(guān)注

    關(guān)注

    4

    文章

    1208

    瀏覽量

    24727
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    808

    瀏覽量

    13244
  • GNN
    GNN
    +關(guān)注

    關(guān)注

    1

    文章

    31

    瀏覽量

    6355

原文標(biāo)題:基于GNN,強(qiáng)于GNN:膠囊圖神經(jīng)網(wǎng)絡(luò)的PyTorch實(shí)現(xiàn) | ICLR 2019

文章出處:【微信號(hào):AI_era,微信公眾號(hào):新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    labview BP神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)

    請(qǐng)問(wèn):我在用labview做BP神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)故障診斷,在NI官網(wǎng)找到了機(jī)器學(xué)習(xí)工具包(MLT),但是里面沒(méi)有關(guān)于這部分VI的幫助文檔,對(duì)于”BP神經(jīng)網(wǎng)絡(luò)分類(lèi)“這個(gè)范例有很多不懂的地方,比如
    發(fā)表于 02-22 16:08

    人工神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)方法有哪些?

    人工神經(jīng)網(wǎng)絡(luò)(Artificial Neural Network,ANN)是一種類(lèi)似生物神經(jīng)網(wǎng)絡(luò)的信息處理結(jié)構(gòu),它的提出是為了解決一些非線(xiàn)性,非平穩(wěn),復(fù)雜的實(shí)際問(wèn)題。那有哪些辦法能實(shí)現(xiàn)人工神經(jīng)
    發(fā)表于 08-01 08:06

    matlab實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò) 精選資料分享

    習(xí)神經(jīng)神經(jīng)網(wǎng)絡(luò),對(duì)于神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)是如何一直沒(méi)有具體實(shí)現(xiàn)一下:現(xiàn)看到一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型用于訓(xùn)
    發(fā)表于 08-18 07:25

    一種新型神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu):膠囊網(wǎng)絡(luò)

    膠囊網(wǎng)絡(luò)是 Geoffrey Hinton 提出的一種新型神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),為了解決卷積神經(jīng)網(wǎng)絡(luò)(ConvNets)的一些缺點(diǎn),提出了膠囊
    的頭像 發(fā)表于 02-02 09:25 ?5886次閱讀

    基于PyTorch的深度學(xué)習(xí)入門(mén)教程之使用PyTorch構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)

    PyTorch的自動(dòng)梯度計(jì)算 Part3:使用PyTorch構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò) Part4:訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)分類(lèi)器 Part5:數(shù)據(jù)并行化 本文是關(guān)于Part3的內(nèi)容。 Part3:使
    的頭像 發(fā)表于 02-15 09:40 ?2113次閱讀

    PyTorch教程8.1之深度卷積神經(jīng)網(wǎng)絡(luò)(AlexNet)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程8.1之深度卷積神經(jīng)網(wǎng)絡(luò)(AlexNet).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 10:09 ?0次下載
    <b class='flag-5'>PyTorch</b>教程8.1之深度卷積<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b>(AlexNet)

    PyTorch教程之循環(huán)神經(jīng)網(wǎng)絡(luò)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程之循環(huán)神經(jīng)網(wǎng)絡(luò).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 09:52 ?0次下載
    <b class='flag-5'>PyTorch</b>教程之循環(huán)<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b>

    PyTorch教程之從零開(kāi)始的遞歸神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程之從零開(kāi)始的遞歸神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 09:55 ?0次下載
    <b class='flag-5'>PyTorch</b>教程之從零開(kāi)始的遞歸<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b><b class='flag-5'>實(shí)現(xiàn)</b>

    PyTorch教程9.6之遞歸神經(jīng)網(wǎng)絡(luò)的簡(jiǎn)潔實(shí)現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程9.6之遞歸神經(jīng)網(wǎng)絡(luò)的簡(jiǎn)潔實(shí)現(xiàn).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 09:56 ?0次下載
    <b class='flag-5'>PyTorch</b>教程9.6之遞歸<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b>的簡(jiǎn)潔<b class='flag-5'>實(shí)現(xiàn)</b>

    PyTorch教程10.3之深度遞歸神經(jīng)網(wǎng)絡(luò)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.3之深度遞歸神經(jīng)網(wǎng)絡(luò).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:12 ?0次下載
    <b class='flag-5'>PyTorch</b>教程10.3之深度遞歸<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b>

    PyTorch教程16.2之情感分析:使用遞歸神經(jīng)網(wǎng)絡(luò)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程16.2之情感分析:使用遞歸神經(jīng)網(wǎng)絡(luò).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 10:55 ?0次下載
    <b class='flag-5'>PyTorch</b>教程16.2之情感分析:使用遞歸<b class='flag-5'>神經(jīng)網(wǎng)絡(luò)</b>

    使用PyTorch構(gòu)建神經(jīng)網(wǎng)絡(luò)

    PyTorch是一個(gè)流行的深度學(xué)習(xí)框架,它以其簡(jiǎn)潔的API和強(qiáng)大的靈活性在學(xué)術(shù)界和工業(yè)界得到了廣泛應(yīng)用。在本文中,我們將深入探討如何使用PyTorch構(gòu)建神經(jīng)網(wǎng)絡(luò),包括從基礎(chǔ)概念到高級(jí)特性的全面解析。本文旨在為讀者提供一個(gè)完整的
    的頭像 發(fā)表于 07-02 11:31 ?729次閱讀

    PyTorch神經(jīng)網(wǎng)絡(luò)模型構(gòu)建過(guò)程

    PyTorch,作為一個(gè)廣泛使用的開(kāi)源深度學(xué)習(xí)庫(kù),提供了豐富的工具和模塊,幫助開(kāi)發(fā)者構(gòu)建、訓(xùn)練和部署神經(jīng)網(wǎng)絡(luò)模型。在神經(jīng)網(wǎng)絡(luò)模型中,輸出層是尤為關(guān)鍵的部分,它負(fù)責(zé)將模型的預(yù)測(cè)結(jié)果以合適的形式輸出。以下將詳細(xì)解析
    的頭像 發(fā)表于 07-10 14:57 ?515次閱讀

    pytorch中有神經(jīng)網(wǎng)絡(luò)模型嗎

    當(dāng)然,PyTorch是一個(gè)廣泛使用的深度學(xué)習(xí)框架,它提供了許多預(yù)訓(xùn)練的神經(jīng)網(wǎng)絡(luò)模型。 PyTorch中的神經(jīng)網(wǎng)絡(luò)模型 1. 引言 深度學(xué)習(xí)是一種基于人工
    的頭像 發(fā)表于 07-11 09:59 ?712次閱讀

    PyTorch如何實(shí)現(xiàn)多層全連接神經(jīng)網(wǎng)絡(luò)

    PyTorch實(shí)現(xiàn)多層全連接神經(jīng)網(wǎng)絡(luò)(也稱(chēng)為密集連接神經(jīng)網(wǎng)絡(luò)或DNN)是一個(gè)相對(duì)直接的過(guò)程,涉及定義網(wǎng)絡(luò)結(jié)構(gòu)、初始化參數(shù)、前向傳播、損失
    的頭像 發(fā)表于 07-11 16:07 ?1236次閱讀