一行代碼切換TensorFlow與PyTorch,模型訓練也能用倆框架

2020-12-22 機器之心Pro

機器之心報導

參與:思源

你是否有時要用 PyTorch,有時又要跑 TensorFlow?這個項目就是你需要的,你可以在訓練中同時使用兩個框架,並端到端地轉換模型。也就是說 TensorFlow 寫的計算圖可以作為某個函數,直接應用到 Torch 的張量上,這操作也是很厲害了。

在早兩天開源的 TfPyTh 中,不論是 TensorFlow 還是 PyTorch 計算圖,它們都可以包裝成一個可微函數,並在另一個框架中高效完成前向與反向傳播。很顯然,這樣的框架交互,能節省很多重寫代碼的麻煩事。

github項目地址:BlackHC/TfPyTh

為什麼框架間的交互很重要

目前 GitHub 上有很多優質的開源模型,它們大部分都是用 PyTorch 和 TensorFlow 寫的。如果我們想要在自己的項目中調用某個開源模型,那麼它們最好都使用相同的框架,不同框架間的對接會帶來各種問題。當然要是不怕麻煩,也可以用不同的框架重寫一遍。

以前 TensorFlow 和 PyTorch 經常會用來對比,討論哪個才是更好的深度學習框架。但是它們之間就不能友好相處麼,模型在兩者之間的相互遷移應該能帶來更多的便利。

在此之前,Facebook 和微軟就嘗試過另一種方式,即神經網絡交換格式 ONNX。直觀而言,該工具定義了一種通用的計算圖,不同深度學習框架構建的計算圖都能轉化為它。雖然目前 ONNX 已經原生支持 MXNet、PyTorch 和 Caffe2 等大多數框架,但是像 TensorFlow 或 Keras 之類的只能通過第三方轉換器轉換為 ONNX 格式。

而且比較重要的一點是,現階段 ONNX 只支持推理,導入的模型都需要在原框架完成訓練。所以,想要加入其它框架的模型,還是得手動轉寫成相同框架,再執行訓練。

神奇的轉換庫 TfPyTh

既然 ONNX 無法解決訓練問題,那麼就輪到 TfPyTh 這類項目出場了,它無需改寫已有的代碼就能在框架間自由轉換。具體而言,TfPyTh 允許我們將 TensorFlow 計算圖包裝成一個可調用、可微分的簡單函數,然後 PyTorch 就能直接調用它完成計算。反過來也是同樣的,TensorFlow 也能直接調用轉換後的 PyTorch 計算圖。

因為轉換後的模塊是可微的,那么正向和反向傳播都沒什麼問題。不過項目作者也表示該項目還不太完美,開源 3 天以來會有一些小的問題。例如張量必須通過 CPU 進行複製與路由,直到 TensorFlow 支持__cuda_array_interface 相關功能才能解決。

目前 TfPyTh 主要支持三大方法:

torch_from_tensorflow:創建一個 PyTorch 可微函數,並給定 TensorFlow 佔位符輸入計算張量輸出;eager_tensorflow_from_torch:從 PyTorch 創建一個 Eager TensorFlow 函數;tensorflow_from_torch:從 PyTorch 創建一個 TensorFlow 運算子或張量。TfPyTh 示例

如下所示為 torch_from_tensorflow 的使用案例,我們會用 TensorFlow 創建一個簡單的靜態計算圖,然後傳入 PyTorch 張量進行計算。

import tensorflow as tfimport torch as thimport numpy as npimport tfpythsession = tf.Session()defget_torch_function(): a = tf.placeholder(tf.float32, name='a') b = tf.placeholder(tf.float32, name='b') c = 3 * a + 4 * b * b f = tfpyth.torch_from_tensorflow(session, [a, b], c).applyreturn ff = get_torch_function()a = th.tensor(1, dtype=th.float32, requires_grad=True)b = th.tensor(3, dtype=th.float32, requires_grad=True)x = f(a, b)assert x == 39.x.backward()assert np.allclose((a.grad, b.grad), (3., 24.))

我們可以發現,基本上 TensorFlow 完成的就是一般的運算,例如設置佔位符和建立計算流程等。TF 的靜態計算圖可以通過 session 傳遞到 TfPyTh 庫中,然後就產生了一個新的可微函數。後面我們可以將該函數用於模型的某個計算部分,再進行訓練也就沒什麼問題了。

