雷鋒網 AI 科技評論按:近期,澳大利亞迪肯大學圖像識別和數據分析中心發表了一篇新的論文,由Tu Dinh Nguyen, Trung Le, Hung Vu, Dinh Phung編寫,該論文就生成對抗網絡(GAN)的模式崩潰問題進行了討論並給出了一種新的有效的解決方案 D2GAN,論文譯稿由雷鋒網 AI 科技評論編輯,原文連結請點擊。
這篇文章介紹了一種解決生成對抗網絡(GAN)模式崩潰問題的方法。這種方法很直觀但是證實有效,特別是當對GAN預先設置一些限制時。在本質上,它結合了Kullback-Leibler(KL)和反向KL散度的差異,生成一個目標函數,從而利用這些分支的互補統計特性捕捉多模式下分散預估密度。這種方法稱為雙鑑別器生成對抗網絡(Dual discriminator generative adversarial nets, D2GAN),顧名思義,與GAN不同的是,D2GAN有兩個鑑別器。這兩個鑑別器仍然與一個生成器一起進行極大極小的博弈,一個鑑別器會給符合分布的數據樣本給與高獎勵,而另外一個鑑別器卻更喜歡生成器生成的數據。生成器就要嘗試同時欺騙兩個鑑別器。理論分析表明,假設使用最強的鑑別器,優化D2GAN的生成器可以讓原始資料庫和生成器產生的數據間的KL和反向KL散度最小化,從而有效地避免模式崩潰的問題。作者進行了大量的合成和真實資料庫的實驗(MNIST,CIFAR-10,STL-10,ImageNet),對比D2GAN和最新的GAN變種的方法,並進行定性定量評估。實驗結果有效地驗證了D2GAN的競爭力和優越的性能,D2GAN生成樣本的質量和多樣性要比基準模型高得多,並可擴展到ImageNet資料庫。
簡介
生成式模型是研究領域的一大分支並且在最近幾年得到了飛速的成長,成功地部署到很多現代的應用中。一般的方法是通過解決密度預測問題,即學習模型分布Pmodel來預測置信度,在數據分布Pdata未知的情況下。這種方法的實現需要解決兩個基本問題。
首先,生成模型的學習表現基於訓練這些模型的目標函數的選擇。最為廣泛使用的目標,即事實標準目標,是遵循遵循最大似然估計原理,尋求模型參數以最大限度地提高訓練數據的似然性。這與最小化KL散度數據分布和模型分布上的差異的方法相似
。這種最小化會導致Pmodel覆蓋Pdata的多種模式,但是可能會引起一些完全看不到的和潛在的不希望的樣本。相反地,另外一種方法通過交換參數,最小化:
,一般稱其為反KL散度。觀察發現,對反KL散度準則優化模擬了模式搜索的過程,Pmodel集中在Pdata的單一模式,而忽略了其他模式,稱這種問題為模式崩潰。
第二個問題是密度函數Pmodel公式的選擇問題。一種方法是定義一個明確的密度函數,然後直接的根據最大似然框架進行參數估計。另外一種方法是使用一個不明確的密度函數記性數據分布估計,不需要使用Pmodel的解析形式。還有一些想法是借用最小包圍球的原理來訓練生成器,訓練和生成的數據,在被映射到特徵空間後,被封閉在同一個球體中。這種方法最為著名的先驅應用是生成對抗網絡(GAN),它是一種表達生成模型,具備生成自然場景的尖銳和真實圖像的能力。與大多數生成模型不同的是,GAN使用了一種激進的方法,模擬了遊戲中兩個玩家對抗的方法:一個生成器G通過從噪聲空間映射輸入空間來生成數據;鑑別器D則表現得像一個分類器,區分真實的樣本和生成器生成的偽圖像。生成器G和鑑別器D都是通過神經網絡參數化得來的,因此,這種方法可以歸類為深度生成模型或者生成神經模型。
GAN的優化實際上是一個極大極小問題,即給定一個最優的D,學習的目標變成尋找可以最小化Jensen-Shannon散度(JSD)的G:
。JSD最小化的行為已經被實踐證實相較於KL散度更近似於反KL散度。這,另一方面,也導致了之前提到的模式崩潰問題,在GAN的應用領域臭名昭著,即生成器只能生成相似的圖片,低熵分布,樣本種類匱乏。
近期的研究通過改進GAN的訓練方式來解決模式崩潰的問題。一個方法是使用mini-batch分辨法巧妙地讓鑑別器分辨與其他生成樣本非正常相似的圖片。儘管這種啟發方式可以幫助快速生成具有視覺吸引力的樣本,但是它的計算代價很高,因此,通常應用於鑑別器的最後一個隱藏層。另外一個方法是把鑑別器的優化通過幾個步驟展開,在訓練中產生一個代理目標來進行生成器的更新。第三種方法是訓練多個生成器,發現不同的數據模式。同期的,還有一些其他的方法,運用autoencoders進行正則化或者輔助損失來補償丟失的模式等。這些方法都可以在一定程度上改善模式崩潰的問題,但是由此帶來了更高的計算複雜度,從而無法擴展到ImageNet這種大規模的和具有挑戰性的視覺資料庫上。
應對這些挑戰,作者們在這篇論文中提出了一種新的方法,既可以高效地避免模式崩潰問題又可以擴展到龐大的資料庫(比如:ImageNet等)。通過結合KL和反KL散度生成一個統一的目標函數,從而利用了兩種散度的互補統計特性,有效地在多模式下分散預估密度。使用GAN的框架,量化這種思路,便形成了一種新穎的生成對抗架構:鑑別器D1(通過鑑別數據來自於Pdata而不在生成分布PG中獲取高分),鑑別器D2(相反地,來自於PG而不在Pdata中)和生成器G(嘗試欺騙D1、D2兩個鑑別器)。作者將這種方法命名為雙鑑別器生成對抗網絡(D2GAN)。
實驗證明,訓練D2GAN與訓練GAN會遇到同樣的極大極小問題,通過交替更新生成器和鑑別器可以得到解決。理論分析表明,如果G、D1和D2具有足夠的容量,如非參數的限制下,在最佳點,對KL和反KL散度而言,訓練標準確實導致了數據和模型分布之間的最小距離。這有助於模型在各種數據分布模式下進行公平的概率分布,使得生成器可一次完成數據分布恢復和生成多樣樣本。另外,作者還引入了超參數實現穩定地學習和各種散度影響的控制。
作者進行了大量的實驗,包括一個合成資料庫和具備不同特徵的四個真實大規模資料庫(MNIST、CIFAR10、STL-10、ImageNet)。眾所周知,評估生成模型是非常困難的,作者花費了很多時間,使用了各種評估辦法,定量的對比D2GAN和最新的基線方法。實驗結果表明,D2GAN可以在保持生成樣本質量的同時提高樣本的多樣性。更重要的是,這種方法可以擴展到更大規模的資料庫(ImageNet),並保持具有競爭力的多樣性結果和生成合理的高品質樣本圖片。
簡而言之,這種方法具有三個重要的貢獻:(i)一種新穎的生成對抗模型,提高生成樣本的多樣性;(ii)理論分析證實這種方法的目標是優化KL和反KL散度的最小差異,並在PG=Pdata時,實現全局最優;(iii)使用大量的定量標準和大規模資料庫對這種方法進行綜合評估。
作者們的實現方法如下:
生成對抗網絡
首先介紹一下生成對抗網絡(GAN),具有兩個玩家:鑑別器D和生成器G。鑑別器D(x),在數據空間中取一個點x,然後計算x在數據分布Pdata中而不是生成器G生成的概率。同時,生成器先向數據空間映射一個取自先導P(z)的噪聲向量z,獲取一個類似於訓練數據的樣本G(z),然後使用這個樣本來欺騙鑑別器。G(z)形成了一個在數據域的生成分布PG,和概率密度函數PG(x)。G和D都由神經網絡構成(見圖1a),並通過如下的極大極小優化得以學習:
學習遵循一個迭代的過程,其中鑑別器和生成器交替地更新。假設固定G,最大化D可以獲得最優鑑別器
,同時,固定最優D*,最小化G可以實現最小化Jensen-Shannon(JS)散度(數據和模型分布:
)。在博弈的納什均衡下,模型分布完全恢復了數據分布:PG=Pdata,從而鑑別器現在無法分辨真假數據:
。
由於JS散度通過大量的實驗數據證實與反KL散度的特性相同,GAN也會有模式崩潰的問題,因此,其生成的數據樣本多樣性很低。
雙鑑別器生成對抗網絡
為了解決GAN的模式崩潰問題,下方介紹了一種框架,尋求近似分布來有效地涵蓋多模式下的多模態數據。這種方法也是基於GAN,但是有三個組成部分,包括兩個不同的鑑別器D1、D2和一個生成器G。假定一個數據空間中的樣本x,如果x是數據分布Pdata中的,D1(x)獲得高分,如果是模式分布PG中的,則獲得低分。相反地,如果x是模式分布PG中的,D2(x)獲得高分,如果是數據分布Pdata中的,D2(x)獲得低分。與GAN不同的是,得分的表現形式為R+而不是[0,1]中的概率。生成器G的角色與GAN中的相似,即從噪聲空間中映射數據與真實數據進行合成後欺騙D1和D2兩個鑑別器。這三個部分都由神經網絡參數化而成,其中D1和D2不分享它們的參數。這種方法被稱為雙鑑別器生成對抗網絡(D2GAN),見上圖1b。D1、D2和G遵循如下的極大極小公式:
其中超參數
為了實現兩個目的。第一個是為了穩定化模型的學習過程。兩個鑑別器的輸出結果都是正的,D1(G(z))和D2(x)可能會變得很大並比LogD1(x)和LogD2(x)有指數性的影響,最終會導致學習的不穩定。為了克服這個問題,降低α和β的值。第二個目的是控制KL和反KL散度對優化的影響。後面介紹過優化方法後再對這個部分進行討論。
與GAN相似的是,通過交替更新D1、D2和G可以訓練D2GAN。
理論分析
通過理論分析發現,假設G、D1和D2具備足夠的容量,如非參數的限制下,在最佳點,G可以通過最小化模型和數據分布的KL和反KL散度恢復數據分布。首先,假設生成器是固定的,通過(w.r.t)鑑別器進行優化分析:
證明:根據誘導測度定理,兩個期望相等:
當
時,
。目標函數可以推演如下:
考慮到裡面的函數積分,給定x,通過兩個變量D1、D2最大化函數,得到D1*(x)和D2*(x)。將D1和D2設置為0,可以得到:
是非正數,則證明成立並得到了最大值。
接下來,
,計算生成器G的最優方案G*。
證明:將D1*和D2*代入極大極小方程,得到:
分別是KL和反KL散度。這些散度通常是非負的,並且只在PG*=Pdata時等於0。換言之,生成器生成的分布PG*與數據分布完全等同,這就意味著由於兩個分布的返回值都是1,兩個鑑別器在這種情況下就不能分辨真假樣本了。
如上公式中生成器的誤差表明提高α可以促進最小化KL散度(
)的優化,提高β可促進最小化反KL散度(
)的優化。通過調整α和β這兩個超參數,可以平衡KL散度和反KL散度的影響,從而有效地避免模式崩潰的問題。
實驗
在這個部分,作者進行了廣泛的實驗來驗證的提高模式覆蓋率和提出的方法應用在大規模資料庫上的能力。使用一個合成的2D資料庫進行視覺和數值驗證,並使用四個真實的資料庫(具有多樣性和大規模)進行數值驗證。同時,將D2GAN和最新的GAN的應用進行對比。
從大量的實驗得出結論:(i)鑑別器的輸出具有softplus activations:
,如正ReLU;(ii)Adam優化器,學習速率0.0002,一階動量0.5;(iii)64個樣本作為訓練生成器和鑑別器的minibatch訓練單元;(iv)0.2斜率的Leaky ReLU;(v)權重從各項同性的高斯(Gaussian)分布:
進行初始化,0偏差。實現的過程使用了TensorFlow,並且在文章發表後發布出來。下文將介紹實驗過程,首先是合成資料庫,然後是4個真實資料庫。
合成資料庫
在第一個實驗中,使用已經設計好的實驗方案對D2GAN處理多模態數據的能力進行評估。特別的是,從2D混合8個高斯分布和協方差矩陣0.02I獲取訓練數據,同時中位數分布在半徑2.0零質心的圓中。使用一個簡單的架構,包含一個生成器(兩個全連接隱藏層)和兩個鑑別器(一個ReLU激發層)。這個設定是相同的,因此保證了公平的對比。圖2c顯示了512個由D2GAN和基線生成的樣本。可以看出,常規的GAN產生的數據在數據分布的有效模式附近的一個單一模式上奔潰了。而unrolledGAN和D2GAN可以在8個混合部分分布數據,這就印證了能夠成功地學習多模態數據的能力。最後,D2GAN所截取的數據比unrolledGAN更精確,在各種模式下,unrolledGAN只能集中在模式質心附近的幾個點,而D2GAN產生的樣本全分布在所有模式附近,這就意味著D2GAN產生的樣本比unrolledGAN多得多。
下一步,定量的進行生成數據質量的對比。因為已知真實的分布Pdata,只需進行兩步測量,即對稱KL散度和Wasserstein距離。這些測量分別是對由D2GAN、unrolledGAN和GAN的10000個點歸一化直方與真實的Pdata之間的距離計算。圖2a/b再次清楚了表明了D2GAN相對於unrolled和GAN的優勢(距離越小越好);特別是Wasserstein度量,D2GAN離真實分布的距離基本上減小到0了。這些圖片也表達了D2GAN相對於GAN(綠色曲線)和unrolledGAN(藍色曲線)在訓練時的穩定性。
真實資料庫
下面,使用真實資料庫對D2GAN進行評估。在真實資料庫條件下,數據擁有更高的多樣性和更大的規模。對含有卷積層的網絡,根據DCGAN進行設計分析。鑑別器使用帶步長的卷積,生成器使用分步帶步長的卷積。每個層都進行批處理標準化,除了生成器輸出層和鑑別器的輸入層。鑑別器還使用Leaky ReLU 激發層,生成器使用ReLU層,除非其輸出是tanh,原因是各像素的強度值在反饋到D2GAN模型前已經變換到[-1,1]的區間內。唯一的區別是,在D2GAN下,當從N(0,0.01)初始化權重時,其表現比從N(0,0.02)初始化權重的效果好。架構的細節請看論文附件。
評估方式
評估生成對抗模型產生的樣本是很難的,原因有生成概率判斷標準繁多、缺乏有意義的圖像相似性度量標準。儘管生成器可以產生看似真實的圖像,但是如果這些圖像看起來非常近似,樣本依然不可使用。因此,為了量化各種模式下的圖像質量,同時生產高質量的樣本圖樣,使用各種不用的ad-hoc度量進行不同的實驗來進行D2GAN方法與各基線方法的效果對比。
首先,使用起始分值(Inception Score),計算通過:
,這裡P(y|x)是通過預訓練的初始模型的圖像x的條件標籤分布,P(y)是邊際分布:
。這種度量方式會給質量高的多樣的圖片給高分,但是有時候很容易被崩潰的模式欺騙,導致產生非常低質量的圖片。因此,這種方式不能用於測量模型是否陷入了錯誤的模式。為了解決這個問題,對有標籤的資料庫,使用MODE score:
這裡,
是訓練數據的預估標籤的經驗分布。MODE score的值可以充分的反應生成圖像的多樣性和視覺質量。
手寫數字圖像
這個部分使用手寫數字圖像-MNIST,資料庫包含有60,000張訓練圖像和10,000張測試灰度圖(28*28像素),數值區間從0到9。首先,假設MNIST有10個模式,代表了數據分支的連接部分,分為10個數字等級。然後使用不同的超參數配置進行擴展的網格搜索,使用兩個正則常數α和β,數值為{0.01,0.05,0.1,0.2}。為了進行公平的對比,對不同的架構使用相同的參數和全連接層。
評估部分,首先訓練一個簡單的但有效的3-layer卷積網絡(MNIST測試庫實現0.65%的誤差),然後將它應用於預估標籤的概率和生成樣本的MODE score計算中。圖3左顯示了3個模式下MODE score的分布。清晰的看到,D2GAN相對於標準GAN和Reg-GAN的巨大優越性,其分數的最大值基本落在區間【8.0-9.0】。值得注意的是,儘管提高網絡的複雜度,MODE score基本保持高水平。這幅圖片中只表現了最小網絡和最少層和隱藏單元的結果。
為了研究α和β的影響,在不同的α和β的數值下進行試驗(圖3右)。結果表明,給定α值,D2GAN可以在β達到一定數值時獲得更好的MODE score,當β數值繼續增大,MODE score降低。
MNIST-1K.假定10個模式的標準MNIST資料庫相當簡單。因此,基於這個資料庫,作者使用一個更具挑戰性的資料庫進行測試。沿用上述的方式,假定一個新的有1000個等級的MNIST資料庫(MNIST-1K),方法為用3個隨機數字組成一個RGB圖像。由此,可以組成1000個離散的模式,從000到999。
在這個實驗中,使用一個更強大的模型,鑑別器使用卷積層,生成器使用轉置卷積。通過測試模式的數量進行模型的性能評估,其中模型在25,600個樣本中至少產生一個模式,同時反KL散度分布介於模型分布(如從預訓練的MNIST分類器預測的標籤分布)和期望的數據分布之間。表1報告了D2GAN與GAN、unrolledGAN、GCGAN和Reg-GAN之間的對比。通過對比可以看出D2GAN具有極大的優勢,同時模型分布和數據分布之間的差距幾近為0。
自然場景圖像
下面是將D2GAN應用到更廣泛的自然場景圖像上,用於驗證其在大規模資料庫上的表現。使用三個經常被使用的資料庫:CIFAR-10,STL-10和ImageNet。CIFAR-10包含50,000張32*32的訓練圖片,有10個等級:飛機,摩託車,鳥,貓,鹿,狗,青蛙,馬,船和卡車(airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck)。STL-10,是ImageNet的子數據集,包含10,000張未被標記的96*96的圖片,相對於CIFAR-10更多樣,但是少於ImageNet。將所有的圖片向下縮小3倍至32*32解析度後,再對網絡進行訓練。ImageNet非常龐大,擁有120百萬自然圖片,包含1000個類別,通常ImageNet是深度網絡領域訓練使用的最為龐大和複雜的資料庫。使用這三個資料庫進行蓄念和計算,Inception score的結果如下圖和下方表格所示:
表格中和圖4中表示了Inception score在不同資料庫和不同模型上的不同值。值得注意的是,這邊的對比是在一個完美無監督的方法下,並且沒有標籤的信息。在CIFAR-10資料庫上使用的8個基線模型,而在STL-10和ImageNet資料庫上使用了DCGAN、DFM(denoising feature matching)作對比。在D2GAN的實現上使用了與DCGAN完全一致的網絡架構,以做公平的對比。在這三個實驗結果中,可以看出,D2GAN的表現低於DFM,但是在很大的程度高於其他任何一個基線模型。這種遜於DFM的結果印證了對高級別的特徵進行自動解碼是提高多樣性的一種有效方法。D2GAN可與這種方式兼容,因此融合這種做法可以為未來的研究帶來更好的效果。
最後,在圖5中展現了使用D2GAN生成的樣本圖片。這些圖片都是隨機產生的,而不是特別挑選的。從圖片中可以看出,D2GAN生成了可以視覺分辨的車,卡車,船和馬(在CIFAR-10資料庫的基礎上)。在STL-10的基礎上,圖片看起來相對比較難以辨認,但是飛機,車,卡車和動物的輪廓還是可以識別的;同時圖片還具備了多種背景,如天空,水下,山和森林(在ImageNet的基礎上)。這印證了使用D2GAN可以生成多樣性的圖片的結論。
結論
總結全文,作者介紹了一種全新的方法,結合KL(Kullback-Leibler)和反KL散度生成一個統一的目標函數來解決密度預測問題。這種方法利用了這兩種散度的互補統計特性來提高生成器產生的圖像的質量和樣本的多樣性。基於這個原理,作者引入了一種新的網絡,基於生成對抗網絡(GAN),由三方構成:兩個鑑別器和一個生成器,並命其為雙鑑別器生成對抗網絡(dual discriminator GAN, D2GAN)。如果設定兩個鑑別器是固定的,同時優化KL和反KL散度進行生成器的學習,通過這種方法可以幫助解決模式崩潰的問題(GAN的一大亟待突破的難點)。
作者通過大量的實驗對其提出的方法進行了驗證。這些實驗的結果證實了D2GAN的高效性和擴展性。實驗使用的資料庫包括合成資料庫和大規模真實圖片資料庫,即MNIST、CIFAR-10,STL-10和ImageNet。相較於最新的基線方法,D2GAN更具擴展性,可以應用於業內最為龐大和複雜的資料庫ImageNet,儘管取得了比融合DFM(denoising feature matching)和GAN的方法低的Inception score,但遠遠高於其他GAN應用的實驗結果。最後,作者指出,未來的研究可以借鑑融合DFM和GAN的做法,在現有的方法基礎上增加類似半監督學習、條件架構和自動編碼等的技術,更進一步的解決生成對抗網絡在應用中的問題。
論文地址:https://arxiv.org/abs/1709.03831
雷鋒網 AI 科技評論編譯