選自arXiv
作者:Will Grathwohl、David Duvenaud 等
參與:Panda、杜偉
今天要介紹的這篇論文來自多倫多大學、Vector Institute 和谷歌,該論文獲得了ICLR 2020 會議 6-8-8 的高分,提出了一種設計判別式分類器的新思路:把判別式分類器重新解讀為基於能量的模型。這種新思路有諸多優勢,比如在單個混合模型上同時實現了生成式學習和判別式學習的最優表現。
論文連結:https://arxiv.org/abs/1912.03263
引言
生成模型已經得到了數十年的研究,因為人們相信生成模型對很多下遊任務有利,比如半監督學習、缺失數據處理和不確定性校準。然而,對深度生成模型的大多數近期研究都忽視了這些問題,而將重心放在了定性樣本質量以及在留存驗證集上的對數似然上。 目前,在相關下遊任務的最強大生成建模方法以及為每個特定問題人工設計的解決方案之間還存在較大的性能差距。一個可能的原因是大多數下遊任務本質上是判別式的,而當前最佳的生成模型與當前最佳的判別架構之間的差異也還很大。因此,即使僅以分類器為目標而訓練時,生成模型的表現也遠遜於最佳判別模型的表現。因此,判別性能的下降就會遠遠超過來自模型中生成組件的好處。 近期一些研究試圖利用可逆架構來提升生成模型的判別表現,但這些方法的表現仍不及以生成模型為目標而聯合訓練的純判別式方法。本論文提出使用基於能量的模型(EBM/energy based models)來幫助生成模型在下遊的判別式問題上發揮自己的潛力。儘管 EBM 模型目前來說還難以使用,但它們相比於其它生成式模型能更自然地應用在判別式的框架中,並有助於使用現代的分類器架構。 這篇論文有以下貢獻: 1. 提出了一種用於聯合建模標籤和數據的全新且直觀的框架;2. 新提出的模型在生成式建模與判別式建模方面都顯著優於之前的最佳混合模型;3. 研究表明,整合生成模型能讓模型的校準性能更高,能實現對分布外數據的檢測,還能實現更好的對抗魯棒性,而且在每個任務上的表現也能與人工設計的方法媲美甚至更好。 基於能量的模型(EBM)是什麼? 基於能量的模型(EBM)最早由 LeCun et al. 於 2006 年提出,其基於以下觀察:對於 x∈R^D,任意概率密度 p(x) 均可表示為:
其中
是能量函數,
是分割函數。
EBM 的訓練必需依靠其它方法。研究者注意到,單個樣本 x 的對數似然相對於 θ 的導數可以表示為:
不幸的是,從 p_θ(x) 取出樣本是很困難的,因此必須藉助 MCMC 來使用梯度估計器。最早期的一些 EBM 就是使用這種方法訓練的。 儘管這樣的小進展已經積累了很長時間,但最近有些研究開始使用這種方法來在高維數據上訓練大規模 EBM,而且使用了深度神經網絡來對其進行參數化。近期的這些成功使用基於隨機梯度 Langevin 動態(SGLD)的採樣器,結果已經接近等式 (2) 的預期,其取出樣本的方式為:
新提出的基於聯合能量的模型(Joint Energy Based Model) 在現代機器學習中,有 K 個類別的分類問題通常是使用一個參數函數來解決,即 f_θ : R^D → R^K,其能將每個數據點 x ∈ R^D 映射成被稱為 logit 的實數值。使用所謂的 softmax 遷移函數,可將這些 logit 用於對類別分布執行參數化:
研究者在本文中給出了一個關鍵性的觀察,即也可以略微重新解讀從 f_θ 獲得的 logit 來定義 p(x,y) 和 p(x)。無需改變 f_θ,可通過下式復用這些 logit 來為數據點 x 和標籤 y 的聯合分布定義一個基於能量的模型:
通過將 y 邊緣化,也可為 x 獲得一個非歸一化的密度模型:
注意,現在任意分類器的 logit 都可被重新用於定義數據點 x 處的能量函數:
由此,研究者就找到了每個標準的判別模型中隱藏的生成模型!因為這種方法提出將分類器重新解讀為基於聯合能量的模型(Joint Energy based Model),所以他們將該方法稱為 JEM。 下圖 1 給出了該框架的概況,其中分類器的 logit 會被重新解讀,以定義數據點和標籤的聯合數據密度以及數據點單獨的數據密度。
圖 1:新方法 JEM 的可視化,其可從分類器架構定義一個聯合 EBM
優化 那麼,這種對分類器架構的新解讀方法能在保留模型強大判別能力的同時也獲得生成模型的優勢嗎?
因為 p(y|x) 的模型參數化是相對 y 進行歸一化的,因此最大化其似然是很簡單的,就如同在標準的分類器訓練中一樣。又因為 p(x) 和 p(x, y) 的模型未歸一化,因此最大化它們的似然並不容易。在這樣的模型下,以最大化數據的似然為目標來訓練 f_θ 的方法有很多。我們可以將等式 (2) 的梯度估計器應用於等式 (5) 的聯合分布下的似然。使用等式 (6) 和 (4),可將該似然分解為:
鑑於這項研究的目標是將 EBM 訓練整合進標準的分類設置中,所涉分布為 p(y|x)。因此,研究者提出使用等式 (8) 的因式分解來確保該分布的優化使用的目標是無偏差的。他們使用了標準的交叉熵來優化 p(y|x),使用了帶 SGLD 的等式 (2) 來優化 log p(x),其中梯度是根據
得到的。
應用 為了展示 JEM 相比於標準分類器的優勢,研究者進行了全面的實驗研究。首先,新方法的表現在判別式建模和生成式建模上都與當前最佳方法媲美。更有意思的是,他們還觀察到一些與判別式模型的實際應用相關的好處,包括不確定性量化的改善、對分布外數據的檢測、對對抗樣本的魯棒性。人們很久以前就預期生成模型能夠提供這些好處,但從來沒有在這樣的規模上展現這一點。 實驗中使用的所有架構都基於 Wide Residual Networks,其中移除了批歸一化以確保模型的輸出是輸入的確定性函數。這將 WRN-28-10 在 CIFAR-10 上的分類誤差從 4.2% 提升到了 6.4%,將其在 SVHN 上的分類誤差從 2.3% 提升到了 3.4%。 所有的模型都是用同樣的方法訓練的,它們的超參數也都一樣,都是在 CIFAR-10 上調節得到的。有趣的是,這裡找到的 SGLD 採樣器參數可以在各種數據集和模型架構上實現很好的泛化。此外,所有模型都在單個 GPU 上訓練完成,耗時大約 36 小時。 混合建模 首先,研究者表明給定的分類器架構可以作為 EBM 訓練,而且能同時實現與分類器和生成模型都相媲美的表現。他們在 CIFAR-10、SVHN 和 CIFAR-100 上訓練了 JEM,並與其它混合模型以及單獨的生成模型和判別模型進行了比較。結果發現 JEM 能在兩個任務上同時取得接近最佳表現的結果,優於其它混合模型(下表 1)。
表 1:CIFAR-10 混合建模的結果。
鑑於這種方法無法計算歸一化的似然,所以研究者提出使用 inception 分數(IS)和 Frechet Inception Distance(FID)來表示結果的質量。結果發現,JEM 能在這些指標上與當前最佳的生成模型相媲美。新提出的模型在 SVHN 和 CIFAR-100 上分別實現了 96.7% 和 72.2% 的準確度。下圖 2 和 3 展示了 JEM 的樣本。
圖 2:CIFAR-10 類-條件樣本。
圖 3:類-條件樣本。 JEM 的訓練目標是最大化等式 (8) 中的似然因式分解。這是為了確保不會把偏差加進 log p(y|x) 的估計中,這在新提出的設置中可以確切地計算出來。在控制變量研究中,為最大化這一目標而訓練的 JEM 的判別性能有顯著的下降(見表 1 第 4 行)。
校準
如果一個分類器的預測置信度 max_y p(y|x) 與其誤分類率是一致的,那麼就認為這個分類器是已校準的。因此,當一個經過校準的分類器以 0.9 的置信度預測標籤 y 時,它應該有 90% 的機率是正確的。對於要在真實世界場景中部署的模型而言,這是一個非常重要的特性,因為在現實場景中,不正確的決策輸出可能造成災難性的後果。在實際應用時,經過良好校準但不夠準確的分類器可能比更準確但校準差的模型更加有用。
研究者發現 JEM 能在顯著提升分類性能的同時維持較高的準確度。 研究者重點關注了在 CIFAR-100 上的表現,因為當前最佳的分類器的準確度大約為 80%。他們在這個數據集上訓練了 JEM,並將其與沒有 EBM 訓練的同樣架構的基準進行了比較。基準模型得到的準確度為 74.2%,JEM 得到的準確度為 72.2%(參考一下,ResNet-110 得到的準確度為 74.8%)。下圖 4 給出了結果。
圖 4:CIFAR-100 校準結果。ECE 是指預期校準誤差。
檢測分布外數據
通常而言,分布外(out-of-distribution,OOD)檢測是二元分類問題,模型的目標是得到一個分布 s_θ(x) ∈R,其中 x 是查詢,θ 是可學習參數的集合。有很多不同的 OOD 檢測方法都可以使用 JEM。
輸入密度如下表 2 第 2 列所示,JEM 為分布內數據分配的似然總是比 OOD 數據高。JEM 相比於 IGEBM 進一步提升的一個可能解釋是其有能力在訓練過程中整合有標註的信息,同時還能推導 p(x) 的一個原理模型。
表 2:OOD 檢測的直方圖。所有模型都是在 CIFAR-10 上訓練的。綠色對應於在分布內 CIFAR-10 數據上的分數,紅色對應在 OOD 數據集上的分數。
預測分布很多成功方法都為 OOD 檢測使用了分類器的預測分布。JEM 是一種很有競爭力的分類器,實驗發現其表現足以媲美優秀的基準分類器,並且顯著優於其它生成模型。下表 3 給出了結果(中行)。
表 3:OOD 檢測結果。所測模型是在 CIFAR-10 訓練的,結果是 AUROC 指標。
一種新分數:近似質量(Approximate Mass)對於在經典數據集之外的高似然數據點,研究者預期其周圍的密度會快速變化,因此其對數密度的梯度範數相比於經典數據集中的樣本會很大(否則它會處於高質量的區域)。基於這一數量,他們提出了一種新的 OOD 分數:
對於 EBM(JEM 和 IGEBM),研究者發現這種預測器的表現顯著優於我們自己的和其它的生成式模型的似然——見表 2 第 3 列。對於易處理的似然方法,他們發現這種預測器與模型的似然是反相關的(anti-correlated),它們對 OOD 檢測而言都不可靠。結果見表 3(底行)。 魯棒性 作者使用了一種基於梯度的優化流程來生成樣本,從而激活特定的高層面網絡激活,然後優化網絡的權重以最小化所生成的樣本對該激活的影響。圍繞數據,對抗訓練和網絡激活的梯度的正則化之間的進一步關聯已經被推導出來。 有了這些關聯,人們可能會疑惑從 EBM 推導出來的分類器是否比標準模型能更穩健地處理對抗樣本。類似地,作者發現 JEM 能在無損判別性能的前提下實現相當不錯的穩健性。
通過 EBM 訓練提升魯棒性在基於 CIFAR-10 訓練的模型上,研究者執行了大量強力的對抗攻擊。他們執行了一次白盒 PGD 攻擊,通過採樣流程向攻擊者提供了對梯度的訪問權。另外,研究者還執行了一些無梯度的黑盒攻擊、邊界攻擊和暴力式逐點攻擊。下圖 5 給出了 PGD 實驗的結果。所有的攻擊都是針對 L2 和 L∞ 範數進行的,他們測試了在輸入中執行 0、1、10 步採樣的 JEM。 實驗表明,新模型的魯棒性顯著優於使用標準分類器訓練得到的基準模型。在這兩個範數上,JEM 的表現與當前最佳的對抗訓練方法相當(但略差一些),也和 Salman et al. (2019) 提出的當前最佳的經過認證的魯棒性方法(圖 5 中的 RandAdvSmooth)相媲美。
圖 5:使用 PGD 攻擊的對抗穩健性結果。JEM 能帶來相當可觀的魯棒性提升。
魯棒性不強模型的另一種常見失敗模式是它們往往會以高置信度分類無意義的輸入。為了分析這一性質,研究者遵照 Schott et al. (2018) 的方法進行了測試。下圖 6 給出了結果。基準方法會有信心地分類非結構化的噪聲圖像。JEM 不能有信心地分類無意義的圖像,所以可以明顯看到圖中出現了汽車屬性和自然圖像屬性。
圖 6:遠端對抗(Distal Adversarials)結果。