WGAN最新進展:從weight clipping到gradient penalty,更加先進的Lipschitz限制手法

2021-02-08 CreateAMind

https://www.zhihu.com/question/52602529/answer/158727900

Wasserstein GAN最新進展:從weight clipping到gradient penalty,更加先進的Lipschitz限制手法

前段時間,Wasserstein  GAN以其精巧的理論分析、簡單至極的算法實現、出色的實驗效果,在GAN研究圈內掀起了一陣熱潮(對WGAN不熟悉的讀者,可以參考我之前寫的介紹文章:令人拍案叫絕的Wasserstein GAN - 知乎專欄)。但是很多人(包括我們實驗室的同學)到了上手跑實驗的時候,卻發現WGAN實際上沒那麼完美,反而存在著訓練困難、收斂速度慢等問題。其實,WGAN的作者Martin Arjovsky不久後就在reddit上表示他也意識到了這個問題,認為關鍵在於原設計中Lipschitz限制的施加方式不對,並在新論文中提出了相應的改進方案:

首先回顧一下WGAN的關鍵部分——Lipschitz限制是什麼。WGAN中,判別器D和生成器G的loss函數分別是:

(公式1)

 (公式2)

公式1表示判別器希望儘可能拉高真樣本的分數,拉低假樣本的分數,公式2表示生成器希望儘可能拉高假樣本的分數。

Lipschitz限制則體現為,在整個樣本空間  上,要求判別器函數D(x)梯度的Lp-norm不大於一個有限的常數K:

 (公式3)

直觀上解釋,就是當輸入的樣本稍微變化後,判別器給出的分數不能發生太過劇烈的變化。在原來的論文中,這個限制具體是通過weight clipping的方式實現的:每當更新完一次判別器的參數之後,就檢查判別器的所有參數的絕對值有沒有超過一個閾值,比如0.01,有的話就把這些參數clip回 [-0.01, 0.01] 範圍內。通過在訓練過程中保證判別器的所有參數有界,就保證了判別器不能對兩個略微不同的樣本給出天差地別的分數值,從而間接實現了Lipschitz限制。

然而weight clipping的實現方式存在兩個嚴重問題:

第一,如公式1所言,判別器loss希望儘可能拉大真假樣本的分數差,然而weight clipping獨立地限制每一個網絡參數的取值範圍,在這種情況下我們可以想像,最優的策略就是儘可能讓所有參數走極端,要麼取最大值(如0.01)要麼取最小值(如-0.01)!為了驗證這一點,作者統計了經過充分訓練的判別器中所有網絡參數的數值分布,發現真的集中在最大和最小兩個極端上:



這樣帶來的結果就是,判別器會非常傾向於學習一個簡單的映射函數(想想看,幾乎所有參數都是正負0.01,都已經可以直接視為一個二值神經網絡了,太簡單了)。而作為一個深層神經網絡來說,這實在是對自身強大擬合能力的巨大浪費!判別器沒能充分利用自身的模型能力,經過它回傳給生成器的梯度也會跟著變差。

在正式介紹gradient penalty之前,我們可以先看看在它的指導下,同樣充分訓練判別器之後,參數的數值分布就合理得多了,判別器也能夠充分利用自身模型的擬合能力:



第二個問題,weight clipping會導致很容易一不小心就梯度消失或者梯度爆炸。原因是判別器是一個多層網絡,如果我們把clipping threshold設得稍微小了一點,每經過一層網絡,梯度就變小一點點,多層之後就會指數衰減;反之,如果設得稍微大了一點,每經過一層網絡,梯度變大一點點,多層之後就會指數爆炸。只有設得不大不小,才能讓生成器獲得恰到好處的回傳梯度,然而在實際應用中這個平衡區域可能很狹窄,就會給調參工作帶來麻煩。相比之下,gradient penalty就可以讓梯度在後向傳播的過程中保持平穩。論文通過下圖體現了這一點,其中橫軸代表判別器從低到高第幾層,縱軸代表梯度回傳到這一層之後的尺度大小(注意縱軸是對數刻度),c是clipping threshold:



