TensorFlow筆記:模型保存、加載和Fine-tune

2021-03-02 跳動的數據
前言

嘗試過遷移學習的同學們都知道,Tensorflow的模型保存加載有不同格式,使用方法也不一樣,新手會覺得亂七八糟,所以本文做一個梳理。從模型的保存到加載,再到使用,力求理清這個流程。

1. 保存

Tensorflow的保存分為三種:

1. checkpoint模式;

2. pb模式;

3. saved_model模式。


1.1 先假設有這麼個模型

首先假定我們已經有了這樣一個簡單的線性回歸網絡結構:

import tensorflow as tfsize = 10X = tf.placeholder(name="input", shape=[None, size], dtype=tf.float32)y = tf.placeholder(name="label", shape=[None, 1], dtype=tf.float32)beta = tf.get_variable(name="beta", shape=[size, 1], initializer=tf.glorot_normal_initializer())bias = tf.get_variable(name="bias", shape=[1], initializer=tf.glorot_normal_initializer())pred = tf.add(tf.matmul(X, beta), bias, name="output")loss = tf.losses.mean_squared_error(y, pred)train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)


我們來簡單初始化,然後跑一下:

feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(sess.run(pred, feed_dict={X:feed_X}))    sess.run(train_op, feed_dict={X:feed_X, y:feed_y})    print(sess.run(pred, feed_dict={X:feed_X}))

可以看到初始化的輸出y值,以及訓練1個step之後的模型輸出發生了變化。

1.2 checkpoint模式

checkpoint模式將網絡和變量數據分開保存,保存好的模型長這個樣子:

|--checkpoint_dir|    |--checkpoint|    |--test-model-550.meta|    |--test-model-550.data-00000-of-00001|    |--test-model-550.index


checkpoint_dir就是保存時候指定的路徑,路徑下會生成4個文件。其中.meta文件(其實就是pb格式文件)用來保存模型結構,.data和.index文件用來保存模型中的各種變量,而checkpoint文件裡面記錄了最新的checkpoint文件以及其它checkpoint文件列表,在inference時可以通過修改這個文件,指定使用哪個model。那麼要如何保存呢?

checkpoint_dir = "./model_ckpt/"saver = tf.train.Saver(max_to_keep=1)    with tf.Session() as sess:    saver.save(sess, checkpoint_dir + "test-model",global_step=i, write_meta_graph=True)

實際就兩步。執行之後就可以在checkpoint_dir下面看到前面提到的4個文件了。(這裡的max_to_keep是指本次訓練在checkpoint_dir這個路徑下最多保存多少個模型文件,新模型會覆蓋舊模型以節省空間)。

1.3 pb模式

pb模式保存的模型,只有在目標路徑pb_dir = "./model_pb/"下孤孤單單的一個文件"test-model.pb",這也是它相比於其他幾種方式的優勢,簡單明了。假設還是前面的網絡結構,如果想保存成pb模式該怎麼做呢?

pb_dir = "./model_pb/"with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    graph_def = tf.get_default_graph().as_graph_def()        var_list = ["input", "label", "beta", "bias", "output"]       constant_graph = tf.graph_util.convert_variables_to_constants(sess, graph_def, var_list)    with tf.gfile.FastGFile(pb_dir + "test-model.pb", mode='wb') as f:        f.write(constant_graph.SerializeToString())

其實pb模式本質上就是把變量先凍結成常數,然後保存到圖結構中。這樣就可以直接加載圖結構和「參數」了。

1.4 saved_model模式


雖然saved_model也支持模型加載,並進行遷移學習。可是不得不說saved_model幾乎就是為了部署而生的,因為依靠tf.Serving部署模型時要求模型格式必須是saved_model格式。除此以外saved_model還有另外一個優點就是可以跨語言讀取,所以本文也介紹一下這種模式的保存於加載。本文樣例的保存在參數設置上會考慮到方便部署。保存好的saved_model結構長這個樣子:

|--saved_model_dir|    |--1|        |--saved_model.pb|        |--variables|            |--variables.data-00000-of-00001|            |--variables.index

保存時需要將保存路徑精確到"saved_model_dir/1/ ",會在下面生成一個pb文件,以及一個variables文件夾。其中「1」文件夾是表示版本的文件夾,應該是一個整數。人為設定這個「版本文件夾」的原因是,在模型部署的時候需要將模型位置精確到saved_model_dir,tf.Serving會在saved_model_dir下搜索版本號最大的路徑下的模型進行服務。模型保存的方法是

