CycleGAN 論文及源碼解讀

2021-01-20 Xerrors

解決的痛點問題是配對的圖像不好找,所以嘗試找到一個映射函數 G,可以將 X 域上的圖像映射到 Y 域上,由於映射關係沒有約束,很容易出現訓練上的問題,所以訓練了兩個映射函數,另外一個映射函數將 F 域上的圖像映射到 X 上,最終實現的效果就是 F(G(X)) 近似於 X。

原論文首先是提出了一個假設:

We assume there is some underlying relationship between the domains – for example, that they are two different renderings of the same underlying scene – and seek to learn that relationship.

我們假設在不同域之間存在某種潛在的關係(例如:它們是同一基本場景的兩種不同的呈現形式),並試圖了解這種關係。

所以作者想要訓練一個映射關係 G,G 可以將 X 域映射到等同於 Y 分布的域 G(X),但是問題出現了,沒法保證這樣的轉換會將輸入 x 和輸出 y 匹配,同時也沒有辦法進行優化;這裡的問題應該就是使用非配對數據都會面臨的問題;同時在訓練的過程中經常會出現「mode collapse」:會將所有的圖像映射到一個輸出上面。

為了保證每次的輸出都能夠跟輸入有關係,這裡可以理解為保留原本的「潛在關係」,所以作者引入了「循環一致性」原理,也就是要保證 x 經過映射關係 G 之後所得到的 G(x) 可以通過另外一個映射關係 F 映射回來,也就是要滿足F(G(x)) = x,G(F(y))=y。

1. 相關工作

這一部分主要是介紹前人的工作和自己的理解,這裡稍微記錄一下;

作者認為 GANs 成功的關鍵在於「對抗性損失」的理念,同時作者也基於此提出了自己的「循環一致性損失」。

The key to GANs』 success is the idea of an adversarial loss that forces the generated images to be, in principle, indistinguishable from real photos.

GANs成功的關鍵在於「對抗性損失」的理念,這種理念迫使生成的圖像原則上與真實的照片無法區分。

image-20201208115526873

在總結前人在 image-to-image 上的工作的時候,說明自己的網絡跟前人相比,不需要特定的與訓練或者特定的匹配關係,同時也沒有假定輸入和輸出的低維映射空間是一樣的。(我覺得這跟他之前的假設不是相悖的嗎?)

Unlike the above approaches, our formulation does not rely on any task-specific, predefined similarity function between the input and output, nor do we assume that the input and output have to lie in the same low-dimensional embedding space.

與上述方法不同,我們的方法不依賴於輸入和輸出之間任何針對特定任務定製的、預定義的相似性函數,也不假設輸入和輸出必須處於相同的低維嵌入空間。

在涉及到「循環一致性」的時候,作者說明,使用傳遞性作為優化方法的概念由來已久,也已經在很多領域比如翻譯、3D模型匹配有所應用,同時使用也已經有人將循環一致性損失運用在模型訓練中,知識跟我們所不同的是,他們知識利用傳遞性來監督 CNN 的訓練。

Of these, Zhou et al. and Godard et al.  are most similar to our work, as they use a cycle consistency loss as a way of using transitivity to supervise CNN training.

其中Zhou和Godard等人與我們的工作最為相似,他們將循環一致性損失作為一種利用傳遞性來監督CNN訓練的方式。

同時也出現了一個巧合 DualGAN:

Concurrent with our work, in these same proceedings, Yi et al. independently use a similar objective for unpaired image-to-image translation, inspired by dual learning in machine translation.

與我們的工作同時,在同一篇論文中,受機器翻譯中的雙重學習啟發,Yi等人獨立地使用了一個類似的目標用於圖像到圖像的非配對翻譯。

同時作者還對比了 Neural Style Transfer,儘管呈現結果相似,但是 Cycle GAN 所注重的是兩個圖像集之間的映射,所提取的是外觀之外了更高級的特徵。所以該模型也更容易應用到其他的任務上。

2. 數學理論

想要理解這裡的數學概念,主要還是先了解這個圖中的每個部分的功能;

image-20201208115526873

網絡中包含兩個映射函數以及兩個判別器,在訓練映射關係 G 和判別器 DY 的時候,目標函數如下所示(F 和 DX 同理):

損失 cycle consistency loss:

結合在一起可以得到,目標函數如下,其中

整個算法的就是為了解決下面這樣一個公式:

除此之外對於風格轉換,還有一個損失函數 identity loss,作者發現引入額外的損失來激勵生成器映射,可以很好的保留輸入和輸出之間的顏色成分:

3. 網絡結構

網絡架構方面,看論文是沒怎麼看懂;倒是下面這張圖片看的比較明白,圖中展示的是一次單向訓練的過程。

生成器由三個部分完成:

編碼-轉換-解碼」第一部分可以理解為特徵提取(編碼),提取原有圖像的抽象特徵,第二部分是轉換,將特徵從圖像域 A 轉換到圖像域 B,之後通過還原(解碼)變成域 B 上的圖片。所以一般採用兩個頂端相對的提醒來表示生成器的模型;

