本篇文章包含以下內容
介紹歷史直觀解釋訓練過程GAN在MNIST數據集上的KERAS實現介紹
生成式敵對網絡通常也稱為GANs,用於生成圖像而不需要很少或沒有輸入。GANs允許我們生成由神經網絡生成的圖像。在我們深入討論這個理論之前,我想向您展示GANs構建您興奮感的能力。把馬變成斑馬(反之亦然)。
歷史
生成式對抗網絡(GANs)是由Ian Goodfellow (GANs的GAN Father)等人於2014年在其題為「生成式對抗網絡」的論文中提出的。它是一種可替代的自適應變分編碼器(VAEs)學習圖像的潛在空間,以生成合成圖像。它的目的是創造逼真的人工圖像,幾乎無法與真實的圖像區分。
GAN的直觀解釋
生成器和鑑別器網絡:
生成器網絡的目的是將隨機圖像初始化並解碼成一個合成圖像。
鑑別器網絡的目的是獲取這個輸入,並預測這個圖像是來自真實的數據集還是合成的。
正如我們剛才看到的,這實際上就是GANs,兩個相互競爭的敵對網絡。
GAN的訓練過程
GANS的訓練是出了名的困難。在CNN中,我們使用梯度下降來改變權重以減少損失。
然而,在GANs中,每一次重量的變化都會改變整個動態系統的平衡。
在GAN的網絡中,我們不是在尋求將損失最小化,而是在我們對立的兩個網絡之間找到一種平衡。
我們將過程總結如下
輸入隨機生成的噪聲圖像到我們的生成器網絡中生成樣本圖像。我們從真實數據中提取一些樣本圖像,並將其與一些生成的圖像混合在一起。將這些混合圖像輸入到我們的鑑別器中,鑑別器將對這個混合集進行訓練並相應地更新它的權重。然後我們製作更多的假圖像,並將它們輸入到鑑別器中,但是我們將它們標記為真實的。這樣做是為了訓練生成器。我們在這個階段凍結了鑑別器的權值(鑑別器學習停止),並且我們使用來自鑑別器的反饋來更新生成器的權值。這就是我們如何教我們的生成器(製作更好的合成圖像)和鑑別器更好地識別贗品的方法。流程圖如下
對於本文,我們將使用MNIST數據集生成手寫數字。GAN的架構是:
使用KERAS實現GANS
首先,我們加載所有必要的庫
import os os.environ["KERAS_BACKEND"] = "tensorflow" import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt from keras.layers import Input from keras.models import Model, Sequential from keras.layers.core import Reshape, Dense, Dropout, Flatten from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import Convolution2D, UpSampling2D from keras.layers.normalization import BatchNormalization from keras.datasets import mnist from keras.optimizers import Adam from keras import backend as K from keras import initializers K.set_image_dim_ordering('th') # Deterministic output. # Tired of seeing the same results every time? Remove the line below. np.random.seed(1000) # The results are a little better when the dimensionality of the random vector is only 10. # The dimensionality has been left at 100 for consistency with other GAN implementations. randomDim = 100
現在我們加載數據集。這裡使用MNIST數據集,所以不需要單獨下載和處理。
(X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = (X_train.astype(np.float32) - 127.5)/127.5 X_train = X_train.reshape(60000, 784)
接下來,我們定義生成器和鑑別器的結構
# Optimizer adam = Adam(lr=0.0002, beta_1=0.5)#generator generator = Sequential() generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) generator.add(LeakyReLU(0.2)) generator.add(Dense(512)) generator.add(LeakyReLU(0.2)) generator.add(Dense(1024)) generator.add(LeakyReLU(0.2)) generator.add(Dense(784, activation='tanh')) generator.compile(loss='binary_crossentropy', optimizer=adam)#discriminator discriminator = Sequential() discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(256)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(1, activation='sigmoid')) discriminator.compile(loss='binary_crossentropy', optimizer=adam)
現在我們把發生器和鑑別器結合起來同時訓練。
# Combined network discriminator.trainable = False ganInput = Input(shape=(randomDim,)) x = generator(ganInput) ganOutput = discriminator(x) gan = Model(inputs=ganInput, outputs=ganOutput) gan.compile(loss='binary_crossentropy', optimizer=adam) dLosses = [] gLosses = []
三個函數,每20個epoch繪製並保存結果,並保存模型。
# Plot the loss from each batch def plotLoss(epoch): plt.figure(figsize=(10, 8)) plt.plot(dLosses, label='Discriminitive loss') plt.plot(gLosses, label='Generative loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.savefig('images/gan_loss_epoch_%d.png' % epoch) # Create a wall of generated MNIST images def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)): noise = np.random.normal(0, 1, size=[examples, randomDim]) generatedImages = generator.predict(noise) generatedImages = generatedImages.reshape(examples, 28, 28) plt.figure(figsize=figsize) for i in range(generatedImages.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r') plt.axis('off') plt.tight_layout() plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch) # Save the generator and discriminator networks (and weights) for later use def saveModels(epoch): generator.save('models/gan_generator_epoch_%d.h5' % epoch) discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)
訓練函數
def train(epochs=1, batchSize=128): batchCount = X_train.shape[0] / batchSize print 'Epochs:', epochs print 'Batch size:', batchSize print 'Batches per epoch:', batchCount for e in xrange(1, epochs+1): print '-'*15, 'Epoch %d' % e, '-'*15 for _ in tqdm(xrange(batchCount)): # Get a random set of input noise and images noise = np.random.normal(0, 1, size=[batchSize, randomDim]) imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)] # Generate fake MNIST images generatedImages = generator.predict(noise) # print np.shape(imageBatch), np.shape(generatedImages) X = np.concatenate([imageBatch, generatedImages]) # Labels for generated and real data yDis = np.zeros(2*batchSize) # One-sided label smoothing yDis[:batchSize] = 0.9 # Train discriminator discriminator.trainable = True dloss = discriminator.train_on_batch(X, yDis) # Train generator noise = np.random.normal(0, 1, size=[batchSize, randomDim]) yGen = np.ones(batchSize) discriminator.trainable = False gloss = gan.train_on_batch(noise, yGen) # Store loss of most recent batch from this epoch dLosses.append(dloss) gLosses.append(gloss) if e == 1 or e % 20 == 0: plotGeneratedImages(e) saveModels(e) # Plot losses from every epoch plotLoss(e)
至此一個簡單的GAN已經完成了,完整的代碼在這裡找到
github/bhaveshgoyal27/mediumblogs/blob/master/KerasMNISTGAN.py
作者:Bhavesh Goyal
deephub翻譯組