【CVPR Oral】TensorFlow實現StarGAN代碼全部開源,1天訓練完

2021-02-20 新智元


  新智元編譯  

來源:github

作者:Junho Kim  編譯:肖琴

【新智元導讀】StarGAN 是去年 11 月由香港科技大學、新澤西大學和韓國大學等機構的研究人員提出的一個圖像風格遷移模型,是一種可以在同一個模型中進行多個圖像領域之間的風格轉換的對抗生成方法。近日,有研究人員將 StarGAN 在 TensorFlow 上實現的全部代碼開源,相關論文獲 CVPR 2018 Oral。

開源地址:https://github.com/taki0112/StarGAN-Tensorflow

StarGAN 是去年 11 月由香港科技大學、新澤西大學和韓國大學等機構的研究人員提出的一個圖像風格遷移模型,是一種可以在同一個模型中進行多個圖像領域之間的風格轉換的對抗生成方法。近日,有研究人員將 StarGAN 在 TensorFlow 上實現的全部代碼開源,相關論文獲 CVPR 2018 Oral。

開源地址:https://github.com/taki0112/StarGAN-Tensorflow

作者:Junho Kim 

看代碼之前,我們先來回顧一下 StarGAN 的原始論文。

圖像到圖像轉換(image-to-image translation)這個任務是指改變給定圖像的某一方面,例如,將人的面部表情從微笑改變為皺眉。在引入生成對抗網絡(GAN)之後,這項任務有了顯著的改進,包括可以改變頭髮顏色,改變風景圖像的季節等等。

給定來自兩個不同領域的訓練數據,這些模型將學習如何將圖像從一個域轉換到另一個域。我們將屬性(attribute)定義為圖像中固有的有意義的特徵,例如頭髮顏色,性別或年齡等,並且將屬性值(attribute value)表示為屬性的一個特定值,例如頭髮顏色的屬性值可以是黑色 / 金色 / 棕色,性別的屬性值是男性 / 女性。我們進一步將域(domain)表示為共享相同屬性值的一組圖像。例如,女性的圖像可以代表一個 domain,男性的圖像代表另一個 domain。

一些圖像數據集帶有多個標籤屬性。例如,CelebA 數據集包含 40 個與頭髮顏色、性別和年齡等面部特徵相關的標籤,RaFD 數據集有 8 個面部表情標籤,如 「高興」、「憤怒」、「悲傷」 等。這些設置使我們能夠執行更有趣的任務,即多域圖像到圖像轉換(multi-domain image-to-image translation),即根據來自多個域的屬性改變圖像。

圖 1:通過從 RaFD 數據集學習遷移知識,應用到 CelebA 的多域圖像到圖像轉換結果。第一列和第六列顯示輸入圖像,其餘列是產生的 StarGAN 圖像。注意,圖像是由一個單一模型網絡生成的,面部表情標籤如生氣、高興、恐懼是從 RaFD 學習的,而不是來自 CelebA。

在圖 1 中,前 5 列顯示了一個 CelebA 的圖像是如何根據 4 個域(「金髮」、「性別」、「年齡」 和 「白皮膚」)進行轉換。我們可以進一步擴展到訓練來自不同數據集的多個域,例如聯合訓練 CelebA 和 RaFD 圖像,使用在 RaFD 上訓練的特徵來改變 CelebA 圖像的面部表情,如圖 1 最右邊的列所示。

然而,現有模型在這種多域圖像轉換任務中既效率低,效果也不好。它們的低效性是因為在學習 k 個域之間的所有映射時,必須訓練 k(k-1)個生成器。圖 2 說明了如何訓練 12 個不同的生成器網絡以在 4 個不同的域中轉換圖像。

圖 2: StarGAN 模型與其他跨域模型的比較。(a)為處理多個域,應該在每兩個域之間都建立跨域模型。(b)StarGAN 用單個生成器學習多域之間的映射。該圖表示連接多個域的拓撲圖。

