不到200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

2020-12-05 雷鋒網

生成對抗網絡(Generative Adversarial Networks,GAN)最早由 Ian Goodfellow 在 2014 年提出,是目前深度學習領域最具潛力的研究成果之一。它的核心思想是:同時訓練兩個相互協作、同時又相互競爭的深度神經網絡(一個稱為生成器 Generator,另一個稱為判別器 Discriminator)來處理無監督學習的相關問題。在訓練過程中,兩個網絡最終都要學習如何處理任務。

通常,我們會用下面這個例子來說明 GAN 的原理:將警察視為判別器,製造假幣的犯罪分子視為生成器。一開始,犯罪分子會首先向警察展示一張假幣。警察識別出該假幣,並向犯罪分子反饋哪些地方是假的。接著,根據警察的反饋,犯罪分子改進工藝,製作一張更逼真的假幣給警方檢查。這時警方再反饋,犯罪分子再改進工藝。不斷重複這一過程,直到警察識別不出真假,那麼模型就訓練成功了。

雖然 GAN 的核心思想看起來非常簡單,但要搭建一個真正可用的 GAN 網絡卻並不容易。因為畢竟在 GAN 中有兩個相互耦合的深度神經網絡,同時對這兩個網絡進行梯度的反向傳播,也就比一般場景困難兩倍。

為此,本文將以深度卷積生成對抗網絡(Deep Convolutional GAN,DCGAN)為例,介紹如何基於 Keras 2.0 框架,以 Tensorflow 為後端,在 200 行代碼內搭建一個真實可用的 GAN 模型,並以該模型為基礎自動生成 MNIST 手寫體數字。

  判別器

判別器的作用是判斷一個模型生成的圖像和真實圖像比,有多逼真。它的基本結構就是如下圖所示的卷積神經網絡(Convolutional Neural Network,CNN)。對於 MNIST 數據集來說,模型輸入是一個 28x28 像素的單通道圖像。Sigmoid 函數的輸出值在 0-1 之間,表示圖像真實度的概率,其中 0 表示肯定是假的,1 表示肯定是真的。與典型的 CNN 結構相比,這裡去掉了層之間的 max-pooling,而是採用了步進卷積來進行下採樣。這裡每個 CNN 層都以 LeakyReLU 為激活函數。而且為了防止過擬合和記憶效應,層之間的 dropout 值均被設置在 0.4-0.7 之間。具體在 Keras 中的實現代碼如下。

self.D = Sequential()
depth = 64
dropout = 0.4
# In: 28 x 28 x 1, depth = 1
# Out: 10 x 10 x 1, depth=64
input_shape = (self.img_rows, self.img_cols, self.channel)
self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,\
padding='same', activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*2, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*4, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*8, 5, strides=1, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
# Out: 1-dim probability
self.D.add(Flatten())
self.D.add(Dense(1))
self.D.add(Activation('sigmoid'))
self.D.summary()

  生成器

生成器的作用是合成假的圖像,其基本機構如下圖所示。圖中,我們使用了卷積的倒數,即轉置卷積(transposed convolution),從 100 維的噪聲(滿足 -1 至 1 之間的均勻分布)中生成了假圖像。如在 DCGAN 模型中提到的那樣,去掉微步進卷積,這裡我們採用了模型前三層之間的上採樣來合成更逼真的手寫圖像。在層與層之間,我們採用了批量歸一化的方法來平穩化訓練過程。以 ReLU 函數為每一層結構之後的激活函數。最後一層 Sigmoid 函數輸出最後的假圖像。第一層設置了 0.3-0.5 之間的 dropout 值來防止過擬合。具體代碼如下。

self.G = Sequential()
dropout = 0.4
depth = 64+64+64+64
dim = 7
# In: 100
# Out: dim x dim x depth
self.G.add(Dense(dim*dim*depth, input_dim=100))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Reshape((dim, dim, depth)))
self.G.add(Dropout(dropout))
# In: dim x dim x depth
# Out: 2*dim x 2*dim x depth/2
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
# Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
self.G.add(Conv2DTranspose(1, 5, padding='same'))
self.G.add(Activation('sigmoid'))
self.G.summary()
return self.G

  生成 GAN 模型

下面我們生成真正的 GAN 模型。如上所述,這裡我們需要搭建兩個模型:一個是判別器模型,代表警察;另一個是對抗模型,代表製造假幣的犯罪分子。

判別器模型

