【前沿】NIPS2017貝葉斯生成對抗網絡TensorFlow實現(附GAN資料下載)

2021-03-02 專知
導讀

今年五月份康奈爾大學的 Andrew Gordon Wilson 和 Permutation Venture 的 Yunus Saatchi 提出了一個貝葉斯生成對抗網絡(Bayesian GAN),結合貝葉斯和對抗生成網絡,提出了一個實用的貝葉斯公式框架,用GAN來進行無監督學習和半監督式學習。論文《Bayesian GAN》也被2017年機器學習頂級會議 NIPS 接受,今天Andrew Gordon Wilson在Twitter上發布消息開源了這篇論文的TensorFlow實現,並且Google GAN之父 Ian Goodfellow 轉發這條推文,讓我們來看下。

摘要

生成式對抗網絡(GANs)能在不知不覺中學習圖像、聲音和數據中的豐富分布。這些分布通常因為具有明確的相似性,所以很難去建模。在這篇論文中,我們提出了一個實用的貝葉斯公式,通過使用GAN來進行無監督學習和半監督式學習。在這一框架之下,使用動態的梯度漢密爾頓蒙特卡洛(Hamiltonian Monte Carlo)來將生成網絡和判別網絡中的權重最大化。提出的方法可以非常直接的獲得最後的結果,並且在不需要任何標準的幹預,比如特徵匹配或者mini-batch discrimination的情況下,都獲得了良好的表現。通過對生成器中的參數部署一個具有表達性的後驗機制。貝葉斯生成式對抗網絡能夠避免模式碰撞,產生可判斷的、多樣化的候選樣本,並且提供在既有的一些基準測試上,能夠提供最好的半監督學習量化結果,比如,SVHN, CelebA 和 CIFAR-10,其效果遠遠超過 DCGAN, Wasserstein GANs 和 DCGAN 等等。

TensorFlow實現的貝葉斯生成對抗網絡Contents

簡介

python 依賴包

訓練參數

使用方法

安裝

合成數據

例子: MNIST, CIFAR10, CelebA, SVHN

自定義數據

簡介

貝葉斯生成對抗網絡中我們提出了使用條件後驗分布來建模生成器和判別器的權重參數,隨後使用了動態的梯度漢密爾頓蒙特卡洛(Hamiltonian Monte Carlo)來將生成網絡和判別網絡中的權重最大化。貝葉斯方法用在生成對抗網絡主要有一下幾個特性:(1),能夠提供很好的半監督學習量化結果。(2),對效果的影響比較小。(3), 可以通過估計概率GAN的邊際相似性;(4),它不容易遭受模型失效(mode collapse)的風險;(5)一個包含針對數據互補的多生成和判別模型,可以形成一個概率集成(ensemble)。

我們展示了在生成器參數上的多模後驗。每種參數設定都和不同的數據生成假設相對應。上圖顯示了對應兩種不同手寫風格的參數設定而產生的樣本。這個貝葉斯生成對抗網絡保留了在參數上的全概率分布。相反,標準的生成對抗網絡使用點估計(類似於單個最大似然估計)來表示這個全概率分布,這樣會丟失一些潛在的並重要的數據解釋。

python 依賴包

這個代碼包含以下依賴包 (版本號非常重要):

python 2.7

tensorflow==1.0.0

在Linux上安裝tensorflow 1.0.0可以參考官方指南 https://www.tensorflow.org/versions/r1.0/install/.

bayesian_gan_hmc.py 包含以下訓練選項.

--out_dir: 輸出目錄

--n_save: 每次保存的樣本和參數的數量 n_save 是迭代次數; 默認為 100

--z_dim: 生成器中 z 向量的維度 ;默認為100

--data_path: 數據目錄; 這個路徑是必須的

--dataset: 數據集可以是 mnist, cifar, svhn or celeb; 默認為 mnist

--gen_observed: 被生成器「觀察」到的數據 ; 這會影響到噪聲離散的尺度和先驗,默認為1000

--batch_size: 一次訓練的批量數 ;默認 64

--prior_std: 權重先驗的標準差;默認為1

--numz: 與論文中的J參數一樣; 參數 z 需要整合的樣本數; 默認 1

