機器之心專欄
作者:陳相寧
可微網絡架構搜索能夠大幅縮短搜索時間,但是穩定性不足。為此,UCLA 基於隨機平滑(random smoothing)和對抗訓練(adversarial training),提出新型 NAS 算法。
可微網絡架構搜索(DARTS)能夠大幅縮短搜索時間,但是其穩定性受到質疑。隨著搜索進行,DARTS 生成的網絡架構性能會逐漸變差。最終生成的結構甚至全是跳過連接(skip connection),沒有任何卷積操作。在 ICML 2020 中,UCLA 基於隨機平滑(random smoothing)和對抗訓練(adversarial training),提出了兩種正則化方法,大幅提升了可微架構搜索算法的魯棒性。
論文:https://arxiv.org/abs/2002.05283
代碼:https://github.com/xiangning-chen/SmoothDARTS
近期,可微架構搜索算法將 NAS 搜索時間縮短至數天,因而備受關注。然而,其穩定生成高性能神經網絡的能力受到廣泛質疑。許多研究者發現隨著搜索進行,DARTS 生成的網絡架構反而越來越差,最終甚至會完全變為跳過連接(skip connection)。為了支持梯度下降,DARTS 對於搜索空間做了連續化近似,並始終在優化一組連續可微的框架權重 A。但是在生成最終框架時,需要將這個權重離散化。
本研究作者觀察到這組連續框架權重 A 在驗證集上的損失函數非常不平滑,DARTS 總是會收斂到一個非常尖銳的區域。因此對於 A 輕微的擾動都會讓驗證集性能大幅下降,更不用說最終的離散化過程了。這樣尖銳的損失函數還會損害搜索算法在架構空間中的探索能力。
於是,本文作者提出了新型 NAS 框架 SmoothDARTS(SDARTS),使得 A 在驗證集上的損失函數變得十分平滑。
該工作的主要貢獻包括:
提出 SDARTS,大幅提升了可微架構搜索算法的魯棒性和泛化性。SDARTS 在搜索時優化 A 整個鄰域的網絡權重,而不僅僅像傳統可微 NAS 那樣只基於當前這一組參數。第一種方法優化鄰域內損失函數的期望,沒有提升搜索時間卻非常有效。第二種方法基於整個鄰域內的最差損失函數(worst-case loss),取得了更強的穩定性和搜索性能。
在數學上,尖銳的損失函數意味著其 Hessian 矩陣範數非常大。作者發現隨著搜索進行,這一範數極速擴大,導致了 DARTS 的不穩定性。而本文提出的兩種框架都有數學保障可以一直降低 Hessian 範數,這也在理論上解釋了其有效性。
最後,本文提出的方法可以廣泛應用於各種可微架構算法。在各種數據集和搜索空間上,作者發現 SDARTS 可以一貫地取得性能提升。
具體方法
傳統 DARTS 使用一組連續的框架權重 A,但是 A 最終卻要被投射到離散空間以獲得最終架構。這一步離散化會導致網絡性能大幅下降,一個高性能的連續框架並不意味著能生成一個高性能的離散框架。因此,儘管 DARTS 可以始終減少連續框架在驗證集上的損失函數,投射後的損失函數通常非常不穩定,甚至會突變得非常大。
因此作者希望最終獲得的連續框架在大幅擾動,例如離散化的情況下,仍然能保持高性能。這也意味了損失函數需要儘可能平滑,並保持很小的 Hessian 範數。因此本文提出在搜索過程中即對 A 進行擾動,這便會讓搜索算法關注在平滑區域。
SDARTS-RS 基本隨機平滑(random smoothing),優化 A 鄰域內損失函數的期望。該研究在均勻分布中採樣了隨機噪聲,並在對網絡權重 w 進行優化前加到連續框架權重 A 之上。
這一方法非常簡單,只增加了一行代碼並且不增加計算量,可作者發現其有效地平滑了在驗證集上的損失函數。
SDARTS-ADV 基於對抗訓練(adversarial training),優化鄰域內最差的損失函數,這一方法希望最終搜索到連續框架權重 A 可以抵禦最強的攻擊,包括生成最終架構的離散化過程。在這裡,我們使用 PGD (projected gradient descent)迭代獲得當前最強擾動。
整個優化過程遵循可微 NAS 的通用範式,交替優化框架權重 A 和網絡權重 w。
理論分析
對 SDARTS-RS 的目標函數進行泰勒展開,作者發現這在搜索過程中,Hessian 矩陣的 trace norm 也在被一直減小。如果 Hessian 矩陣近似 PSD,那麼近似於一直在減小 Hessian 的正特徵值。相似地,在通常的範數選擇下(2 範數和無窮範數),SDARTS-ADV 目標函數中第二項近似於被 Hessian 範數 bound 住。因此它也可以隨著搜索降低範數。
這些理論分析進一步解釋了為何 SDARTS 可以獲得平滑的損失函數,在擾動下保持魯棒性與泛化性。
實驗結果
NAS-Benchmark-1Shot1 實驗
這個 benchmark 含有 3 個不同大小的搜索空間,並且可以直接獲得架構的性能,不需要任何訓練過程。這也使本文可以跟蹤搜索算法任意時刻得到架構的精確度,並比較他們的穩定性。
如圖 4 所示,DARTS 隨著搜索進行生成的框架不斷變差,甚至在最後的性能直接突變得很差。近期提出的一些新的改進算法,例如 NASP 與 PC-DARTS 也難以始終保持高穩定性。與之相比,SDARTS-RS 與 SDARTS-ADV 大幅提升了搜索穩定性。得益於平滑的損失函數,該研究提出的兩種方法還具有更強的探索能力,甚至在搜索迭代了 80 輪之後仍能持續發現精度更高的架構。
另外,作者還在圖 5 中跟蹤了 Hessian 範數的變化情況,所有 baseline 方法的範數都擴大了 10 倍之多,而本文提出的方法一直在降低該範數,這與上文的理論分析一致。
CIFAR-10 實驗
作者在通用的基於 cell 的空間上進行搜索,這裡需要對獲得架構進行 retrain 以獲得其精度。值得注意的是,除了 DARTS,本文提出的方法可以普遍適用於可微 NAS 下的許多方法,例如 PC-DARTS 和 P-DARTS。如表 1 所示,作者將原本 DARTS 的 test error 從 3.00% 減少至 2.61%,將 PC-DARTS 從 2.57% 減少至 2.49%,將 P-DARTS 從 2.50% 減少至 2.48%。搜索結果的方差也由於穩定性的提升而減小。
ImageNet 實驗
為了測試在大數據集上的性能,作者將搜索的架構遷移到 ImageNet 上。在表 2 中,作者獲得了 24.2% 的 top1 test error,超過了所有相比較的方法。
與其他正則項方法比較
作者還在另外 4 個搜索空間 S1-S4 和 3 個數據集上做實驗。這四個空間與 CIFAR-10 上的搜索空間類似,只是包含了更少的操作,例如 S2 只包含 3x3 卷積和跳過連接,S4 只包括 3x3 卷積和噪聲。在這些簡化的空間上實驗能進一步驗證 SDARTS 的有效性。
如表 4 所示,SDARTS 在這 12 個任務中的 9 個中包攬了前兩名,SDARTS-ADV 分別平均超過 DARTS、R-DARTS (L2)、DARTS-ES、R-DARTS (DP) 和 PC-DARTS 31.1%、11.5%、11.4%、10.9% 和 5.3%。