文 / David Budden 與 Matteo Hessel
DeepMind 工程師通過構建工具、對算法進行拓展和創造具有挑戰性的虛擬和物理環境來訓練和測試人工智慧 (AI) 系統,加速我們的研究。作為這項工作的一部分,我們在持續評估機器學習新的庫和框架。
近來,我們發現由 Google Research 團隊開發的機器學習框架 JAX 為越來越多的項目提供良好支持。JAX 與我們的工程理念產生了很好的共鳴,並在去年被我們的研究社區廣泛使用。本文將分享我們的 JAX 使用經驗,來說明我們認為它有助於我們 AI 研究的原因,並概述我們正在為支持各地研究人員而建立的生態系統。
JAX 是為高性能數字計算(尤其是機器學習研究)而設計的 Python 庫。其用於數值計算的 API 基於 NumPy 這樣一個用於科學計算的函數庫所構建。得益於 Python 和 NumPy 較高的使用率和知名度,使得 JAX 簡潔靈活、易於使用。
除了其 NumPy API 之外,JAX 還具有一個用於可組合函數的轉換的擴展系統,在以下幾方面幫助機器學習研究:
自動微分我們發現,JAX 幫助新型算法和架構的研究進行快速實驗,為近期發表的多篇論文奠定了基礎。要了解詳情,請參考我們在 NeurIPS 虛擬大會上舉辦的 JAX 圓桌會議。
對前沿 AI 研究的支持意味著能在快速原型設計與快速迭代間保持平衡的同時,兼顧在傳統生產環境中成規模部署的能力。而這一切帶來挑戰的原因為研究領域發展十分迅速且難以預測。往往一項新的研究突破能在任意時刻改變整個領域發展的方向與需求。在這種瞬息萬變的環境中,我們工程團隊的核心使命便是確保在研究項目中可以有效復用現有的經驗與代碼。
一種成熟的方法是模塊化:我們將每個研究項目中開發的最重要和最關鍵的代碼塊提取至經過測試且高效的組件中。這使得研究人員能夠專注研究的同時受益於我們的核心庫所實現的算法部分的代碼重用、錯誤修復和性能提升。我們還發現,應該確保每個庫都有明確定義的範圍,並確保庫之間在能夠互相調用的同時保證相互獨立。增量更新,即使用版本特性時不會受制於其餘部分,對於為研究人員提供最大的靈活性並持續支持其選擇正確的工作工具至關重要。
JAX 生態系統開發中的其他考慮因素包括確保其與現有 TensorFlow 庫(如 Sonnet 和 TRFL)的設計(儘可能)保持一致。我們還構建了(在相關時)儘可能接近其基礎數學的組件,以實現自我描述,並最大程度地減少「從紙面到代碼」的思維跳轉。最後,我們選擇將我們的庫開源,以促進分享研究成果,並鼓勵更廣泛的社區探索 JAX 生態系統。
TensorFlow 庫
https://tensorflow.google.cn/guide
Sonnet
https://deepmind.com/blog/article/open-sourcing-sonnet
TRFL
https://deepmind.com/blog/article/trfl
最後,我們選擇將我們的庫開源,以促進分享研究成果,並鼓勵更廣泛的社區探索 JAX 生態系統。
可組合函數轉換的 JAX 編程模型可能會使對有狀態對象的處理複雜化,例如具有可訓練參數的神經網絡。Haiku 神經網絡庫允許用戶使用常見的面向對象的編程模型,同時利用強勁而便利的 JAX 純功能範式。
Haiku 的活躍用戶包括 DeepMind 和 Google 的數百名研究員,Haiku 也已在多個外部項目(如 Coax、DeepChem、NumPyro)中得到採用。它以 Sonnet 的 API 為基礎。Sonnet 是我們在 TensorFlow 中基於模塊的神經網絡編程模型,我們希望儘可能簡化從 Sonnet 到 Haiku 的移植。
在 GitHub 上了解更多信息。
Optax
梯度優化是 ML 的基礎。Optax 提供了梯度轉換庫以及允許在單行代碼中實現許多標準優化器(例如 RMSProp 或 Adam)的合成算子(例如鏈)。
Optax 的合成性質自然支持在自定義優化器中重組相同的基本成分。此外,它還提供了許多用於隨機梯度估算和二階優化的實用工具。
許多 Optax 用戶已經採用 Haiku,但根據我們的增量購買理念,任何以 JAX 樹結構表示參數的庫都可獲得支持(例如 Elegy、Flax 和 Stax)。請在此處查看關於這一豐富多樣的 JAX 庫生態系統的更多信息。
在 GitHub 上了解更多信息。
RLax
我們許多最成功的項目都位於深度學習與強化學習 (RL) 的交匯處,也就是深度強化學習。RLax 庫為構建 RL 代理提供了實用的構建塊。
RLax 中的組件涵蓋了廣泛的算法和概念:TD 學習、政策梯度、actor-critic、MAP、近端政策優化、非線性價值轉換、一般價值函數和許多探索方法。
雖然提供了一些介紹性的示例代理,但 RLax 並不是用於構建和部署完整 RL 代理系統的框架。Acme 是基於 RLax 組件構建的全功能代理框架示例。
在 GitHub 上了解更多信息。
Chex
測試對於軟體可靠性至關重要,研究代碼也不例外。只有保證研究代碼正確,才能從研究實驗中得出科學結論。Chex 測試實用工具集合可支持庫作者驗證通用構建塊是否正確耐用,還可支持最終用戶檢查其實驗代碼。
Chex 提供了多種實用工具,包括 JAX 感知單元測試、JAX 數據類型的屬性斷言、mock 和 fake 以及多設備測試環境。Chex 廣泛用於 DeepMind 的整個 JAX 生態系統以及 Coax 和 MineRL 等外部項目。
在 GitHub 上了解更多信息。
Jraph
圖神經網絡 (GNN) 是一個激動人心的研究領域,包括許多大有前途的應用。例如,我們最近在 Google 地圖中的交通預測工作和物理模擬方面的工作。Jraph(發音同「giraffe」)是一個輕量級庫,支持在 JAX 中使用 GNN。
Jraph 提供了標準化的圖數據結構,用於處理圖的一組實用程序,以及易於分叉和可擴展的圖神經網絡模型的「zoo」。包括其他關鍵特性:有效利用硬體加速器的 GraphTuples 批處理,通過填充和遮蔽對可變形圖的 JIT 編譯支持,以及在輸入分區上定義的損失。與 Optax 和我們的其他庫一樣,Jraph 對用戶的神經網絡庫選擇沒有任何限制。
從我們豐富的示例中詳細了解如何使用庫。
在 GitHub 上了解更多信息。
我們的 JAX 生態系統正在不斷發展,我們希望 ML 研究社區能夠探索我們的庫和 JAX 的潛力,從而加速自己的研究。
如果您發現 DeepMind JAX 生態系統有助於您的工作,請使用此引用(託管在 GitHub 上)。
點擊屏末 | 閱讀原文 | 探索 JAX 庫