為了解決這類問題,我們提出了 StarGAN,這是一個能夠學習多個域之間映射的生成對抗網絡。如圖 2(b) 所示,我們的模型接受多個域的訓練數據,僅使用一個生成器就可以學習所有可用域之間的映射。

這個想法很簡單。我們的模型不是學習固定的轉換(例如,將黑頭髮變成金色頭髮),而是將圖像和域信息作為輸入,學習將輸入的圖像靈活地轉換為相應的域。我們使用一個標籤來表示域信息。在訓練過程中,我們隨機生成一個目標域標籤,並訓練模型將輸入圖像轉換為目標域。這樣,我們可以控制域標籤並在測試階段將圖像轉換為任何想要的域。

我們還介紹了一種簡單但有效的方法,通過在域標籤中添加一個掩碼向量(mask vector)來實現不同數據集域之間的聯合訓練。我們提出的方法可以確保模型忽略未知的標籤,並關注特定數據集提供的標籤。這樣,我模型就可以很好地完成任務,比如利用從 RaFD 中學到的特徵合成 CelebA 圖像的面部表情,如圖 1 最右邊的列所示。據我們所知,這是第一個在不同的數據集上成功地完成多域圖像轉換的工作。

總結而言,這個研究的貢獻如下:

提出 StarGAN,這是一個新的生成對抗網絡,只使用一個生成器和一個鑑別器來學習多個域之間的映射,能有效地利用所有域的圖像進行訓練。

演示了如何通過使用 mask vector 來學習多個數據集之間的多域圖像轉換,使 StarGAN 能夠控制所有可用的域標籤。

使用 StarGAN 在面部屬性轉換和面部表情合成任務提供了定性和定量的結果,優於 baseline 模型

圖 3:StarGAN 的概觀,包含兩個模塊:一個鑑別器 D 和一個生成器 G。(a)D 學習區分真實圖像和假圖像,並將真實圖像分類到相應的域。(b)G 接受圖像和目標域標籤作為輸入並生成假圖像。 (c)G 嘗試在給定原始域標籤的情況下,從假圖像中重建原始圖像。(d)G 嘗試生成與真實圖像非常像的假圖像,並通過 D 將其分類為目標域。

圖4:CelebA 數據集上面部屬性轉換的結果對凱勒巴數據集。第1列顯示輸入圖像,後4列顯示單個屬性轉換的結果,最右邊的列顯示多個屬性的轉換結果。H:頭髮的顏色;G:性別;A:年齡

圖5:RaFD 數據集上面部表情合成的結果

圖6:StarGAN-SNG 和 StarGAN-JNT 在 CelebA 數據集上的面部表情合成結果。

要求:

> python download.py celebA

下載數據集

> python download.py celebA

訓練

測試

預訓練模型

結果 (128x128, wgan-gp)

女性

男性

預訓練權重:https://drive.google.com/open?id=1ezwtU1O_rxgNXgJaHcAynVX8KjMt0Ua-

訓練時間:少於 1 天

硬體:GTX 1080Ti

閱讀更多:【明星自動大變臉】最新 StarGAN 對抗生成網絡實現多領域圖像變換(附代碼)


【加入社群】

新智元 AI 技術 + 產業社群招募中,歡迎對 AI 技術 + 產業落地感興趣的同學,加小助手微信號: aiera2015_3  入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名 - 公司 - 職位;專業群審核較嚴,敬請諒解)。

