【導讀】深度學習框架TensorFlow不僅在學術界得到了普及,在工業界也有非常廣泛的應用。日常我們接觸到的TensorFlow的用法大多為基於Python的實驗用法,並不能直接用於工業界的線上產品。本文介紹一種簡單的發布TensorFlow模型的方法。
在工業產品中使用TensorFlow模型的方法
在工業產品中TensorFlow大概有下面幾種使用方法:
用TensorFlow的C++/Java/Nodejs API直接使用保存的TensorFlow模型:類似Caffe,適合做桌面軟體。
直接將使用TensorFlow的Python代碼放到Flask等Web程序中,提供Restful接口:實現和調試方便,但效率不太高,不大適合高負荷場景,且沒有版本管理、模型熱更新等功能。
將TensorFlow模型託管到TensorFlow Serving中,提供RPC或Restful服務:實現方便,高效,自帶版本管理、模型熱更新等,很適合大規模線上業務。
本文介紹的是方法3,如何用最簡單的方法將TensorFlow發布到TensorFlow Serving中。
一句代碼保存TensorFlow模型
# coding=utf-8
import tensorflow as tf
# 模型版本號
model_version = 1
# 定義模型
x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
y = tf.layers.dense(x, 10, activation=tf.nn.softmax)
with tf.Session() as sess:
# 初始化變量
sess.run(tf.global_variables_initializer())
# 模型訓練過程,省略
# .
# 保存訓練好的模型到"model/版本號"中
tf.saved_model.simple_save(
session=sess,
export_dir="model/{}".format(model_version),
inputs={"x": x},
outputs={"y": y}
)
代碼中除了最後一句,其它部分都是常規的TensorFlow代碼,模型定義、進入Session、模型訓練等。代碼的最後用tf.saved_model.simple_save將模型保存為SavedModel。注意,這裡將模型保存在了"model/版本號"文件夾中,而不是直接保存在了"model"文件夾中,這是因為TensorFlow Serving要求在模型目錄下加一層版本目錄,來進行版本維護、熱更新等:
安裝TensorFlow Serving
方法一:用apt-get安裝
對於Ubuntu或Debian(Bash on Windows10也可以),可以使用apt-get安裝Tensorflow Serving。先用下面的命令添加軟體源:
echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
添加成功後可直接用apt-get進行安裝:
apt-get update && apt-get install tensorflow-model-server
方法二:用Docker安裝
TensorFlow Serving官方提供了Docker容器,可以一鍵安裝:
docker pull tensorflow/serving
將模型發布到TensorFlow Serving中
下面的方法基於在本機使用apt-get安裝TensorFlow Serving的方法。對於Docker用戶,需要將模型掛載或複製到Docker中,按照Docker中的路徑來執行下面的教程。
用下面這行命令,就可以啟動TensorFlow Serving,並將剛才保存的模型發布到TensorFlow Serving中。注意,這裡的模型所在路徑是剛才"model"目錄的路徑,而不是"model/版本號"目錄的路徑,因為TensorFlow Serving認為用戶的模型所在路徑中包含了多個版本的模型。
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=模型名 --model_base_path=模型所在路徑
客戶端可以用GRPC和Restful兩種方式來調用TensorFlow Serving,這裡我們介紹基於Restful的方法,可以看到,命令中指定的Restful服務埠為8501,我們可以用curl命令來查看服務的狀態:
curl http://localhost:8501/v1/models/model
執行結果:
{
"model_version_status": [
{
"version": "1",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
下面我們用curl向TensorFlow Serving發送一個輸入x=[1.1, 1.2, 0.8, 1.3],來獲取預測的輸出信息y:
curl -d '{"instances": [[1.1,1.2,0.8,1.3]]}' -X POST http://localhost:8501/v1/models/模型名:predict
伺服器返回的結果如下:
{
"predictions": [[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]]
}
我們的模型成功地輸出了y=[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]
這裡我們使用的是curl命令,在實際工程中,使用requests(Python)、OkHttp(Java)等Http請求庫可以用類似的方法方便地請求TensorFlow Serving來獲取模型的預測結果。
版本維護和模型熱更新
剛才我們將模型保存在了"model/1"中,其中1是模型的版本號。如果我們的算法工程師研發出了更好的模型,此時我們並不需要將TensorFlow Serving重啟,只需要將新模型發布在"model/新版本號"中,如"model/2"。TensorFlow Serving就會自動發布新版本的模型,客戶端也可以請求新版本對應的API了。
PC登錄www.zhuanzhi.ai或者點擊閱讀原文,可以獲取更多AI知識資料!
加入專知主題群(請備註主題類型:AI、NLP、CV、 KG等)可以其他同行一起交流~ 請加專知小助手微信(掃一掃如下二維碼添加),
AI 項目技術 & 商務合作:bd@zhuanzhi.ai, 或掃描上面二維碼聯繫!請關注專知公眾號,獲取人工智慧的專業知識!
點擊「閱讀原文」,使用專知