不到200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

2020-12-11 雷鋒網

生成對抗網絡(Generative Adversarial Networks,GAN)最早由 Ian Goodfellow 在 2014 年提出,是目前深度學習領域最具潛力的研究成果之一。它的核心思想是:同時訓練兩個相互協作、同時又相互競爭的深度神經網絡(一個稱為生成器 Generator,另一個稱為判別器 Discriminator)來處理無監督學習的相關問題。在訓練過程中,兩個網絡最終都要學習如何處理任務。

通常,我們會用下面這個例子來說明 GAN 的原理:將警察視為判別器,製造假幣的犯罪分子視為生成器。一開始,犯罪分子會首先向警察展示一張假幣。警察識別出該假幣,並向犯罪分子反饋哪些地方是假的。接著,根據警察的反饋,犯罪分子改進工藝,製作一張更逼真的假幣給警方檢查。這時警方再反饋,犯罪分子再改進工藝。不斷重複這一過程,直到警察識別不出真假,那麼模型就訓練成功了。

雖然 GAN 的核心思想看起來非常簡單,但要搭建一個真正可用的 GAN 網絡卻並不容易。因為畢竟在 GAN 中有兩個相互耦合的深度神經網絡,同時對這兩個網絡進行梯度的反向傳播,也就比一般場景困難兩倍。

為此,本文將以深度卷積生成對抗網絡(Deep Convolutional GAN,DCGAN)為例,介紹如何基於 Keras 2.0 框架,以 Tensorflow 為後端,在 200 行代碼內搭建一個真實可用的 GAN 模型,並以該模型為基礎自動生成 MNIST 手寫體數字。

  判別器

判別器的作用是判斷一個模型生成的圖像和真實圖像比,有多逼真。它的基本結構就是如下圖所示的卷積神經網絡(Convolutional Neural Network,CNN)。對於 MNIST 數據集來說,模型輸入是一個 28x28 像素的單通道圖像。Sigmoid 函數的輸出值在 0-1 之間,表示圖像真實度的概率,其中 0 表示肯定是假的,1 表示肯定是真的。與典型的 CNN 結構相比,這裡去掉了層之間的 max-pooling,而是採用了步進卷積來進行下採樣。這裡每個 CNN 層都以 LeakyReLU 為激活函數。而且為了防止過擬合和記憶效應,層之間的 dropout 值均被設置在 0.4-0.7 之間。具體在 Keras 中的實現代碼如下。

self.D = Sequential()
depth = 64
dropout = 0.4
# In: 28 x 28 x 1, depth = 1
# Out: 10 x 10 x 1, depth=64
input_shape = (self.img_rows, self.img_cols, self.channel)
self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,\
padding='same', activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*2, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*4, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*8, 5, strides=1, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
# Out: 1-dim probability
self.D.add(Flatten())
self.D.add(Dense(1))
self.D.add(Activation('sigmoid'))
self.D.summary()

  生成器

生成器的作用是合成假的圖像,其基本機構如下圖所示。圖中,我們使用了卷積的倒數,即轉置卷積(transposed convolution),從 100 維的噪聲(滿足 -1 至 1 之間的均勻分布)中生成了假圖像。如在 DCGAN 模型中提到的那樣,去掉微步進卷積,這裡我們採用了模型前三層之間的上採樣來合成更逼真的手寫圖像。在層與層之間,我們採用了批量歸一化的方法來平穩化訓練過程。以 ReLU 函數為每一層結構之後的激活函數。最後一層 Sigmoid 函數輸出最後的假圖像。第一層設置了 0.3-0.5 之間的 dropout 值來防止過擬合。具體代碼如下。

self.G = Sequential()
dropout = 0.4
depth = 64+64+64+64
dim = 7
# In: 100
# Out: dim x dim x depth
self.G.add(Dense(dim*dim*depth, input_dim=100))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Reshape((dim, dim, depth)))
self.G.add(Dropout(dropout))
# In: dim x dim x depth
# Out: 2*dim x 2*dim x depth/2
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
# Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
self.G.add(Conv2DTranspose(1, 5, padding='same'))
self.G.add(Activation('sigmoid'))
self.G.summary()
return self.G

  生成 GAN 模型

下面我們生成真正的 GAN 模型。如上所述,這裡我們需要搭建兩個模型:一個是判別器模型,代表警察;另一個是對抗模型,代表製造假幣的犯罪分子。

判別器模型