說了這麼多,gradient penalty到底是什麼?

前面提到,Lipschitz限制是要求判別器的梯度不超過K,那我們何不直接設置一個額外的loss項來體現這一點呢?比如說:

 (公式4)

不過,既然判別器希望儘可能拉大真假樣本的分數差距,那自然是希望梯度越大越好,變化幅度越大越好,所以判別器在充分訓練之後,其梯度norm其實就會是在K附近。知道了這一點,我們可以把上面的loss改成要求梯度norm離K越近越好,效果是類似的:

 (公式5)

究竟是公式4好還是公式5好,我看不出來,可能需要實驗驗證,反正論文作者選的是公式5。接著我們簡單地把K定為1,再跟WGAN原來的判別器loss加權合併,就得到新的判別器loss:

 (公式6)

這就是所謂的gradient penalty了嗎?還沒完。公式6有兩個問題,首先是loss函數中存在梯度項,那麼優化這個loss豈不是要算梯度的梯度?一些讀者可能對此存在疑惑,不過這屬於實現上的問題,放到後面說。

其次,3個loss項都是期望的形式,落到實現上肯定得變成採樣的形式。前面兩個期望的採樣我們都熟悉,第一個期望是從真樣本集裡面採,第二個期望是從生成器的噪聲輸入分布採樣後,再由生成器映射到樣本空間。可是第三個分布要求我們在整個樣本空間 上採樣,這完全不科學!由於所謂的維度災難問題,如果要通過採樣的方式在圖片或自然語言這樣的高維樣本空間中估計期望值,所需樣本量是指數級的,實際上沒法做到。

所以,論文作者就非常機智地提出,我們其實沒必要在整個樣本空間上施加Lipschitz限制,只要重點抓住生成樣本集中區域、真實樣本集中區域以及夾在它們中間的區域就行了。具體來說,我們先隨機採一對真假樣本,還有一個0-1的隨機數:

(公式7)

然後在 的連線上隨機插值採樣:

 (公式8)

把按照上述流程採樣得到的 所滿足的分布記為 ,就得到最終版本的判別器loss:

 (公式9)

這就是新論文所採用的gradient penalty方法,相應的新WGAN模型簡稱為WGAN-GP。我們可以做一個對比:

論文還講了一些使用gradient penalty時需要注意的配套事項,這裡只提一點:由於我們是對每個樣本獨立地施加梯度懲罰,所以判別器的模型架構中不能使用Batch Normalization,因為它會引入同個batch中不同樣本的相互依賴關係。如果需要的話,可以選擇其他normalization方法,如Layer Normalization、Weight Normalization和Instance Normalization,這些方法就不會引入樣本之間的依賴。論文推薦的是Layer Normalization。

實驗表明,gradient penalty能夠顯著提高訓練速度,解決了原始WGAN收斂緩慢的問題:



雖然還是比不過DCGAN,但是因為WGAN不存在平衡判別器與生成器的問題,所以會比DCGAN更穩定,還是很有優勢的。不過,作者憑什麼能這麼說?因為下面的實驗體現出,在各種不同的網絡架構下,其他GAN變種能不能訓練好,可以說是一件相當看人品的事情,但是WGAN-GP全都能夠訓練好,尤其是最下面一行所對應的101層殘差神經網絡:



剩下的實驗結果中,比較厲害的是第一次成功做到了「純粹的」的文本GAN訓練!我們知道在圖像上訓練GAN是不需要額外的有監督信息的,但是之前就沒有人能夠像訓練圖像GAN一樣訓練好一個文本GAN,要麼依賴於預訓練一個語言模型,要麼就是利用已有的有監督ground truth提供指導信息。而現在WGAN-GP終於在無需任何有監督信息的情況下,生成出下圖所示的英文字符序列:



它是怎麼做到的呢?我認為關鍵之處是對樣本形式的更改。以前我們一般會把文本這樣的離散序列樣本表示為sequence of index,但是它把文本表示成sequence of probability vector。對於生成樣本來說,我們可以取網絡softmax層輸出的詞典概率分布向量,作為序列中每一個位置的內容;而對於真實樣本來說,每個probability vector實際上就蛻化為我們熟悉的onehot vector。