相關焦點

  • 模型秒變API只需一行代碼,支持TensorFlow等框架
    選自GitHub機器之心編譯參與:一鳴、杜偉還在為機器學習模型打包成 API 發愁?這個工具能讓你一行代碼直接打包。專注於機器學習應用的人們知道,從訓練好的模型到實際的工業生產工具還有一定的距離。其中工作量很大的地方在於將模型打包,預留 API 接口,並和現有的生產系統相結合。近日,GitHub 上有了這樣一個項目,能夠讓用戶一行代碼將任意模型打包為 API。這一工具無疑能夠幫助開發者在實際的生產應用中快速部署模型。
  • TensorFlow與PyTorch之爭,哪個框架最適合深度學習
    訓練後的模型可以用在不同的應用中,比如目標檢測、圖像語義分割等等。儘管神經網絡架構可以基於任何框架實現,但結果卻並不一樣。訓練過程有大量參數都與框架息息相關。舉個例子,如果你在 PyTorch 上訓練一個數據集,那麼你可以使用 GPU 來增強其訓練過程,因為它們運行在 CUDA(一種 C++ 後端)上。
  • 初學AI神經網絡應該選擇Keras或是Pytorch框架?
    對於許多開發者來說,TensorFlow是他們接觸的第一個機器學習框架。TensorFlow框架儘管意義非凡,引起極大關注和神經網絡學習風潮,但對一般開發者用戶太不友好。軟體開發者畢竟不是科學家,很多時候簡單易學易用是程式設計師選擇的第一要素。
  • 一行代碼即可調用18款主流模型!PyTorch Hub輕鬆解決論文可復現性
    ,加載ResNet、BERT、GPT、VGG、PGAN還是MobileNet等經典模型只需一行代碼。圖靈獎得主Yann LeCun發推表示,只需要一行代碼就可以調用所有倉庫裡的模型,通過一個pull請求來發布你自己的模型。同時,PyTorch Hub整合了Google Colab,併集成了論文代碼結合網站Papers With Code,可以直接找到論文的代碼。PyTorch Hub怎麼用?
  • 如何在PyTorch和TensorFlow中訓練圖像分類模型
    這將是你的起點,然後你可以選擇自己喜歡的任何框架,也可以開始構建其他計算機視覺模型。PyTorch為我們提供了一個框架,可以隨時隨地構建計算圖,甚至在運行時進行更改。特別是,對於我們不知道創建神經網絡需要多少內存的情況,這很有用。你可以使用PyTorch應對各種深度學習挑戰。
  • 寫給純小白的深度學習環境搭建寶典:pytorch+tensorflow
    每天給小編五分鐘,小編用自己的代碼,讓你輕鬆學習人工智慧。本文將手把手帶你快速搭建你自己的深度學習環境,然後實現自己的第一個深度學習程序。野蠻智能,小白也能看懂的人工智慧。Anaconda+pytorch環境準備如果你的電腦帶有GPU,可以先安裝Nvidia驅動 + cuda + cudnn,然後再搭建環境,這樣可以達到更高的運行速度。如果不想使用GPU,學習階段也可以使用cpu版本,對於簡單的程序用CPU和GPU其實沒差別。
  • Transformers2.0讓你三行代碼調用語言模型,兼容TF2.0和PyTorch
    更新後的 Transformers 2.0 汲取了 PyTorch 的易用性和 Tensorflow 的工業級生態系統。藉助於更新後的 Transformers 庫,科學家和實踐者可以更方便地在開發同一語言模型的訓練、評估和製作階段選擇不同的框架。那麼更新後的 Transformers 2.0 具有哪些顯著的特徵呢?
  • TensorFlow發布JavaScript開發者的機器學習框架TensorFlow.js
    發布新的 TensorFlow 官方博客(http://blog.tensorflow.org/)與 TensorFlow YouTube 頻道;2. 面向 JavaScript 開發者的全新機器學習框架 TensorFlow.js;3.
  • 上線倆月,TensorFlow 2.0被吐槽太難用,網友:看看人家PyTorch
    我有個想法,我想要在訓練過程中逐漸改變損失函數的『形狀』;2. 我搜索『tensorflow 在訓練中改變損失函數』;3. 最高搜索結果是一個 Medium 的文章,我們去看看吧;4. 這個 Medium 文章介紹的是均方誤差(MSE)損失函數,以及你怎樣在 TensorFlow 中用它訓練一個深度神經網絡;5. 我只好用腦袋砸鍵盤了。
  • 正確debug的TensorFlow的姿勢
    當討論在tensorflow上編寫代碼時,總是將其與PyTorch進行比較,討論框架有多複雜,以及為什麼要使用tf.contrib的某些部分,做得太爛了。此外,我認識很多數據科學家,他們只用Github上已有的repo來用tensorflow。對這個框架持這種態度的原因是非常不同的,但是今天讓我們關注更實際的問題:調試用tensorflow編寫的代碼並理解它的主要特性。
  • 對比PyTorch和TensorFlow的自動差異和動態子類化模型
    使用自定義模型類從頭開始訓練線性回歸,比較PyTorch 1.x和TensorFlow 2.x之間的自動差異和動態模型子類化方法,這篇簡短的文章重點介紹如何在PyTorch 1.x和TensorFlow 2.x中分別使用帶有模塊/模型API的動態子類化模型,以及這些框架在訓練循環中如何使用AutoDiff獲得損失的梯度並從頭開始實現
  • TensorFlow.js 進行模型訓練
    因為我們正在訓練模型來預測連續數字,所以此任務有時被稱為回歸任務。我們將通過展示輸入的許多示例以及正確的輸出來訓練模型。這被稱為監督學習。你將建立什麼您將創建一個使用TensorFlow.js在瀏覽器中訓練模型的網頁。
  • 基於tensorflow框架對手寫字體MNIST數據集的識別
    本文我們利用python語言,通過tensorflow框架對手寫字體MNIST資料庫進行識別。學習每一門語言都有一個「Hello World」程序,而對數字手寫體資料庫MNIST的識別就是深度學習的「Hello World」代碼。下面我們給出詳細的步驟。tensorflow概述tensorflow是用C++語言實現的一個深度學習模塊。
  • tensorflow安裝教程
    tensorflow as tf,有警告但是沒有報錯,說明安裝成功。python測試中有個很好用的工具jupyter notebook,有了這個工具我們可以在瀏覽器上輸入代碼,並查看結果,使用靈活,比使用命令行和編輯.py代碼文件方便,可以極大提高工作效率。
  • 漫談分布式計算框架
    如果問 spark 與 tensorflow 呢,就可能有點迷糊,這倆關注的領域不太一樣啊。但是再問 spark 與 MPI 呢?這個就更遠了。雖然這樣問多少有些不嚴謹,但是它們都有共同的一部分,這就是我們今天談論的一個話題,一個比較大的話題:分布式計算框架。
  • TensorFlow驚現大bug?網友:這是逼著我們用PyTorch啊
    在事情發酵後,TensorFlow 團隊終於回復了,表示已經在改,但對應的功能將在 2.4 版本中才能用。谷歌團隊 2015 年發布的 TensorFlow 框架是目前機器學習領域最流行的框架之一。雖然後起之秀 PyTorch 奮起直追,但 TensorFlow 框架的使用者仍然眾多。
  • TensorFlow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。
  • tensorflow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、 tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。下圖是 query 詞頻時序圖,從中可以看出 tensorflow 的火爆程度。
  • TensorFlow vs PyTorch:哪個是深度學習網絡編程的最佳框架呢?
    import tensorflow as tf而對於PyTorch來說,需要的兩個庫為:import torchimport torchvisionb)加載和預處理數據使用TensorFlow加載和準備數據可以通過以下兩行代碼完成:在PyTorch中是這樣的:我們可以使用matplotlib.pyplot庫驗證這兩段代碼是否已經正確加載了相同數據
  • 玩轉TensorFlow?你需要知道這30功能
    3)TFX 數據驗證如何自動確保用於重新訓練模型的數據與最初用於訓練模型的數據具有相同的格式、源、命名約定等。對於線上訓練來說,這是一個量很大的工作!https://www.tensorflow.org/tfx/data_validation/?hl=zh-cn