--num_mcmc: 與論文中的M參數一樣; 每個zde 蒙特卡洛 NN權重樣本; 默認是1

--lr: Adam 優化器的學習率; 默認 0.0002

--optimizer: 優化方法: adam (tf.train.AdamOptimizer) 或者 sgd (tf.train.MomentumOptimizer); 默認使用 adam

--semi_supervised: 進行半監督學習

--N: 進行半監督學習的標註樣本數

--train_iter: 訓練迭代次數; 默認 50000

--save_samples: 訓練中保存生成樣本

--save_weights: 訓練中保存生成權重

--random_seed: 隨機種子;注意如果使用了GPU,因為這個操作結果不能做到%100復現

你可以使用--wasserstein來運行WGANs 或者使用 --ml_ensemble <num_dcgans>來訓練多個 DCGANs 的集成. 此外你還可以使用-ml_ensemble 1來訓練DCGAN

使用方法安裝

安裝要求的依賴集

克隆代碼倉庫

合成數據

為了能再論文中提到的合成數據上運行你可以使用T bgan_synth 腳本. 比如,下面的命令訓練 貝葉斯生成對抗網絡(with D=100 and d=10)迭代 5000 詞並將結果保存在 <results_path>.

`./bgan_synth.py --x_dim 100 --z_dim 10 --numz 10 --out \<results_path\>

`在此數據集上運行 ML GAN可以運行

`./bgan_synth.py --x_dim 100 --z_dim 10 --numz 1 --out \<results_path\>

bgan_synth有--save_weights,--out_dir,--z_dim,--numz,--wasserstein,--train_iter以及--x_dim這些參數.x_dim控制觀測數據的維度 (也就是論文中的x` ).

如果你運行了以上兩條命令後你會看到每100次迭代的輸出結果 <results_path>. 舉例來說貝葉斯生成對抗網絡在第900次迭代的結果如下圖:

對比來說標準 GAN (對應於numz=1, 使用最大似然估計) 產生的結果如下:

上面的圖展示了標準GAN容易遇到模型失效(mode collapse)而我們提出的 Bayesian GAN則可以避免這種情況。

為了進一步探究合成的數據, 同時生成JS散度 ,你可以運行 synth.ipynb.

MNIST, CIFAR10, CelebA, SVHN

bayesian_gan_hmc script allows to train the model on standard and custom datasets. Below we describe the usage of this script.

數據準備

為了重現在 MNIST, CIFAR10, CelebA 和 SVHN 數據集上的實驗,你需要使用正確的--data_path來準備數據.

對於 MNIST你不需要預處理數據,可以指定任意的 --data_path;

對於 CIFAR10 你需要從https://www.cs.toronto.edu/kriz/cifar.htmlPython處理的數據please下載並解壓出適合 download ;

對於 SVHN數據, 從http://ufldl.stanford.edu/housenumbers/下載 train_32x32.mat 和 test_32x32.mat 文件

對於CelebA數據,你需要首先安裝 openCV. 可以從這個連結來下載數據http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. 首先創建一個包含「 Anno 和 img_align_celeba 子目錄的目錄celebA folder 」Anno 『 必須包含list_attr_celeba.txt ,而 img_align_celeba 必須包含 .jpg 文件. 你還需要使用 datasets/crop_faces.py 腳本來裁剪圖片, 其中包含參數 --data_path <path> 來指定』celebA『的目錄。

無監督訓練

你可以通過運行不包含--semi 參數的bayesian_gan_hmc 腳本來訓練無監督版本的訓練,. 比如使用:

`./bayesian_gan_hmc.py --data_path \<data_path\> --dataset svhn --numz 1 --num_mcmc 10 --out_dir \<results_path\> --train_iter 75000 --save_samples --n_save 100

在SVHN 數據集上訓練模型. 這條命令將迭代75000次並且每100次迭代保存一次樣本。 這裡的必須指向結果產生的目錄.

半監督訓練

你可以用腳本帶--semi 選項的bayesian_gan_hmc 腳本來訓練半監督版本的模型。 用 -N 參數來設定需要訓練的標註樣本數目。比如運行:

`./bayesian_gan_hmc.py --data_path \<data_path\> --dataset cifar --numz 1 --num_mcmc 10--out_dir \<results_path\> --train_iter 75000 --N 4000 --semi --lr 0.00005

