手把手 | OpenAI開發可拓展元學習算法Reptile,能快速學習

2021-01-15 大數據文摘

大數據文摘作品

編譯:Zoe Zuo、丁慧、Aileen

本文來自OpenAI博客,介紹一種新的元學習算法Retile。

在OpenAI, 我們開發了一種簡易的元學習算法,稱為Reptile。它通過對任務進行重複採樣,利用隨機梯度下降法,並將初始參數更新為在該任務上學習的最終參數。

其性能可以和MAML(model-agnostic meta-learning,由伯克利AI研究所研發的一種應用廣泛的元學習算法)相媲美,操作簡便且計算效率更高。

MAML元學習算法:

http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/

元學習是學習如何學習的過程。此算法接受大量各種的任務進行訓練,每項任務都是一個學習問題,然後產生一個快速的學習器,並且能夠通過少量的樣本進行泛化。

一個深入研究的元學習問題是小樣本分類(few-shot classification),其中每項任務都是一個分類問題,學習器在每個類別下只能看到1到5個輸入-輸出樣本(input-output examples),然後就要給新輸入的樣本進行分類。

下面是應用了Reptile算法的單樣本分類(1-shot classification)的互動演示,大家可以嘗試一下。

嘗試單擊「Edit All」按鈕,繪製三個不同的形狀或符號,然後在右側的輸入區中繪製其中一個,並查看Reptile如何對它進行分類。前三張圖是標記樣本,每圖定義一個類別。最後一張圖代表未知樣本,Reptile要輸出此圖屬於每個類別的概率。

Reptile的工作原理

像MAML一樣,Reptile試圖初始化神經網絡的參數,以便通過新任務產生的少量數據來對網絡進行微調。

但是,當MAML藉助梯度下降算法的計算圖來展開和區分時,Reptile只是以標準方法在每個任務中執行隨機梯度下降(stochastic gradient descent, SGD)算法,並不展開計算圖或者計算二階導數。這使得Reptile比MAML需要更少的計算和內存。示例代碼如下:

初始化Φ,初始參數向量對於迭代1,2,3……執行隨機抽樣任務T 在任務T上執行k>1步的SGD,輸入參數Φ,輸出參數w 更新:Φ←Φ+(wΦ)結束返回Φ