Generator生成器
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block (256x256x3 -> 256x256x64)
        ## https://zhuanlan.zhihu.com/p/66989411 詳解 nn.ReflectionPad2d
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling, Encoding (256x256x64 -> 128x128x128 -> 64x64x256)
        ## 採用兩個卷積層進行特徵提取
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks, Transformation (64x64x256 -> 64x64x256)
        ## 添加 6 個(默認)殘差模塊進行風格轉換
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling, Decoding (64x64x256 -> 128x128x128 -> 256x256x64)
        ## 將圖像特徵還原到圖像域 B 上,使用兩層逆卷積操作
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer (2556x256x64 -> 256x256x)
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

殘差模塊

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

判別器

就是一個簡單的分類器

discriminator

判別器:

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

4. 訓練過程

對生成器進行判斷的時候,使用了三種損失函數,GAN loss、Identity loss、Cycle loss;一圖勝千言,我畫的太棒了!

訓練生成器的過程

從圖中可以看到對於生成器的訓練的輸入是取自域 A 和域 B 的兩個圖片 Real A、Real B,整個生成器的訓練分為了三個部分,分別是橙色的區域來計算 identity loss,黃色的區域用來計算 Cycle loss,紫色的其餘來計算 GAN loss;對於兩個生成器所生成的 6 個 loss,進行累加作為生成器的 loss,之後進行反向傳播,後面的代碼也能說明這一點。

判別器的訓練方法:

訓練判別器的過程

代碼可以參考這個大佬實現的易讀的 Pytorch 版本:PyTorch-CycleGAN/blob/master/train

參考資料Tensorflow implementation of CycleGAN - GitHub https://github.com/hardikbansal/CycleGAN代碼參考:A clean and readable Pytorch implementation of CycleGAN - GitHub https://github.com/aitorzip/PyTorch-CycleGAN/

