官方資源帖!手把手教你在TF2.0中實現CycleGAN,推特上百贊

2021-01-19 量子位

銅靈 發自 凹非寺量子位 出品| 公眾號 QbitAI

CycleGAN,一個可以將一張圖像的特徵遷移到另一張圖像的酷算法,此前可以完成馬變斑馬、冬天變夏天、蘋果變桔子等一顆賽艇的效果。

這行被頂會ICCV收錄的研究自提出後,就為圖形學等領域的技術人員所用,甚至還成為不少藝術家用來創作的工具。

也是目前大火的「換臉」技術的老前輩了。

如果你還沒學會這項厲害的研究,那這次一定要抓緊上車了。

現在,TensorFlow開始手把手教你,在TensorFlow 2.0中CycleGAN實現大法。

這個官方教程貼幾天內收穫了滿滿人氣,獲得了Google AI工程師、哥倫比亞大學數據科學研究所Josh Gordon的推薦,推特上已近600贊。

有國外網友稱讚太棒,表示很高興看到TensorFlow 2.0教程中涵蓋了最先進的模型。

這份教程全面詳細,想學CycleGAN不能錯過這個:

詳細內容

在TensorFlow 2.0中實現CycleGAN,只要7個步驟就可以了。

1、設置輸入Pipeline

安裝tensorflow_examples包,用於導入生成器和鑑別器。

!pip install -q git+https://github.com/tensorflow/examples.git

!pip install -q tensorflow-gpu==2.0.0-beta1import tensorflow as tf

from __future__ import absolute_import, division, print_function, unicode_literalsimport tensorflow_datasets as tfdsfrom tensorflow_examples.models.pix2pix import pix2piximport osimport timeimport matplotlib.pyplot as pltfrom IPython.display import clear_outputtfds.disable_progress_bar()AUTOTUNE = tf.data.experimental.AUTOTUNE

2、輸入pipeline

在這個教程中,我們主要學習馬到斑馬的圖像轉換,如果想尋找類似的數據集,可以前往:

https://www.tensorflow.org/datasets/datasets#cycle_gan

在CycleGAN論文中也提到,將隨機抖動( Jitter )和鏡像應用到訓練集中,這是避免過度擬合的圖像增強技術。

和在Pix2Pix中的操作類似,在隨機抖動中嗎,圖像大小被調整成286×286,然後隨機裁剪為256×256。

在隨機鏡像中嗎,圖像隨機水平翻轉,即從左到右進行翻轉。

dataset, metadata = tfds.load('cycle_gan/horse2zebra',with_info=True, as_supervised=True)train_horses, train_zebras = dataset['trainA'], dataset['trainB']test_horses, test_zebras = dataset['testA'], dataset['testB']

BUFFER_SIZE = 1000BATCH_SIZE = 1IMG_WIDTH = 256IMG_HEIGHT = 256

def random_crop(image):cropped_image = tf.image.random_crop( image, size=[IMG_HEIGHT, IMG_WIDTH, 3]) return cropped_image

# normalizing the images to [-1, 1]def normalize(image):image = tf.cast(image, tf.float32) image = (image / 127.5) - 1 return image

def random_jitter(image):# resizing to 286 x 286 x 3 image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) # randomly cropping to 256 x 256 x 3 image = random_crop(image) # random mirroring image = tf.image.random_flip_left_right(image) return image

def preprocess_image_train(image, label):image = random_jitter(image) image = normalize(image) return image

def preprocess_image_test(image, label):image = normalize(image) return image

train_horses = train_horses.map(preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)train_zebras = train_zebras.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)test_horses = test_horses.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)test_zebras = test_zebras.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)

sample_horse = next(iter(train_horses))sample_zebra = next(iter(train_zebras))

plt.subplot(121)plt.title('Horse')plt.imshow(sample_horse[0] * 0.5 + 0.5)plt.subplot(122)plt.title('Horse with random jitter')plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