下面代碼展示了如何在 Keras 框架下生成判別器模型。上文定義的判別器是為模型訓練定義的損失函數。這裡由於判別器的輸出為 Sigmoid 函數,因此採用了二進位交叉熵為損失函數。在這種情況下,以 RMSProp 作為優化算法可以生成比 Adam 更逼真的假圖像。這裡我們將學習率設置在 0.0008,同時還設置了權值衰減和clipvalue等參數來穩定後期的訓練過程。如果你需要調節學習率,那麼也必須同步調節其他相關參數。

optimizer = RMSprop(lr=0.0008, clipvalue=1.0, decay=6e-8)
self.DM = Sequential()
self.DM.add(self.discriminator())
self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

對抗模型

如圖所示,對抗模型的基本結構是判別器和生成器的疊加。生成器試圖騙過判別器,同時從其反饋中提升自己。如下代碼中演示了如何基於 Keras 框架實現這一部分功能。其中,除了學習速率的降低和相對權值衰減之外,訓練參數與判別器模型中的訓練參數完全相同。

optimizer = RMSprop(lr=0.0004, clipvalue=1.0, decay=3e-8)
self.AM = Sequential()
self.AM.add(self.generator())
self.AM.add(self.discriminator())
self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

訓練

搭好模型之後,訓練是最難實現的部分。這裡我們首先用真實圖像和假圖像對判別器模型單獨進行訓練,以判斷其正確性。接著,對判別器模型和對抗模型輪流展開訓練。如下圖展示了判別器模型訓練的基本流程。在 Keras 框架下的實現代碼如下所示。

images_train = self.x_train[np.random.randint(0,
self.x_train.shape[0], size=batch_size), :, :, :]
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
images_fake = self.generator.predict(noise)
x = np.concatenate((images_train, images_fake))
y = np.ones([2*batch_size, 1])
y[batch_size:, :] = 0
d_loss = self.discriminator.train_on_batch(x, y)
y = np.ones([batch_size, 1])
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
a_loss = self.adversarial.train_on_batch(noise, y)

訓練過程中需要非常耐心,這裡列出一些常見問題和解決方案:

問題1:最終生成的圖像噪點太多。

解決:嘗試在判別器和生成器模型上引入 dropout,一般更小的 dropout 值(0.3-0.6)可以產生更逼真的圖像。

問題2:判別器的損失函數迅速收斂為零,導致發生器無法訓練。

解決:不要對判別器進行預訓練。而是調整學習率,使判別器的學習率大於對抗模型的學習率。也可以嘗試對生成器換一個不同的訓練噪聲樣本。

問題3:生成器輸出的圖像仍然看起來像噪聲。

解決:檢查激活函數、批量歸一化和 dropout 的應用流程是否正確。

問題4:如何確定正確的模型/訓練參數。

解決:嘗試從一些已經發表的論文或代碼中找到參考,調試時每次只調整一個參數。在進行 2000 步以上的訓練時,注意觀察在 500 或 1000 步左右參數值調整的效果。

  輸出情況

下圖展示了在訓練過程中,整個模型的輸出變化情況。可以看到,GAN 在自己學習如何生成手寫體數字。

完整代碼地址:

https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/GAN/dcgan_mnist.py 

來源:medium,雷鋒網(公眾號:雷鋒網)編譯

雷鋒網(公眾號:雷鋒網)相關閱讀:

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基於 PyTorch)

生成對抗網絡(GANs )為什麼這麼火?盤點它誕生以來的主要技術進展

雷鋒網版權文章,未經授權禁止轉載。詳情見轉載須知。