相關焦點

  • TensorFlow 2.1指南:keras模式、渴望模式和圖形模式(附代碼)
    Keras模式import numpy as npimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras.layers import Input, Dense, Flatten, Conv2Dfrom tensorflow.keras
  • 令人困惑的 TensorFlow!(II)
    絕大多數情況下,名稱會自動創建;例如,一個常量節點會以 Const 命名,當創建更多常量節點時,其名稱將是 Const_1,Const_2 等。還可以通過 name=的屬性設置節點名稱,列舉後綴仍會自動添加:代碼:import tensorflow as tfa = tf.constant(0.)b = tf.constant(1.)
  • 詳解深度強化學習展現TensorFlow 2.0新特性(代碼)
    本文完整代碼資源連結:GitHub:https://github.com/inoryy/tensorflow2-deep-reinforcement-learningGoogle Colab:https://colab.research.google.com/drive/12QvW7VZSzoaF-Org-u-N6aiTdBN5ohNA
  • 生成「貓狗版」川普,造假臉工具StarGANv2被玩壞,算法現已開源 | CVPR 2020
    最近他們在GitHub上公布了官方實現代碼,很快就被網友玩壞了。StarGANv2有兩個訓練數據集,一個針對人臉,一個針對動物臉。兩者本來是「井水不犯河水」,但是有人偏偏要混用,拿川普的臉輸入到動物臉預訓練模型裡。結果川普的臉轉換成貓就成了這樣:
  • 【TF秘籍】令人困惑的 TensorFlow!(II)
    絕大多數情況下,名稱會自動創建;例如,一個常量節點會以 Const 命名,當創建更多常量節點時,其名稱將是 Const_1,Const_2 等。還可以通過 name=的屬性設置節點名稱,列舉後綴仍會自動添加:代碼:import tensorflow as tfa = tf.constant(0.)b = tf.constant(1.)
  • 令人困惑的TensorFlow!
    但讓我沒想到的是,學習曲線相當的陡峭,甚至在加入該項目幾個月後,我還偶爾對如何使用 TensorFlow 代碼來實現想法感到困惑。我把這篇博文當作瓶中信寫給過去的自己:一篇我希望在學習之初能被給予的入門介紹。我希望這篇博文也能幫助到其他人。以往的教程缺少了那些內容?自 TensorFlow 發布的三年以來,其已然成為深度學習生態系統中的一塊基石。
  • TensorFlow 2.X,會是它走下神壇的開始嗎?
    這樣的開原始碼,即使到現在,很多最新的前沿模型,尤其是谷歌大腦的各項研究,仍然採用的 1.X 的寫法與 API。比如說,預訓練語言模型 T5、Albert、Electra 或者圖像處理模型 EfficientNet 等等。他們實際上還是用 1.X 那一套方法寫的,只不過能兼容 TensorFlow 2.X。
  • TensorFlow Recommenders 現已開源,讓推薦系統更上一層樓!
    import tensorflow as tfimport tensorflow_datasets as tfdsimport tensorflow_recommenders as tfrs# Ratings data.
  • TensorFlow2.0-Keras入門
    最近在知乎上寫一些學習tensorflow2.0的筆記心得,整理成中文教程,希望幫助想學習tensorflow2的朋友更好的了解tensorflow2的同時,也是倒逼自己更好的學習。我始終相信:最好的學習是輸出知識,最好的成長是共同成長。希望可以通過這個公眾號和廣大的深度學習愛好者,一起學習成長。
  • 使用Amazon SageMaker 運行基於 TensorFlow 的中文命名實體識別
    本文將使用預訓練語言模型ALBERT做中文命名實體識別,該項目基於開源的代碼修改而來(本文代碼見參考資料1,原始代碼見參考資料2),使用TensorFlow框架開發,在下一節,我們將展示如何在Amazon SageMaker中進行該模型的訓練。
  • TensorFlow 2.1.0-rc2發布
    pip install tensorflow安裝的TF默認也有gpu支持了(windows、linux),pip install tensorflow-gpu仍舊可以用,如果考慮到包的大小,也可以用只支持CPU的:tensorflow-cpuWindows用戶:為了使用新的/d2ReducedOptimizeHugeFunctions
  • TensorFlow 2.4來了:上線對分布式訓練和混合精度的新功能支持
    像單工作器的 MirroredStrategy 一樣,MultiWorkerMirroredStrategy 通過同步數據並行實現分布式訓練,顧名思義,藉助 MultiWorkerMirroredStrategy 可以在多臺機器上進行訓練,每臺機器都可能具有多個 GPU。
  • 【TF專輯 1】TensorFlow在Win10上的安裝教程和簡單示例
    5.按照提示,激活之:activate tensorflow 想切換到哪個環境就 activate哪個~ 這篇文章既然是安裝tensorflow的,當然要avtivate tensorflow! 三、TensorFlow安裝 1.按照官網的指示: 安裝CPU版本輸入pipinstall --ignore-installed --upgrade tensorflow安裝GPU版本輸入pipinstall --ignore-installed --upgrade tensorflow-gpu
  • 入門 | 關於TensorFlow,你應該了解的9件事
    使用 TensorFlow.js 在瀏覽器中執行實時人體姿態估計。打開你的相機試一下?1.6 秒計算時間?是的!香蕉識別率超過 97%?是的!#7:專用硬體更強勁如果你已經厭倦了在訓練神經網絡過程中需要等待 CPU 完成數據處理,那麼現在你可以使用專門為 Cloud TPU 設計的硬體,T 即 tensor。就像 TensorFlow……巧合嗎?我認為不是!不久前,谷歌在 alpha 版中發布了第三版 TPU。
  • mnist tensorflow 預測專題及常見問題 - CSDN
    實驗大致步驟如下,加載MNIST數據集,同時初始化網絡超參; 建立計算圖; 建立Session會話,執行計算圖進行AlexNet模型訓練和結果預測(訓練模型和評估模型)。實現代碼如下, 1 #coding=utf-8 2 from __future__ import print_function 3 4 from tensorflow.examples.tutorials.mnist import input_data 5 mnist = input_data.read_data_sets("/
  • 使用Tensorflow實現RNN-LSTM的菜鳥指南
    Tensorflow和其他各種庫(Theano,Torch,PyBrain)為用戶提供了設計模型的工具,而沒有深入了解實現神經網絡,優化或反向傳播算法的細節。Danijar概述了組織Tensorflow模型的好方法,您可能希望稍後使用它來整理代碼。出於本教程的目的,我們將跳過這一點,並專注於編寫正常工作的代碼。首先導入所需的包。
  • 如何在PyTorch和TensorFlow中訓練圖像分類模型
    數據集的標準拆分用於評估和比較模型,其中60,000張圖像用於訓練模型,而單獨的10,000張圖像集用於測試模型。現在,我們也了解了數據集。因此,讓我們在PyTorch和TensorFlow中使用CNN構建圖像分類模型。我們將從PyTorch中的實現開始。我們將在google colab中實現這些模型,該模型提供免費的GPU以運行這些深度學習模型。
  • 能看破並說破一切的TensorFlow
    預訓練模型中使用的各種架構如下表所示:MobileNet-SSDSSD框架是一個單一卷積網絡,可學習預測邊界框位置並對其一次性分類。因此,可以端到端地訓練SSD。k個邊界框具有各自的預設形狀,這些形狀是在訓練前就設定好的。例如,在上圖中有4個框,那麼k=4。
  • 如何用Tensorflow object-detection API訓練模型,找到聖誕老爺爺?
    本文將教會你如何通過Tensorflow object-detection API訓練自己的目標檢測模型(object detector),來找到聖誕老人。本文的代碼可見於github:https://github.com/turnerlabs/character-finder代碼產生的模型可被延伸用於抓取其他的動畫或者真實人物。
  • 【致敬周杰倫】基於TensorFlow讓機器生成周董的歌詞(附源碼)
    建立模型主要分為三步:確定好編碼器和解碼器中cell的結構,即採用什麼循環單元,多少個神經元以及多少個循環層;將輸入數據轉化成tensorflow的seq2seq.rnn_decoder需要的格式,並得到最終的輸出以及最後一個隱含狀態;將輸出數據經過softmax層得到概率分布,並且得到誤差函數,確定梯度下降優化器;由於tensorflow