plt.subplot(121)plt.title('Zebra')plt.imshow(sample_zebra[0] * 0.5 + 0.5)plt.subplot(122)plt.title('Zebra with random jitter')plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)

3、導入並重新使用Pix2Pix模型

通過安裝tensorflow_examples包,從Pix2Pix中導入生成器和鑑別器。

這個教程中使用的模型體系結構與Pix2Pix中很類似,但也有一些差異,比如Cyclegan使用的是實例規範化而不是批量規範化,比如Cyclegan論文使用的是修改後的resnet生成器等。

我們訓練兩個生成器(G和F)和兩個鑑別器(X和Y)。生成器G架構圖像X轉換為圖像Y,生成器F將圖像Y轉換為圖像X。

鑑別器D_X區分圖像X和生成的圖像X(F(Y)),辨別器D_Y區分圖像Y和生成的圖像Y(G(X))。

OUTPUT_CHANNELS = 3generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

to_zebra = generator_g(sample_horse)to_horse = generator_f(sample_zebra)plt.figure(figsize=(8, 8))contrast = 8plt.subplot(221)plt.title('Horse')plt.imshow(sample_horse[0] * 0.5 + 0.5)plt.subplot(222)plt.title('To Zebra')plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)plt.subplot(223)plt.title('Zebra')plt.imshow(sample_zebra[0] * 0.5 + 0.5)plt.subplot(224)plt.title('To Horse')plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)plt.show()

plt.figure(figsize=(8, 8))plt.subplot(121)plt.title('Is a real zebra?')plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')plt.subplot(122)plt.title('Is a real horse?')plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')plt.show()

4、損失函數

在CycleGAN中,因為沒有用於訓練的成對數據,因此無法保證輸入X和目標Y在訓練期間是否有意義。因此,為了強制學習正確的映射,CycleGAN中提出了「循環一致性損失」(cycle consistency loss)。

鑑別器和生成器的損失與Pix2Pix中的類似。

LAMBDA = 10

loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):real_loss = loss_obj(tf.ones_like(real), real) generated_loss = loss_obj(tf.zeros_like(generated), generated) total_disc_loss = real_loss + generated_loss return total_disc_loss * 0.5

def generator_loss(generated):return loss_obj(tf.ones_like(generated), generated)

循環一致性意味著結果接近原始輸入。

例如將一個句子和英語翻譯成法語,再將其從法語翻譯成英語後,結果與原始英文句子相同。

在循環一致性損失中,圖像X通過生成器傳遞C產生的圖像Y^,生成的圖像Y^通過生成器傳遞F產生的圖像X^,然後計算平均絕對誤差X和X^。

前向循環一致性損失為:

反向循環一致性損失為:

def calc_cycle_loss(real_image, cycled_image):loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image)) return LAMBDA * loss1

初始化所有生成器和鑑別器的的優化:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

5、檢查點

checkpoint_path = "./checkpoints/train"ckpt = tf.train.Checkpoint(generator_g=generator_g,generator_f=generator_f, discriminator_x=discriminator_x, discriminator_y=discriminator_y, generator_g_optimizer=generator_g_optimizer, generator_f_optimizer=generator_f_optimizer, discriminator_x_optimizer=discriminator_x_optimizer, discriminator_y_optimizer=discriminator_y_optimizer)ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)# if a checkpoint exists, restore the latest checkpoint.if ckpt_manager.latest_checkpoint: ckpt.restore(ckpt_manager.latest_checkpoint) print ('Latest checkpoint restored!!')

6、訓練

注意:為了使本教程的訓練時間合理,本示例模型迭代次數較少(40次,論文中為200次),預測效果可能不如論文準確。

EPOCHS = 40

def generate_images(model, test_input):prediction = model(test_input) plt.figure(figsize=(12, 12)) display_list = [test_input[0], prediction[0]] title = ['Input Image', 'Predicted Image'] for i in range(2): plt.subplot(1, 2, i+1) plt.title(title[i]) # getting the pixel values between [0, 1] to plot it. plt.imshow(display_list[i] * 0.5 + 0.5) plt.axis('off') plt.show()

