本教程介紹如何使用 tf.Keras 時序 API 從頭開始訓練模型,將 tf.Keras 模型轉換為 tflite 格式,並在 Android 上運行該模型。我將以 MNIST 數據為例介紹圖像分類,並分享一些你可能會面臨的常見問題。本教程著重於端到端的體驗,我不會深入探討各種 tf.Keras API 或 Android 開發。
下載我的示例代碼並執行以下操作:
1.訓練自定義分類器
加載數據
我們將使用作為tf.keras框架一部分的mnst數據。
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
預處理數據
接下來,我們將輸入圖像從 28x28 變為 28x28x1 的形狀,將其標準化,並對標籤進行 one-hot 編碼。
定義模型體系結構
然後我們將用 cnn 定義網絡架構。
def create_model():
# Define the model architecture
model = keras.models.Sequential([
# Must define the input shape in the first layer of the neural network
keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(28,28,1)),
keras.layers.MaxPooling2D(pool_size=2),
keras.layers.Dropout(0.3),
keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
keras.layers.MaxPooling2D(pool_size=2),
keras.layers.Dropout(0.3),
keras.layers.Flatten(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(10, activation='softmax')
])
# Compile the model
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
return model
訓練模型
然後我們使用 model.fit()來訓練模型。
model.fit(x_train,
y_train,
batch_size=64,
epochs=3,
validation_data=(x_test, y_test))
2.模型保存和轉換
訓練結束後,我們將保存一個 Keras 模型並將其轉換為 TFLite 格式。
保存一個 Keras 模型
下面是保存 Keras 模型的方法-
# Save tf.keras model in HDF5 format
keras_model = "mnist_keras_model.h5"
keras.models.save_model(model, keras_model)
將keras模型轉換為tflite
當使用 TFLite 轉換器將 Keras 模型轉換為 TFLite 格式時,有兩個選擇- 1)從命令行轉換,或 2)直接在 python 代碼中轉換,這個更加推薦。
1)通過命令行轉換
$ tflite_convert \
$ --output_file=mymodel.tflite \
$ --keras_model_file=mymodel.h5
2)通過 python 代碼轉換
如果你可以訪問模型訓練代碼,則這是轉換的首選方法。
# Convert the model
flite_model = converter.convert()
# Create the tflite model file
tflite_model_name = "mymodel.tflite"
open(tflite_model_name, "wb").write(tflite_model)
你可以將轉換器的訓練後量化設置為 true。
# Set quantize to true
converter.post_training_quantize=True
驗證轉換的模型
將 Keras 模型轉換為 TFLite 格式後,驗證它是否能夠與原始 Keras 模型一樣正常運行是很重要的。請參閱下面關於如何使用 TFLite 模型運行推斷的 python 代碼片段。示例輸入是隨機輸入數據,你需要根據自己的數據更新它。
# Load TFLite model and allocate tensors. interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors
input_details = interpreter.get_input_details() output_details = interpreter.get_output_details()
# Test model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape),
dtype=np.float32) interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
ps:確保在轉換後和將 TFLite 模型放到 Android 上面之前始終測試它。否則,當它在你的 Android 應用程式上不能工作時,你無法分清是你的 android 代碼有問題還是 ML 模型有問題。
3.在 Android 上實現 tflite 模型
現在我們準備在 Android 上實現 TFLite 模型。創建一個新的 Android 項目並遵循以下步驟
將 mnist.tflite 模型放在 assets 文件夾下
更新 build.gradle 以包含 tflite 依賴項
為用戶創建自定義視圖
創建一個進行數字分類的分類器
從自定義視圖輸入圖像
圖像預處理
用模型對圖像進行分類
後處理
在用戶界面中顯示結果
Classifier 類是大多數 ML 魔術發生的地方。確保在類中設置的維度與模型預期的維度匹配:
28x28x1 的圖像
10 位數字的 10 個類:0、1、2、3…9
要對圖像進行分類,請執行以下步驟:
預處理輸入圖像。將位圖轉換為 bytebuffer 並將像素轉換為灰度,因為 MNIST 數據集是灰度的。
使用由內存映射到 assets 文件夾下的模型文件創建的解釋器運行推斷。
後處理輸出結果以在 UI 中顯示。我們得到的結果有 10 種可能,我們將選擇在 UI 中顯示概率最高的數字。
過程中的挑戰
以下是你可能遇到的挑戰:
如果 Android 應用程式崩潰,請查看 logcat 中的 stacktrace 錯誤:
aaptOptions {
noCompress "tflite"
}
總體來說,用 tf.Keras 訓練一個簡單的圖像分類器是輕而易舉的,保存 Keras 模型並將其轉換為 TFLite 也相當容易。目前,我們在 Android 上實現 TFLite 模型的方法仍然有點單調,希望將來能有所改進。
via:https://medium.com/@margaretmz/e2e-tfkeras-tflite-android-273acde6588
雷鋒網雷鋒網雷鋒網(公眾號:雷鋒網)
雷鋒網版權文章,未經授權禁止轉載。詳情見轉載須知。