下面代碼展示了如何在 Keras 框架下生成判別器模型。上文定義的判別器是為模型訓練定義的損失函數。這裡由於判別器的輸出為 Sigmoid 函數,因此採用了二進位交叉熵為損失函數。在這種情況下,以 RMSProp 作為優化算法可以生成比 Adam 更逼真的假圖像。這裡我們將學習率設置在 0.0008,同時還設置了權值衰減和clipvalue等參數來穩定後期的訓練過程。如果你需要調節學習率,那麼也必須同步調節其他相關參數。

optimizer = RMSprop(lr=0.0008, clipvalue=1.0, decay=6e-8)
self.DM = Sequential()
self.DM.add(self.discriminator())
self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

對抗模型

如圖所示,對抗模型的基本結構是判別器和生成器的疊加。生成器試圖騙過判別器,同時從其反饋中提升自己。如下代碼中演示了如何基於 Keras 框架實現這一部分功能。其中,除了學習速率的降低和相對權值衰減之外,訓練參數與判別器模型中的訓練參數完全相同。

optimizer = RMSprop(lr=0.0004, clipvalue=1.0, decay=3e-8)
self.AM = Sequential()
self.AM.add(self.generator())
self.AM.add(self.discriminator())
self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

訓練

搭好模型之後,訓練是最難實現的部分。這裡我們首先用真實圖像和假圖像對判別器模型單獨進行訓練,以判斷其正確性。接著,對判別器模型和對抗模型輪流展開訓練。如下圖展示了判別器模型訓練的基本流程。在 Keras 框架下的實現代碼如下所示。

images_train = self.x_train[np.random.randint(0,
self.x_train.shape[0], size=batch_size), :, :, :]
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
images_fake = self.generator.predict(noise)
x = np.concatenate((images_train, images_fake))
y = np.ones([2*batch_size, 1])
y[batch_size:, :] = 0
d_loss = self.discriminator.train_on_batch(x, y)
y = np.ones([batch_size, 1])
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
a_loss = self.adversarial.train_on_batch(noise, y)

訓練過程中需要非常耐心,這裡列出一些常見問題和解決方案:

問題1:最終生成的圖像噪點太多。

解決:嘗試在判別器和生成器模型上引入 dropout,一般更小的 dropout 值(0.3-0.6)可以產生更逼真的圖像。

問題2:判別器的損失函數迅速收斂為零,導致發生器無法訓練。

解決:不要對判別器進行預訓練。而是調整學習率,使判別器的學習率大於對抗模型的學習率。也可以嘗試對生成器換一個不同的訓練噪聲樣本。

問題3:生成器輸出的圖像仍然看起來像噪聲。

解決:檢查激活函數、批量歸一化和 dropout 的應用流程是否正確。

問題4:如何確定正確的模型/訓練參數。

解決:嘗試從一些已經發表的論文或代碼中找到參考,調試時每次只調整一個參數。在進行 2000 步以上的訓練時,注意觀察在 500 或 1000 步左右參數值調整的效果。

  輸出情況

下圖展示了在訓練過程中,整個模型的輸出變化情況。可以看到,GAN 在自己學習如何生成手寫體數字。

完整代碼地址:

https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/GAN/dcgan_mnist.py 

來源:medium,雷鋒網編譯

雷鋒網(公眾號:雷鋒網(公眾號:雷鋒網))相關閱讀:

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基於 PyTorch)

生成對抗網絡(GANs )為什麼這麼火?盤點它誕生以來的主要技術進展

雷鋒網版權文章,未經授權禁止轉載。詳情見轉載須知。