儘管訓練起來很複雜,但基本的步驟只有四個,分別為:獲取預測、計算損失、使用反向傳播計算梯度、將梯度應用於優化程序。

@tf.functiondef train_step(real_x, real_y):# persistent is set to True because gen_tape and disc_tape is used more than # once to calculate the gradients. with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape( persistent=True) as disc_tape: fake_y = generator_g(real_x, training=True) cycled_x = generator_f(fake_y, training=True) fake_x = generator_f(real_y, training=True) cycled_y = generator_g(fake_x, training=True) disc_real_x = discriminator_x(real_x, training=True) disc_real_y = discriminator_y(real_y, training=True) disc_fake_x = discriminator_x(fake_x, training=True) disc_fake_y = discriminator_y(fake_y, training=True) # calculate the loss gen_g_loss = generator_loss(disc_fake_y) gen_f_loss = generator_loss(disc_fake_x) # Total generator loss = adversarial loss + cycle loss total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x) total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y) disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x) disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y) # Calculate the gradients for generator and discriminator generator_g_gradients = gen_tape.gradient(total_gen_g_loss, generator_g.trainable_variables) generator_f_gradients = gen_tape.gradient(total_gen_f_loss, generator_f.trainable_variables) discriminator_x_gradients = disc_tape.gradient( disc_x_loss, discriminator_x.trainable_variables) discriminator_y_gradients = disc_tape.gradient( disc_y_loss, discriminator_y.trainable_variables) # Apply the gradients to the optimizer generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables)) generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables)) discriminator_x_optimizer.apply_gradients( zip(discriminator_x_gradients, discriminator_x.trainable_variables)) discriminator_y_optimizer.apply_gradients( zip(discriminator_y_gradients, discriminator_y.trainable_variables))

for epoch in range(EPOCHS):start = time.time() n = 0 for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)): train_step(image_x, image_y) if n % 10 == 0: print ('.', end='') n+=1 clear_output(wait=True) # Using a consistent image (sample_horse) so that the progress of the model # is clearly visible. generate_images(generator_g, sample_horse) if (epoch + 1) % 5 == 0: ckpt_save_path = ckpt_manager.save() print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path)) print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

7、使用測試集生成圖像

# Run the trained model on the test datasetfor inp in test_horses.take(5):generate_images(generator_g, inp)

8、進階學習方向

在上面的教程中,我們學習了如何從Pix2Pix中實現的生成器和鑑別器進一步實現CycleGAN,接下來的學習你可以嘗試使用TensorFlow中的其他數據集。

你還可以用更多次的迭代改善結果,或者實現論文中修改的ResNet生成器,進行知識點的進一步鞏固。

傳送門

教程地址:

連結因政策隱去,請求助搜尋引擎。關鍵字:TensorFlow,Tutorial,CycleGAN。

GitHub地址:

連結因政策隱去,請求助搜尋引擎。關鍵字:GitHub,TensorFlow,Tutorial,CycleGAN。

