銅靈 發自 凹非寺量子位 出品| 公眾號 QbitAI
CycleGAN,一個可以將一張圖像的特徵遷移到另一張圖像的酷算法,此前可以完成馬變斑馬、冬天變夏天、蘋果變桔子等一顆賽艇的效果。
這行被頂會ICCV收錄的研究自提出後,就為圖形學等領域的技術人員所用,甚至還成為不少藝術家用來創作的工具。
也是目前大火的「換臉」技術的老前輩了。
如果你還沒學會這項厲害的研究,那這次一定要抓緊上車了。
現在,TensorFlow開始手把手教你,在TensorFlow 2.0中CycleGAN實現大法。
這個官方教程貼幾天內收穫了滿滿人氣,獲得了Google AI工程師、哥倫比亞大學數據科學研究所Josh Gordon的推薦,推特上已近600贊。
有國外網友稱讚太棒,表示很高興看到TensorFlow 2.0教程中涵蓋了最先進的模型。
這份教程全面詳細,想學CycleGAN不能錯過這個:
詳細內容
在TensorFlow 2.0中實現CycleGAN,只要7個步驟就可以了。
1、設置輸入Pipeline
安裝tensorflow_examples包,用於導入生成器和鑑別器。
!pip install -q git+https://github.com/tensorflow/examples.git
!pip install -q tensorflow-gpu==2.0.0-beta1import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literalsimport tensorflow_datasets as tfdsfrom tensorflow_examples.models.pix2pix import pix2piximport osimport timeimport matplotlib.pyplot as pltfrom IPython.display import clear_outputtfds.disable_progress_bar()AUTOTUNE = tf.data.experimental.AUTOTUNE
2、輸入pipeline
在這個教程中,我們主要學習馬到斑馬的圖像轉換,如果想尋找類似的數據集,可以前往:
https://www.tensorflow.org/datasets/datasets#cycle_gan
在CycleGAN論文中也提到,將隨機抖動( Jitter )和鏡像應用到訓練集中,這是避免過度擬合的圖像增強技術。
和在Pix2Pix中的操作類似,在隨機抖動中嗎,圖像大小被調整成286×286,然後隨機裁剪為256×256。
在隨機鏡像中嗎,圖像隨機水平翻轉,即從左到右進行翻轉。
dataset, metadata = tfds.load('cycle_gan/horse2zebra',with_info=True, as_supervised=True)train_horses, train_zebras = dataset['trainA'], dataset['trainB']test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000BATCH_SIZE = 1IMG_WIDTH = 256IMG_HEIGHT = 256
def random_crop(image):cropped_image = tf.image.random_crop( image, size=[IMG_HEIGHT, IMG_WIDTH, 3]) return cropped_image
# normalizing the images to [-1, 1]def normalize(image):image = tf.cast(image, tf.float32) image = (image / 127.5) - 1 return image
def random_jitter(image):# resizing to 286 x 286 x 3 image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) # randomly cropping to 256 x 256 x 3 image = random_crop(image) # random mirroring image = tf.image.random_flip_left_right(image) return image
def preprocess_image_train(image, label):image = random_jitter(image) image = normalize(image) return image
def preprocess_image_test(image, label):image = normalize(image) return image
train_horses = train_horses.map(preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)train_zebras = train_zebras.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)test_horses = test_horses.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)test_zebras = test_zebras.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))sample_zebra = next(iter(train_zebras))
plt.subplot(121)plt.title('Horse')plt.imshow(sample_horse[0] * 0.5 + 0.5)plt.subplot(122)plt.title('Horse with random jitter')plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
plt.subplot(121)plt.title('Zebra')plt.imshow(sample_zebra[0] * 0.5 + 0.5)plt.subplot(122)plt.title('Zebra with random jitter')plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
3、導入並重新使用Pix2Pix模型
通過安裝tensorflow_examples包,從Pix2Pix中導入生成器和鑑別器。
這個教程中使用的模型體系結構與Pix2Pix中很類似,但也有一些差異,比如Cyclegan使用的是實例規範化而不是批量規範化,比如Cyclegan論文使用的是修改後的resnet生成器等。
我們訓練兩個生成器(G和F)和兩個鑑別器(X和Y)。生成器G架構圖像X轉換為圖像Y,生成器F將圖像Y轉換為圖像X。
鑑別器D_X區分圖像X和生成的圖像X(F(Y)),辨別器D_Y區分圖像Y和生成的圖像Y(G(X))。
OUTPUT_CHANNELS = 3generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)to_horse = generator_f(sample_zebra)plt.figure(figsize=(8, 8))contrast = 8plt.subplot(221)plt.title('Horse')plt.imshow(sample_horse[0] * 0.5 + 0.5)plt.subplot(222)plt.title('To Zebra')plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)plt.subplot(223)plt.title('Zebra')plt.imshow(sample_zebra[0] * 0.5 + 0.5)plt.subplot(224)plt.title('To Horse')plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)plt.show()
plt.figure(figsize=(8, 8))plt.subplot(121)plt.title('Is a real zebra?')plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')plt.subplot(122)plt.title('Is a real horse?')plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')plt.show()
4、損失函數
在CycleGAN中,因為沒有用於訓練的成對數據,因此無法保證輸入X和目標Y在訓練期間是否有意義。因此,為了強制學習正確的映射,CycleGAN中提出了「循環一致性損失」(cycle consistency loss)。
鑑別器和生成器的損失與Pix2Pix中的類似。
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):real_loss = loss_obj(tf.ones_like(real), real) generated_loss = loss_obj(tf.zeros_like(generated), generated) total_disc_loss = real_loss + generated_loss return total_disc_loss * 0.5
def generator_loss(generated):return loss_obj(tf.ones_like(generated), generated)
循環一致性意味著結果接近原始輸入。
例如將一個句子和英語翻譯成法語,再將其從法語翻譯成英語後,結果與原始英文句子相同。
在循環一致性損失中,圖像X通過生成器傳遞C產生的圖像Y^,生成的圖像Y^通過生成器傳遞F產生的圖像X^,然後計算平均絕對誤差X和X^。
前向循環一致性損失為:
反向循環一致性損失為:
def calc_cycle_loss(real_image, cycled_image):loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image)) return LAMBDA * loss1
初始化所有生成器和鑑別器的的優化:
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
5、檢查點
checkpoint_path = "./checkpoints/train"ckpt = tf.train.Checkpoint(generator_g=generator_g,generator_f=generator_f, discriminator_x=discriminator_x, discriminator_y=discriminator_y, generator_g_optimizer=generator_g_optimizer, generator_f_optimizer=generator_f_optimizer, discriminator_x_optimizer=discriminator_x_optimizer, discriminator_y_optimizer=discriminator_y_optimizer)ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)# if a checkpoint exists, restore the latest checkpoint.if ckpt_manager.latest_checkpoint: ckpt.restore(ckpt_manager.latest_checkpoint) print ('Latest checkpoint restored!!')
6、訓練
注意:為了使本教程的訓練時間合理,本示例模型迭代次數較少(40次,論文中為200次),預測效果可能不如論文準確。
EPOCHS = 40
def generate_images(model, test_input):prediction = model(test_input) plt.figure(figsize=(12, 12)) display_list = [test_input[0], prediction[0]] title = ['Input Image', 'Predicted Image'] for i in range(2): plt.subplot(1, 2, i+1) plt.title(title[i]) # getting the pixel values between [0, 1] to plot it. plt.imshow(display_list[i] * 0.5 + 0.5) plt.axis('off') plt.show()
儘管訓練起來很複雜,但基本的步驟只有四個,分別為:獲取預測、計算損失、使用反向傳播計算梯度、將梯度應用於優化程序。
@tf.functiondef train_step(real_x, real_y):# persistent is set to True because gen_tape and disc_tape is used more than # once to calculate the gradients. with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape( persistent=True) as disc_tape: fake_y = generator_g(real_x, training=True) cycled_x = generator_f(fake_y, training=True) fake_x = generator_f(real_y, training=True) cycled_y = generator_g(fake_x, training=True) disc_real_x = discriminator_x(real_x, training=True) disc_real_y = discriminator_y(real_y, training=True) disc_fake_x = discriminator_x(fake_x, training=True) disc_fake_y = discriminator_y(fake_y, training=True) # calculate the loss gen_g_loss = generator_loss(disc_fake_y) gen_f_loss = generator_loss(disc_fake_x) # Total generator loss = adversarial loss + cycle loss total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x) total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y) disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x) disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y) # Calculate the gradients for generator and discriminator generator_g_gradients = gen_tape.gradient(total_gen_g_loss, generator_g.trainable_variables) generator_f_gradients = gen_tape.gradient(total_gen_f_loss, generator_f.trainable_variables) discriminator_x_gradients = disc_tape.gradient( disc_x_loss, discriminator_x.trainable_variables) discriminator_y_gradients = disc_tape.gradient( disc_y_loss, discriminator_y.trainable_variables) # Apply the gradients to the optimizer generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables)) generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables)) discriminator_x_optimizer.apply_gradients( zip(discriminator_x_gradients, discriminator_x.trainable_variables)) discriminator_y_optimizer.apply_gradients( zip(discriminator_y_gradients, discriminator_y.trainable_variables))
for epoch in range(EPOCHS):start = time.time() n = 0 for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)): train_step(image_x, image_y) if n % 10 == 0: print ('.', end='') n+=1 clear_output(wait=True) # Using a consistent image (sample_horse) so that the progress of the model # is clearly visible. generate_images(generator_g, sample_horse) if (epoch + 1) % 5 == 0: ckpt_save_path = ckpt_manager.save() print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path)) print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))
7、使用測試集生成圖像
# Run the trained model on the test datasetfor inp in test_horses.take(5):generate_images(generator_g, inp)
8、進階學習方向
在上面的教程中,我們學習了如何從Pix2Pix中實現的生成器和鑑別器進一步實現CycleGAN,接下來的學習你可以嘗試使用TensorFlow中的其他數據集。
你還可以用更多次的迭代改善結果,或者實現論文中修改的ResNet生成器,進行知識點的進一步鞏固。
傳送門
教程地址:
連結因政策隱去,請求助搜尋引擎。關鍵字:TensorFlow,Tutorial,CycleGAN。
GitHub地址:
連結因政策隱去,請求助搜尋引擎。關鍵字:GitHub,TensorFlow,Tutorial,CycleGAN。