但是如果按照傳統GAN的思路來分析,這不是作死嗎?一邊是hard onehot vector,另一邊是soft probability vector,判別器一下子就能夠區分它們,生成器還怎麼學習?沒關係,對於WGAN來說,真假樣本好不好區分並不是問題,WGAN只是拉近兩個分布之間的Wasserstein距離,就算是一邊是hard onehot另一邊是soft probability也可以拉近,在訓練過程中,概率向量中的有些項可能會慢慢變成0.8、0.9到接近1,整個向量也會接近onehot,最後我們要真正輸出sequence of index形式的樣本時,只需要對這些概率向量取argmax得到最大概率的index就行了。

新的樣本表示形式+WGAN的分布拉近能力是一個「黃金組合」,但除此之外,還有其他因素幫助論文作者跑出上圖的效果,包括:

上面第三點非常有趣,因為它讓我聯想到前段時間挺火的語言學科幻電影《降臨》:



裡面的外星人「七肢怪」所使用的語言跟人類不同,人類使用的是線性的、串行的語言,而「七肢怪」使用的是非線性的、並行的語言。「七肢怪」在跟主角交流的時候,都是一次性同時給出所有的語義單元的,所以說它們其實是一些多層反卷積網絡進化出來的人工智慧生命嗎?




開完腦洞,我們回過頭看,不得不承認這個實驗的setup實在過於簡化了,能否擴展到更加實際的複雜場景,也會是一個問題。但是不管怎樣,生成出來的結果仍然是突破性的。

最後說回gradient penalty的實現問題。loss中本身包含梯度,優化loss就需要求梯度的梯度,這個功能並不是現在所有深度學習框架的標配功能,不過好在Tensorflow就有提供這個接口——tf.gradients。開頭連結的GitHub源碼中就是這麼寫的:

# interpolates就是隨機插值採樣得到的圖像,gradients就是loss中的梯度懲罰項gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]

對於我這樣的PyTorch黨就非常不幸了,高階梯度的功能還在開發,感興趣的PyTorch黨可以訂閱這個GitHub的pull request:Autograd refactor,如果它被merged了話就可以在最新版中使用高階梯度的功能實現gradient penalty了。但是除了等待我們就沒有別的辦法了嗎?其實可能是有的,我想到了一種近似方法來實現gradient penalty,只需要把微分換成差分:

 (公式10)

也就是說,我們仍然是在分布 上隨機採樣,但是一次採兩個,然後要求它們的連線斜率要接近1,這樣理論上也可以起到跟公式9一樣的效果,我自己在MNIST+MLP上簡單驗證過有作用,PyTorch黨甚至Tensorflow黨都可以嘗試用一下。