相關焦點

  • 22考研|點擊查看,手把手教你查資料~
    沒關係,考研幫手把手教你查詢各方考研信息~
  • 手把手教你NumPy來實現Word2vec
    在實際的訓練中,你應該隨機初始化這些權重(比如使用np.random.uniform())。想要這麼做,把第九第十行注釋掉,把11和12行取消注釋就好。這是通過對y_pred 與在w_c 中的每個上下文詞之間的差的加合來實現的。
  • 手把手教你做一顆氫彈
    氫彈今天小編手把手教你做一顆氫彈,來看具體操作。首先,你最好住在海邊,因為海水裡可以提煉出大量的氘和氚 ,它們都是氫元素的同位素。(氘有兩個中子,氚有三個中子,還有一個中子的氕,自然界中最多的就是氕,水分子裡的氫元素絕大部分都是氕,含有較多氘或氚元素組成的水叫做重水,極其稀少)然後把上一集教大家做好的原子彈拿過來,要是你沒看上集,沒做好原子彈就沒辦法了,趕快回去補課把原子彈做出來。
  • 手把手教你學ELISA、PCR、免疫組化
    1.手把手教你學ELISAELISA的基礎是抗原或抗體的固相化及抗原或抗體的酶標記。結合在固相載體 表面的抗原或抗體仍保持其免疫學活性,酶標記的抗原或抗體既保留其免疫學活性,又保留酶的活性。在這種測定方法中有三個必要的試劑:(1)固相的抗菌素原或抗體,即"免疫吸附劑"(immunosorbent);(2)酶標記的抗原或抗體,稱為"結合物"(conjugate);(3)酶反應的底物。根據試劑的來源和標本的情況以及檢測的具體條件,可設計出各種不同類型的檢測方法。
  • 手把手教你在ppt中設置超連結
    手把手教你在ppt中設置超連結時間:2017-08-05 13:37   來源:三聯   責任編輯:沫朵 川北在線核心提示:原標題:ppt中超連結怎麼添加? 手把手教你在ppt中設置超連結 ppt中超連結怎麼添加?
  • Excel表格中如何繪製稜錐圖 手把手教你在excel2007中插入稜錐圖
    Excel表格中如何繪製稜錐圖 手把手教你在excel2007中插入稜錐圖時間:2017-07-02 14:14   來源:三聯   責任編輯:沫朵 川北在線核心提示:原標題:Excel表格中如何繪製稜錐圖 手把手教你在excel2007中插入稜錐圖 1、打開一個Excel的文件,選中一個單元格,然後滑鼠左鍵單擊菜單【插入】
  • PPT教程:手把手教你高逼格的PPT動畫
    原標題:PPT教程:手把手教你高逼格的PPT動畫 今天河南中公優就業IT培訓小編給大家分享的是一個比賽用的PPT,為了說明水浮蓮(水葫蘆)對水域的汙染十分嚴重,其中一頁PPT裡展示了一組數據
  • 學姐分享|手把手教你免費安裝、激活Office
    前言:請大家支持正版,儘可能選擇官方渠道下載、購買激活碼。以下僅供個人學習、交流使用,請勿傳播。每次要安裝各種大型軟體比如Office時,頭就大了,作為一個工科女,這點小事怎麼難得到我呢?來,小九學姐手把手教你免費安裝、激活office各版本!
  • 手把手教你用PyTorch實現圖像分類器(第一部分)
    通過3篇短文,介紹如何實現圖像分類器的概念基礎——這是一種能夠理解圖像內容的算法。本文的目標不是提供手把手的指導,而是幫助理解整個過程。如果你正在考慮學習機器學習或人工智慧,你將不得不做類似的項目,並理解本系列文章中介紹的概念。文章主要進行概念上的解釋,不需要知道如何編寫代碼。
  • 麥克講堂—手把手教你進行吸附熱分析(20200616)
    本次講堂將手把手教你搞定吸附熱數據。查看「手把手教你進行吸附熱分析」視頻spm=a2hzp.8244740.0.0麥克儀器公司是提供材料表徵解決方案的全球領先廠商,在密度、比表面積及孔隙度、粒度及粒形、粉體表徵、催化劑表徵及工藝開發等五個核心領域擁有一流的儀器和應用技術。
  • 手把手教你如何在WPS表格中求標準差
    手把手教你如何在WPS表格中求標準差時間:2017-08-08 14:42   來源:系統天堂   責任編輯:沫朵 川北在線核心提示:原標題:wps如何求標準差? 手把手教你如何在WPS表格中求標準差 wps如何求標準差?怎麼求一系列數據的標準方差呢?對於很多網友來說,這個還是很難的問題,所以今天小編為大家帶來相關介紹。
  • 手把手教你如何隱藏電腦文件夾
    手把手教你如何隱藏電腦文件夾時間:2017-08-03 19:24   來源:三聯   責任編輯:沫朵 川北在線核心提示:原標題:電腦怎麼隱藏文件夾? 手把手教你如何隱藏電腦文件夾 朋友會向你借電腦,但是電腦上有一些文件又不想讓其他人看到。該怎麼辦呢?有的人把它們放到U盤或移動硬碟,貼身保管;有的人則用軟體進行加密。
  • 手把手教你認石斑——赤點石斑魚
    本期的手把手教你認石斑系列,要教大家認識的是,在咱們餐桌上經常見到的一種石斑魚,說不定你也曾經品嘗過它。   在中國,紅色是代表喜慶的顏色,過年、過節、結婚、賀壽等,人們都喜歡用大紅顏色來裝點,名字中帶有「紅」字的食物自然也是餐桌上的常客。
  • 手把手教你製作ppt日記本
    手把手教你製作ppt日記本時間:2017-07-16 15:06   來源:三聯   責任編輯:沫朵 川北在線核心提示:原標題:ppt怎麼製作筆記本? 手把手教你製作ppt日記本 ppt怎麼製作筆記本?本文介紹了使用ppt製作日記本的方法,製作方法簡單,一起來學習吧!
  • 咔哇熊攜手黑獅子私人健身會所利用微贊直播實現品牌共贏
    打造了一場【猛男教練教你keep fit產後身材】的專題直播,成功吸引了高達10014人次的觀看。直播運營者十分用心地準備了直播頁面中的各類宣傳資料,不放過一絲可以利用的資源。其中,在【詳情】中插入了黑獅子品牌的詳細介紹、限時優惠、聯繫方式等信息。
  • 《最強蝸牛》該隱打法攻略 手把手教你該隱怎麼打
    《最強蝸牛》該隱打法攻略 手把手教你該隱怎麼打時間:2020-10-01 15:47   來源:遊俠網   責任編輯:沫朵 川北在線核心提示:原標題:《最強蝸牛》該隱打法攻略 手把手教你該隱怎麼打 最強蝸牛該隱怎麼打?
  • 我的***煙花炮竹圖文教程 手把手教你煙花炮竹怎麼做
    :原標題:我的****煙花炮竹圖文教程 手把手教你煙花炮竹怎麼做 除夕時我們都會用炮竹驅趕年魔,也叫夕,那麼如果想在我的****中建造一個煙花炮竹應該怎麼做呢?想必各位玩家對此存在許多疑惑,接下來我們一起來跟隨諸葛教科書看看我的****煙花炮竹教程吧。
  • PSoC 4 手把手教你成為魔方大神
    小小立方體,轉動,打亂,再將其復原,玩家在這樣的過程中不斷精進速度,使得解魔方甚至成為了一項著名的競技運動。截至2020年,三階魔方還原官方世界紀錄由中國的杜宇生保持(2018年11月24日於蕪湖賽),單次3.475秒。在美國加州,一款新型魔方「 HEYKUBE 」吸引了我們的注意。
  • 手把手教你如何使用斐波那契回調線
    原標題:手把手教你如何使用斐波那契回調線 總的來說,通過對該數列的探索可以推導出兩組重要的數列——0.191、0.382、0.5、0.618、0.809;1、1.382、1.5、1.618、2、2.382、2.618。這兩組數列中最為重要的是0.382、0.5、0.618、1、1.618五個數字,它們在黃金外匯分析中使用十分廣泛而且效果極佳。
  • 剪映素材庫在哪裡 手把手教你剪映素材庫怎麼用
    剪映素材庫在哪裡 手把手教你剪映素材庫怎麼用時間:2020-07-01 17:09   來源:騰牛網    責任編輯:沫朵 川北在線核心提示:原標題:剪映素材庫在哪裡 手把手教你剪映素材庫怎麼用 剪映是非常好用的剪輯工具,現在很多人都愛用,裡面有很多功能,可以製作出很好的視頻。