version = "1/"saved_model_dir = "./saved_model/test-model-dir/"builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir + version)
signature = tf.saved_model.signature_def_utils.build_signature_def( inputs={"input": tf.saved_model.utils.build_tensor_info(X)}, outputs={"output": tf.saved_model.utils.build_tensor_info(pred)}, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME )
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) builder.add_meta_graph_and_variables(sess, tags=[tf.saved_model.tag_constants.SERVING], signature_def_map={"serving_default": signature}, ) builder.save()

因為涉及到部署,比較複雜,這裡不得不說明一下。

在保存之前需要構建一個signature,用來構造signature的build_signature_def函數有三個參數:inputs、outputs、method_name。其中inputs和outputs分別用來獲取輸入輸出向量的信息,在部署服務後來的數據會餵到inputs中,服務吐的結果會以outputs的形式返回;而method_name如果用來部署模型的話需要設置為

"tensorflow/serving/predict",

"tensorflow/serving/classify", 

"tensorflow/serving/regress" 

中的一個。如果不是用來服務,就可以寫一個其他的。

在保存的時候,除了剛剛構建的signature,還需要提供一個tags 參數,如果用來部署的話需要填[tf.saved_model.tag_constants.SERVING],否則可以填其他。另外如果用來部署模型的話,signature_def_map的key必須是"serving_default"。

2. 加載

下面說如何加載,checkpoint和pb兩種模式的加載方法也不一樣。下面分別說


2.1 checkpoint加載(略煩)

checkpoint模式的網絡結構和變量是分來保存的,加載的時候也需要分別加載。而網絡結構部分你有兩種選擇:

1. 加載.meta文件中的結構,

2. 手動重新寫一遍原樣結構。

我們先說後一個,如果你不光有模型文件,還有源碼,可以把源碼構建模型那部分複製過來,然後只加載變量就好,這是手動重新搭建網絡結構:

import tensorflow as tfsize = 10X = tf.placeholder(name="input", shape=[None, size], dtype=tf.float32)y = tf.placeholder(name="label", shape=[None, 1], dtype=tf.float32)beta = tf.get_variable(name="beta", shape=[size, 1], initializer=tf.glorot_normal_initializer())bias = tf.get_variable(name="bias", shape=[1], initializer=tf.glorot_normal_initializer())pred = tf.sigmoid(tf.matmul(X, beta) + bias, name="output")

然後加載變量:

feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)saver = tf.train.Saver()with tf.Session() as sess:    saver.restore(sess, tf.train.latest_checkpoint('./model_ckpt/'))            print(sess.run(pred, feed_dict={X:feed_X}))

所以手動構建網絡結構後,只需要saver.restore一下,就可以加載模型中的參數。

另外,如果將上面的sess.run(tf.global_variables_initializer())注釋掉,

那每次運行的結果都一樣,可見此時模型中的變量確實是加載進來的變量。如果取消注釋這一句,每次跑出來的結果都不同,因為加載進來的變量又被初始化函數覆蓋了,所以每次都不一樣。這也說明了:通過checkpoint這種模式加載進來的變量,依然是變量,而且是trainable=True的。

print(tf.trainable_variables())

結果為:[<tf.Variable 'beta:0' shape=(10, 1) dtype=float32_ref>, <tf.Variable 'bias:0' shape=(1,) dtype=float32_ref>]

那如果我懶,活著沒有源碼,無法手動構建網絡呢?就需要從.meta文件裡導入網絡結構了。

import numpy as npimport tensorflow as tfsize = 10saver=tf.train.import_meta_graph('./model_ckpt/test-model-0.meta')

什麼?這就完了?網絡結構在哪呢?先別急,這種方法就是這樣,網絡結構已經加載進來了,那怎麼用呢?

feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)with tf.Session() as sess:    saver.restore(sess, tf.train.latest_checkpoint('./model_ckpt/'))      graph = tf.get_default_graph()    X = graph.get_tensor_by_name("input:0")            pred = graph.get_tensor_by_name("output:0")            print(sess.run(pred, feed_dict={X:feed_X}))

其實前面把網絡結構加載進來之後,如果需要對某tensor進行操作的話(run、feed、concat等等)需要通過tensor的name獲取成變量。同樣通過sess.run(tf.global_variables_initializer())可以看出,加載進來的變量,還是變量。

總結一下:手動構建網絡結構的話,缺點是麻煩!優點是你想用什麼變量直接用就行;而通過.meta文件來加載網絡結構,優點是省事,缺點是如果想用某個變量,必須通過name獲取變量。

