TensorFlow Lite是TensorFlow在移動和嵌入式設備上的輕量級解決方案,目前只能用於預測,還不能進行訓練。TensorFLow Lite針對移動和嵌入設備開發,具有如下三個特點:
目前TensorFlow Lite已經支持Android、iOS、Raspberry等設備,本章會基於Android設備上的部署方法進行講解,內容包括模型保存、轉換和部署。
2、模型保存我們以keras模型訓練和保存為例進行講解,如下是keras官方的mnist模型訓練樣例。
'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
batch_size = 128
num_classes = 10
epochs = 12
img_rows, img_cols = 28, 28
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])創建mnist_cnn.py文件,將以上內容拷貝進去,並在最後加上如下一行代碼:
model.save('mnist_cnn.h5')在終端中執行mnist_cnn.py文件,如下:
python mnist_cnn.py註:該過程需要連接網絡獲取mnist.npz文件(https://s3.amazonaws.com/img-datasets/mnist.npz),會被保存到$HOME/.keras/datasets/。如果網絡連接存在問題,可以通過其他方式獲取mnist.npz後,直接保存到該目錄即可。
執行過程會比較久,執行結束後,會產生在當前目錄產生mnist_cnn.h5文件(HDF5格式),就是keras訓練後模型,其中已經包含了訓練後的模型結構和權重等信息。
該模型可以在伺服器端,可以直接通過keras.models.load_model("mnist_cnn.h5")加載,然後進行推測;在行動裝置需要將HDF5模型文件轉換為TensorFlow Lite的格式,然後提供相應平臺提供的Interpreter加載,然後進行推測。
3、模型轉換不能直接在移動端部署,因為模型大小和運行效率比較低,最終需要通過工具轉化為Flat Buffer格式的模型。
谷歌提供了多種轉換方式:
tflight_convert跟tensorflow是一起下載的,筆者通過brew安裝的python,pip安裝tf-nightly後tflight_convert路徑如下:
/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin實際上,應該是/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/bin,但是軟連結到了如上路徑。如果命令行不能執行到tflight_convert,則在~/.bash_profile(macOS)或~/.bashrc(Linux)添加如下環境變量:
export PATH="/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin:$PATH" 然後執行
source ~/.bash_profile或
source ~/.bashrc在命令執行
tflight_convert -h輸出結果如下,則證明安裝配置成功。
usage: tflite_convert [-h] --output_fileOUTPUT_FILE
(--graph_def_file GRAPH_DEF_FILE | --saved_model_dirSAVED_MODEL_DIR | --keras_model_fileKERAS_MODEL_FILE)
[--output_format {TFLITE,GRAPHVIZ_DOT}]
[--inference_type {FLOAT,QUANTIZED_UINT8}]
[--inference_input_type {FLOAT,QUANTIZED_UINT8}]
[--input_arrays INPUT_ARRAYS]下面我們開始轉換模型,具體命令如下:
tflite_convert --keras_model_file=./mnist_cnn.h5 --output_file=./mnist_cnn.tflite到此,我們已經得到一個可以運行的TensorFlow Lite模型了,即mnist_cnn.tflite。
註:這裡只介紹了keras HDF5格式模型的轉換,其他模型轉換建議參考:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md
4、Android部署現在開始在Android環境部署,對於國內的讀者,需要先給Android Studio配置proxy,因為gradle編譯環境需要獲取相應的資源,請大家自行解決,這裡不再贅述。
4.1 配置app/build.gradle新建一個Android Project,打開app/build.gradle添加如下信息
android{
aaptOptions{
noCompress"tflite"
}
}
repositories{
maven{
url'https://google.bintray.com/tensorflow'
}
}
dependencies{
implementation'org.tensorflow:tensorflow-lite:1.10.0'
}其中,
1、aaptOptions設置tflite文件不壓縮,確保後面tflite文件可以被Interpreter正確加載。
2、org.tensorflow:tensorflow-lite的最新版本號,可以在這裡查詢https://bintray.com/google/tensorflow/tensorflow-lite,目前最新的是1.10.0版本。
設置好後,sync和build整個工程,如果build成功說明,配置成功。
4.2 添加tflite文件到assets文件夾在app目錄先新建assets目錄,並將mnist_cnn.tflite文件保存到assets目錄。重新編譯apk,檢查新編譯出來的apk的assets文件夾是否有mnist_cnn.tflite文件。
使用apk analyzer查看新編譯出來的apk,存在如下目錄即編譯打包成功。
assets
|__mnist_cnn.tflite4.3 加載模型使用如下函數將mnist_cnn.tflite文件加載到memory-map中,作為Interpreter實例化的輸入。
private static final String MODEL_PATH = "mnist_cnn.tflite";
private MappedByteBuffer loadModelFile(Activityactivity) throws IOException{
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
returnf ileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}實例化Interpreter,其中this為當前acitivity
tflite = new Interpreter(loadModelFile(this));4.4 運行輸入我們使用mnist test測試集中的某張圖片作為輸入,mnist圖像大小28*28,單像素。這樣我們輸入的數據需要設置成如下格式。
private ByteBufferimgData = null;
private static final int DIM_BATCH_SIZE = 1;
private static final int DIM_PIXEL_SIZE = 1;
private static final int DIM_IMG_WIDTH = 28;
private static final int DIM_IMG_HEIGHT=28;
protected void onCreate() {
imgData = ByteBuffer.allocateDirect(
4 * DIM_BATCH_SIZE * DIM_IMG_WIDTH * DIM_IMG_HEIGHT * DIM_PIXEL_SIZE);
imgData.order(ByteOrder.nativeOrder());
}將mnist圖片轉化成ByteBuffer,並保持到imgData中。
private int[] intValues = new int[DIM_IMG_WIDTH * DIM_IMG_HEIGHT];
private void convertBitmapToByteBuffer(Bitmapbitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
int pixel = 0;
for (int i = 0; i < DIM_IMG_WIDTH; ++i) {
for (int j = 0; j < DIM_IMG_HEIGHT; ++j) {
final int val = intValues[pixel++];
imgData.putFloat(val);
}
}
}convertBitmapToByteBuffer的輸出即為模型運行的輸入。
4.5 運行輸出定義一個1*10的多維數組,因為我們只有1個batch和10個label(TODO:need double check),具體代碼如下:
private float[][] labelProbArray = new float[1][10];運行結束後,每個二級元素都是一個label的概率。
4.6 運行及結果處理開始運行模型,具體代碼如下:
tflite.run(imgData, labelProbArray);針對某個圖片,運行後labelProbArray的內容如下,也就是各個label識別的概率。
index 0 prob is 0.0
index 1 prob is 0.0
index 2 prob is 0.0
index 3 prob is1.0
index 4 prob is 0.0
index 5 prob is 0.0
index 6 prob is 0.0
index 7 prob is 0.0
index 8 prob is 0.0
index 9 prob is 0.0接下來,我們要做的就是根據對這些概率進行排序,找出Top的label並界面呈現給用戶
5、總結至此,整個TensorFlow Lite的部署就完成了,包含四個階段:
模型保存:我們使用的是keras Squential類的save函數
模型轉換:我們使用的tflite_convert工具
Android部署:配置build.gradle和assets,通過memory-map加載圖片並轉化為ByteBuffer作為輸入和固定維數的float數組作為輸出,最後調用Interpreter.run()
處理和顯示運行結果
6、附錄TF Lite Command-line tools
TF Lite Android App
Google TF Lite Codelab
TensorFlow Lite Example
What I know about TensorFlow Lite
TensorFlow Lite for mobile developers (Google I/O '18)