人工智慧目前的核心目標應該是賦予機器自主理解我們所在世界的能力。對於人類來說,我們對這個世界所了解的知識可能很快就會忘記,比如我們所處的三維環境中,物體能夠交互,移動,碰撞;什麼動物會飛,什麼動物吃草等等。這些巨大的並且不斷擴大的信息現在是很容易被機器獲取的,問題的關鍵是怎麼設計模型和算法讓機器更好的去分析和理解這些數據中所蘊含的寶藏。
Generative models(生成模型)現在被認為是能夠實現這一目標的最有前景的方法之一。Generative models通過輸入一大堆特定領域的數據進行訓練(比如圖像,句子,聲音等)來使得模型能夠產生和輸入數據相似的輸出。這一直覺的背後可以由下面名言闡述。
「What I cannot create, I do not understand.」 —Richard Feynman
生成模型由一個參數數量比訓練數據少的多神經網絡構成,所以生成模型為了能夠產生和訓練數據相似的輸出就會迫使自己去發現數據中內在的本質內容。訓練Generative models的方法有幾種,在這裡我們主要闡述其中的Adversarial Training(對抗訓練)方法。
Adversarial Training上文說過Adversarial Training是訓練生成模型的一種方法。為了訓練生成模型,Adversarial Training提出一種Discriminative Model(判別模型)來和生成模型產生對抗,下面來說說Generative models G(z) 和 Discriminative Model D(x) 是如何相互作用的。
其中生成模型和判別模型合起來的框架被稱為GAN網絡。通過下圖我們來理清判別模型和生成模型之間的輸入輸出關係:生成模型通過輸入隨機噪聲 z(z 屬於 p_z) 產生合成樣本;而判別模型通過分別輸入真實的訓練數據和生成模型的訓練數據來判斷輸入的數據是否真實。
描述了GAN的網絡結構,但它的優化目標是什麼?怎麼就可以通過訓練使得生成模型能夠產生和真實數據相似的輸出?優化的目標其實很簡單,簡單來說就是:
下面用形式化說明下如果訓練GAN網絡, 先定義一些參數:
參數含義p_z輸入隨機噪聲 z 的分布p_{data}未知的輸入樣本的數據分布p_g生成模型的輸出樣本的數據分布,GAN的目標就是要p_g=p_{data}訓練判別模型 D(x) 的目標:
對每一個輸入數據 x 屬於 p_{data} 要使得 D(x) 最大;
對每一個輸入數據 x 不屬於 p_{data} 要使得 D(x) 最小。
訓練生成模型 G(z) 的目標是來產生樣本來欺騙判別模型 D, 因此目標為最大化 D(G(z)),也就是把生成模型的輸出輸入到判別模型,然後要讓判別模型預測其為真實數據。同時,最大化 D(G(z)) 等同於最小化 1-D(G(z)),因為 D 的輸出是介於0到1之間的,真實數據努力預測為1,否則為0。
所以把生成模型和判別模型的訓練目標結合起來,就得到了GAN的優化目標:
總結一下上面的內容,GAN啟發自博弈論中的二人零和博弈,在二人零和博弈中,兩位博弈方的利益之和為零或一個常數,即一方有所得,另一方必有所失。GAN模型中的兩位博弈方分別由生成模型和判別模型充當。生成模型G捕捉樣本數據的分布,判別模型是一個二分類器,估計一個樣本來自於訓練數據(而非生成數據)的概率。G和D一般都是非線性映射函數,例如多層感知機、卷積神經網絡等。生成模型的輸入是一些服從某一簡單分布(例如高斯分布)的隨機噪聲z,輸出是與訓練圖像相同尺寸的生成圖像。向判別模型D輸入生成樣本,對於D來說期望輸出低概率(判斷為生成樣本),對於生成模型G來說要儘量欺騙D,使判別模型輸出高概率(誤判為真實樣本),從而形成競爭與對抗。
GAN實現一個簡單的一維數據GAN網絡的tensorflow實現:genadv_tutorial其一維訓練數據分布如下所示,是一個均值-1, sigma =1 的正態分布。
我們結合代碼和上面的理論內容來分析下GAN的具體實現,判別模型的優化目標為最大化下式,其中 D_1(x) 表示判別真實數據, D_2(G(z)) 表示對生成的數據進行判別, 其中 D_1 和 D_2 是共享參數的, 也就是說是同一個判別模型。
對應的python代碼如下:
batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_d,global_step=batch,var_list=theta_d)
為了優化 G, 我們想要最大化 D_2(x')(成功欺騙 D ),因此 G 的優化函數為:
對應的python代碼:
batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_g,global_step=batch,var_list=theta_g)
定義好優化目標後,下面就是訓練的主要代碼了:
for i in range(TRAIN_ITERS):
x= np.random.normal(mu,sigma,M)
z= np.random.random(M)
sess.run(opt_d, {x_node: x, z_node: z})
z= np.random.random(M)
sess.run(opt_g, {z_node: z})
下面是實驗的結果,左圖是訓練之間的數據,可以看到生成數據的分布和訓練數據相差甚遠;右圖是訓練後的數據分析,生成數據和訓練數據的分布接近了很多,且此時判別模型的輸出分布在0.5左右,說明生成模型順利的欺騙到判別模型。
DCGANGAN的一個改進模型就是DCGAN。這個網絡的生成模型的輸入為一個100個符合均勻分布的隨機數(通常被稱為code),然後產生輸出為64x64x3的輸出圖像(下圖中 G(z) ), 當code逐漸遞增時,生成模型輸出的圖像也逐漸變化。下圖中的生產模型主要由反卷積層構成, 判別模型就由簡單的卷積層組成,最後輸出一個判斷輸入圖片是否為真實數據的概率 P(x) 。
下圖為隨著迭代次數,DCGAN產生圖像的變化過程。
訓練好網絡之後,其中的生成模型和判別模型都有其他的作用。一個訓練好的判別模型能夠用來對數據提取特徵然後進行分類任務。通過輸入隨機向量生成模型可以產生一些非常有意思的的圖片,如下圖所示,當輸入空間平滑變化時,輸出的圖片也在平滑轉變。
GAN的訓練及其改進上面使用GAN產生的圖像雖然效果不錯,但其實GAN網絡的訓練過程是非常不穩定的。通常在實際訓練GAN中所碰到的一個問題就是判別模型的收斂速度要比生成模型的收斂速度要快很多,通常的做法就是讓生成模型多訓練幾次來趕上生成模型,但是存在的一個問題就是通常生成模型和判別模型的訓練是相輔相成的,理想的狀態是讓生成模型和判別模型在每次的訓練過程中同時變得更好。判別模型理想的minimum loss應該為0.5,這樣才說明判別模型分不出是真實數據還是生成模型產生的數據。
Improved GANsImproved techniques for training GANs這篇文章提出了很多改進GANs訓練的方法,其中提出一個想法叫Feature matching,之前判別模型只判別輸入數據是來自真實數據還是生成模型。現在為判別模型提出了一個新的目標函數來判別生成模型產生圖像的統計信息是否和真實數據的相似。讓 f(x) 表示判別模型中間層的輸出, 新的目標函數被定義為:
其實就是要求真實圖像和合成圖像在判別模型中間層的距離要最小。這樣可以防止生成模型在當前判別模型上過擬合。
InfoGAN到這可能有些同學會想到,我要是想通過GAN產生我想要的特定屬性的圖片改怎麼辦?普通的GAN輸入的是隨機的噪聲,輸出也是與之對應的隨機圖片,我們並不能控制輸出噪聲和輸出圖片的對應關係。這樣在訓練的過程中也會倒置生成模型傾向於產生更容易欺騙判別模型的某一類特定圖片,而不是更好的去學習訓練數據的分布,這樣對模型的訓練肯定是不好的。InfoGAN的提出就是為了解決這一問題,通過對輸入噪聲添加一些類別信息以及控制圖像特徵(如mnist數字的角度和厚度)的隱含變量來使得生成模型的輸入不在是隨機噪聲。雖然現在輸入不再是隨機噪聲,但是生成模型可能會忽略這些輸入的額外信息還是把輸入當成和輸出無關的噪聲,所以需要定義一個生成模型輸入輸出的互信息,互信息越高,說明輸入輸出的關聯越大。
下面三張圖片展示了通過分別控制輸入噪聲的類別信息,數字角度信息,數字筆畫厚度信息產生指定輸出的圖片,可以看出InfoGAN產生圖片的效果還是很好的。
其他應用GAN網絡還有很多其他的有趣應用,比如下圖所示的根據一句話來產生對應的圖片,可能大家都有了解karpathy大神的看圖說話, 但是GAN有能力把這個過程給反過來。
還有下面這個「圖像補全」, 根據圖像剩餘的信息來匹配最佳的補全內容。
還有下面這個圖像增強的例子,有點去馬賽克的意思,效果還是挺不錯的:-D。
總結顏樂存說過,2016年深度學習領域最讓他興奮技術莫過於對抗學習。對抗學習確實是解決非監督學習的一個有效方法,而無監督學習一直都是人工智慧領域研究者所孜孜追求的「終極目標」之一。
參考Generative Adversarial Networks(https://arxiv.org/abs/1406.2661)(https://arxiv.org/abs/1511.06434)
(https://arxiv.org/abs/1606.03657)