2.2 pb模式加載


相比之下,pb模式的加載舊沒那麼複雜,因為他的網絡結構和數據是存在一起的。

import numpy as npimport tensorflow as tf
pb_dir = "./model_pb/"with tf.gfile.FastGFile(pb_dir + "test-model.pb", "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) X, pred = tf.import_graph_def(graph_def, return_elements=["input:0", "output:0"])

現在我們就已經有了X和pred,下面來跑一個pred吧

feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)with tf.Session() as sess:        print(sess.run(pred, feed_dict={X:feed_X}))

就這麼簡單!從pb中獲取進來的「變量」就可以直接用。為什麼我要給變量兩個字打上引號呢?因為在pb模型裡保存的其實是常量了,取消注釋sess.run(tf.global_variables_initializer())後,多次運行的結果還是一樣的。此時的「beta:0」和"bias:0"已經不再是variable,而是constant。這帶來一個好處:讀取模型中的tensor可以在Session外進行。相比之下checkpoint只能在Session內讀取模型,對Fine-tune來說就比較麻煩。



2.3 saved_model模式加載

前兩種加載方法想要獲取tensor,要麼需要手動搭建網絡,要麼需要知道tensor的name,如果用模型和訓模型的不是同一個人,那在沒有源碼的情況下,就不方便獲取每個tensor的name。好在saved_model可以通過前面提到的signature_def_map的方法獲取tensor。先看一下直接通過tensor的name獲取變量的加載方式:

size = 10feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)
saved_model_dir = "./saved_model/1/"with tf.Session() as sess: meta_graph_def = tf.saved_model.loader.load(sess, tags=["serve"], export_dir=saved_model_dir) graph = tf.get_default_graph() X = graph.get_tensor_by_name("input:0") pred = graph.get_tensor_by_name("output:0") print(sess.run(pred, feed_dict={X:feed_X}))

這裡和checkpoint的加載過程很相似,先一個load過程,然後get_tensor_by_name。這需要我們事先知道tensor的name。如果有了signature的信息就不一樣了:

size = 10feed_X = np.ones((8,size)).astype(np.float32)feed_y = np.ones((8,1)).astype(np.float32)
saved_model_dir = "./saved_model/1/"with tf.Session() as sess: meta_graph_def = tf.saved_model.loader.load(sess, tags=["serve"], export_dir=saved_model_dir) signature = meta_graph_def.signature_def X = signature["serving_default"].inputs["input"].name pred = signature["serving_default"].outputs["output"].name print(sess.run(pred, feed_dict={X:feed_X}))

這時即使我們沒有源碼,也可以通過print(signature)獲知關於tensor的信息,如上就展示了沒有源碼時,通過signature獲取tensor的name,並獲取tensor的過程。這裡輸出的signature長這樣:

print(signature)
"""INFO:tensorflow:Restoring parameters from ./saved_model/1/variables/variables{'serving_default': inputs { key: "input" value { name: "input:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 10 } } }}outputs { key: "output" value { name: "output:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 1 } } }}method_name: "tensorflow/serving/predict"}"""



3. Fine-tune

最後不管保存還是加載模型,多數情況都是為了能夠進行遷移學習。其實大部分無非就是將模型加載進來之後,使用某一個節點的值,作為我們後續模型的輸入唄。比如我要用前面的模型結果作為特徵通過一元羅輯回歸去預測z,這樣新的網絡結構就是這樣:

