編者按:幾個月前,Deepmind在ICML上發(fā)表了一篇論文《Neural Processes》,提出了一種兼具神經(jīng)網(wǎng)絡(luò)高效性和高斯過程靈活性的方法——神經(jīng)過程,被稱為是高斯過程的深度學(xué)習(xí)版本。雖然倍受關(guān)注,但目前真正能直觀解讀神經(jīng)過程的文章并不多,今天論智帶來的是牛津大學(xué)在讀PHDKaspar M?rtens的一篇可視化佳作。
在今年的ICML上,研究人員提出了不少有趣的工作,其中神經(jīng)過程(NPs)引起了許多人的注意,它基于神經(jīng)網(wǎng)絡(luò)概率模型,但又可以表示隨機過程的分布。這意味著NPs結(jié)合了兩個領(lǐng)域的元素:
深度學(xué)習(xí):神經(jīng)網(wǎng)絡(luò)是靈活的非線性函數(shù),可以直接訓(xùn)練
高斯過程:GP提供了一個概率框架,可用于學(xué)習(xí)非線性函數(shù)的分布
兩者都有各自的優(yōu)點和缺點。當(dāng)數(shù)據(jù)量有限時,由于本身具備概率性質(zhì)可以描述不確定性,GP是首選(這和非貝葉斯神經(jīng)網(wǎng)絡(luò)不同,后者只能捕捉單個函數(shù),而不是函數(shù)分布);而當(dāng)有大量數(shù)據(jù)時,訓(xùn)練神經(jīng)網(wǎng)絡(luò)比GP推斷更具擴展性,因此優(yōu)勢更大。
神經(jīng)過程的目標(biāo)就是實現(xiàn)神經(jīng)網(wǎng)絡(luò)和GP的優(yōu)勢融合。
什么是神經(jīng)過程?
NP是一種基于神經(jīng)網(wǎng)絡(luò)的方法,用于表示函數(shù)的分布。下圖展示了如何建立NP模型,以及訓(xùn)練模型背后的一般想法:
給定一系列觀察值(xi,yi),把它們分成“context points”和“target points”兩組?,F(xiàn)在,我們要根據(jù)“context points”中已知的輸入輸出對(xc,yc),其中c=1,…,C,和“target points”中的未知輸入x?t,其中t=1,…,T,預(yù)測其相應(yīng)的函數(shù)值y?t。
我們可以把NP看作是根據(jù)“context points”中的“target points”建模的模型,相關(guān)信息通過潛在空間z從左側(cè)流向右側(cè),從而提供新的預(yù)測。右側(cè)本質(zhì)上是從x映射到y(tǒng)的有限維嵌入,而z是個隨機變量,這就使NP成了概率模型,能捕捉函數(shù)的不確定性。一旦模型完成訓(xùn)練,我們就可以用z的近似后驗分布作為測試時進(jìn)行預(yù)測的先驗。
乍看之下,這種分“context points”和“target points”的做法有點類似把數(shù)據(jù)集分成訓(xùn)練集和測試集,但事實并非如此,因為“target points”集也是直接參與NP模型訓(xùn)練的——這意味著模型的(概率)損失函數(shù)在這個集上有明確意義。這樣做也有助于防止模型過擬合和提供更好的泛化性。在實踐中,我們還需要反復(fù)把訓(xùn)練數(shù)據(jù)通過隨機采樣分為“context points”中的“target points”,以獲得更全面的概括。
讓我們來思考以下兩種情況:
基于單個數(shù)據(jù)集推斷函數(shù)的分布
當(dāng)存在多個數(shù)據(jù)集且它們之間存在某種相關(guān)性時,推斷函數(shù)的分布
對于情況一,常規(guī)的(概率)監(jiān)督學(xué)習(xí)就能解決:給定一個包含N個樣本的數(shù)據(jù)集,比如(xi, yi),其中i=1,…,N。假設(shè)確實存在一個函數(shù)f,它能產(chǎn)生yi=f(xi),我們的目標(biāo)就是學(xué)習(xí)f的后驗分布,然后用它預(yù)測測試集上某點的函數(shù)值f(x?)。
對于情況二,我們則需要從元學(xué)習(xí)的角度去觀察。給定D個數(shù)據(jù)集,其中d=1,…,D,每個數(shù)據(jù)集包含Nd個數(shù)據(jù)對(xi(d), yi(d))。如果我們假設(shè)每個數(shù)據(jù)集都有自己的基函數(shù)fd,輸入xi后,它們有yi=fd(xi),那么在這種情況下,我們就可能想要了解每個fd的后驗分布,然后把經(jīng)驗推廣到新數(shù)據(jù)集d?上。
對于數(shù)據(jù)集很多但它們的樣本很少的情況,情況二的做法特別有用,因為這時模型學(xué)到的經(jīng)驗基于所有fd,它的內(nèi)核、超參數(shù)是這些函數(shù)共享的。當(dāng)給出新的數(shù)據(jù)集d?時,我們可以用后驗函數(shù)作為先驗函數(shù),然后執(zhí)行函數(shù)回歸。
之所以要舉著兩個例子,是因為一般來說,GP適用于情況一,即便N很小,這種做法也很有效。而NP背后的思路似乎主要來自元學(xué)習(xí)——在這種情況下,潛在的z可以被看作是用于不同數(shù)據(jù)集間信息共享的機制。但是,NP同樣具有概率模型的特征,事實上,它同時適用于以上兩種情況,具體分析請見下文。
NP模型是怎么實現(xiàn)的?
下面是NP生成模型的詳細(xì)圖解:
如果要逐步分解這個過程,就是:
首先,“context points”里的數(shù)據(jù)(xc,yc)通過神經(jīng)網(wǎng)絡(luò)h映射,獲得潛在表征rc
其次,這個向量rc經(jīng)聚合(操作:平均)獲得單個值r(和每個rc具有相同的維數(shù))
這個r的作用是使z的分布參數(shù)化,例如p(z|x1:C,y1:C)=N(μz(r),σ2z(r))
最后,為了預(yù)測輸入x?t后的函數(shù)值,對z采樣并將樣本與x?t組成數(shù)對,用神經(jīng)網(wǎng)絡(luò)g映射(z,x?t)獲得預(yù)測分布中的樣本y?t。
NP的推斷是在變分推斷(VI)框架中進(jìn)行的。具體來說,我們介紹了兩種近似分布:
讓q(z|context)去近似條件先驗p(z|context)
讓q(z|context,target)去近似于各自的p(z|context,target),其中context:=(x1:C,y1:C),target:=(x?1:T,y?1:T)
下圖是近似后驗q(z|·)的具體推斷過程。也就是說,我們用相同的神經(jīng)網(wǎng)絡(luò)h映射兩個數(shù)據(jù)集,獲得聚合的r,再把r映射到μz和σz,使后驗q(z|?)=N(μz,σz)被參數(shù)化。
變分下界包含兩個項(下式),其中第一項是target集上的預(yù)期對數(shù)似然,即先從z~q(z|context,target)上采樣(上圖左側(cè)),然后用這個z在target set上預(yù)測(上圖右側(cè))。
第二項是個正則項,它描述了q(z|context,target)和q(z|context)之間的KL散度。這和常規(guī)的KL(q||p)有點不同,因為我們的生成模型一開始就把p(z|context)當(dāng)做條件先驗,不是p(z),而這個條件先驗有依賴于神經(jīng)網(wǎng)絡(luò)h,這就是我們沒法得到確切值,只能用一個近似值q(z|context)。
實驗
NP作為先驗
我們先來看看把NP作為先驗的效果,也就是沒有觀察任何數(shù)據(jù),模型也沒有經(jīng)過訓(xùn)練。初始化權(quán)重后,對z~N(0,I)進(jìn)行采樣,然后通過x?值的生成先驗預(yù)測分布并繪制函數(shù)圖。
和具有可解釋內(nèi)核超參數(shù)的GP相反,NP先驗不太明確,它涉及各種架構(gòu)選擇(如多少隱藏層,用什么激活函數(shù)等),這些都會影響函數(shù)空間的先驗分布。
例如,如果我們用的激活函數(shù)是sigmoid,調(diào)整z的維數(shù)為{1, 2, 4, 8}。
如果用的是ReLU:
在一個小數(shù)據(jù)集上訓(xùn)練NP
假設(shè)我們只有5個數(shù)據(jù)點:
由于NP模型需要context set和target set兩個數(shù)據(jù)集,一種方法是選取固定大小的context set,另一種方法則是用不同大小的context set,然后多迭代幾次(1個點、2個點……以此類推)。一旦模型在這些隨機子集上完成訓(xùn)練,我們就可以用它作為所有數(shù)據(jù)的先驗和條件,然后根據(jù)預(yù)測結(jié)果繪制圖像。下圖展示了NP模型訓(xùn)練時的預(yù)測分布變化。
可以發(fā)現(xiàn),NP似乎已經(jīng)成功學(xué)習(xí)了這5個數(shù)據(jù)點的映射分布,那它的泛化性能如何呢?我們把這個訓(xùn)練好的模型放在另一個新的context set上,它的表現(xiàn)如下圖所示:
這個結(jié)果不足為奇,數(shù)據(jù)量太少了,模型過擬合可以理解。為了更好地提高模型泛化性,我們再來試試更大的函數(shù)集。
在一小類函數(shù)上訓(xùn)練NP
上文已經(jīng)用單個(固定)數(shù)據(jù)集探索了模型的訓(xùn)練情況,為了讓NP像GP一樣通用,我們需要在更大的一類函數(shù)上進(jìn)行訓(xùn)練。但在準(zhǔn)備復(fù)雜函數(shù)前,我們先來看看模型在簡單場景下的表現(xiàn),也就是說,這里觀察的不是單個函數(shù),而是一小類函數(shù),比如它們都包含a?sin(x),其中a∈[?1,1]。
我們的目標(biāo)是探究:
NP能不能捕捉這些函數(shù)?
NP能不能概括這類函數(shù)以外的函數(shù)?
下面是具體步驟:
設(shè)a滿足均勻分布:a~U(?2,2)
設(shè)xi~U(?3,3)
定義yi:=f(xi),其中f(x)=a?sin(x)
把數(shù)據(jù)對(xi,yi)隨機分成context set和target set兩個數(shù)據(jù)集,并進(jìn)行優(yōu)化
重復(fù)上述步驟
為了方便可視化,這里我們用了二維z,具體圖像如下所示:
從左往右看,模型似乎編碼了參數(shù)a,如果這幅圖不夠直觀,下面是調(diào)整某一潛在維度(z1或z2)的動態(tài)可視化:
需要注意的是,這里我們沒有用任何context set里的數(shù)據(jù),只是為了可視化指定了具體的(z1, z2)值。接下來,就讓我們用這個模型進(jìn)行預(yù)測。
如下左圖所示,當(dāng)context set數(shù)據(jù)集里只包含(0, 0)一個點時,模型覆蓋了一個較寬的范圍,包含不同a取值下a?sin(x)的值域(雖然a∈[?2,2],但訓(xùn)練時并沒有完全用到)。
往context set數(shù)據(jù)集里添加第二個點(1,sin(1))后,可視化如中圖所示,相比左圖,它不再包含a為負(fù)數(shù)的情況。右圖是繼續(xù)添加f(x)=1.0sin(x)的點后的情況,這時模型后驗開始接近函數(shù)的真實分布情況。
這之后,我們就可以開始探究NP模型的泛化性,以2.5sin(x)和|sin(x)|為例,前者需要在a?sin(x)的基礎(chǔ)上做一些推斷,而后者的值始終是個正數(shù)。
如上圖所示,模型的值域還是和訓(xùn)練期間一樣,但它在兩種情況下都出現(xiàn)了符合函數(shù)分布的一些預(yù)期。需要注意的是,這里我們并沒有給NP提供足夠多的不確定性,所以它預(yù)測不準(zhǔn)確也情有可原,畢竟比起易于解釋的模型,這種自帶黑盒特性的模型更難衡量。
之后,作者又比較了GP和NP的預(yù)測分布情況,發(fā)現(xiàn)兩者性能非常接近,只是隨著給出的數(shù)據(jù)點越來越多時,NP會因為架構(gòu)選擇(神經(jīng)網(wǎng)絡(luò)過小、低緯度z)出現(xiàn)性能急劇下降。對此,以下幾個改進(jìn)方法可以幫助解決問題:
2維z適合用于學(xué)習(xí)理解,在實際操作中,可視情況采用更高的維度
讓神經(jīng)網(wǎng)絡(luò)h和g變得更深,擴大隱藏層
在訓(xùn)練期間使用更多樣化的函數(shù)(更全面地訓(xùn)練NP超參數(shù)),可提高NP模型泛化性
結(jié)論
雖然NP號稱結(jié)合了神經(jīng)網(wǎng)絡(luò)和GP,能預(yù)測函數(shù)的分布,但它從本質(zhì)上看還是更接近神經(jīng)網(wǎng)絡(luò)模型——只需優(yōu)化架構(gòu)和訓(xùn)練過程,模型性能就可以大幅提高。但是,這些變化都是隱含的,使得NP更難被解釋為先驗。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4773瀏覽量
100861 -
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1208瀏覽量
24727 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5506瀏覽量
121255
原文標(biāo)題:函數(shù)分布視角下的神經(jīng)過程
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論