擴展之Tensorflow2.0 | 19 TF2模型的存儲與載入

2021-02-16 機器學習煉丹術

擴展之Tensorflow2.0 | 18 TF2構建自定義模型

擴展之Tensorflow2.0 | 17 TFrec文件的創建與讀取

擴展之Tensorflow2.0 | 16 TF2讀取圖片的方法

擴展之Tensorflow2.0 | 15 TF2實現一個簡單的服裝分類任務

小白學PyTorch | 14 tensorboardX可視化教程

小白學PyTorch | 13 EfficientNet詳解及PyTorch實現

小白學PyTorch | 12 SENet詳解及PyTorch實現

小白學PyTorch | 11 MobileNet詳解及PyTorch實現

小白學PyTorch | 10 pytorch常見運算詳解

小白學PyTorch | 9 tensor數據結構與存儲結構

小白學PyTorch | 8 實戰之MNIST小試牛刀

小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解

小白學PyTorch | 6 模型的構建訪問遍歷存儲(附代碼)

小白學PyTorch | 5 torchvision預訓練模型與數據集全覽

小白學PyTorch | 4 構建模型三要素與權重初始化

小白學PyTorch | 3 淺談Dataset和Dataloader

小白學PyTorch | 2 淺談訓練集驗證集和測試集

小白學PyTorch | 1 搭建一個超簡單的網絡

小白學PyTorch | 動態圖與靜態圖的淺顯理解


這個系列《小白學PyTorch》的所有代碼和數據集放在了公眾號【機器學習煉丹術】後臺,回復【pytorch】獲取(還在更新的呢):

機器學習煉丹術的粉絲的人工智慧交流群已經建立,目前有目標檢測、醫學圖像、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617.

參考目錄:

1 模型的構建

2 結構參數的存儲與載入

3 參數的存儲與載入

4 結構的存儲與載入

本文主要講述TF2.0的模型文件的存儲和載入的多種方法。主要分成兩類型:模型結構和參數一起載入,模型的結構載入。

1 模型的構建
import tensorflow.keras as keras

class CBR(keras.layers.Layer):
    def __init__(self,output_dim):
        super(CBR,self).__init__()
        self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
        self.bn = keras.layers.BatchNormalization(axis=3)
        self.ReLU = keras.layers.ReLU()

    def call(self, inputs):
        inputs = self.conv(inputs)
        inputs = self.ReLU(self.bn(inputs))
        return inputs

class MyNet(keras.Model):
    def __init__ (self):
        super(MyNet,self).__init__()
        self.cbr1 = CBR(16)
        self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
        self.cbr2 = CBR(32)
        self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))

    def call(self, inputs):
        inputs = self.maxpool1(self.cbr1(inputs))
        inputs = self.maxpool2(self.cbr2(inputs))
        return inputs

model = MyNet()

部分朋友可以發現,上面的代碼就是上一次課程所構建的一個自定義的網絡。

我們現在需要展示這個模型的框架:

model.build((16,224,224,3))
print(model.summary())

運行結果為:

這裡需要對網絡執行一個構建.build()函數,之後才能生成model.summary()這樣的模型的描述。 這是因為模型的參數量是需要知道輸入數據的通道數的,假如我們輸入的是單通道的圖片,那麼就是:

model.build((16,224,224,1))
print(model.summary())

輸出結果為:

2 結構參數的存儲與載入
model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')

這裡並不能保存成功,出現這樣的錯誤:

大概的意思就是:因為你的模型不是官方的模型,是自定義的,所以並不能同時保存結構和參數。只有官方的模型可以時候上面的保存的方法,同時保存參數和權重;自定義的模型建議只保存參數

3 參數的存儲與載入
model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')

這樣子就可以保存自定義的模型了。在對應的目錄下會出現這幾個文件:

我們來看一下原來的模型和載入的模型對於同一個樣本給出的結果是否相同:

# 看一下原來的模型和載入的模型預測相同的樣本的輸出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]

結果相同,載入的沒有問題~

4 結構的存儲與載入

結構的存儲有兩種方法:

需要注意的是,上面的兩個方法和save的問題一樣,是不能用在自定義的模型中的,如果你在其中使用了自定義的Layer類,那麼只能!只能用save_weights的方式進行保存

下面依然給出這兩種方法的代碼,對於簡單的、已經封裝好的一些網絡層構成的網絡,是可以使用這些的。我個人還是常用save_weights啦

