今年五月份康奈爾大學的 Andrew Gordon Wilson 和 Permutation Venture 的 Yunus Saatchi 提出了一個貝葉斯生成對抗網絡(Bayesian GAN),結合貝葉斯和對抗生成網絡,提出了一個實用的貝葉斯公式框架,用GAN來進行無監督學習和半監督式學習。論文《Bayesian GAN》也被2017年機器學習頂級會議 NIPS 接受,今天Andrew Gordon Wilson在Twitter上發布消息開源了這篇論文的TensorFlow實現,並且Google GAN之父 Ian Goodfellow 轉發這條推文,讓我們來看下。
摘要生成式對抗網絡(GANs)能在不知不覺中學習圖像、聲音和數據中的豐富分布。這些分布通常因為具有明確的相似性,所以很難去建模。在這篇論文中,我們提出了一個實用的貝葉斯公式,通過使用GAN來進行無監督學習和半監督式學習。在這一框架之下,使用動態的梯度漢密爾頓蒙特卡洛(Hamiltonian Monte Carlo)來將生成網絡和判別網絡中的權重最大化。提出的方法可以非常直接的獲得最後的結果,並且在不需要任何標準的幹預,比如特徵匹配或者mini-batch discrimination的情況下,都獲得了良好的表現。通過對生成器中的參數部署一個具有表達性的後驗機制。貝葉斯生成式對抗網絡能夠避免模式碰撞,產生可判斷的、多樣化的候選樣本,並且提供在既有的一些基準測試上,能夠提供最好的半監督學習量化結果,比如,SVHN, CelebA 和 CIFAR-10,其效果遠遠超過 DCGAN, Wasserstein GANs 和 DCGAN 等等。
TensorFlow實現的貝葉斯生成對抗網絡Contents簡介
python 依賴包
訓練參數
使用方法
安裝
合成數據
例子: MNIST, CIFAR10, CelebA, SVHN
自定義數據
簡介貝葉斯生成對抗網絡中我們提出了使用條件後驗分布來建模生成器和判別器的權重參數,隨後使用了動態的梯度漢密爾頓蒙特卡洛(Hamiltonian Monte Carlo)來將生成網絡和判別網絡中的權重最大化。貝葉斯方法用在生成對抗網絡主要有一下幾個特性:(1),能夠提供很好的半監督學習量化結果。(2),對效果的影響比較小。(3), 可以通過估計概率GAN的邊際相似性;(4),它不容易遭受模型失效(mode collapse)的風險;(5)一個包含針對數據互補的多生成和判別模型,可以形成一個概率集成(ensemble)。
我們展示了在生成器參數上的多模後驗。每種參數設定都和不同的數據生成假設相對應。上圖顯示了對應兩種不同手寫風格的參數設定而產生的樣本。這個貝葉斯生成對抗網絡保留了在參數上的全概率分布。相反,標準的生成對抗網絡使用點估計(類似於單個最大似然估計)來表示這個全概率分布,這樣會丟失一些潛在的並重要的數據解釋。
python 依賴包這個代碼包含以下依賴包 (版本號非常重要):
python 2.7
tensorflow==1.0.0
在Linux上安裝tensorflow 1.0.0可以參考官方指南 https://www.tensorflow.org/versions/r1.0/install/.
bayesian_gan_hmc.py 包含以下訓練選項.
--out_dir: 輸出目錄
--n_save: 每次保存的樣本和參數的數量 n_save 是迭代次數; 默認為 100
--z_dim: 生成器中 z 向量的維度 ;默認為100
--data_path: 數據目錄; 這個路徑是必須的
--dataset: 數據集可以是 mnist, cifar, svhn or celeb; 默認為 mnist
--gen_observed: 被生成器「觀察」到的數據 ; 這會影響到噪聲離散的尺度和先驗,默認為1000
--batch_size: 一次訓練的批量數 ;默認 64
--prior_std: 權重先驗的標準差;默認為1
--numz: 與論文中的J參數一樣; 參數 z 需要整合的樣本數; 默認 1
--num_mcmc: 與論文中的M參數一樣; 每個zde 蒙特卡洛 NN權重樣本; 默認是1
--lr: Adam 優化器的學習率; 默認 0.0002
--optimizer: 優化方法: adam (tf.train.AdamOptimizer) 或者 sgd (tf.train.MomentumOptimizer); 默認使用 adam
--semi_supervised: 進行半監督學習
--N: 進行半監督學習的標註樣本數
--train_iter: 訓練迭代次數; 默認 50000
--save_samples: 訓練中保存生成樣本
--save_weights: 訓練中保存生成權重
--random_seed: 隨機種子;注意如果使用了GPU,因為這個操作結果不能做到%100復現
你可以使用--wasserstein來運行WGANs 或者使用 --ml_ensemble <num_dcgans>來訓練多個 DCGANs 的集成. 此外你還可以使用-ml_ensemble 1來訓練DCGAN
使用方法安裝安裝要求的依賴集
克隆代碼倉庫
合成數據為了能再論文中提到的合成數據上運行你可以使用T bgan_synth 腳本. 比如,下面的命令訓練 貝葉斯生成對抗網絡(with D=100 and d=10)迭代 5000 詞並將結果保存在 <results_path>.
`./bgan_synth.py --x_dim 100 --z_dim 10 --numz 10 --out \<results_path\>
`在此數據集上運行 ML GAN可以運行
`./bgan_synth.py --x_dim 100 --z_dim 10 --numz 1 --out \<results_path\>
bgan_synth有--save_weights,--out_dir,--z_dim,--numz,--wasserstein,--train_iter以及--x_dim這些參數.x_dim控制觀測數據的維度 (也就是論文中的x` ).
如果你運行了以上兩條命令後你會看到每100次迭代的輸出結果 <results_path>. 舉例來說貝葉斯生成對抗網絡在第900次迭代的結果如下圖:
對比來說標準 GAN (對應於numz=1, 使用最大似然估計) 產生的結果如下:
上面的圖展示了標準GAN容易遇到模型失效(mode collapse)而我們提出的 Bayesian GAN則可以避免這種情況。
為了進一步探究合成的數據, 同時生成JS散度 ,你可以運行 synth.ipynb.
MNIST, CIFAR10, CelebA, SVHNbayesian_gan_hmc script allows to train the model on standard and custom datasets. Below we describe the usage of this script.
數據準備為了重現在 MNIST, CIFAR10, CelebA 和 SVHN 數據集上的實驗,你需要使用正確的--data_path來準備數據.
對於 MNIST你不需要預處理數據,可以指定任意的 --data_path;
對於 CIFAR10 你需要從https://www.cs.toronto.edu/kriz/cifar.htmlPython處理的數據please下載並解壓出適合 download ;
對於 SVHN數據, 從http://ufldl.stanford.edu/housenumbers/下載 train_32x32.mat 和 test_32x32.mat 文件
對於CelebA數據,你需要首先安裝 openCV. 可以從這個連結來下載數據http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. 首先創建一個包含「 Anno 和 img_align_celeba 子目錄的目錄celebA folder 」Anno 『 必須包含list_attr_celeba.txt ,而 img_align_celeba 必須包含 .jpg 文件. 你還需要使用 datasets/crop_faces.py 腳本來裁剪圖片, 其中包含參數 --data_path <path> 來指定』celebA『的目錄。
無監督訓練你可以通過運行不包含--semi 參數的bayesian_gan_hmc 腳本來訓練無監督版本的訓練,. 比如使用:
`./bayesian_gan_hmc.py --data_path \<data_path\> --dataset svhn --numz 1 --num_mcmc 10 --out_dir \<results_path\> --train_iter 75000 --save_samples --n_save 100
在SVHN 數據集上訓練模型. 這條命令將迭代75000次並且每100次迭代保存一次樣本。 這裡的必須指向結果產生的目錄.
半監督訓練你可以用腳本帶--semi 選項的bayesian_gan_hmc 腳本來訓練半監督版本的模型。 用 -N 參數來設定需要訓練的標註樣本數目。比如運行:
`./bayesian_gan_hmc.py --data_path \<data_path\> --dataset cifar --numz 1 --num_mcmc 10--out_dir \<results_path\> --train_iter 75000 --N 4000 --semi --lr 0.00005
在 CIFAR10 數據集上使用 4000 標註樣本來訓練模型. 這條命令將迭代75000次訓練模型,並將結果保存在` 文件夾中.
為了在MNIST數據集上使用200個標註樣本訓練模型你可以使用以下命令:
`./bayesian_gan_hmc.py --data_path \<data_path\>/ --dataset mnist --numz 5 --num_mcmc 5--out_dir \<results_path\> --train_iter 30000 -N 200 --semi --lr 0.001
`
自定義數據為了在自定義的數據集上訓練模型,你需要為每一個分類定義特定的接口。比如你想在 digits(http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) 數據集上訓練模型.,這個數據集包含8x8的數字圖片。假設數據被分別存儲在x_tr.npy, y_tr.npy, x_te.npy and y_te.npy 文件中,我們認為 x_tr.npy and x_te.npy 的大小為 (?, 8, 8, 1). 隨後我們可以在bgan_util.py 中定義針對這個數據集類:
`class Digits:def __init__(self):self.imgs = np.load('x_tr.npy') self.test_imgs = np.load('x_te.npy')self.labels = np.load('y_tr.npy')self.test_labels = np.load('y_te.npy')self.labels = one_hot_encoded(self.labels, 10)self.test_labels = one_hot_encoded(self.test_labels, 10) self.x_dim = [8, 8, 1](#)self.num_classes = 10@staticmethoddef get_batch(batch_size, x, y): """Returns a batch from the given arrays."""idx = np.random.choice(range(x.shape[0](#)), size=(batch_size,), replace=False)return x[idx](#), y[idx](#)def next_batch(self, batch_size, class_id=None):return self.get_batch(batch_size, self.imgs, self.labels)def test_batch(self, batch_size):return self.get_batch(batch_size, self.test_imgs, self.test_labels)
這個類必須有next_batch和test_batch等函數, 同時要包含imgs,labels,test_imgs,test_labels,x_dim以及num_classes` 屬性.
這時候我們就可以引入 Digits 類到 bayesian_gan_hmc.py中了
`from bgan_util import Digits
同時可以在--dataset` 參數中添加如下行
`if args.dataset == "digits":dataset = Digits()
` 在準備工作結束後,我們可以用下面命令來訓練模型
`./bayesian_gan_hmc.py --data_path \<any_path\> --dataset digits --numz 1 --num_mcmc 10 --out_dir \<results path\> --train_iter 5000 --save_samples
聲明感謝Pavel Izmailov對代碼進行的壓力測試,並且寫出這份教程。
參考網址連結:代碼:https://github.com/andrewgordonwilson/bayesgan
論文:https://arxiv.org/abs/1705.09558
請關注專知公眾號(掃一掃最下面專知二維碼,或者點擊上方藍色專知),
請登錄專知,獲取GAN知識資料,請PC登錄www.zhuanzhi.ai或者點擊閱讀原文,頂端搜索「GAN」 主題,查看獲得對應主題專知薈萃全集知識等資料!如下圖所示~
歡迎轉發到你的微信群和朋友圈,分享專業AI知識!
更多專知薈萃知識資料全集獲取,請查看:
【專知薈萃01】深度學習知識資料大全集(入門/進階/論文/代碼/數據/綜述/領域專家等)(附pdf下載)
【專知薈萃02】自然語言處理NLP知識資料大全集(入門/進階/論文/Toolkit/數據/綜述/專家等)(附pdf下載)
【專知薈萃03】知識圖譜KG知識資料全集(入門/進階/論文/代碼/數據/綜述/專家等)(附pdf下載)
【專知薈萃04】自動問答QA知識資料全集(入門/進階/論文/代碼/數據/綜述/專家等)(附pdf下載)
【專知薈萃05】聊天機器人Chatbot知識資料全集(入門/進階/論文/軟體/數據/專家等)(附pdf下載)
【專知薈萃06】計算機視覺CV知識資料大全集(入門/進階/論文/課程/會議/專家等)(附pdf下載)
【專知薈萃07】自動文摘AS知識資料全集(入門/進階/代碼/數據/專家等)(附pdf下載)
【專知薈萃08】圖像描述生成Image Caption知識資料全集(入門/進階/論文/綜述/視頻/專家等)
【專知薈萃09】目標檢測知識資料全集(入門/進階/論文/綜述/視頻/代碼等)
【專知薈萃10】推薦系統RS知識資料全集(入門/進階/論文/綜述/視頻/代碼等)
【教程實戰】Google DeepMind David Silver《深度強化學習》公開課教程學習筆記以及實戰代碼完整版
【GAN貨】生成對抗網絡知識資料全集(論文/代碼/教程/視頻/文章等)
【乾貨】Google GAN之父Ian Goodfellow ICCV2017演講:解讀生成對抗網絡的原理與應用
【AlphaGoZero核心技術】深度強化學習知識資料全集(論文/代碼/教程/視頻/文章等)
請掃描小助手,加入專知人工智慧群,交流分享~
獲取更多關於機器學習以及人工智慧知識資料,請訪問www.zhuanzhi.ai, 或者點擊閱讀原文,即可得到!
-END-
歡迎使用專知
專知,一個新的認知方式!目前聚焦在人工智慧領域為AI從業者提供專業可信的知識分發服務, 包括主題定製、主題鏈路、搜索發現等服務,幫你又好又快找到所需知識。
使用方法>>訪問www.zhuanzhi.ai, 或點擊文章下方「閱讀原文」即可訪問專知
中國科學院自動化研究所專知團隊
@2017 專知