相關焦點

  • 異父異母的三胞胎:CycleGAN, DiscoGAN, DualGAN
    如果大家看過去年一篇cvpr的論文,叫pix2pix的話,對這個任務就比較熟悉。就算你們不知道pix2pix,之前有一個很火的,可以把線條畫變成貓的網頁應用,就是用的pix2pix的算法。然而pix2pix的模型是在成對的數據上訓練的,也就是說,對於線條到貓的應用,我們訓練的時候就需要提供一對一對的數據:一個線條畫,和對應的真實的貓圖片。
  • 論文解讀: CycleGAN-VC3
    論文解讀:《CycleGAN-VC3: Examining and Improving CycleGAN-VCs for Mel-spectrogram
  • 成對抗網絡圖像處理工具 CycleGAN
    部分內容來自:品玩下載地址:https://www.oschina.net/p/cyclegan(可點原文連結之美進入)維權聲明:凡是機器人2025發布的文章都會找原文作者授權並給予白名單,若沒有授權到的文章,如涉及版權等問題,請及時聯繫運營者(微信:xiaoejiqiren)我們將第一時間處理,謝謝!
  • 源碼資本張宏江博士連續三年居全球頂級計算機科學家榜單中國大陸...
    源碼資本投資合伙人張宏江博士居中國大陸科學家之首,華人排名第5,全球排名第36。源碼資本投資合伙人 張宏江博士張宏江博士2011年底加入金山,任集團執行董事及執行長(CEO),兼任金山雲的執行長以及獵豹移動、迅雷及世紀互聯的董事,並於2016年12月從金山退休,之後加盟源碼資本任投資合伙人。
  • 萬眾期待:Hinton團隊開源CapsNet源碼
    【導讀】前幾天,Hinton團隊的膠囊網絡論文第一作者Sara Sabour將其源碼在GitHub上開源,其實,該論文「Dynamic Routing Between Capsules」早在去年10月份就已經發表,直到今日,其官方實現終於開源。此前,Hinton一再強調,當前的反向傳播和CNN網絡存在很大的局限性,表明AI的下一代研究方向是「無監督學習」。
  • 系列(二)CycleGAN
    issuu.com/stanislaschaillou/docs/stanislas_chaillou_thesis_(3) Cycle Generative Adversarial Network (CycleGAN)https://www.geeksforgeeks.org/cycle-generative-adversarial-network-cyclegan
  • 源碼張宏江博士連續三年居全球頂級計算機科學家榜單中國大陸第一
    源碼資本投資合伙人張宏江博士居中國大陸科學家之首,華人排名第5,全球排名第36。(CEO),兼任金山雲的執行長以及獵豹移動、迅雷及世紀互聯的董事,並於2016年12月從金山退休,之後加盟源碼資本任投資合伙人。
  • 源碼張宏江博士連續三年居全球頂級計算機科學家榜單中國大陸第一
    源碼資本投資合伙人張宏江博士居中國大陸科學家之首,華人排名第5,全球排名第36。源碼資本投資合伙人 張宏江博士張宏江博士2011年底加入金山,任集團執行董事及執行長(CEO),兼任金山雲的執行長以及獵豹移動、迅雷及世紀互聯的董事,並於2016年12月從金山退休,之後加盟源碼資本任投資合伙人。
  • 源碼資本王星石:消費升級大潮下,如何挖掘獨角獸? | 獵雲網
    源碼資本王星石認為,需求分化、此消彼長是消費升級的主要行為模式,「此消」、「彼長」兩個方向都有機會出現獨角獸。獨角獸來自三個象限,即大眾升級需求、大眾降級需求、小眾升級需求。未來大眾升級需求裡面的偉大公司,很可能隱藏在小眾升級需求市場裡面。文章轉自:源碼資本(ID:sourcecodecapital),作者:王星石。王星石:投資部副總裁,2014年加入源碼資本。
  • 這個項目利用CycleGAN生成不同屬性的神奇寶貝
    惡屬性 → 其他屬性參考連結:https://www.rileynwong.com/blog/2019/5/22/pokemon2pokemon-using-cyclegan-to-generate-pokemon-as-different-elemental-types
  • CVPR2017精彩論文解讀:結合序列學習和交叉形態卷積的3D生物醫學...
    雷鋒網(公眾號:雷鋒網) AI科技評論按:雖然CVPR 2017已經落下帷幕,但對精彩論文的解讀還在繼續下文是宜遠智能的首席科學家劉凱對此次大會收錄的《結合序列學習和交叉形態卷積的3D生物醫學圖像分割》(Joint Sequence Learning and Cross-Modality Convolution for 3D Biomedical Segmentation)一文進行的解讀。
  • 仿微信的IM聊天時間顯示格式(含iOS/Android/Web實現)[圖文+源碼]
    ,效果可媲美微信 [附件下載]》《高仿Android版手機QQ可拖拽未讀數小氣泡源碼 [附件下載]》《Android聊天界面源碼:實現了聊天氣泡、表情圖標(可翻頁) [附件下載]》《高仿Android版手機QQ首頁側滑菜單源碼 [附件下載]》《分享java AMR音頻文件合併源碼,全網最全》《Android版高仿微信聊天界面源碼
  • React源碼之組件的實現與首次渲染
    源碼中的兩種數據結構 貫穿源碼,常見的兩種數據結構,有助於快速閱讀源碼。 ReactElement ] }, }; } } ReactDOM.render( { type: App, props: {}, }, document.getElementById("root") ); ReactDOM.render 先來看下 ReactDOM.render 源碼的執行過程
  • 俯瞰Dubbo全局,閱讀源碼前必須掌握這些!!
    為使更多童鞋受益,現給出開源框架地址:https://github.com/sunshinelyz/mykit-delay既然是要寫深度解析Dubbo源碼的系列專題,我們首先要做的就是搭建一套Dubbo的源碼環境,正所謂「工欲善其事,必先利其器」。
  • ——源碼太陽能路燈
    以上就是關於高效太陽能板的特點是什麼的講解,源碼太陽能路燈採用高效單晶A級組件,轉化率大於19%,通過了權威檢測認證機構的測試與認證:德國TUV 和中 國質量認證中心CQC,歐盟CE,ROHS等認證。
  • 直播系統源碼開發:關於安卓開發工具和obs直播推流
    隨著移動網際網路技術的不算發展,直播系統源碼不再局限於娛樂直播的範疇尤其對於今年來說,購物直播行業的迅速發展,對直播系統源碼開發的需求進一步擴大,同時對直播源碼開發技術也有了新的要求。
  • java任務調度之Timer定時器(案例和源碼分析)
    對此就有必要深入其源碼看看了。二、Timer源碼分析對於一個類的源碼分析,我一貫的思路就是先從參數開始,然後構造方法,最後就是常用方法。下面我們就按照這個思路開始今天的源碼分析,在這裡基於jdk1.8。
  • 直播帶貨app源碼用Java語言來開發有哪些好處?
    而直播帶貨APP源碼的開發十分的重要,且在目前來看,最常用的還是Java語言,那麼相比較於其他語言開發,Java語言的直播帶貨系統有什麼優勢呢?下面就由小編為大家介紹吧。 一、源碼獨立性 Java開發直播帶貨APP源碼可以給企業自主搭建的權利,無需通過第三方平臺交易,不再依賴第三方平臺的流量。
  • ACL 2020 復旦大學系列論文解讀開始了!
    AI 科技評論聯合復旦大學,推出「ACL 2020 - 復旦大學系列論文解讀」。 AI科技評論 x 復旦大學4月3日,備受矚目的 NLP 頂會 ACL 2020 公布了錄用論文情況。本屆 ACL 共收到 3429 篇投稿,相比於去年增加了500多篇,儘管接收論文數量還沒有統計,但相比於去年必然將只多不少。
  • SpringSecurity 默認表單登錄頁展示流程源碼
    :f520875f-ea2b-4b5d-9b0c-f30c0c17b90b登錄成功並且瀏覽器又會重定向到剛剛訪問的接口2.springSecurityFilterchain 過濾器鏈如果你看過我另一篇關於SpringSecurity初始化源碼的博客