相關焦點

  • 生成式對抗網絡GAN的高級議題
    生成對抗網絡(GAN)就是一個幫你實現的新朋友。"GAN是過去10年機器學習中最有趣的想法。" - Facebook AI人工智慧研究總監Yann LeCun最近引入了生成對抗網作為訓練生成模型的新方法,即創建能夠生成數據的模型。它們由兩個"對抗"模式:生成模型G獲得數據和判別模型D來估計訓練數據提供的樣本的準確性。G和D可能是一個非線性映射函數,如多層感知。在生成對抗網絡(GAN)中,我們有兩個神經網絡在零和遊戲中相互對抗,其中第一個網絡,即生成器,其任務是欺騙第二個網絡,即鑑別器。
  • 用Keras搭建GAN:圖像去模糊中的應用(附代碼)
    這篇文章主要介紹在Keras中搭建GAN實現圖像去模糊。所有的Keras代碼可點擊這裡。可點擊查看原始出版文章和Pytorch實現。快速回憶生成對抗網絡GAN中兩個網絡的訓練相互競爭。這些只是對生成對抗網絡的一個簡單回顧,如果還是不夠明白的話,可以參考完整介紹。數據Ian Goodfellow首次使用GAN模型是生成MNIST數據。 而本篇文章是使用生成對抗網絡進行圖像去模糊。因此生成器的輸入不是噪聲,而是模糊圖像。
  • 5分鐘入門GANS:原理解釋和keras代碼實現
    GANs,用於生成圖像而不需要很少或沒有輸入。GANs允許我們生成由神經網絡生成的圖像。在我們深入討論這個理論之前,我想向您展示GANs構建您興奮感的能力。把馬變成斑馬(反之亦然)。歷史生成式對抗網絡(GANs)是由Ian Goodfellow (GANs的GAN Father)等人於2014年在其題為「生成式對抗網絡」的論文中提出的。
  • 這些資源你肯定需要!超全的GAN PyTorch+Keras實現集合
    機器之心編譯參與:劉曉坤、思源、李澤南生成對抗網絡一直是非常美妙且高效的方法,自 14 年 Ian Goodfellow 等人提出第一個生成對抗網絡以來,各種變體和修正版如雨後春筍般出現,它們都有各自的特性和對應的優勢。
  • 一篇文章教你用11行Python代碼實現神經網絡
    聲明:本文是根據英文教程 (用 11 行 Python 代碼實現的神經網絡)學習總結而來,關於更詳細的神經網絡的介紹可以參考我的另一篇博客:。A Neural Network in 11 lines of Python從感知機到人工神經網絡如果你讀懂了下面的文章,你會對神經網絡有更深刻的認識,有任何問題,請多指教。
  • 只需130 行代碼,用 GAN 生成二維樣本的小例子
    PyTorch 平臺用 50 行代碼實現 GAN(生成對抗網絡),詳情參見:《GAN 很複雜?如如何用不到 50 行代碼訓練 GAN》。近期,針對文中介紹的「50 行代碼 GAN 模型」,有開發者指出了局限性,並基於此模型給出了改進版本,也就是本文將要介紹的「130 行代碼實現 GAN 二維樣本」。本文原載於知乎專欄,作者達聞西,雷鋒網經授權發布。
  • GAN快速入門資料推薦:17種變體的Keras開原始碼,附相關論文
    夏乙 編譯整理量子位 出品 | 公眾號 QbitAI圖片來源:Kaggle blog從2014年誕生至今,生成對抗網絡(GAN)始終廣受關注,已經出現了200多種有名有姓的變體。在論文中,研究人員給出了用MNIST和多倫多人臉數據集 (TFD)訓練的模型所生成的樣本。
  • GAN(生成對抗網絡)萬字長文綜述
    GAN的基本介紹生成對抗網絡(GAN,Generative Adversarial Networks)作為一種優秀的生成式模型,引爆了許多圖像生成的有趣應用。GAN的基本概念GAN(Generative Adversarial Networks)從其名字可以看出,是一種生成式的,對抗網絡。再具體一點,就是通過對抗的方式,去學習數據分布的生成式模型。所謂的對抗,指的是生成網絡和判別網絡的互相對抗。
  • 萬字綜述之生成對抗網絡(GAN)
    文章目錄如下:GAN的基本介紹生成對抗網絡(GAN,Generative Adversarial Networks)作為一種優秀的生成式模型,引爆了許多圖像生成的有趣應用。GAN的基本概念GAN(Generative Adversarial Networks)從其名字可以看出,是一種生成式的,對抗網絡。再具體一點,就是通過對抗的方式,去學習數據分布的生成式模型。所謂的對抗,指的是生成網絡和判別網絡的互相對抗。
  • 手把手教你用Keras進行多標籤分類(附代碼)
    或者至少訓練一個神經網絡來完成三項分類任務? 我不想在if / else代碼的級聯中單獨應用它們,這些代碼使用不同的網絡,具體取決於先前分類的輸出。 謝謝你的幫助Switaj提出了一個美妙的問題:Keras深度神經網絡是否有可能返回多個預測?
  • Keras結合Keras後端搭建個性化神經網絡模型(不用原生Tensorflow)
    它幫我們實現了一系列經典的神經網絡層(全連接層、卷積層、循環層等),以及簡潔的迭代模型的接口,讓我們能在模型層面寫代碼,從而不用仔細考慮模型各層張量之間的數據流動。但是,當我們有了全新的想法,想要個性化模型層的實現,Keras的高級API是不能滿足這一要求的,而換成Tensorflow又要重新寫很多輪子,這時,Keras的後端就派上用場了。
  • 谷歌推出新框架:只需5行代碼,就能提高模型準確度和魯棒性
    曉查 發自 凹非寺量子位 出品 | 公眾號 QbitAI今天,谷歌推出了新開源框架——神經結構學習(NSL),它使用神經圖學習方法,來訓練帶有圖(Graph)和結構化數據的神經網絡,可以帶來更強大的模型。現在,通過TensorFlow就能獲取和使用。NSL有什麼用?
  • 資源|17類對抗網絡經典論文及開原始碼(附源碼)
    原標題:資源|17類對抗網絡經典論文及開原始碼(附源碼) 全球人工智慧 文章來源:Github 對抗網絡專題文獻集convolutional networks)(ICLR) [Paper]https://arxiv.org/abs/1511.06434 [Code]https://github.com/jacobgil/keras-dcgan
  • 生成對抗網絡的最新研究進展
    生成對抗網絡的工作原理給定一組目標樣本,生成器試圖生成一些能夠欺騙判別器、使判別器相信它們是真實的樣本。判別器試圖從假(生成)樣本中解析真實(目標)樣本。然而,他們的對偶形式(用下確界代替上確界或者用上確界代替下確界)可能易於優化。對偶原則為將一種形式轉換為另一種形式奠定了框架。關於這一點的詳細解釋,你可以查看這篇博客。4.Lipschitz 連續性一個 Lipschitz 連續函數的變化速度是有限的。
  • Keras入門系列教程:兩分鐘構建你的第一個神經網絡模型
    要開始使用tf.keras, 請將其作為TensorFlow程序的一部分導入:import tensorflow as tffrom tensorflow import kerastf.keras 可以運行任何與Keras兼容的代碼
  • Keras和TensorFlow究竟哪個會更好?
    -10 數據集上訓練兩個單獨的卷積神經網絡 (CNN),方案如下: 方法 1 :以 TensorFlow 作為後端的 Keras 模型 方法 2 :使用 tf.keras 中 Keras 子模塊 在介紹的過程中我還會展示如何把自定義的 TensorFlow 代碼寫入你的 Keras 模型中。
  • 圖像分類入門,輕鬆拿下90%準確率|教你用Keras搞Fashion-MNIST
    原作 Margaret Maynard-Reid王小新 編譯自 TensorFlow的Medium量子位 出品 | 公眾號 QbitAI這篇教程會介紹如何用TensorFlow裡的tf.keras函數,對Fashion-MNIST數據集進行圖像分類。
  • Keras 之父講解 Keras:幾行代碼就能在分布式環境訓練模型 |...
    但現在,我們把 Keras API 直接整合入 TensorFlow 項目中,這樣能與你的已有工作流無縫結合。至此,Keras 成為了 TensorFlow 內部的一個新模塊:tf.keras,它包含完整的 Keras API。「對於 TensorFlow 用戶,這意味著你獲得了一整套易於使用的深度學習組件,並能與你的工作流無縫整合。
  • 張東升,我知道是你!如何使用GAN做一個禿頭生產器
    看過這部劇後,我突然很想知道自己禿頭是什麼樣子,於是查了一下飛槳官網,果然它有圖片生成的模型庫。那麼,我們如何使用PaddlePaddle做出一個禿頭生成器呢。  生成對抗網絡介紹  說到圖像生成,就必須說到GAN,它是一種非監督學習的方式,通過讓兩個神經網絡相互博弈的方法進行學習,該方法由lan Goodfellow等人在2014年提出。生成對抗網絡由一個生成網絡和一個判別網絡組成,生成網絡從潛在的空間(latent space)中隨機採樣作為輸入,其輸出結果需要儘量模仿訓練集中的真實樣本。
  • 通過Keras 構建基於 LSTM 模型的故事生成器
    LSTM 的使用背景當你讀這篇文章的時候,你可以根據你對前面所讀單詞的理解來理解上下文。 你不會從一開始或者從中間部分閱讀就能夠直接理解文本意義,而是隨著你閱讀的深入,你的大腦才最終形成上下文聯繫,能夠理解文本意義。