相關焦點

  • 手把手教你用keras搭建GAN
    在遙遠的九月份,我開始做了keras的系列教程,現在我主要的研究方向轉到了生成對抗網絡,生成對抗網絡的代碼實現和訓練機制比分類模型都要複雜和難入門.之前一段時間時間一直在幫璇姐跑cvpr的實驗代碼,做了蠻多的對比實驗,其中我就發現了,keras的代碼實現和可閱讀性很好,搭生成對抗網絡網絡GAN就好像搭樂高積木一樣有趣哦。
  • 【專知薈萃11】GAN生成式對抗網絡知識資料全集(理論/報告/教程/綜述/代碼等)
    今天專知為大家呈送第十一篇專知主題薈萃-生成式對抗網絡GAN知識資料大全集薈萃 (理論/報告/教程/綜述/代碼等),請大家查看!]The GAN Zoo千奇百怪的生成對抗網絡,都在這裡了。/soumith/ganhacks]OpenAI生成模型參考連結:[https://blog.openai.com/generative-models/]用Keras實現MNIST生成對抗模型參考連結:[https://oshearesearch.com/index.PHP/2016/07/01/mnist-generative-adversarial-model-in-keras
  • 輕鬆構建 PyTorch 生成對抗網絡(GAN)
    『模仿手寫字體』,為了完成這個課題,您將親手體驗生成對抗網絡的設計和實現。『模仿手寫字體』與人像生成的基本原理和工程流程基本是一致的,雖然它們的複雜性和精度要求有一定差距,但是通過解決『模仿手寫字體』問題,可以為生成對抗網絡的原理和工程實踐打下基礎,進而可以逐步嘗試和探索更加複雜先進的網絡架構和應用場景。《生成對抗網絡》(GAN)由 Ian Goodfellow 等人在 2014年提出,它是一種深度神經網絡架構,由一個生成網絡和一個判別網絡組成。
  • GAN對抗網絡入門教程
    A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.ai/wiki/generative-adversarial-network-gan生成對抗網絡(英語:Generative Adversarial Network,簡稱GAN)是非監督式學習的一種方法,通過讓兩個神經網絡相互博弈的方式進行學習
  • 5分鐘入門GANS:原理解釋和keras代碼實現
    GANs,用於生成圖像而不需要很少或沒有輸入。GANs允許我們生成由神經網絡生成的圖像。在我們深入討論這個理論之前,我想向您展示GANs構建您興奮感的能力。把馬變成斑馬(反之亦然)。歷史生成式對抗網絡(GANs)是由Ian Goodfellow (GANs的GAN Father)等人於2014年在其題為「生成式對抗網絡」的論文中提出的。
  • 教程 | 在Keras上實現GAN:構建消除圖片模糊的應用
    生成對抗網絡簡介在生成對抗網絡中,有兩個網絡互相進行訓練。生成器通過生成逼真的虛假輸入來誤導判別器,而判別器會分辨輸入是真實的還是人造的。GAN 訓練流程訓練過程中有三個關鍵步驟:請注意,判別器的權重在第三步中被凍結。
  • GAN生成式對抗網絡及應用詳解
    從本節開始,我們將討論如何將生成對抗網絡(GAN)應用於深度學習的某個領域。
  • 生成對抗網絡GANs學習路線
    導讀MLmindset作者發布了一篇生成對抗網絡合集文章,整合了各類關於GAN的資源,如GANs文章、模型、代碼、應用、課程、書籍
  • 科普 | ​生成對抗網絡(GAN)的發展史
    Ian Goodfellow等人在「Generative Adversarial Networks」中提出了生成對抗網絡。學術界和工業界都開始接受並歡迎GAN的到來。GAN的崛起不可避免。首先,GAN最厲害的地方是它的學習性質是無監督的。GAN也不需要標記數據,這使GAN功能強大,因為數據標記的工作非常枯燥。其次,GAN的潛在用例使它成為交談的中心。
  • 【前沿】NIPS2017貝葉斯生成對抗網絡TensorFlow實現(附GAN資料下載)
    導讀今年五月份康奈爾大學的 Andrew Gordon Wilson 和 Permutation Venture 的 Yunus Saatchi 提出了一個貝葉斯生成對抗網絡(Bayesian GAN),結合貝葉斯和對抗生成網絡,提出了一個實用的貝葉斯公式框架,用GAN來進行無監督學習和半監督式學習。
  • GAN快速入門資料推薦:17種變體的Keras開原始碼,附相關論文
    夏乙 編譯整理量子位 出品 | 公眾號 QbitAI圖片來源:Kaggle blog從2014年誕生至今,生成對抗網絡(GAN)始終廣受關注,已經出現了200多種有名有姓的變體。在論文中,研究人員給出了用MNIST和多倫多人臉數據集 (TFD)訓練的模型所生成的樣本。
  • 教程 | 一招教你使用 tf.keras 和 eager execution 解決複雜問題
    和 eager execution)解決了四類複雜問題:文本生成、生成對抗網絡、神經網絡機器翻譯、圖片標註。(文本生成)能生成一張貓的圖片嗎?(生成對抗網絡)能翻譯句子嗎?(神經網絡機器翻譯)能根據圖片生成標題嗎?(圖片標註)在暑期實習期間,我使用 TensorFlow 的兩個最新 API(tf.keras 和 eager execution)開發了這些示例,以下是分享內容。希望你們能覺得它們有用,有趣!
  • 只需130 行代碼,用 GAN 生成二維樣本的小例子
    PyTorch 平臺用 50 行代碼實現 GAN(生成對抗網絡),詳情參見:《GAN 很複雜?如如何用不到 50 行代碼訓練 GAN》。近期,針對文中介紹的「50 行代碼 GAN 模型」,有開發者指出了局限性,並基於此模型給出了改進版本,也就是本文將要介紹的「130 行代碼實現 GAN 二維樣本」。本文原載於知乎專欄,作者達聞西,雷鋒網經授權發布。
  • 雲計算必備知識-基於PyTorch機器學習構建生成對抗網絡
    『模仿手寫字體』與人像生成的基本原理和工程流程基本是一致的,雖然它們的複雜性和精度要求有一定差距,但是通過解決『模仿手寫字體』問題,可以為生成對抗網絡的原理和工程實踐打下基礎,進而可以逐步嘗試和探索更加複雜先進的網絡架構和應用場景。《生成對抗網絡》(GAN)由 Ian Goodfellow 等人在 2014年提出,它是一種深度神經網絡架構,由一個生成網絡和一個判別網絡組成。
  • GAN(生成對抗網絡)萬字長文綜述
    GAN的基本介紹生成對抗網絡(GAN,Generative Adversarial Networks)作為一種優秀的生成式模型,引爆了許多圖像生成的有趣應用。GAN的基本概念GAN(Generative Adversarial Networks)從其名字可以看出,是一種生成式的,對抗網絡。再具體一點,就是通過對抗的方式,去學習數據分布的生成式模型。所謂的對抗,指的是生成網絡和判別網絡的互相對抗。生成網絡儘可能生成逼真樣本,判別網絡則儘可能去判別該樣本是真實樣本,還是生成的假樣本。示意圖如下:
  • GAN(生成對抗網絡)的最新應用狀況
    G 的 loss 包含 content loss 部分,因此 G 並非完全的非監督,它也用到了監督信息:它強制要求生成圖像提取的特徵與真實圖像提取的特徵要匹配,文中用到的特徵提取網絡為 VGG,content loss 定義如下:
  • 【GAN】四、CGAN論文詳解與代碼詳解
    在之前我們介紹了DCGAN與原始GAN的相關理論,並給出了DCGAN生成手寫數字圖像的代碼。CGAN生成手寫數字的keras代碼請移步:https://github.com/Daipuwei/CGAN-mnist。一、 GAN回顧為了兼顧CGAN的相關理論介紹,我們首先回顧GAN相關細節。GAN主要包括兩個網絡,一個是生成器和判別器,生成器的目的就是將隨機輸入的高斯噪聲映射成圖像(「假圖」),判別器則是判斷輸入圖像是否來自生成器的概率,即判斷輸入圖像是否為假圖的概率。
  • 萬字綜述之生成對抗網絡(GAN)
    文章目錄如下:GAN的基本介紹生成對抗網絡(GAN,Generative Adversarial Networks)作為一種優秀的生成式模型,引爆了許多圖像生成的有趣應用。GAN的基本概念GAN(Generative Adversarial Networks)從其名字可以看出,是一種生成式的,對抗網絡。再具體一點,就是通過對抗的方式,去學習數據分布的生成式模型。所謂的對抗,指的是生成網絡和判別網絡的互相對抗。
  • 深度 | 生成對抗網絡初學入門:一文讀懂GAN的基本原理(附資源)
    在這篇文章中,我們將對生成對抗網絡(GAN)背後的一般思想進行全面的介紹,並向你展示一些主要的架構以幫你很好地開始學習,另外我們還將提供一些有用的技巧,可以幫你顯著改善你的結果。GAN 的發明生成模型的基本思想是輸入一個訓練樣本集合,然後形成這些樣本的概率分布的表徵。常用的生成模型方法是直接推斷其概率密度函數。
  • 資源|帶自注意力機制的生成對抗網絡,實現效果怎樣?
    在前一段時間,Han Zhang 和 Goodfellow 等研究者提出添加了自注意力機制的生成對抗網絡,這種網絡可使用全局特徵線索來生成高解析度細節。本文介紹了自注意力生成對抗網絡的 PyTorch 實現,讀者也可以嘗試這一新型生成對抗網絡。