在 CIFAR10 數據集上使用 4000 標註樣本來訓練模型. 這條命令將迭代75000次訓練模型,並將結果保存在` 文件夾中.

為了在MNIST數據集上使用200個標註樣本訓練模型你可以使用以下命令:

`./bayesian_gan_hmc.py --data_path \<data_path\>/ --dataset mnist --numz 5 --num_mcmc 5--out_dir \<results_path\> --train_iter 30000 -N 200 --semi --lr 0.001

自定義數據

為了在自定義的數據集上訓練模型,你需要為每一個分類定義特定的接口。比如你想在 digits(http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) 數據集上訓練模型.,這個數據集包含8x8的數字圖片。假設數據被分別存儲在x_tr.npy, y_tr.npy, x_te.npy and y_te.npy 文件中,我們認為 x_tr.npy and x_te.npy 的大小為 (?, 8, 8, 1). 隨後我們可以在bgan_util.py 中定義針對這個數據集類:

`class Digits:def __init__(self):self.imgs = np.load('x_tr.npy') self.test_imgs = np.load('x_te.npy')self.labels = np.load('y_tr.npy')self.test_labels = np.load('y_te.npy')self.labels = one_hot_encoded(self.labels, 10)self.test_labels = one_hot_encoded(self.test_labels, 10) self.x_dim = [8, 8, 1](#)self.num_classes = 10@staticmethoddef get_batch(batch_size, x, y): """Returns a batch from the given arrays."""idx = np.random.choice(range(x.shape[0](#)), size=(batch_size,), replace=False)return x[idx](#), y[idx](#)def next_batch(self, batch_size, class_id=None):return self.get_batch(batch_size, self.imgs, self.labels)def test_batch(self, batch_size):return self.get_batch(batch_size, self.test_imgs, self.test_labels)

這個類必須有next_batch和test_batch等函數, 同時要包含imgs,labels,test_imgs,test_labels,x_dim以及num_classes` 屬性.

這時候我們就可以引入 Digits 類到 bayesian_gan_hmc.py中了

`from bgan_util import Digits

同時可以在--dataset` 參數中添加如下行

`if args.dataset == "digits":dataset = Digits()

` 在準備工作結束後,我們可以用下面命令來訓練模型