# 第一種方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二種方法
json_config = model.to_json()
# 把json寫的文件中
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)
# 讀取本地json文件
with open('model_config.json') as json_file:
    json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)

今天的內容就是這麼多,雖然提供了四種方法,但是對於自定義程度較高的模型,還是要使用save_weights哦~

- END -


小白學論文 | EfficientNet強在哪裡

小白學論文 | 神經網絡初始化Xavier

小白學論文 | 端側神經網絡GhostNet(2019)

小白學目標檢測 | RCNN, SPPNet, Fast, Faster

小白學圖像 | BatchNormalization詳解與比較

小白學圖像 | Group Normalization詳解+PyTorch代碼

小白學圖像 | 八篇經典CNN論文串講

圖像增強 | CLAHE 限制對比度自適應直方圖均衡化

小白學卷積 | 深入淺出卷積網絡的平移不變性

小白學卷積 | (反)卷積輸出尺寸計算

損失函數 | 焦點損失函數 FocalLoss 與 GHM

<<小白學機器學習>>

小白學ML | 隨機森林 全解 (全網最全)

小白學SVM | SVM優化推導 + 拉格朗日 + hingeLoss

小白學LGB | LightGBM = GOSS + histogram + EFB

小白學LGB | LightGBM的調參與並行

小白學XGB | XGBoost推導與牛頓法

評價指標 | 詳解F1-score與多分類F1

小白學ML | Adaboost及手推算法案例

小白學ML | GBDT梯度提升樹

小白學優化 | 最小二乘法與嶺回歸&Lasso回歸

小白學排序 | 十大經典排序算法(動圖)

雜談 | 正態分布為什麼如此常見

Adam優化器為什麼被人吐槽?

機器學習不得不知道的提升技巧:SWA與pseudo-label

<<小白面經>>

秋招總結 | 一個非Top學校的跨專業的算法應屆研究生的幾十場面試

【小白面經】快手 AI算法崗 附答案解析

【小白面經】 拼多多 AI算法崗 附帶解析

【小白面經】八種應對樣本不均衡的策略

【小白面經】之防止過擬合的所有方法

【小白面經】梯度消失爆炸及其解決方法

【小白面經】 判別模型&生成模型

<<小白健身>>

【小白健身】腹肌搓衣板化

【小白健身】8個動作練爆胸大肌

【小白健身 】背闊大作戰(下)

【小白健身】背闊大作戰(上)

【小白健身】徒手健身40個動作(gif)

【小白健身】彈力帶輕度健身gif動圖