最後一步中,我們可以將ΦW作為梯度,並將其插入像這篇論文裡(https://arxiv.org/abs/1412.6980)Adam這樣更為先進的優化器中作為替代方案。

首先令人驚訝的是,這種方法完全有效。如果k=1,這個算法就相當於 「聯合訓練」(joint training)——對多項任務的混合體執行SGD。雖然在某些情況下,聯合訓練可以學習到有用的初始化,但當零樣本學習(zero-shot learning)不可能實現時(比如,當輸出標籤是隨機排列時),聯合訓練就幾乎無法學習得到結果。

Reptile要求k>1,也就是說,參數更新要依賴於損失函數的高階導數實現,此時算法的表現和k=1(聯合訓練)時是完全不同的。

為了分析Reptile的工作原理,我們使用泰勒級數(Taylor series)來逼近參數更新。Reptile的更新將同一任務中不同小批量的梯度內積(inner product)最大化,從而提高了的泛化能力。

這一發現可能超出了元學習領域的指導意義,比如可以用來解釋SGD的泛化性質。進一步分析表明,Reptile和MAML的更新過程很相近,都包括兩個不同權重的項。

泰勒級數:

https://en.wikipedia.org/wiki/Taylor_series

在我們的實驗中,展示了Reptile和MAML在Omniglot和Mini-ImageNet基準測試中對少量樣本分類時產生相似的性能,由於更新具有較小的方差,因此Reptile也可以更快的收斂到解決方案。

Omniglot:

https://github.com/brendenlake/omniglot

Mini-ImageNet:

https://arxiv.org/abs/1606.04080

我們對Reptile的分析表明,通過不同的SGD梯度組合,可以獲得大量不同的算法。在下圖中,假設針對每一任務中不同小批量執行k步SGD,得出的梯度分別為g1,g2,…,gk。

下圖顯示了在 Omniglot 上由梯度之和作為元梯度而繪製出的學習曲線。g2對應一階MAML,也就是原先MAML論文中提出的算法。由於方差縮減,納入更多梯度明顯會加速學習過程。需要注意的是,僅僅使用g1(對應k=1)並不會給這個任務帶來改進,因為零樣本學習的性能無法得到改善。

X坐標:外循環迭代次數

Y坐標:Omniglot對比5種方式的

5次分類的準確度

算法實現

我們在GitHub上提供了Reptile的算法實現,它使用TensorFlow來完成相關計算,並包含用於在Omniglot和Mini-ImageNet上小樣本分類實驗的代碼。我們還發布了一個較小的JavaScript實現,對TensorFlow預先訓練好的模型進行了微調。文章開頭的互動演示也是藉助JavaScript完成的。

GitHub:

https://github.com/openai/supervised-reptile

較小的JavaScript實現:

https://github.com/openai/supervised-reptile/tree/master/web

最後,展示一個小樣本回歸(few-shot regression)的簡單示例,用以預測10(x,y)對的隨機正弦波。該示例基於PyTorch實現,代碼如下:

import numpy as npimport torchfrom torch import nn, autograd as agimport matplotlib.pyplot as pltfrom copy import deepcopyseed = 0plot = Trueinnerstepsize = 0.02 # stepsize in inner SGDinnerepochs = 1 # number of epochs of each inner SGDouterstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimizationniterations = 30000 # number of outer updates; each iteration we sample one task and update on itrng = np.random.RandomState(seed)torch.manual_seed(seed)# Define task distributionx_all = np.linspace(-5, 5, 50)[:,None] # All of the x pointsntrain = 10 # Size of training minibatchesdef gen_task():"Generate classification problem" phase = rng.uniform(low=0, high=2*np.pi) ampl = rng.uniform(0.1, 5) f_randomsine = lambda x : np.sin(x + phase) * ampl return f_randomsine# Define model. Reptile paper uses ReLU, but Tanh gives slightly better resultsmodel = nn.Sequential( nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1),)def totorch(x): return ag.Variable(torch.Tensor(x))def train_on_batch(x, y): x = totorch(x) y = totorch(y) model.zero_grad() ypred = model(x) loss = (ypred - y).pow(2).mean() loss.backward() for param in model.parameters(): param.data -= innerstepsize * param.grad.datadef predict(x): x = totorch(x) return model(x).data.numpy()# Choose a fixed task and minibatch for visualizationf_plot = gen_task()xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]# Reptile training loopfor iteration in range(niterations): weights_before = deepcopy(model.state_dict()) # Generate task f = gen_task() y_all = f(x_all) # Do SGD on this task inds = rng.permutation(len(x_all)) for _ in range(innerepochs): for start in range(0, len(x_all), ntrain): mbinds = inds[start:start+ntrain] train_on_batch(x_all[mbinds], y_all[mbinds]) # Interpolate between current weights and trained weights from this task # I.e. (weights_before - weights_after) is the meta-gradient weights_after = model.state_dict() outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule model.load_state_dict({name : weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize for name in weights_before}) # Periodically plot the results on a particular task and minibatch if plot and iteration==0 or (iteration+1) % 1000 == 0: plt.cla() f = f_plot weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1)) for inneriter in range(32): train_on_batch(xtrain_plot, f(xtrain_plot)) if (inneriter+1) % 8 == 0: frac = (inneriter+1) / 32 plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac)) plt.plot(x_all, f(x_all), label="true", color=(0,1,0)) lossval = np.square(predict(x_all) - f(x_all)).mean() plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k") plt.ylim(-4,4) plt.legend(loc="lower right") plt.pause(0.01) model.load_state_dict(weights_before) # restore from snapshot print(f"----") print(f"iteration {iteration+1}") print(f"loss on plotted curve {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity

論文連結:

https://arxiv.org/abs/1803.02999

代碼連結:

https://github.com/openai/supervised-reptile

原文連結:

https://blog.openai.com/reptile/

相關焦點

  • OpenAI提出Reptile:可擴展的元學習算法
    近日,OpenAI發布了簡單元學習算法Reptile,該算法對一項任務進行重複採樣,執行隨機梯度下降,更新初始參數直到習得最終參數。該方法的性能可與MAML(一種廣泛應用的元學習算法)媲美,且比後者更易實現,計算效率更高。
  • OpenAI發布強化學習環境Gym Retro:支持千種遊戲
    Gym 是 OpenAI 發布的用於開發和比較強化學習算法的工具包。使用它我們可以讓 AI 智能體做很多事情,比如行走、跑動,以及進行多種遊戲。目前,它運行在支持 Python 3.5 和 3.6 的 Linux、macOS 和 Windows 系統上。
  • 讓機器像人類一樣學習? 伯克利 AI 研究院提出新的元強化學習算法!
    針對這類既需要採取行動又需要積累過往經驗的智能體來說,元強化學習可以幫助其快速適應新的場景。但美中不足的是,雖然訓練後的策略可以幫助智能體快速適應新的任務,但元訓練過程需要用到來自一系列訓練任務的大量數據,這就加劇了困惱強化學習算法的樣本效率低下問題。因此,現有的元強化學習算法在很大程度上只能在模擬環境中正常運行。
  • OpenAI開源機器人仿真軟體Roboschool:已整合OpenAI Gym
    Roboschool 不再有這一限制,使得每個人皆可進行研究而無需擔心預算。Roboschool 基於 Bullet Physics Engine(一個開源、受到廣泛許可的物理庫),該庫已用於其他仿真軟體,比如 Gazebo 和 V-REP。
  • 深度學習預測比特幣價格;基於神經網絡的自動化前端開發 | Github...
    該項目用深度學習的方法預測比特幣的走勢,Siraj Raval 小哥也在視頻裡說了,這套模型還可以用來預測任何 Altcoin。看來現在的礦工們都應該學點機器學習了。基於神經網絡的自動化前端開發 —— Screenshot-to-code-in-KerasGithub 地址:https://github.com/emilwallner/Screenshot-to-code-in-Keras目前,自動化前端開發最大的障礙是計算能力,不過,我們可以用深度學習算法和數據訓練來探索自動化前端設計。
  • 從星際2深度學習環境到神經機器翻譯,上手機器學習這些開源項目必...
    機器學習是用數據來學習、概括、預測的研究。近幾年,隨著數據的開發、算法的改進以及硬體計算能力的提升,機器學習技術得以快速發展,不斷延伸至新的領域。由於我們需要這些數據來訓練機器學習算法,所以獲取高質量的數據集是如今機器學習領域的最大挑戰之一。算法:如何處理和分析數據機器學習算法可利用數據執行特定的任務,最常見的機器學習算法有如下幾種:1.監督學習。
  • 與模型無關的元學習,UC Berkeley提出一種可推廣到各類任務的元...
    人們正在開發多種技術來解決此類問題,我將在本文中對其進行概述,同時也將介紹我們實驗室開發的最新技術「與模型無關的元學習」(model-agnostic meta-learning)。元學習方法的運行機制首先元學習系統會在大量任務中進行訓練,然後測試其學習新任務的能力。例如每一個類別給出幾個樣本,那么元學習是否能在將新的圖片正確分類,或者在僅提供一條穿過迷宮的通道時,模型能否學會快速穿過新的迷宮。該方法包括在單個任務上訓練和在留出樣本上測試,與很多標準機器學習技術不同。
  • ICLR 2018最佳論文:基於梯度的元學習算法
    於 4 月 30 日開幕的 ICLR 2018 最近公布了三篇最佳論文,分別關注於最優化方法、卷積神經網絡和元學習算法。不出所料的是,這三篇最佳論文在 2017 年 11 月公布的評審結果中,都有很高的得分。機器之心以前已經介紹過關於修正 Adam 與球面 CNN 的最佳論文,本文將重點介紹第三篇關於元學習的最佳論文。
  • 專題| 深度強化學習綜述:從AlphaGo背後的力量到學習資源分享(附...
    該概述的大綱如下:第二節,深度學習及強化學習的背景知識及對測試平臺的介紹;第三節,對深度 Q 網絡及其拓展的介紹;第四節,異步放法的介紹;第五節,策略優化;第六節,獎勵;第七節,規劃;第八節,注意和記憶機制,特別是對可微分神經計算機(DNC)的介紹;第九節,非監督學習;第十節;學習去學習(learning to learn);第十一節,遊戲/博弈,包括棋類遊戲、視頻遊戲及非完美信息博弈
  • ...深度強化學習綜述:從AlphaGo背後的力量到學習資源分享(附論文)
    該概述的大綱如下:第二節,深度學習及強化學習的背景知識及對測試平臺的介紹;第三節,對深度 Q 網絡及其拓展的介紹;第四節,異步放法的介紹;第五節,策略優化;第六節,獎勵;第七節,規劃;第八節,注意和記憶機制,特別是對可微分神經計算機(DNC)的介紹;第九節,非監督學習;第十節;學習去學習(learning to learn);第十一節,遊戲/博弈,包括棋類遊戲、視頻遊戲及非完美信息博弈
  • 機器學習算法集錦:從貝葉斯到深度學習及各自優缺點
    Learning Algorithms)圖模型(Graphical Models)正則化算法(Regularization Algorithms)它是另一種方法(通常是回歸方法)的拓展,這種方法會基於模型複雜性對其進行懲罰,它喜歡相對簡單能夠更好的泛化的模型。
  • AI學會「以牙還牙」,OpenAI發布多智能體深度強化學習新算法LOLA
    OpenAI和牛津大學等研究人員合作,提出了一種新的算法LOLA,讓深度強化學習智能體在更新自己策略的同時,考慮到他人的學習過程,甚至實現雙贏。每個LOLA智能體都調整自己的策略,以便用有利的方式塑造其他智能體的學習過程。初步試驗結果表明,兩個LOLA 智能體相遇後會出現「以牙還牙/投桃報李」(tit-for-tat)策略,最終在無限重複囚徒困境中出現合作行為。
  • 前沿| 利用遺傳算法優化神經網絡:Uber提出深度學習訓練新方式
    開發包括神經進化在內的各種有力的學習方法將幫助 Uber 發展更安全、更可靠的運輸方案。遊戲;而且,它能在許多遊戲中比現代深度強化學習(RL)算法(例如 DQN 和 A3C)或進化策略(ES)表現得更好,同時由於更好的並行化能達到更快的速度。
  • 實踐入門NLP:基於深度學習的自然語言處理
    【課程亮點】三大模塊,五大應用,手把手快速入門NLP算法+實踐,搭配典型行業應用海外博士講師,豐富項目經驗專業學習社群,隨到隨學【適合人群】本次課程主要適合具備一定編程基礎的開發人員,以及對自然語言處理和深度學習有興趣的踐行者。
  • 谷歌開發出的深度學習算法模型,可用於預測DNA鏈等亞細胞結構的變化
    不同於以往,這裡的研究成果不斷,背後的首要功臣是谷歌研究團隊開發的3D細胞結構模型的算法。繼Alpha Go之後,谷歌研究團隊又一「黑科技」秒殺人類。而隨著科技的發展,深度學習成為圖像處理領域的最佳利器,故而許多研究人員開發出了算法,以用於處理活細胞等微生物螢光圖像:如當科學家希望利用深度學習來分析基因組中的基因突變,他們先將DNA鏈中的鹼基轉換為計算機可以識別的圖像,然後將已知的DNA突變片段信息與基因組信息一起用於訓練神經網絡系統,隨後用機器學習進行預測和數據分析。
  • 聯邦學習算法綜述
    對於每一個用戶來說,人們希望通過他的特徵X,學習一個模型來預測他的標籤Y。在現實中,不同的參與方可能是不同的公司或者機構,人們不希望自己的數據被別人知道,但是人們希望可以聯合訓練一個更強大的模型來預測標籤Y。根據聯邦學習的數據特點(即不同參與方之間的數據重疊程度),聯邦學習可被分為橫向聯邦學習、縱向聯邦學習、遷移聯邦學習。
  • 元學習幫你解決
    這樣,模型就可以跨任務學習準確地解決一個新的、不可見的少鏡頭分類任務。 標準的學習分類算法學習映射圖像→標籤,元學習算法通常學習映射支持集→c(.),其中c是映射查詢→標籤。  元學習算法 既然我們知道了算法元訓練的含義,那麼還有一個謎團:元學習模型是如何解決一個少鏡頭分類任務的?當然,解決方案不止一個。在這裡,我們將關注最受歡迎的方案。
  • 聽說你了解深度學習最常用的學習算法:Adam優化算法?
    By蔣思源2017年7月12日  深度學習常常需要大量的時間和機算資源進行訓練,這也是困擾深度學習算法開發的重大原因。雖然我們可以採用分布式並行訓練加速模型的學習,但所需的計算資源並沒有絲毫減少。而唯有需要資源更少、令模型收斂更快的最優化算法,才能從根本上加速機器的學習速度和效果,Adam算法正為此而生!
  • Python學習步驟
    Python10大特點:易於學習:Python有相對較少的關鍵字,結構簡單,和一個明確定義的語法,學習起來更加簡單。易於閱讀:Python代碼定義的更清晰。可移植:基於其開放原始碼的特性,Python已經被移植(也就是使其工作)到許多平臺。可擴展:如果你需要一段運行很快的關鍵代碼,或者是想要編寫一些不願開放的算法,你可以使用C或C++完成那部分程序,然後從你的Python程序中調用。
  • 成都學習Python開發哪家好
    企業內推:就業老師幫學員簡歷直通HR,可快速面試 簡歷置頂:各大招聘網站享有置頂權,優先面試 了解服務詳情 成都python培訓基礎課程簡介