相關焦點

  • 掀起熱潮的Wasserstein GAN,在近段時間又有哪些研究進展?
    GAN最新進展:從weight clipping到gradient penalty,更加先進的Lipschitz限制手法雷鋒網按:本文作者為中山大學鄭華濱,他在知乎的提問《生成式對抗網絡GAN有哪些最新的發展,可以實際應用到哪些場景中?》
  • 生成對抗網絡的最新研究進展
    計算機視覺和人工智慧的愛好者 Bharath Raj 近日發布以一篇博文,總結了生成對抗網絡的原理、缺點和為了克服這些缺點所做的研究的最新進展。雷鋒網 AI 科技評論編譯整理如下:當然,許多研究人員已經提出了很好的解決方案,以減輕 GAN 網絡訓練中所涉及到的一些問題。然而,這一領域的研究進展速度如此之快,以至於人們來不及去追蹤很多有趣的想法。這個博客列出了一些常用的使 GAN 訓練表現穩定的技術。
  • 資源 | 生成對抗網絡新進展與論文全集
    以下是兩篇原文的連結:GAN 理論&實踐的新進展首先我們看看 Liping Liu 在 github.io 上發布的這篇介紹了 GAN 理論和實踐上的新進展的文章。這篇文章對兩篇 GAN 相關的論文進行了探討;其中第一篇是 Arora et al.
  • 訓練GAN,你應該知道的二三事
    maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:9.GAN loss除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。
  • 簡單易用 TensorFlow 代碼集,GAN通用框架、函數
    Optionpadding='SAME'pad_typesnRaloss_funcganlsganhingewganwgan-gpdragan權重(Weight)weight_init= tf.truncated_normal_initializer(mean=0.0, stddev=0.02)weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
  • 谷歌工程師:聊一聊深度學習的weight initialization
    背景深度學習模型訓練的過程本質是對weight(即參數 W)進行更新,這需要每個參數有相應的初始值。有人可能會說:「參數初始化有什麼難點?直接將所有weight初始化為0或者初始化為隨機數!」對一些簡單的機器學習模型,或當optimization function是convex function時,這些簡單的方法確實有效。
  • weight normalization 原理和實現
    (注意,替換前 v.v 表示向量點乘,替換後用 w'w 代表向量點乘,和論文保持一致,方便得到後面的矩陣表示。)得到:(9) 最後一行,中間括號部分為一個投影矩陣,將空間中任意一個向量投影到向量 w 的補空間(垂直於 w 的超平面)。
  • GAN(生成對抗網絡)的最新應用狀況
    值得一提的是,作者提出 G 需為保距映射的限制,這使得整個過程的大部分操作可以轉換為求解優化問題,整個修改過程近乎實時。細節比較多,這裡不再展開,請參考文獻 [6],代碼請參考文末的 iGAN。文獻 [7] 設計了一種算法 pix2pix,將 GAN 應用到 image to image translation 上。
  • [PyTorch 學習筆記] 6.1 weight decay 和 dropout
    = net_normal(train_x), net_weight_decay(train_x)    loss_normal, loss_wdecay = loss_func(pred_normal, train_y), loss_func(pred_wdecay, train_y)    optim_normal.zero_grad()    optim_wdecay.zero_grad
  • 還記得Wasserstein GAN嗎?不僅有Facebook參與,也果然被 ICML 接收...
    不過在討論中,還是有人反映 WGAN 存在訓練困難、收斂速度慢等問題,WGAN 論文一作 Martin Arjovsky 也在 reddit 上表示自己意識到了,然後對 WGAN 做了進一步的改進。改進後的論文為「Improved Training of Wasserstein GANs」。
  • Sparse Deep Neural Networks Through L_{1,\infty}-Weight...
    報告人:中國科學技術大學楊周旺教授報告時間:10月20日 上午9:30-10:30報告地點:數學樓一樓第一報告廳報告題目:Sparse Deep Neural Networks Through L_{1,\infty}-Weight Normalization報告摘要:We study L_{1,\infty}-weight
  • Was Sweden robbed of a penalty kick against Germany?
    Sweden players and fans could not believe they were not awarded a penalty for an apparent foul as the referee didn't check VAR and ex-pros were clearly in agreement
  • 深度 | 從修正Adam到理解泛化:概覽2017年深度學習優化算法的最新研究進展
    在這篇博客中,我將介紹深度學習優化算法中幾個最有意義的進展以及最有潛力的方向。這篇博客假定讀者熟悉 SGD 和適應性學習率方法,如 Adam。如果想快速入門,可以查看 Ruder 以前的博客以概覽當前已有的梯度下降優化算法:深度解讀最流行的優化算法:梯度下降,或參見更加基礎與入門的文章:目標函數的經典優化算法介紹。
  • 何愷明團隊:stop gradient是孿生網絡對比學習成功的關鍵
    如果輸出「崩潰」到了常數向量,那麼其每個通道的標準差應當是0,見上圖middle。在無Stop-gradient時,其分類精度僅有0.1%,而添加Stop-gradient後最終分類精度可達67.7%。上述實驗表明:「崩潰」確實存在。但「崩潰」的存在不足以說明所提方法可以避免「崩潰」,儘管上述對比中僅有「stop-gradient」的區別。