相關焦點

  • tensorflow安裝教程
    Anaconda安裝和使用,AkShare入門,安裝tensorflow2.0。首先打開anaconda,執行conda create --name tf2.0 python=3.7建立一個名為tf2.0的虛擬環境。細節不說了,參考我之前的文章,就是一直選yes,安裝就行了。如果報HTTPSConnectionPool字樣的錯誤,是網速慢的原因,多試幾次就好了。
  • 運行tensorflow2.0出錯
    今天在調試tf2.0的代碼的時候,Console丟了一個錯誤出來:AttributeError: module 'tensorflow'
  • 帶你入門機器學習與TensorFlow2.x
    在後續的文章中將深入講解用Tensorflow2.x訓練各種模型,以及利用模型完成相關的工作。pip install tensorflow如果要安裝Tensorflow1.x,那麼需要按前面的步驟創建一個名為tf1的Python虛擬環境,然後使用下面的命令安裝tensorflow的特定版本。
  • 安裝TensorFlow 2.0 preview進行深度學習(附Jupyter Notebook)
    本文介紹安裝TensorFlow 2.0 preview的方法,並介紹一個Github項目tf2_course,它包含了一些TensorFlow 2的練習和解決方案,以Jupyter Notebook的形式展現。TensorFlow是最流行的深度學習框架之一,大家期待已久的TensorFlow 2.0現在出了Preview版本,並且可以直接通過pip安裝。
  • 《30天吃掉那隻 TensorFlow2.0 》全新TF2.0教程收穫1000 Star
    作者 | lyhue1991來源 | GitHub轉自 | Python與算法之美【導讀】本文對書籍《30天吃掉那隻 TensorFlow2.0 》作簡單的內容介紹。TensorFlow2.0 還是 Pytorch?
  • 6 種方法部署 TensorFlow2 機器學習模型,簡單 + 快速 + 跨平臺!
    知識點Keras 導入預訓練模型預訓練模型的使用方法保存模型為 HDF5 格式保存模型為 SavedModel 格式3. 環境配置目前,TensorFlow 2 已正式發布,你需要通過 pip install -U tensorflow 進行升級安裝。
  • tensorflow機器學習模型的跨平臺上線
    ,這個方法當然也適用於tensorflow生成的模型,但是由於tensorflow模型往往較大,使用無法優化的PMML文件大多數時候很笨拙,因此本文我們專門討論下tensorflow機器學習模型的跨平臺上線的方法。
  • 【小白學PyTorch】18.TF2構建自定義模型
    擴展之Tensorflow2.0 | 17 TFrec文件的創建與讀取擴展之Tensorflow2.0 | 16 TF2讀取圖片的方法擴展之Tensorflow2.0 | 15 TF2實現一個簡單的服裝分類任務
  • TensorFlow 2.0 概述
    接下來先來看一段演示代碼:# 將通過清華鏡像下載的tensorflow包導入import tensorflow as tfa = tf.constant([[1.0,-2],[-3,4]])print(a)控制臺輸出結果如下:tf.Tensor
  • 數據載入過慢?這裡有一份TensorFlow加速指南
    根據以往經驗,在TensorFlow中,feed-dict函數可能是最慢的一種數據載入方法,儘量少用。把數據輸入到模型的最佳方法是使用輸入流水線(input pipeline),來確保GPU無須等待新數據輸入。幸好,TensorFlow有一個內置接口,叫做Dataset。這個接口是為了更容易地實現數據輸入,在1.3版本已被提出。
  • 深度學習基礎(十):TensorFlow 2.x模型的驗證、正則化和回調
    此外,對模型進行泛化是一項實際需求。否則,如果您無法使用機器學習模型來訓練模型所依據的數據,而無法成功地進行預測,那麼該方法又有什麼用呢?通過模型選擇和正則化來實現泛化和避免過度擬合。Tensorflow 2.x提供了 回調 函數的功能,工程師可以通過該 功能在培訓進行時根據保留的驗證集監視模型的性能。
  • TensorFlow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。
  • tensorflow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、 tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。下圖是 query 詞頻時序圖,從中可以看出 tensorflow 的火爆程度。
  • 擴展之Tensorflow2.0 | 21 Keras的API詳解(上)卷積、激活、初始化、正則
    2 Keras參數初始化 把之前提到的簡單的例子,增加卷積核和偏置的初始化:import tensorflow as tfinput_shape = (4, 28, 28, 3)initializer = tf.keras.initializers.RandomNormal
  • TensorFlow 模型優化工具包:模型大小減半,精度幾乎不變!
    在 IEEE 754-2008 標準中,16 位 base-2 格式稱為 binary16。它用於在高精度對於執行算術計算不是必需的應用中存儲浮點值,並且 IEEE 754 標準將 binary16 指定為具有以下格式:
  • 模型秒變API只需一行代碼,支持TensorFlow等框架
    每個模型都載入到一個 Docker 容器中,包括相關的 Python 包和處理請求的代碼。模型通過網絡服務,如 Elastic Load Balancing (ELB)、Flask、TensorFlow Serving 和 ONNX Runtime 公開 API 給用戶使用。
  • TensorFlow 2.0 中文手寫字識別(漢字OCR)
    總之一句話,模型太簡單和太複雜都不好,甚至會發散!(想親身體驗模型訓練發散抓狂的可以來嘗試一下!)。但是,挑戰這個任務也有很多好處:本項目實現了基於CNN的中文手寫字識別,並且採用標準的tensorflow 2.0 api 來構建!
  • Tensorflow 2.0 到底好在哪裡?
    在以前的文章中,我曾評測過 :TensorFlow r0.10(2016):https://www.infoworld.com/article/3127397/review-tensorflow-shines-a-light-on-deep-learning.html  TensorFlow 1.5(2018):https://www.infoworld.com
  • TensorFlow安裝與卷積模型
    使用pip安裝1)下載安裝Python 2)打開windows的命令行窗口,安裝CPU版本pip installtensorflow安裝GPU版本Pip install tensorflow-gpu之後驗證是否安裝了 TensorFlow 可以嘗試一下代碼>>> importtensorflow
  • tf2+cnn+中文文本分類優化系列(2)
    import tensorflowas tffrom tensorflow.keras import layersfrom tensorflow.keras.callbacks import ReduceLROnPlateauimport numpy as npimport collectionsimport matplotlib.pyplot as pltimport codecs