魚羊 編輯整理
量子位 報導 | 公眾號 QbitAI
編者按:作為一個Java開發者,你是否曾為在PyTorch上部署模型而苦惱?這篇來自AWS軟體工程師的投稿,結合實例,詳細介紹了DJL這個為Java開發者設計的深度學習庫:5分鐘,你就能在PyTorch上,用Java實現目標檢測。
5分鐘,用Java實現目標檢測
文 / 知乎用戶@Lanking
PyTorch在深度學習領域中的應用日趨廣泛,得益於它獨到的設計。無論是數據的並行處理還是動態計算圖,一切都為Python做出了很多簡化。很多論文都選擇使用PyTorch去實現也證明了它在訓練方面的效率以及易用性。
在PyTorch領域,儘管部署一個模型有很多選擇,可為Java開發人員準備的選項卻屈指可數。
在過去,用戶可以用PyTorch C++ 寫JNI (Java Native Interface) 來實現這個過程。最近,PyTorch 1.4 也發布了試驗性的Java 前端。
可是這兩種解決方案都沒有辦法能讓Java開發者很好的使用:用戶需要從易於使用和易於維護中二選一。
針對於這個問題,亞馬遜雲服務 (AWS)開源了 Deep Java Library (DJL),一個為Java開發者設計的深度學習庫。它兼顧了易用性和可維護性,一切運行效率以及內存管理問題都得到了很好的處理。
DJL使用起來異常簡單。只需幾行代碼,用戶就可以輕鬆部署深度學習模型用作推理。那麼我們就開始上手用DJL部署一個PyTorch 模型吧。
前期準備
用戶可以輕鬆使用maven或者gradle等Java常用配置管理包來引用DJL。下面是一個示例:
plugins {
id 'java'
}
repositories {
jcenter()
}
dependencies {
implementation "ai.djl:api:0.4.0"
implementation "ai.djl:repository:0.4.0"
runtimeOnly "ai.djl.pytorch:pytorch-model-zoo:0.4.0"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto:1.4.0"
}
然後只需gradle build,基本配置就大功告成了。
開始部署模型
我們用到的目標檢測模型來源於NVIDIA在torchhub發布的預訓練模型。我們用下面這張圖來推理幾個可以識別的物體(狗,自行車以及皮卡)。
可以通過下面的代碼來實現推理的過程:
public static void main(String[] args) throws IOException, ModelException, TranslateException {
String url = "https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg";
BufferedImage img = BufferedImageUtils.fromUrl(url);
Criteria criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(BufferedImage.class, DetectedObjects.class)
.optFilter("backbone", "resnet50")
.optProgress(new ProgressBar())
.build();
try (ZooModel model = ModelZoo.loadModel(criteria)) {
try (Predictor predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
System.out.println(detection);
}
}
}
然後,就結束了。相比於其他解決方案動輒上百行的代碼,DJL把所有過程簡化到了不到30行完成。那麼我們看看輸出的結果:
[
class: "dog", probability: 0.96709, bounds: [x=0.165, y=0.348, width=0.249, height=0.539]
class: "bicycle", probability: 0.66796, bounds: [x=0.152, y=0.244, width=0.574, height=0.562]
class: "truck", probability: 0.64912, bounds: [x=0.609, y=0.132, width=0.284, height=0.166]
]
你也可以用我們目標檢測圖形化API來看一下實際的檢測效果:
你也許會說,這些代碼都包裝的過於厲害,真正的小白該如何上手呢?
讓我們仔細的看一下剛才的那段代碼:
// 讀取一張圖片
String url = "https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg";
BufferedImage img = BufferedImageUtils.fromUrl(url);
// 創建一個模型的尋找標準
Criteria criteria =
Criteria.builder()
// 設置應用類型:目標檢測
.optApplication(Application.CV.OBJECT_DETECTION)
// 確定輸入輸出類型 (使用默認的圖片處理工具)
.setTypes(BufferedImage.class, DetectedObjects.class)
// 模型的過濾條件
.optFilter("backbone", "resnet50")
.optProgress(new ProgressBar())
.build();
// 創建一個模型對象
try (ZooModel model = ModelZoo.loadModel(criteria)) {
// 創建一個推理對象
try (Predictor predictor = model.newPredictor()) {
// 推理
DetectedObjects detection = predictor.predict(img);
System.out.println(detection);
}
}
這樣是不是清楚了很多?DJL建立了一個模型庫(ModelZoo)的概念,引入了來自於GluonCV, TorchHub, Keras 預訓練模型, huggingface自然語言處理模型等70多個模型。所有的模型都可以一鍵導入,用戶只需要使用默認或者自己寫的輸入輸出工具就可以實現輕鬆的推理。我們還在不斷的添加各種預訓練模型。
了解DJL
DJL是亞馬遜雲服務在2019年re:Invent大會推出的專為Java開發者量身定製的深度學習框架,現已運行在亞馬遜數以百萬的推理任務中。
如果要總結DJL的主要特色,那麼就是如下三點:
DJL不設限制於後端引擎:用戶可以輕鬆的使用 MXNet, PyTorch, TensorFlow和fastText來在Java上做模型訓練和推理。
DJL的算子設計無限趨近於numpy:它的使用體驗上和numpy基本是無縫的,切換引擎也不會造成結果改變。
DJL優秀的內存管理以及效率機制:DJL擁有自己的資源回收機制,100個小時連續推理也不會內存溢出。
James Gosling (Java 創始人) 在使用後給出了讚譽:
對於PyTorch的支持
DJL現已支持PyTorch 1.5。我們深度整合了PyTorch C++ API,開發了一套JNI提供Java的底層支持。DJL提供各類PyTorch原生算子算法,現在支持所有的 TorchScript模型。
現在可以在 Mac/Linux/Windows全平臺運行DJL PyTorch。DJL具有自檢測CUDA版本的功能,也會自動採用對應的CUDA版本包來運行gpu任務。
想了解更多,請參見下面幾個連結:
https://djl.ai
https://github.com/awslabs/djl
也歡迎加入我們slack論壇:
https://app.slack.com/client/TPX8YGQTW
— 完 —
本文系網易新聞•網易號特色內容激勵計劃籤約帳號【量子位】原創內容,未經帳號授權,禁止隨意轉載。
報名 | 智慧生活行業私享會
歡迎報名,與峰瑞資本、石頭科技、網易有道、思必馳、九號機器人、視感科技、雲丁科技等企業高管,共同探討如何借力資本市場、把握行業趨勢,打造全場景智慧生活:
量子位 QbitAI · 頭條號籤約作者
վ'ᴗ' ի 追蹤AI技術和產品新動態
原標題:《5分鐘!用Java實現目標檢測 | PyTorch》
閱讀原文