雷鋒網(公眾號:雷鋒網)(公眾號:雷鋒網)按:本文為雷鋒字幕組編譯的技術博客,原標題 Smart way to serialize/deserialize classes to/from Tensorflow graph ,作者為 Francesco Zuppichini 。
翻譯 |王褘 整理 | MY
將類中的欄位和 graph 中的 tensorflow 變量進行自動綁定,並且在不需要手動將變量從 graph 中取出的情況下進行重存,聽起來有沒有很炫酷?
可以點擊這裡找到本文所涉及的代碼。Jupyter-notebook 的版本點擊這裡。
假設你有一個 Model 類。
一般來說,首先需要構建模型,然後對模型進行訓練。之後無需再次從頭重新構建訓練模型,而是從已經保存的 graph 中獲取舊變量來進行使用。
假設我們已經訓練好了模型,現在我們想要把它保存下來。通常的模式是:
接下來你會通過加載已保存的 graph 來執行 inference,也就是把變量取出的操作。在下面的例子中,我們將變量命名為 variable 。
現在我們可以從 graph 中取出變量 variable 。
假如我們想要再次使用 model 類要怎麼辦?如果我們嘗試去調用 model.variable,得到的結果會是 None。
一個解決方案是重新構建整個模型,然後重新保存一個 graph 。
可以想見,這個過程肯定非常耗費時間。我們可以通過直接將 model.variable 綁定到相應的 graph 節點上來實現,如下:
假設我們有一個非常大的模型,且內含嵌套變量。
為了能夠將變量指針正確的重存進模型,你需要
如果可以通過在 Model 類中將變量設置為欄位的方式來實現自動檢索,這聽起來就很酷,有沒有?
TFGraphConvertible
我創建了一個 TFGraphConvertible 類,你可以用這個 TFGraphConvertible 類來自動進行類的序列化和反序列化。
讓我們來重新創建我們的模型。
它會暴露兩個方法: to_graph 和 from_graph 方法。
序列化 — to_graph
你可以通過調用 to_graph 方法來進行類的序列化,這個方法會創建一個以欄位為 key , tensorflow 變量名為值的字典。
你想要序列化哪些欄位來構建這個字典,那麼你需要將這些欄位作為 fields 參數傳入。
在下例中,我們傳入所有這些欄位。
這會創建全量字典,以欄位作為關鍵字,以每個欄位對應的 tensorflow 變量名作為值。
反序列化 — from_graph
你可以通過調用 from_graph 方法來進行類的反序列化,這個方法通過我們在上文中構建的字典內容,將類中的欄位綁定到對應的 tensorflow 變量上。
現在你恢復了 model 。
完整的例子
來看一個更有趣的例子!我們接下來要用 MNIST 數據集來訓練/恢復一個模型。
首先,獲取數據集。
現在我們用這個數據集來進行訓練
完美!接下來我們將這個序列化後的模型存到內存中。
接著我們重置 graph,並且重建模型。
顯而易見,變量並沒有在 mnist_model 中。
我們通過調用 from_graph 方法來重建它們
現在 mnist_model 已經可以使用了,我們來看一下在測試集上的精確度如何吧。
結論
通過這次的教程,我們了解了如何進行類的序列化,以及如何在 tensorflow graph 中將類中的欄位反綁到對應的變量上。
並且可以將 serialized_model 保存成 .json 格式,然後從任意位置直接加載它。
通過這種方式,你可以通過面向對象編程的方式來直接創建模型,且無需重新構建就可以索引到所有的變量。
感謝您的閱讀。
原文連結:https://towardsdatascience.com/smart-way-to-srialize-deserialise-class-to-from-tensorflow-graph-1b131db50c7d
雷鋒網雷鋒網
雷鋒網原創文章,未經授權禁止轉載。詳情見轉載須知。