`./bayesian_gan_hmc.py --data_path \<any_path\> --dataset digits --numz 1 --num_mcmc 10 --out_dir \<results path\> --train_iter 5000 --save_samples

聲明

感謝Pavel Izmailov對代碼進行的壓力測試,並且寫出這份教程。

參考網址連結:

代碼:https://github.com/andrewgordonwilson/bayesgan

論文:https://arxiv.org/abs/1705.09558

請關注專知公眾號(掃一掃最下面專知二維碼,或者點擊上方藍色專知),

請登錄專知,獲取GAN知識資料,請PC登錄www.zhuanzhi.ai或者點擊閱讀原文,頂端搜索「GAN」 主題,查看獲得對應主題專知薈萃全集知識等資料!如下圖所示~

歡迎轉發到你的微信群和朋友圈,分享專業AI知識!


更多專知薈萃知識資料全集獲取,請查看:

【專知薈萃01】深度學習知識資料大全集(入門/進階/論文/代碼/數據/綜述/領域專家等)(附pdf下載)

【專知薈萃02】自然語言處理NLP知識資料大全集(入門/進階/論文/Toolkit/數據/綜述/專家等)(附pdf下載)

【專知薈萃03】知識圖譜KG知識資料全集(入門/進階/論文/代碼/數據/綜述/專家等)(附pdf下載)

【專知薈萃04】自動問答QA知識資料全集(入門/進階/論文/代碼/數據/綜述/專家等)(附pdf下載)

【專知薈萃05】聊天機器人Chatbot知識資料全集(入門/進階/論文/軟體/數據/專家等)(附pdf下載)

【專知薈萃06】計算機視覺CV知識資料大全集(入門/進階/論文/課程/會議/專家等)(附pdf下載)

【專知薈萃07】自動文摘AS知識資料全集(入門/進階/代碼/數據/專家等)(附pdf下載)

【專知薈萃08】圖像描述生成Image Caption知識資料全集(入門/進階/論文/綜述/視頻/專家等)

【專知薈萃09】目標檢測知識資料全集(入門/進階/論文/綜述/視頻/代碼等)

【專知薈萃10】推薦系統RS知識資料全集(入門/進階/論文/綜述/視頻/代碼等)

【教程實戰】Google DeepMind David Silver《深度強化學習》公開課教程學習筆記以及實戰代碼完整版

【GAN貨】生成對抗網絡知識資料全集(論文/代碼/教程/視頻/文章等)

【乾貨】Google GAN之父Ian Goodfellow ICCV2017演講:解讀生成對抗網絡的原理與應用

【AlphaGoZero核心技術】深度強化學習知識資料全集(論文/代碼/教程/視頻/文章等)

請掃描小助手,加入專知人工智慧群,交流分享~

獲取更多關於機器學習以及人工智慧知識資料,請訪問www.zhuanzhi.ai,  或者點擊閱讀原文,即可得到!


-END-

歡迎使用專知

專知,一個新的認知方式!目前聚焦在人工智慧領域為AI從業者提供專業可信的知識分發服務, 包括主題定製、主題鏈路、搜索發現等服務,幫你又好又快找到所需知識。

使用方法>>訪問www.zhuanzhi.ai, 或點擊文章下方「閱讀原文」即可訪問專知

中國科學院自動化研究所專知團隊

@2017 專知

相關焦點

  • 生成對抗網絡GANs學習路線
    導讀MLmindset作者發布了一篇生成對抗網絡合集文章,整合了各類關於GAN的資源,如GANs文章、模型、代碼、應用、課程、書籍
  • Tensorflow入門教程(二十四)——生成對抗網絡(GAN)
    目前在深度學習領域中,生成對抗網絡是非常熱門,能給我們帶來不可思議的一個領域方向。
  • 【GAN貨】生成對抗網絡知識資料全集(論文/代碼/教程/視頻/文章等)
    Inference with Inverse  Autoregressive Flow )2016https://papers.nips.cc/paper/6581-improving-variational-autoencoders-with-inverse-autoregressive-flow.pdf深度學習系統對抗樣本黑盒攻擊(Practical     Black-Box Attacks
  • 資源 NIPS 2016上22篇論文的實現匯集
    如何訓練生成對抗網絡(How to Train a GAN)6. Phased LSTM:為長的或基於事件的序列加速循環網絡訓練(Phased LSTM: Accelerating Recurrent Network Training for Long or Event-based Sequences)7.
  • 生成對抗網絡詳解與代碼演示
    生成對抗網絡(GAN)生成對抗網絡(Generative Adversarial Nets)在圖像生成、音樂與文本生成方面都有著很多神奇效果,生成對抗網絡產生受到都來自博弈論與對戰遊戲的啟發,生成對抗網絡,需要三個輸入輸入數據– 一組樣本數據x-P(data).生成器G – 隨機初始化生成數P(g),終極目標是生成跟樣本數據分布一致的數據.
  • 如何應用TFGAN快速實踐生成對抗網絡?
    AI 前線導讀:生成對抗網絡(Generative Adversarial Nets ,GAN)目前已廣泛應用於圖像生成、超解析度圖片生成、圖像壓縮、圖像風格轉換、數據增強、文本生成等場景。越來越多的研發人員從事 GAN 網絡的研究,提出了各種 GAN 模型的變種,包括 CGAN、InfoGAN、WGAN、CycleGAN 等。
  • 乾貨 | 請收下這份機器學習清單(附下載連結)
    (GANs) 什麼是對抗式生成網絡模型?-8-bit-pixel-art-e45d9b96cee7 對抗式生成網絡入門(TensorFlow)(aylien.com)http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/ 《對抗式生成網絡》(小學一年級~上冊
  • Python深度學習之生成對抗神經網絡(GAN)簡介及實現
    生成對抗網絡是現在人工智慧領域的當紅技術之一,網絡是近兩年深度學習領域的新秀,火的不行,本文旨在淺顯理解GAN,分享學習心得。GAN原理生成式對抗網絡GAN (Generative adversarial networks) 是 Goodfellow 等在 2014 年提出的一種生成式模型(原文arxiv:https://arxiv.org/abs/1406.2661)。
  • NIPS2018深度學習(18)|亮點: 貝葉斯深度學習;圖卷積(論文及代碼)
    在貝葉斯邏輯回歸算法中幾種方法在多個數據集上的效果對比如下EF全稱為Empirical Fisher在貝葉斯神經網絡算法中幾種方法在多個數據集上的效果對比如下ICML, 2017.代碼地址https://github.com/davidBelanger/SPEN本文所提出的網絡模型圖示如下
  • 基於tensorRT實現TensorFlow模型的高效推理
    自2017年9月26日英偉達創始人正式宣布TensorRT3神經網絡推理加速器以來,經歷了TensorRT4和TensorRT5的迭代,目前最新版本為TensorRT5.1.5。2.2 優化方式[6]訓練時圖定義和網絡的權重文件是兩個獨立的文件,這樣在部署生產環境時不方便,所以有了凍結的概念。此外,為方便TF(TensorFlow)模型的發布,Google發布了用於生產環境的TensorFlow Serving系統,該系統基於saved_model.pb模型文件發布。3.1 ckpt文件  訓練過程中會生成如下所示文件。
  • 可能是史上最全的Tensorflow學習資源匯總
    2)從Tensorflow基礎知識到有趣的項目應用:https://github.com/pkmital/tensorflow_tutorials同樣是適合新手的教程,從安裝到項目實戰,教你搭建一個屬於自己的神經網絡。
  • 資源 | GitHub萬星:適用於初學者的TensorFlow代碼資源集
    除了傳統的「原始」TensorFlow 實現之外,你還可以找到最新的 TensorFlow API 實踐(如層、估計器、數據集等)。連結:https://github.com/aymericdamien/TensorFlow-Examples最近一次更新(2017.08.27):本教程推薦使用 TensorFlow v1.3。
  • 在Android中藉助TensorFlow使用機器學習
    Build so文件和jar文件首先要clone TensorFlow的代碼:git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git注意:—recurse-submodules的目的是為了pull submodules下載NDK:下載地址
  • 乾貨 | tensorflow模型導出與OpenCV DNN中使用
    ,支持基於深度學習模塊前饋網絡運行、實現圖像與視頻場景中的其模型導入與加載的相關API支持以下深度學習框架tensorflow - readNetFromTensorflowcaffe - readNetFromCaffepytorch - readNetFromTorchdarknet - readNetFromDarknet
  • tensorflow(6)利用tensorflow/serving實現模型部署及預測
    在文章tensorflow(5)將ckpt轉化為pb文件並利用tensorflow/serving實現模型部署及預測中,筆者以一個簡單的例子,來介紹如何在tensorflow中將ckpt轉化為pb文件,並利用tensorflow/serving來實現模型部署及預測。本文將會介紹如何使用tensorflow/serving來實現單模型部署、多模型部署、模型版本控制以及模型預測。
  • OpenCV+Tensorflow實現實時人臉識別演示
    FaceNet是谷歌提出的人臉識別模型,它跟其他人臉識別模型最大的一個不同就是它不是一個中間層輸出,而是直接在歐幾裡德低維空間嵌入生成人臉特徵,這個對以後的各種識別、分類、相似度比較都非常方便。FaceNet網絡設計目標任務有如下1.驗證-驗證是否為同一張臉2.識別-識別是否為同一個人3.聚類-發現人臉具有相同之處的人關於什麼是神經網絡嵌入,這個解釋比較複雜,簡單的說神經網絡的嵌入學習可以幫助我們把離散變量表示為連續的向量,在低維空間找到最近鄰,tensorflow中的word2vec就是用了嵌入。
  • 手把手教你實現Tensorflow Lite動態庫編譯(適用於Windows端模型部署)
    按照官方的說法,TensorFlow Lite 是一組工具,可幫助開發者在行動裝置、嵌入式設備和 loT 設備上運行模型,以便實現設備端機器學習。所以在設計之初,Tensorflow Lite沒有打算在Windows端進行部署的,但是最近它提供了CMakeLists.txt編譯腳本,因而可以將其編譯為動態庫以在Windows端調用。