【注意力機制】transformers之轉換Tensorflow的Checkpoints

2021-02-11 磐創AI

來源 | Github

作者 | huggingface

編譯 | VK

【導讀】本節提供了一個命令行界面來轉換模型中的原始Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM的Checkpoints,然後使用庫的from_pretrained方法加載該Checkpoints。

注意:從2.3.0版本開始,轉換腳本現在已成為 transformers CLI(transformers-cli)的一部分,在任何transformers)=2.3.0的都可用。以下文檔反映了transformers-cli convert命令格式。



你可以通過使用convert_tf_checkpoint_to_pytorch.py將任意的BERT的Tensorflow的Checkpoints轉換為PyTorch格式(特別是由Google發布的預訓練模型(https://github.com/google-research/bert#pre-trained-models))此CLI將TensorFlow checkpoints(三個以bert_model.ckpt開頭的文件)和關聯的配置文件(bert_config.json)作為輸入,並為此配置創建PyTorch模型,並加載在PyTorch模型中從TensorFlow checkpoints進行權重計算,然後將生成的模型保存到標準PyTorch格式文件中,該文件可以使用torch.load()導入(請參閱run_bert_extract_features.py, run_bert_classifier.py and run_bert_squad.py的示例)。你只需一次運行此轉換腳本即可獲得PyTorch模型。然後你可以忽略TensorFlow checkpoints(以bert_model.ckpt開頭的三個文件),但請確保保留配置文件(bert_config.json)和詞彙表文件(vocab.txt),因為PyTorch模型也需要這些。要運行此特定的轉換腳本,你將需要安裝TensorFlow和PyTorch(pip install tensorflow)。存儲庫的其餘部分僅需要PyTorch。這是一個預訓練的BERT-Base Uncased模型的轉換過程示例:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

transformers-cli convert --model_type bert \
--tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
--config $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin

你可以在此處(https://github.com/google-research/bert#pre-trained-models)下載Google的預訓練模型。


這是一個預訓練OpenAI GPT模型轉換過程的示例,假設你的NumPy checkpoints保存的格式與OpenAI的預訓練模型相同(請參見此處(https://github.com/openai/finetune-transformer-lm))
export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights

transformers-cli convert --model_type gpt \
--tf_checkpoint $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
[--config OPENAI_GPT_CONFIG] \
[--finetuning_task_name OPENAI_GPT_FINETUNED_TASK] \



這是預訓練OpenAI GPT-2模型轉換過程的示例(請參見此處(https://github.com/openai/gpt-2))
export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/gpt2/pretrained/weights

transformers-cli convert --model_type gpt2 \
--tf_checkpoint $OPENAI_GPT2_CHECKPOINT_PATH \
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
[--config OPENAI_GPT2_CONFIG] \
[--finetuning_task_name OPENAI_GPT2_FINETUNED_TASK]



這是預訓練Transformer-XL模型轉換過程的示例(請參見此處的(https://github.com/kimiyoung/transformer-xl/tree/master/tf#obtain-and-evaluate-pretrained-sota-models))

export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint

transformers-cli convert --model_type transfo_xl \
--tf_checkpoint $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
[--config TRANSFO_XL_CONFIG] \
[--finetuning_task_name TRANSFO_XL_FINETUNED_TASK]



這是一個預訓練XLNet模型的轉換過程示例:

export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint
export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config

transformers-cli convert --model_type xlnet \
--tf_checkpoint $TRANSFO_XL_CHECKPOINT_PATH \
--config $TRANSFO_XL_CONFIG_PATH \
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
[--finetuning_task_name XLNET_FINETUNED_TASK] \


這是一個預訓練XLM模型的轉換過程示例:

export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint

transformers-cli convert --model_type xlm \
--tf_checkpoint $XLM_CHECKPOINT_PATH \
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT
[--config XML_CONFIG] \
[--finetuning_task_name XML_FINETUNED_TASK]

原文連結:https://huggingface.co/transformers/converting_tensorflow_models.html

相關焦點

  • TensorFlow 攜手 NVIDIA,使用 TensorRT 優化 TensorFlow Serving...
    HTTP/REST API at:localhost:8501 …$ curl -o /tmp/resnet/resnet_client.py https://raw.githubusercontent.com/tensorflow/serving/master/tensorflow_serving/example/resnet_client.py
  • Tensorflow基礎教程15天之創建Tensor
    在將Tensor定義為Variable之後,Tensorflow才會將其傳入計算圖。如何操作我們將在這裡介紹創建Tensor的主要方法。序列TensorTensorflow允許我們定義數組Tensor。
  • 步履不停:TensorFlow 2.4新功能一覽!
    參數伺服器訓練教程           https://tensorflow.google.cn/tutorials/distribute/parameter_server_training    ClusterCoordinator           https://tensorflow.google.cn/api_docs/python
  • 直觀理解並使用Tensorflow實現Seq2Seq模型的注意機制
    import numpy as npimport pandas as pdfrom tensorflow.keras.preprocessing.text import Tokenizerfrom tensorflow.keras.preprocessing.sequence import pad_sequencesimport tensorflow as tffrom sklearn.model_selection
  • 分享TensorFlow Lite應用案例
    內存大小控制機制存在一定的問題,例如模型本身在計算時只有 20MB,但加載到內存之後的運行時峰值可能會飆升 40 到 70MB。   TF Lite 對於 CNN 類的應用支持較好,目前對於 RNN 的支持尚存在 op 支持不足的缺點。
  • Tensorflow 2.0 即將入場
    而就在即將到來的2019年,Tensorflow 2.0將正式入場,給暗流湧動的框架之爭再燃一把火。如果說兩代Tensorflow有什麼根本不同,那應該就是Tensorflow 2.0更注重使用的低門檻,旨在讓每個人都能應用機器學習技術。
  • TensorFlow 資源大全中文版
    (點擊上方藍字,快速關注我們)譯文:伯樂在線專欄作者 - Yalye英文:jtoy如有好文章投稿
  • TensorFlow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。
  • 終於來了,TensorFlow 新增官方 Windows 支持
    選自Google Developers Blog機器之心編譯參與:李澤南昨日,Google Brain 工程師團隊宣布在 TensorFlow 0.12 中加入初步的 Windows 支持。TensorFlow 宣布開源剛剛過去一年。在谷歌的支持下,TensorFlow 已成為 GitHub 上最受歡迎的機器學習開源項目。
  • Tensorflow 2.0的這些新設計,你適應好了嗎?
    而就在即將到來的2019年,Tensorflow 2.0將正式入場,給暗流湧動的框架之爭再燃一把火。如果說兩代Tensorflow有什麼根本不同,那應該就是Tensorflow 2.0更注重使用的低門檻,旨在讓每個人都能應用機器學習技術。
  • 教程| 如何用TensorFlow在安卓設備上實現深度學習推斷
    從源安裝和配置 TensorFlow(https://www.tensorflow.org/install/install_sources)。3.在 TensorFlow 目錄下運行下列命令行:bazel build tensorflow/tools/graph_transforms:transform_graphbazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --in_graph=/your/.pb/file \ --outputs="output_node_name
  • 資源| TensorFlow版本號升至1.0,正式版即將到來
    選自github機器之心編譯參與:吳攀2015 年 11 月份,谷歌宣布開源了深度學習框架 TensorFlow,一年之後,TensorFlow 就已經成長為了 GitHub 上最受歡迎的深度學習框架(參見機器之心文章《深度 | TensorFlow 開源一周年:這可能是一份最完整的盤點》),儘管那時候 TensorFlow 的版本號還是 v0.11。
  • 5個簡單的步驟掌握Tensorflow的Tensor
    在這篇文章中,我們將深入研究Tensorflow Tensor的細節。我們將在以下五個簡單步驟中介紹與Tensorflow的Tensor中相關的所有主題:第一步:張量的定義→什麼是張量?我們經常將NumPy與TensorFlow一起使用,因此我們還可以使用以下行導入NumPy:import tensorflow as tfimport numpy as np張量的創建:創建張量對象有幾種方法可以創建tf.Tensor對象。讓我們從幾個例子開始。
  • 谷歌開放GNMT教程:如何使用TensorFlow構建自己的神經機器翻譯系統
    選自谷歌機器之心編譯參與:機器之心編輯部近日,谷歌官方在 Github 開放了一份神經機器翻譯教程,該教程從基本概念實現開始,首先搭建了一個簡單的NMT模型,隨後更進一步引進注意力機制和多層 LSTM 加強系統的性能,最後谷歌根據 GNMT 提供了更進一步改進的技巧和細節,這些技巧能令該NMT系統達到極其高的精度。
  • 玩轉TensorFlow?你需要知道這30功能
    地址是:tensorflow.org/tfx/?https://www.tensorflow.org/tfx/data_validation/?hl=zh-cn4)TFX -TensorFlow 變換同樣地,你可能希望用於重新訓練的數據也能被自動進行預處理:對特定特性進行歸一化、將字符串轉換為數值等。Transform 不僅可以對單個樣本進行這些操作,還能批處理數據。
  • tensorflow初級必學算子
    在之前的文章中介紹過,tensorflow框架的核心是將各式各樣的神經網絡抽象為一個有向無環圖,圖是由tensor以及tensor變換構成;雖然現在有很多高階API可以讓開發者忽略這層抽象,但對於靈活度要求比較高的算法仍然需要開發者自定義網絡圖,所以建議開發者儘量先學習tf1.x
  • Tensorflow 全網最全學習資料匯總之Tensorflow 的入門與安裝【2】
    《TensorFlow學習筆記1:入門》連結:http://www.jeyzhang.com/tensorflow-learning-notes.html本文與上一篇的行文思路基本一致,首先概括了TensorFlow的特性,然後介紹了graph、session、variable 等基本概念的含義,以具體代碼的形式針對每個概念給出了進一步的解釋
  • TensorFlow 中文資源全集,學習路徑推薦
    https://gitee.com/fendouai/Awesome-TensorFlow-Chinese很多內容下面這個英文項目:Inspired by https://github.com/jtoy/awesome-tensorflow官方網站官網:https://www.tensorflow.org/中文:https://tensorflow.google.cn
  • TensorFlow 2.0開源工具書,30天「無痛」上手
    開源電子書地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/GitHub 項目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days為什麼一定要學
  • 谷歌正式發布TensorFlow 1.5,究竟提升了哪些功能?
    機器之心對這次更新的重大改變以及主要功能和提升進行了編譯介紹,原文請見文中連結。GitHub 地址:https://github.com/tensorflow/tensorflow/releases/tag/v1.5.0原始碼(zip):https://github.com/tensorflow/tensorflow/archive/v1.5.0.zip原始碼(tar.gz):https://github.com/tensorflow/tensorflow