import numpy as npimport tensorflow as tf
pb_dir = "./model_pb/"with tf.gfile.FastGFile(pb_dir + "test-model.pb", "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) X, pred = tf.import_graph_def(graph_def, return_elements=["input:0", "output:0"])
z = tf.placeholder(name="new_label", shape=[None, 1], dtype=tf.float32)new_beta = tf.get_variable(name="new_beta", shape=[1], initializer=tf.glorot_normal_initializer())new_bias = tf.get_variable(name="new_bias", shape=[1], initializer=tf.glorot_normal_initializer())new_pred = tf.sigmoid(new_beta * pred + new_beta)
new_loss = tf.reduce_mean(tf.losses.log_loss(predictions=new_pred, labels=z))train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(new_loss)

就是這樣,把保存好的模型看作一個黑盒,餵進去X吐出來pred,然後我們直接用pred就好了。

但是這裡存在一個問題,就是只能通過name獲取節點。比如這裡的new_pred就沒有name,那我想要基於這個新模型再次進行Fine-tune的時候,就不能獲取這個new_pred,就無法進行Fine-tune。所以大家還是要養成一個好習慣,多給變量起名字,尤其是placeholder!要是連placeholder都沒名字,別人就沒法用你的模型啦。如果保存的是saved_model,建議一定要設置signature。

下面來實驗一下這個Fine-tune的模型吧:

feed_X = np.ones((8,size)).astype(np.float32)feed_z = np.array([[1],[1],[0],[0],[1],[1],[0],[0]]).astype(np.float32)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(sess.run(new_pred, feed_dict={X:feed_X}))    sess.run(train_op,  feed_dict={X:feed_X, z:feed_z})    print(sess.run(new_pred, feed_dict={X:feed_X}))

這裡補充一下:通過pb模式導入進來的參數其實是constants,所以在Fine-tune的時候不會變化,而通過checkpoint模式導入進來的參數是variables,在後續Fine-tune的時候是會發生變化的。具體讓不讓他trainable就看你的實際需要了。

4. 其他補充

在2.2中,加載pb模型的時候,並不需要把所有的tensor都獲取到,只要「一頭一尾」即可。因為頭("input:0")是需要進行feed操作的,而尾("output:0")是需要輸出,或者在遷移學習中要進行其他操作。至於中間哪些其他不需要進行操作的tensor,可以不獲取。

因為只有pb模式在加載的時候,可以在Session外進行加載,方便Fine-tune。所以個人建議,如果要進行遷移學習,先將模型轉化為pb模式。

其他的想起來在寫

相關焦點

  • 詳解Tensorflow模型量化(Quantization)原理及其實現方法
    因此,為了解決此類問題模型量化應運而生,本篇我們將探討模型量化的概念原理、優缺點及tensorflow模型量化的實現方法。五、tensorflow訓練後量化(Post-training quantization)介紹及其實現方法tensorflow訓練後量化是針對已訓練好的模型來說的,針對大部分我們已訓練未做任何處理的模型來說均可用此方法進行模型量化,而tensorflow提供了一整套完整的模型量化工具,如TensorFlow
  • Keras訓練的h5文件轉pb文件並用Tensorflow加載
    ,而pb格式的文件一般比較適合部署,pb模型文件的大小要比h5文件小一點,同時pb文件也適用於在TensorFlow Serving,所以需要把Keras保存的h5模型文件轉成TensorFlow加載的pb格式來使用。
  • TensorFlow Object Detection API 實踐
    然而構建準確率高的、能定位和識別單張圖片裡多種物體的模型仍然是計算機視覺領域一大挑戰。TF Object Detection API 【1】是一個構建在 TensorFlow 之上的可以簡化構建、訓練、部署目標檢測模型的開源框架。TF Object Detection API 安裝步驟參考【2】。
  • 【TensorFlow實戰筆記】 遷移學習實戰--卷積神經網絡CNN-Inception-v3模型
    模型 : https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip3.flower_photos: http://download.tensorflow.org/example_images/flower_photos.tgz
  • tensorflow使用object detection實現目標檢測超詳細全流程(視頻+圖像集檢測)
    使用tensorflow object detection進行訓練檢測。參考原始代碼:https://github.com/tensorflow/models/tree/master/research本文以mobilenet-ssd-v2為例進行處理,通過換模型即可實現faster RCNN等的訓練檢測。
  • 所有的Tensorflow模型都可以嵌入到行動裝置
    AI 前線導讀:我們通常將 TensorFlow 模型保存為可以被 TesnorFlow Serving 和 Simple TensorFlow Serving 調用的本地文件。然後行動裝置可以通過在線訪問這些服務實現移動端的一些應用。
  • PyTorch 深度剖析:如何保存和加載PyTorch模型?
    目錄1 需要掌握3個重要的函數2 state_dict2.1 state_dict 介紹2.2 保存和加載 state_dict (已經訓練完,無需繼續訓練)2.3 保存和加載整個模型 (已經訓練完,無需繼續訓練)2.4 保存和加載 state_dict (沒有訓練完,還會繼續訓練
  • 【NLP預訓練模型】你finetune BERT的姿勢可能不對哦?
    預訓練模型BERT是NLP領域如今最大的網紅,BERT的預訓練過程學習了大量的自然語言中詞、句法以及常識等泛領域的知識。因此,在運用BERT到實際的NLP任務中,通常的做法都是基於特定領域內的少量數據(幾千到幾萬)集,再進行Finetune,以適用於當前的任務和領域。通常來說,基於BERT進行Finetune效果都會還不錯。
  • tensorflow(8)將h5文件轉化為pb文件並利用tensorflow/serving實現模型部署
    在文章NLP(三十四)使用keras-bert實現序列標註任務中,我們使用Keras和Keras-bert進行模型訓練、模型評估和模型預測。我們對人民日報實體數據集進行模型訓練,保存後的模型文件為example.h5,h5是Keras保存模型的一種文件格式。
  • 社區分享 | Spark 玩轉 TensorFlow 2.0
    本文來自社區投稿與徵集,作者梁雲,轉自:https://github.com/lyhue1991/eat_tensorflow2_in_30_days本篇文章介紹在 Spark 中調用訓練好的 TensorFlow 模型進行預測的方法。本文內容的學習需要一定的 Spark 和 Scala 基礎。
  • TensorFlow Serving入門
    、驗證和預測,但模型完善之後的生產上線流程,就變得五花八門了。TF Serving的工作流程主要分為以下幾個步驟:Source會針對需要進行加載的模型創建一個Loader,Loader中會包含要加載模型的全部信息;Source通知Manager有新的模型需要進行加載
  • 深度學習入門篇——手把手教你用 TensorFlow 訓練模型
    Tensorflow在更新1.0版本之後多了很多新功能,其中放出了很多用tf框架寫的深度網絡結構,大大降低了開發難度,利用現成的網絡結構,無論fine-tuning
  • 教程 | 從零開始:TensorFlow機器學習模型快速部署指南
    更典型的 ML 用例通常基於數百個圖像,這種情況我推薦大家對現有模型進行微調。例如,https://www.tensorflow.org/tutorials/image_retraining 頁面上有如何微調 ImageNet 模型對花樣本數據集(3647 張圖像,5 個類別)進行分類的教程。
  • 使用OpenCV加載TensorFlow2模型
    Suaro希望使用OpenCV來實現模型加載與推演,但是沒有成功,因此開了issue尋求我的幫助。首先,我們先解決OpenCV加載模型的問題。使用OpenCV加載模型OpenCV在3.0的版本時引入了一個dnn模塊,實現了一些基本的神經網絡模型layer。在最新的4.5版本中,dnn模塊使用函數 readNet 實現模型加載。
  • Nvidia Jetson Nano:使用Tensorflow和OpenCV從頭開始自定義對象檢測
    本文作者:轉載自:https://medium.com/swlh/nvidia-jetson-nano-custom-object-detection-from-scratch-using-tensorflow-and-opencv
  • 老闆讓我用少量樣本 finetune 模型,我還有救嗎?急急急,在線等!
    傳統的 RNN 在這個樣本大小下很難被訓練好,自然地,我們會想到使用預訓練模型,在其基礎上進行 finetune。具體來講,就是將預訓練模型作為模型的底層,在上面添加與當前任務特點相關的網絡結構。這樣就引入了預訓練的知識,對當前任務能產生很大的幫助。
  • Tensorflow如何導出與使用預測圖
    tf.train.Saver API說明保存於恢復變量,對定義好完成訓練或者完成部分訓練的計算圖所有OP操作的中間變量進行保存,保存為檢查點文件(checkpoint file),檢查點文件通過restore方法完成恢復,實現從變量到張量值(tensor value)得映射加載,可以進行調用或者繼續訓練。
  • 還在 Fine-tune 大規模預訓練模型? 該了解下最新玩法 Prompt-tuning 啦!
    我們先來總結下 fine-tune 存在的一些問題:(以下 'PLM' 代表 Pre-trained Language Model,即預訓練模型)· PLM 規模不斷增大,對其進行 fine-tune 的硬體要求和數據需求
  • 用TensorFlow實現文本分析模型,做個聊天機器人
    blogId=121我最近每天都會學一點,拿出解讀來和大家分享一下。本文結構:聊天機器人的架構簡圖用 TensorFlow 實現 Chatbot 的模型如何準備 chatbot 的訓練數據Chatbot 源碼解讀1.
  • Tensorflow的C語言接口部署DeeplabV3+語義分割模型
    tensorflow框架一般都是基於Python調用,但是有些時候跟應用場景,我們希望調用tensorflow C語言的接口,在C++的應用開發中使用它。要這麼幹,首先需要下載tensorflow源碼,完成編譯,然後調用相關的API函數實現C語言版本的調用,完成模型的加載、前向推理預測與解析。