機器之心原創
作者:仵冀穎
編輯:H4O
2019 年,NeurIPS 接受與元學習相關的研究論文約有 20 餘篇。元學習(Meta-Learning)是近幾年的研究熱點,其目的是基於少量無標籤數據實現快速有效的學習。本文對本次接收的元學習論文進行了梳理和解讀。
2019 年 NeurIPS 將於 12 月 8 日至 14 日在加拿大溫哥華舉行。NeurIPS 今年共收到投稿 6743 篇,其中接受論文 1429 篇,接受率達到了 21.1%。作為人工智慧領域的年度盛會,每年人工智慧的專家學者以及工業企業界的研發人員都會積極參會,發布最新的理論研究結果以及應用實踐方面的成果。今年,國外的高校和研究機構一如既往的踴躍參加本屆 NeurIPS,其中 Google 共貢獻了 179 篇文章,斯坦福和卡耐基梅隆分別有 79 篇和 75 篇文章。國內的企業界騰訊上榜 18 篇、阿里巴巴上榜 10 篇,高校和研究機構中清華參與完成的文章共有 35 篇。
2019 年,NeurIPS 接受與元學習相關的研究論文約有 20 餘篇。元學習(Meta-Learning)是近幾年的研究熱點,其目的是基於少量無標籤數據實現快速有效的學習。元學習通過首先學習與相似任務匹配的內部表示,為機器提供了一種使用少量樣本快速適應新任務的方法。學習這種表示的方法主要有基於模型的(model-based meta-learning)和模型不可知的(model-agnostic meta-learning,MAML)兩類。基於模型的元學習方法利用少量樣本的任務標記(task identity)來調整模型參數,使用模型完成新任務,這種方法最大的問題是設計適用於未知任務的元學習策略非常困難。模型不可知的方法首先由 Chelsea Finn 研究組提出,通過初始化模型參數,執行少量的梯度更新步驟就能夠成功完成新的任務。
本文從 NeurIPS 2019 的文章中選擇了四篇來看看元學習的最新的研究方向和取得的成果。Chelsea Finn 以及他的老師 Pieter Abbeel 在元學習領域一直非常活躍,他們的研究團隊在這個方向已經貢獻了大量的優秀成果,推動了元學習在不同任務中的有效應用。在本次 NeurIPS 中,他們的研究團隊針對基於梯度(或優化)的元學習提出了一種只依賴於內部級別優化的解決方案,從而有效地將元梯度計算與內部循環優化器的選擇分離開來。另外,針對強化學習問題,提出了一種元強化學習算法,通過有監督的仿真學習有效的強化學習過程,大大加快了強化學習程序和先驗知識的獲取。我們將在這篇提前看中深入分析和理解這些工作。
Chelsea Finn 是史丹福大學計算機科學和電子工程的助理教授,同時也擔任 Google Brain 的研究科學家。Chelsea Finn 在她的博士論文《Learning to Learn with Gradients》中提出的 MAML 是目前元學習的三大方法之一,Chelsea Finn 證明了 MAML 的理論基礎,並在元學習領域中將其發揚光大,在少樣本模仿學習、元強化學習、少樣本目標推斷等中都獲得了很好的應用。
本文還選擇另外兩篇關於元學習的文章進行討論,其中一篇是 Facebook 的工作,提出了一種元序列到序列(Meta seq2seq)的方法,通過學習成分概括,利用域的代數結構來幫助理解新的語句。另外一篇提出了一個多模態 MAML(Multimodal MAML)框架,該框架能夠根據所識別的模式調整其元學習先驗參數,從而實現更高效的快速自適應。
論文清單:
Meta-Learning with Implicit GradientsGuided Meta-Policy SearchCompositional generalization through meta sequence-to-sequence learningMultimodal Model-Agnostic Meta-Learning via Task-Aware Modulation1、Aravind Rajeswaran,Chelsea Finn,Sham Kakade,Sergey Levine,Meta-Learning with Implicit Gradients ,https://papers.nips.cc/paper/8306-meta-learning-with-implicit-gradients.pdf
基於優化的元學習方法主要有兩種途徑,一是直接訓練元學習目標模型,即將元學習過程表示為神經網絡參數學習任務。另一種是將元學習看做一個雙層優化的過程,其中「內部」優化實現對給定任務的適應,「外部」優化的目標函數是元學習目標模型。本文是對後一種方法的研究和改進。元學習過程需要計算高階導數,因此給計算和記憶帶來不小的負擔,另外,元學習還面臨優化過程中的梯度消失問題。這些問題使得基於(雙層)優化的元學習方法很難擴展到涉及大中型數據集的任務,或者是需要執行許多內環優化步驟的任務中。
本文提出了一種隱式梯度元學習方法(implicit model-agnostic meta-learning,iMAML),利用隱式微分,推導出元梯度解析表達式,該表達式僅依賴於內部優化的解,而不是內部優化算法的優化路徑,這就將元梯度計算和內部優化兩個任務解耦。具體見圖 1 中所示,其中經典的任務不可知的元學習(model-agnostic meta-learning,MAML)方法沿綠色的路徑計算元梯度,一階 MAML 則利用一階倒數計算元梯度,本文提出的 iMAML 方法通過估計局部曲率,在不區分優化路徑的情況下,推導出精確的元梯度的解析表達式。
圖 1. 不同方法元梯度計算圖示
針對元學習任務 {Ti},i=1,...,M,分別對應數據集 Di,其中每個數據集包含兩個集(set):訓練集 D^tr 和測試集 D^test,每個集中的數據結構均為數據對,以訓練集為例:
元學習任務 Ti 的目標是,通過優化損失函數 L,基於訓練集學習任務相關的參數φ _i,從而實現測試集中的損失值最小。雙層優化的元學習任務為:
其中,本文重點關注 Alg 部分的顯示或隱式計算。經典 MAML 中,Alg 對應一步或幾步梯度下降處理:
在數值計算過程中,為避免過擬合問題、梯度消失問題以及優化路徑參數帶來的計算和內存壓力問題,採用顯示正則化優化方法:
由此雙層元學習優化任務為:
其中
採用顯示迭代優化算法計算 Alg*存在下列問題:1、依賴於顯示優化路徑,參數計算和存儲存在很大負擔;2、三階優化計算比較困難;3、該方法無法處理非可微分的操作。因此,本文考慮隱式計算 Alg*。具體算法如下:
考慮內部優化問題的近似解,它可以用迭代優化算法(如梯度下降)來獲得,即:
對 Alg*的優化可以通過雅克比向量積近似逼近:
其中,φ_ i = Alg_i( θ)。觀察到 g_i 可以作為優化問題的近似解獲得:
共軛梯度算法(Conjugate Gradient, CG)由於其迭代複雜度和僅滿足 Hessian 矢量積的要求而特別適合於求解此問題。不同方法的計算複雜度和內存消耗見表 1。用 k 來表示由 g_i 引起的內部問題的條件數,即內部優化計算問題的計算難度。Mem() 表示計算一個導數的內存負載。
表 1:不同方法的內存及計算負載
為了證明本文方法的有效性,作者給出了三個實驗:
一是,通過實驗驗證 iMAML 是否能夠準確計算元梯度。圖 2(a)考慮了一個人工模擬的回歸示例,其中的預測參數是線性的。iMAML 和 MAML 都能夠漸近匹配精確的元梯度,但 iMAML 在有限迭代中能夠計算出更好的近似。
圖 2. 準確度、計算複雜度和內存負載對比。其中 MAML 為經典方法,iMAML 為本文提出的方法,FOMAML 為一階 MAML 方法
二是,通過實驗驗證在有限迭代下 iMAML 是否能夠比 MAML 更精確地逼近元梯度。圖 2(b) 中實驗可知,iMAML 的內存是基於 Hessian 向量積的,與內部循環中梯度下降步數無關。內存使用也與 CG 迭代次數無關,因為中間計算不需要存儲在內存中。MAML 和 FOMAML 不通過優化過程反向傳播,因此計算成本僅為執行梯度下降的損耗。值得注意的是,FOMAML 儘管具有較小的計算複雜度和內存負載,但是由於它忽略了 Jacobian,因此 FOMAML 不能夠計算精確的元梯度。
三是,對比與 MAML 相比的計算複雜度和內存負載,以及通過實驗驗證 iMAML 是否能在現實的元學習問題中產生更好的結果,本文使用了 Omniglot 和 Mini ImageNet 的常見少數鏡頭圖像識別任務(few-shot)進行驗證。在現實元學習實驗中,選擇了 MAML、FOMAML (First order MAML) 和 Reptile 作為對比方法。在 Omniglot 域上,作者發現 iMAML 的梯度下降(GD)版本與全 MAML 算法相比具有競爭力,並且在亞空間上優於其近似值(即 FOMAML 和 Reptile),特別是對於較難的 20 路(20-way)任務。此外,實驗還表明無 Hessian 優化的 iMAML 比其他方法有更好的性能,這表明內部循環中強大的優化器可以改進元學習的效果。在 Mini-ImageNet 域中,iMAML 的效果也優於 MAML 和 FOMAML。
表 2. Omniglot 實驗結果
表 3. Mini ImageNet 實驗結果
2、Russell Mendonca,Abhishek Gupta,Rosen Kralev,Pieter Abbeel,Sergey Levine,Chelsea Finn,Guided Meta-Policy Search,https://papers.nips.cc/paper/9160-guided-meta-policy-search.pdf
元學習的目的是利用完成不同任務的歷史經驗幫助學習完成新任務的技能,元強化學習通過與環境的少量交互通過嘗試和改正錯誤來解決這一問題。元強化學習的關鍵是使得 agent 具有適應性,能夠以新的方式操作新對象,而不必為每個新對象和目標從頭學習。目前元強化學習在優化穩定性、解決樣本複雜度等方面還存在困難,因此主要在簡單的任務領域中應用,例如低維連續控制任務、離散動作指令導航等。
本文的研究思路是:元強化學習是為了獲得快速有效的強化學習過程,這些過程本身不需要通過強化學習直接獲得,相反,可以使用一個更加穩定和高效的算法來提供元級(meta-level)監控,例如引入監督模仿學習。本文首次提出了在元學習環境中將模仿(imitation)和強化學習(RL)相結合。在執行元學習的過程中,首先由本地學習者單獨解決任務,然後將它們合併為一個中心元學習者。但是,與目標是學習能夠解決所有任務的單一策略的引導式策略搜索(guided policy search)不同,本文提出的方法旨在元學習到能夠適應訓練任務分布的單一學習者,通過概括和歸納以適應訓練期間未知的新任務。
圖 3. 引導式元策略搜索算法綜述
圖 3 給出本文提出的引導式元策略搜索算法的總體結構。通過在內部循環優化過程中使用增強學習以及在元優化過程引入監督學習,學習能夠快速適應新任務的策略π_θ。該方法將元學習問題明確分解為兩個階段:任務學習階段和元學習階段。此分解使得可以有效利用以前學習的策略或人工提供的演示輔助元學習。
現有的元強化學習算法一般使用同步策略方法(on-policy)從頭開始進行元學習。在元訓練期間,這通常需要大量樣本。本文的目標是使用以前學到的技能來指導元學習過程。雖然仍然需要用於內部循環採樣的同步策略數據,但所需要的數據比不使用先前經驗的情況下要少得多。經典 MAML 的目標函數如下:
應用於元強化學習中,每個數據集表示為如下軌跡形式:s_1,a_1,...,a_H-1,,s_H。內部和外部循環的損失函數為:
將元訓練任務的最優或接近最優的策略標記為 {(π_i)^*},其中每個政策定義為「專家」。元學習階段的優化目標 L_RL(φ_i,D_i) 與 MAML 相同,其中φ_i 表示策略參數,D_i 為數據集。
內部策略優化過程利用第一階段學習到的策略優化元目標函數,特別的,把外部目標建立在專家行為的監督模仿或行為克隆(Behavior Cloning,BC)上。BC 損失函數為:
監督學習的梯度方差較小,因此比強化學習的梯度更加穩定。第二階段的任務是:首先利用每個策略 (π_i)^*,為每個元訓練任務 Ti 收集專家軌跡 (Di)^*的數據集。使用此初始數據集,根據以下元目標更新策略:
由此得到一些能夠適用於不同任務的列初始策略參數θ從而生成φ_i。在單任務模擬學習環境中,進一步的,可以繼續通過從學習到的策略中收集額外的數據 (擴展數據集 D*),然後用專家策略中的最優操作標記訪問狀態。具體步驟為:(1)利用策略參數θ生成 {φ_i};(2)針對每個任務,利用當前策略 {π_(φ_i)} 生成狀態 {{s_t}_i};(3)利用專家生成監督數據 D={{s_t,π_i(s_t))}_i};(4)使用現有監督數據聚合該數據。
引導式元策略搜索算法(Guided Meta-policy Search, GMPS)如下:
本文使用 Sawyer 機器人控制任務和四足步行機任務驗證 GMPS 的有效性。所選擇的對比算法包括:基於異步策略方法的 PEARL、策略梯度版本的 MAML(內部循環使用 REINFORCE,外部循環使用 TRPO)、RL2、針對所有元訓練任務的單一政策方法 MultiTask、附加結構化噪聲的模型不可知算法 (MAESN)。圖 4 給出完成全狀態推送任務和密集獎勵運動的元訓練效率。所有方法都達到了相似的漸近性能,但 GMPS 需要的樣本數量明顯較少。與 PEARL 相比,GMPS 給出了相近的漸進性能性能。與 MAML 相比,GMPS 完成 Sawyer 物體推送任務的性能提高了 4 倍,完成四足步行機任務的性能提高了約 12 倍。GMPS 的下述處理方式:(1)採用了用於獲取每個任務專家的異步策略增強學習算法和(2)能夠執行多個異步策略監督梯度步驟的組合,例如外部循環中的專家,使得 GMPs 與基於策略的元增強學習算法(如 MAML)相比,獲得了顯著的總體樣本效率增益,同時也顯示出比 PEARL 等數據效率高的上下文方法更好的適應性。
圖 4. Sawyer 機器人任務效果對比
圖 5. 稀疏獎勵開門動作(左)、稀疏獎勵螞蟻移動(中)和視覺推手動作(右)的元訓練比較
對於涉及稀疏獎勵和圖像觀察的具有挑戰性的任務,有效利用人工提供的演示可以極大地改進強化學習的效果,圖 5 中給出了相關的實驗。與其他傳統方法相比,GMPS 能夠更加有效且容易的利用演示信息。在圖 5 所有的實驗中,關於目標位置的位置信息都不作為輸入,而元學習算法必須能夠發現一種從獎勵中推斷目標的策略。對於基於視覺的任務,GMPS 能夠有效地利用演示快速、穩定地學習適應。此外,圖 5 也表明,GMPS 能夠在稀疏的獎勵設置中成功地找到一個好的解決方案,並學會探索。GMPS 和 MAML 都能在所有訓練任務中獲得比單一策略訓練的強化學習更好的性能。
3、Brenden M. Lake,Compositional generalization through meta sequence-to-sequence learning,https://papers.nips.cc/paper/9172-compositional-generalization-through-meta-sequence-to-sequence-learning.pdf
由於人具有創作學習的能力,他們可以學習新單詞並立即能夠以多種方式使用它們。一旦一個人學會了動詞「to Facebook」的意思,他或她就能理解如何「慢慢地 Facebook」、「急切地 Facebook」或「邊走邊 Facebook」。這就是創造性的能力,或是通過結合熟悉的原語來理解和產生新穎話語的代數能力。作為一種機器學習方法,神經網絡長期以來一直因缺乏創造性而受到批評,導致批評者認為神經網絡不適合建模語言和思維。最近的研究通過對現代神經結構的研究,重新審視了這些經典的評論,特別是成功的將序列到序列(seq2seq)模型應用於機器翻譯和其他自然語言處理任務中。這些研究也表明,在創造性的概括方面,seq2seq 仍存在很大困難,尤其是需要把一個新的概念(「到 Facebook」)和以前的概念(「慢慢地」或「急切地」)結合起來時。也就是說,當訓練集與測試集相同時,seq2seq 等遞歸神經網絡能夠獲得較好的效果,但是當訓練集與測試集不同,即需要發揮「創造性」時,seq2seq 無法成功完成任務。
這篇文章中展示了如何訓練記憶增強神經網絡,從而通過「元-序列到序列學習」方法(meta seq2seq)實現創造性的概括。與標準的元學習方法類似,在「元訓練」的過程中,訓練是基於分布在一系列稱為「集(episode)」的小數據集上完成的,而不是基於單個靜態數據集。在「元 seq2seq 學習」過程中,每一集(episode)都是一個新的 seq2seq 問題,它為序列對(輸入和輸出)和「查詢」序列(僅輸入)提供「支持」。該方法的網絡支持將序列對加載到外部內存中,以提供為每個查詢序列生成正確輸出序列所需的上下文。將網絡的輸出序列與目標任務進行比較,從而獲得由支持項目到查詢項目的創造性概括能力。元 seq2seq 網絡對需要進行創造性組合泛化的多個 seq2seq 問題進行元訓練,目的是獲得解決新問題所需的組合技能。新的 seq2seq 問題完全使用網絡的激活動力學和外部存儲器來解決;元訓練階段結束後,不會進行權重更新。通過其獨特的結構選擇和訓練過程,網絡可以隱式地學習操作變量的規則。
圖 6. 元 seq2seq 學習
圖 6 給出了一個元 seq2seq 學習的示例,其任務是根據支撐數據集處理查詢指令「跳兩次」,支撐集包括「跑兩次」、「走兩次」、「看兩次」和「跳」。利用一個遞歸神經網絡(Recurrent Neural Network,RNN)編碼器(圖 6 中右側下部的紅色 RNN)和一個 RNN 解碼器(圖 6 中右側上部綠色 RNN)理解輸入語句生成輸出語句。這個結構與標準 seq2seq 不同,它利用了支撐數據集、外部存儲和訓練過程。當消息從查詢編碼器傳遞到查詢解碼器時,它們受到了由外部存儲提供的逐步上下文信息 C 影響。
下面將詳細描述體系結構的內部工作流程:
1、輸入編碼器
輸入編碼器 f_ie(圖 6 中紅色部分)對輸入查詢指令以及支撐數據集中的輸入指令進行編碼,生成輸入嵌入特徵 w_t,利用 RNN 轉化為隱層嵌入特徵 h_t:
對於查詢序列,在每個步驟 t 時的嵌入特徵 h_t 通過外部存儲器,傳遞到解碼器。對於每個支撐序列,只需要最後一步隱藏嵌入特徵,表示為 K_i。這些向量 K_i 作為外部鍵值存儲器中的鍵使用。本文使用的是雙向長短時記憶編碼(bidirectional long short-term memory encorders)方法。
2、輸出編碼器
輸出編碼器 f_oe(圖 6 中藍色部分)用於每個支撐數據集中的項目和其對應的輸出序列。首先,編碼器使用嵌入層嵌入輸出符號序列(例如動作)。第二,使用與 f_ie 相同的處理過程計算數列的嵌入特徵。最後一層 RNN 的狀態作為支撐項目的特徵向量存儲 V_i。仍然使用 biLSTM。
3、外部存儲器
該架構使用類似於存儲器網絡的軟鍵值存儲器,鍵值存儲器使用的注意函數為:
每個查詢指令從 RNN 編碼器生成 T 個嵌入,每個查詢符號對應一個,填充查詢矩陣 Q 的行。編碼的支撐項目分別為輸入和輸出序列的 K 行和 V 行。注意權重 A 表示對於每個查詢步驟,哪些內存單元處於活動狀態。存儲器的輸出是矩陣 M=AV,其中每一行是值向量的加權組合,表明查詢輸入步驟中每一步的存儲器輸出。最後,通過將查詢輸入嵌入項 h_t 和分步內存輸出項 M_t 與連接層 C_t=tanh(Wc1 [h_t;M_t])結合來計算分步上下文,從而生成分步上下文矩陣 C。
4、輸出解碼器
輸出解碼器將逐步上下文 C 轉換為輸出序列(圖 6 中綠色部分)。解碼器將先前的輸出符號嵌入為向量 o_j-1,該向量 o_j-1 與先前的隱藏狀態 g_j-1 一起啊輸入到 RNN(LSTM)以獲得下一個隱藏狀態,
初始隱藏狀態 g_0 被設置為最後一步的上下文 C_T。使用 Luong 式注意計算解碼器上下文 u_j,使得 u_j=Attension(g_j,C,C)。這個上下文通過另一個連接層 g_j=tanh(Wc2 [g_j;u_j]),然後映射到 softmax 輸出層以產生輸出符號。此過程重複,直到產生所有輸出符號,RNN 通過產生序列結束符號來終止響應。
5、元訓練
元訓練通過一系列訓練集優化網絡,每個訓練集都是一個帶有 n_s 支撐項目和 n_q 查詢項目的新 seq2seq 問題。模型的詞彙表是事件(episode)詞彙表的組合,損失函數是查詢的預測輸出序列的負對數似然。
本文方法的 PyTorch 代碼已公開發布:https://github.com/brendenlake/meta_seq2seq
本文給出了不同的實驗驗證元 seq2seq 方法的有效性。通過置換元訓練增加一個新的原語的實驗,評估了元 seq2seq 學習方法在添加新原語的 SCAN 任務中的效果。通過將原始 SCAN 任務分解為一系列相關的 seq2seq 子任務,訓練模型進行創造性的概括。目標是學習一個新的基本指令,並將其組合使用。例如模型學習一個新的原始「跳躍」,並將其與其他指令結合使用,類似於本文前面介紹的「to Facebook」示例。實驗結果見表 4 結果中間列。其中,標準 seq2seq 方法完全失敗,正確率僅為 0.03%。元 seq2seq 方法能夠成功完成學習複合技能的任務,表中所示達到了平均 99.95% 的正確率。
表 4. 在不同訓練模式下測試 SCAN「添加跳躍」任務的準確性
通過增強元訓練增加一個新的原語的實驗目的是表明元 seq2seq 方法可以「學習如何學習」原語的含義並將其組合使用。文章只考慮了四個輸入原語和四個意義的非常簡單的實驗,目前的研究情況下,作者認為尚不能確定元 seq2seq 學習是否適用於更複雜的任務領域。實驗結果見表 4 的最右側列。元 seq2seq 方法能夠完成獲得指令「跳」並正確使用的任務,正確率達到了 98.71%。標準 seq2seq 得益於增強訓練的處理得到了 12.26% 的正確率。
關於利用元訓練合成類似概念的任務,實驗結果見表 5 左側結果列。元 seq2seq 學習方法能夠近乎完美的完成這個任務(正確率 99.96%),能夠根據其組成部分推斷「around right」的含義。而標準 seq2seq 則完全無法完成這個任務(0.0% 正確率),syntactic attention 方法完成這個任務的正確率為 28.9%。最後一個實驗驗證了元 seq2seq 方法是否能夠學習推廣到更長的序列,即測試序列比元訓練期間的任何經驗語句序列都長。實驗結果見表 5 最右側列。可以看到,所有方法在這種情況下表現都不佳,元 seq2seq 方法僅有 16.64% 的正確率。儘管元 seq2seq 方法在合成任務上較為成功,但它缺乏對較長序列進行外推所需的真正系統化的概括能力。
表 5. 測試 SCAN「左右」和「長度」任務的準確性
元 seq2seq 學習對於理解人們如何從基本成分元素創造性的概括推廣到其它概念有著重要的意義。
人們是在動態環境中學習的,目的是解決一系列不斷變化的學習問題。在經歷過一次像「to Facebook」這樣的新動詞之後,人們能夠系統地概括這種學習或激勵方式是如何完成的。這篇文章的作者認為,元學習是研究學習和其他難以捉摸的認知能力的一個強大的新工具,儘管,在目前的研究條件下還需要更多的工作來理解它對認知科學的影響。
本文所研究的模型只是利用了網絡動態參數和外部存儲器就實現了在測試階段學到如何賦予單詞新的意義。雖然功能強大,但這個工作仍然是一個有限的「變量」概念,因為它需要熟悉元訓練期間所有可能的輸入和輸出分配。這是目前所有神經網絡體系架構所共有的問題。作者在文末提到,在未來的工作中,打算探索在現有網絡結構中添加更多的象徵性組織(symbolic machinery),以處理真正的新符號,同時解決推廣到更長輸出序列的挑戰。
4、Risto Vuorio,Shao-Hua Sun,Hexiang Hu,Joseph J. Lim,Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation,https://papers.nips.cc/paper/8296-multimodal-model-agnostic-meta-learning-via-task-aware-modulation.pdf
經典的模型不可知的元學習方法(MAML)需要找到在整個任務分布中共享的公共初始化參數。但是,當任務比較複雜時,針對任務採樣需要能夠找到實質性不同的參數。本文的研究目標是,基於 MAML,找到能夠獲得特定模式的先驗參數的元學習者,快速適應從多模式任務分布中抽取的給定任務。本文提出了一個多模態模型不可知元學習框架(Multimodal Model-Agnostic Meta-Learning,MMAML),該框架同時利用基於模型的元學習方法和模型不可知的元學習方法,能夠根據識別的模式調整其元學習先驗參數,從而實現更高效的快速自適應。圖 7 給出了 MMAML 整體框架。MMAML 的重點是利用兩種神經網絡實現快速適應新任務。首先,稱為調製網絡(Modulation Network)的網絡預測任務模式的標識。然後將預測出的模式標識作為任務網絡 (Task Network)的輸入,該網絡通過基於梯度的優化進一步適應任務。具體算法如下:
圖 7. MMAML 框架
調製網絡負責識別採樣任務的模式,並生成一組特定於該任務的參數。首先將給定的 K 個數據及其標籤 {x_k,y_k}_k=1,…,K 輸入到任務編碼器 f 中,並生成一個嵌入向量 v,該向量對任務的特徵進行編碼:
然後基於編碼後的任務嵌入向量 v 計算任務特定參數 τ,進而對任務網絡的元學習先驗參數進行調製。任務網絡可以是任意參數化的函數,例如深卷積網絡、多層遞歸網絡等。為了調整任務網絡中每個塊的參數作為解決目標任務的初始化參數,使用塊級轉換來縮放和行動網路中每個隱藏單元的輸出激活。具體地,調製網絡為每個塊 i 產生調製向量,表示為:
其中 N 是任務網絡中的塊數。上述過程表示
其中θ_i 為初始化參數,Φ_i 是任務網絡的調製先驗參數。本文選用了特徵線性調製方法(feature-wise linear modula-
tion,FiLM)作為調製運算方法。
使用調製網絡生成的任務特定參數τ={τ_i | i=1,···,N} 來調製任務網絡的每個塊的參數,該參數可以在參數空間 f(x;θ,τ)中生成模式感知初始化。在調製步驟之後,對任務網絡的元學習先驗參數進行幾步梯度下降以進一步優化任務τ_i 的目標函數。在元訓練和元測試時,採用了相同的調製和梯度優化方法。
作者表示,詳細的網絡結構和訓練超參數會因應用領域的不同而有所不同。本文在多模態任務分布下,評估了 MMAML 和基線極限方法在不同任務中的效果,包括回歸、圖像分類和強化學習等。基線對比方法包括使用多任務網絡的 MAML 和 Multi-MAML。
表 6. 回歸實驗結果
表 6 給出了 2、3 和 5 模式下多模態五次回歸的均方誤差(MSE)。應用μ=0 和σ=0.3 的高斯噪聲。Multi-MAML 方法使用基本事實的任務模式來選擇對應的 MAML 模型。本文提出的方法(使用 FiLM 調製)比其他方法效果稍好。
表 7. 圖像分類實驗結果
表 7 給出了 2、3、5 模式多模式少鏡頭圖像分類準確度測試結果,結果證明了本文提出的方法與 MAML 比有較好的效果,並且與 Multi-MAML 的性能相當。
表 8. 元強化學習實驗結果
表 8 給出在 3 個隨機種子上報告的 2、4 和 6 個模式的多模態強化學習問題中,每集(episode)累積獎勵的平均值和標準差。元強化學習的目標是在有限的任務經驗基礎上適應新的任務。本文使用 ProMP 算法優化策略和調製網絡,同時使用 ProMP 算法作為實驗對比基線,Multi-ProMP 是一個人工基線,用於顯示使用 ProMP 為每個模式訓練一個策略的性能。表 8 所示的實驗結果表明,MMAML 始終優於未經調製的 ProMP。只考慮單一模式的 Multi-ProMP 所展示出的良好性能表明,在該實驗環境下,不同方法面臨的適應性困難主要來自於多種模式。
圖 8. 從隨機抽樣的任務生成的任務嵌入的 tSNE 圖;標記顏色表示任務分布的不同模式
最後,圖 8 給出了上述各個實驗從隨機抽樣的任務生成的任務嵌入的 tSNE 圖,其中標記顏色表示任務分布的不同模式。圖(b)和圖(d)顯示了根據不同任務模式的清晰聚類,這表明 MMAML 能夠從少量樣本中識別任務並產生有意義的嵌入量。(a)回歸:模式之間的距離與函數相似性的情況一致(例如,二次函數有時可以類似於正弦函數或線性函數,而正弦函數通常不同於線性函數)(b)少鏡頭圖像分類:每個數據集(即模式)形成自己的簇。(c)-(d)強化學習:聚類數字代表不同的任務分配模式。不同模式的任務在嵌入空間中能夠清晰地聚集在一起。
作者介紹:仵冀穎,工學博士,畢業於北京交通大學,曾分別於香港中文大學和香港科技大學擔任助理研究員和研究助理,現從事電子政務領域信息化新技術研究工作。主要研究方向為模式識別、計算機視覺,愛好科研,希望能保持學習、不斷進步。