如何使用TensorFlow Hub的ESRGAN模型來在安卓app中生成超分圖片

2020-12-11 電子發燒友

從一張低解析度的圖片生成一張對應的高解析度圖片的任務通常被稱為單圖超分(Single Image Super Resolution - SISR)。儘管可以使用傳統的插值方法(如雙線性插值和雙三次插值)來完成這個任務,但是產生的圖片質量卻經常差強人意。深度學習,尤其是對抗生成網絡 GAN,已經被成功應用在超分任務上,比如 SRGAN 和 ESRGAN 都可以生成比較真實的超分圖片。那麼在本文裡,我們將介紹一下如何使用TensorFlow Hub上的一個預訓練的 ESRGAN 模型來在一個安卓 app 中生成超分圖片。最終的 app 效果如下圖,我們也已經將完整代碼開源給大家參考。

SRGAN
https://arxiv.org/abs/1609.04802

ESRGAN
https://arxiv.org/abs/1809.00219

完整代碼
https://github.com/tensorflow/examples/tree/master/lite/examples/super_resolution

首先,我們可以很方便的從 TFHub 上加載 ESRGAN 模型,然後很容易的將其轉化為一個 TFLite 模型。注意在這裡我們使用了動態範圍量化(dynamic range quantization),並將輸入圖片的尺寸固定在50x50像素(我們已經將轉化後的模型上傳到 TFHub 上了):

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1") concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] concrete_func.inputs[0].set_shape([1, 50, 50, 3]) converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() # Save the TF Lite model. with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f: f.write(tflite_model) esrgan_model_path = './ESRGAN.tflite'

TFHub
https://hub.tensorflow.google.cn/

TFHub(轉化後模型)
https://hub.tensorflow.google.cn/captain-pool/lite-model/esrgan-tf2/1

現在 TFLite 已經支持動態大小的輸入,所以你也可以在模型轉化的時候不指定輸入圖片的大小,而在運行的時候動態指定。如果你想使用動態輸入大小,請參考這個例子。

例子
https://github.com/tensorflow/tensorflow/blob/c58c88b23122576fc99ecde988aab6041593809b/tensorflow/lite/python/lite_test.py#L529-L560

模型轉化完之後,我們可以很快驗證 ESRGAN 生成的超分圖片質量確實比雙三次插值要好很多。如果你想更多的了解 ESRGAN 模型,我們還有另外一個教程可供參考:

lr = cv2.imread(test_img_path) lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB) lr = tf.expand_dims(lr, axis=0) lr = tf.cast(lr, tf.float32) # Load TFLite model and allocate tensors. interpreter = tf.lite.Interpreter(model_path=esrgan_model_path) interpreter.allocate_tensors() # Get input and output tensors. input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # Run the model interpreter.set_tensor(input_details[0]['index'], lr) interpreter.invoke() # Extract the output and postprocess it output_data = interpreter.get_tensor(output_details[0]['index']) sr = tf.squeeze(output_data, axis=0) sr = tf.clip_by_value(sr, 0, 255) sr = tf.round(sr) sr = tf.cast(sr, tf.uint8)

教程
https://tensorflow.google.cn/hub/tutorials/image_enhancing

LR: 輸入的低解析度圖片,該圖從 DIV2K 數據集中的一張蝴蝶圖片中切割出來. ESRGAN (x4): ESRGAN 模型生成的超分圖片,單邊解析度提升4倍. Bicubic: 雙三次插值生成圖片. 在這裡大家可以很容易看出來,雙三次插值生成的圖片要比 ESRGAN 模型生成的超分圖片模糊很多

你可能已經知道,TensorFlow Lite 是 TensorFlow 用於在端側運行的官方框架,目前全球已有超過40億臺設備在運行 TFLite,它可以運行在安卓,iOS,基於 Linux 的 IoT 設備以及微處理器上。你可以使用 Java, C/C++ 或其他程式語言來運行 TFLite。在這篇文章中,我們將使用 TFLite C API,因為有許多的開發者表示希望我們能提供這樣一個範例。

DIV2K 
https://data.vision.ee.ethz.ch/cvl/DIV2K/

Java, C/C++ 
https://tensorflow.google.cn/lite/guide/android

TFLite C API
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api.h

我們在預先編譯好的 AAR 文件中包含了 TFLite C API需要的頭文件和庫 (包括核心庫和 GPU 庫)。為了正確的設置好 Android 項目,我們首先需要下載兩個 JAR 文件並將相應的頭文件和庫抽取出來。我們可以在一個 download.gradle 文件中定義這些任務,然後將這些任務導入 build.gradle。下面我們先定義下載 TFLite JAR 文件的兩個任務:

task downloadTFLiteAARFile() { download { src "https://bintray.com/google/tensorflow/download_file?file_path=org%2Ftensorflow%2Ftensorflow-lite%2F2.3.0%2Ftensorflow-lite-2.3.0.aar" dest "${project.rootDir}/libraries/tensorflow-lite-2.3.0.aar" overwrite false retries 5 } } task downloadTFLiteGPUDelegateAARFile() { download { src "https://bintray.com/google/tensorflow/download_file?file_path=org%2Ftensorflow%2Ftensorflow-lite-gpu%2F2.3.0%2Ftensorflow-lite-gpu-2.3.0.aar" dest "${project.rootDir}/libraries/tensorflow-lite-gpu-2.3.0.aar" overwrite false retries 5 } }

AAR 文件
https://tensorflow.google.cn/lite/guide/android#use_tflite_c_api

然後我們定義另一個任務來講頭文件和庫解壓然後放到正確的位置:

task fetchTFLiteLibs() { copy { from zipTree("${project.rootDir}/libraries/tensorflow-lite-2.3.0.aar") into "${project.rootDir}/libraries/tensorflowlite/" include "headers/tensorflow/lite/c/*h" include "headers/tensorflow/lite/*h" include "jni/**/libtensorflowlite_jni.so" } copy { from zipTree("${project.rootDir}/libraries/tensorflow-lite-gpu-2.3.0.aar") into "${project.rootDir}/libraries/tensorflowlite-gpu/" include "headers/tensorflow/lite/delegates/gpu/*h" include "jni/**/libtensorflowlite_gpu_jni.so" }

因為我們是用安卓 NDK 來編譯這個 app,我們需要讓 Android Studio 知道如何處理對應的原生文件。我們在 CMakeList.txt 文件中這樣寫:

set(TFLITE_LIBPATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite/jni") set(TFLITE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite/headers") set(TFLITE_GPU_LIBPATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite-gpu/jni") set(TFLITE_GPU_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite-gpu/headers") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++14") set(CMAKE_CXX_STANDARD 14) add_library(SuperResolution SHARED SuperResolution_jni.cpp SuperResolution.cpp) add_library(lib_tensorflowlite SHARED IMPORTED) set_target_properties(lib_tensorflowlite PROPERTIES IMPORTED_LOCATION ${TFLITE_LIBPATH}/${ANDROID_ABI}/libtensorflowlite_jni.so) add_library(lib_tensorflowlite_gpu SHARED IMPORTED) set_target_properties(lib_tensorflowlite_gpu PROPERTIES IMPORTED_LOCATION ${TFLITE_GPU_LIBPATH}/${ANDROID_ABI}/libtensorflowlite_gpu_jni.so) find_library(log-lib log) include_directories(${TFLITE_INCLUDE}) target_include_directories(SuperResolution PRIVATE ${TFLITE_INCLUDE}) include_directories(${TFLITE_GPU_INCLUDE}) target_include_directories(SuperResolution PRIVATE ${TFLITE_GPU_INCLUDE}) target_link_libraries(SuperResolution android lib_tensorflowlite lib_tensorflowlite_gpu ${log-lib})

我們在 app 裡包含了3個示例圖片,這樣用戶可能會運行同一個模型多次,這意味著為了提高運行效率,我們需要將 TFLite 解釋執行器進行緩存。這一點我們可以在解釋執行器成功建立後通過將其指針從 C++ 傳回到 Java 來實現:

extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_examples_superresolution_MainActivity_initWithByteBufferFromJNI(JNIEnv *env, jobject thiz, jobject model_buffer, jboolean use_gpu) { const void *model_data = static_cast(env->GetDirectBufferAddress(model_buffer)); jlong model_size_bytes = env->GetDirectBufferCapacity(model_buffer); SuperResolution *super_resolution = new SuperResolution(model_data, static_cast(model_size_bytes), use_gpu); if (super_resolution->IsInterpreterCreated()) { LOGI("Interpreter is created successfully"); return reinterpret_cast(super_resolution); } else { delete super_resolution; return 0; } }

解釋執行器建立之後,運行模型實際上就非常簡單了,我們只需要按照 TFLite C API 來就好。不過我們需要注意的是如何從每個像素中抽取 RGB 值:

// Extract RGB values from each pixel float input_buffer[kNumberOfInputPixels * kImageChannels]; for (int i = 0, j = 0; i < kNumberOfInputPixels; i++) { // Alpha is ignored input_buffer[j++] = static_cast((lr_img_rgb[i] >> 16) & 0xff); input_buffer[j++] = static_cast((lr_img_rgb[i] >> 8) & 0xff); input_buffer[j++] = static_cast((lr_img_rgb[i]) & 0xff); }

TFLite C API 
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api.h

運行完模型後我們需要再將 RGB 值再打包進每個像素:

// Postprocess the output from TFLite int clipped_output[kImageChannels]; auto rgb_colors = std::make_unique(kNumberOfOutputPixels); for (int i = 0; i < kNumberOfOutputPixels; i++) { for (int j = 0; j < kImageChannels; j++) { clipped_output[j] = std::max(0, std::min(255, output_buffer[i * kImageChannels + j])); } // When we have RGB values, we pack them into a single pixel. // Alpha is set to 255. rgb_colors[i] = (255u & 0xff) << 24 | (clipped_output[0] & 0xff) << 16 | (clipped_output[1] & 0xff) << 8 | (clipped_output[2] & 0xff); }

那麼到這裡我們就完成了這個 app 的關鍵步驟,我們可以用這個 app 來生成超分圖片。您可以在對應的代碼庫中看到更多信息。我們希望這個範例能作為一個好的參考來幫助剛剛起步的開發者更快的掌握如何使用 TFLite C/C++ API 來搭建自己的機器學習 app。

對應的代碼庫中
https://github.com/tensorflow/examples/tree/master/lite/examples/super_resolution

致謝

作者十分感謝 @captain__pool 將他實現的 ESRGAN 模型上傳到 TFHub, 以及 TFLite 團隊的 Tian Lin 和 Jared Duke 提供十分有幫助的反饋。

— 參考 —

[1] Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, Wenzhe Shi. 2016. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.

[2] Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Chen Change Loy, Yu Qiao, Xiaoou Tang. 2018. ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

[3] Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

https://github.com/krasserm/super-resolution

[4] @captain__pool 的 ESGRAN 代碼實現

https://github.com/captain-pool/GSOC

[5] Eirikur Agustsson, Radu Timofte. 2017. NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study.

責任編輯:xj

原文標題:學習教程 | 使用 TensorFlow Lite 在 Android App 中生成超分圖片

文章出處:【微信公眾號:TensorFlow】歡迎添加關注!文章轉載請註明出處。

打開APP閱讀更多精彩內容

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容圖片侵權或者其他問題,請聯繫本站作侵刪。 侵權投訴

相關焦點

  • 在TensorFlow中使用模型剪枝將機器學習模型變得更小
    學習如何通過剪枝來使你的模型變得更小剪枝是一種模型優化技術,這種技術可以消除權重張量中不必要的值。這將會得到更小的模型,並且模型精度非常接近標準模型。在本文中,我們將通過一個例子來觀察剪枝技術對最終模型大小和預測誤差的影響。
  • 教程| 如何用TensorFlow在安卓設備上實現深度學習推斷
    邊緣計算(Edge computing)是一種在物理上靠近數據生成的位置從而對數據進行處理和分析的方法,為解決這些問題提供了方案。以「Ok Google」這個功能為例:用一名用戶的聲音來訓練「Ok Google」,他的手機在接收到這個關鍵詞的時候就會被喚醒。
  • MobileNet教程(2):用TensorFlow搭建安卓手機上的圖像分類App
    MobileNet是為移動端量身打造的,因此這次我們準備把之前做的辨別道路的模型應用到一個Android App中,看看它在行動裝置上效果如何。建立數據集在前一篇推送中,我們為了辨認「道路/非道路」,從多個來源拉取了圖片作為訓練素材。現在我們再來思考一下這樣做是否有必要。如果你記得的話,這個項目的目標是為了保護用戶隱私,當車上的攝像頭打開的時候,如果它看見的不是道路,就應該自動關掉。
  • 使用Python+Tensorflow的CNN技術快速識別驗證碼
    一開始學習tensorflow是盲目的,不知如何下手,網上的資料都比較單一,為了回報社會,讓大家少走彎路,我將詳細介紹整個過程。本教程所需要的完整材料,我都會放在這裡。限於個人水平,如有錯誤請指出!接下來我將介紹如何使用Python+Tensorflow的CNN技術快速識別驗證碼。在此之前,介紹我們用到的工具:1.
  • tensorflow機器學習模型的跨平臺上線
    作者:劉建平編輯:黃俊嘉在用PMML實現機器學習模型的跨平臺上線中,我們討論了使用PMML文件來實現跨平臺模型上線的方法
  • 模型秒變API只需一行代碼,支持TensorFlow等框架
    每個模型都載入到一個 Docker 容器中,包括相關的 Python 包和處理請求的代碼。模型通過網絡服務,如 Elastic Load Balancing (ELB)、Flask、TensorFlow Serving 和 ONNX Runtime 公開 API 給用戶使用。
  • 如何從Tensorflow中創建CNN,並在GPU上運行該模型(附代碼)
    在本教程中,您將學習卷積神經網絡(CNN)的架構,如何在Tensorflow中創建CNN,並為圖像標籤提供預測。
  • 玩轉TensorFlow?你需要知道這30功能
    1)TensorFlow 擴展(TFX)大家都知道我特別喜歡用 TFX 以及它的全套工具來把機器學習模型部署到生產環境中。如果你關心如何使模型保持最新並監控它們,那麼你可以了解一下這個產品、看看它的論文。
  • 如何使用 TensorFlow mobile 將 PyTorch 和 Keras 模型部署到行動裝置
    截止到今年,已經有超過 20 億活躍的安卓設備。安卓手機的迅速普及很大程度上是因為各式各樣的智能 app,從地圖到圖片編輯器應有盡有。
  • TensorFlow開發者證書 中文手冊
    該等級的證書考試主考查開發人員將機器學習集成到工具和應用程式中的基本知識。要獲得證書需要理解如何使用計算機視覺、卷積神經網絡、自然語言處理以及真實世界的圖像數據和策略來構建TensorFlow模型。搭建自然語言處理系統準備模型使用的文本數據使用二分類搭建模型識別文本片段使用多分類搭建模型識別文本片段在你的模型中使用詞向量
  • 基於TensorFlow的深度學習實戰
    --upgrade six(tensorflow)$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.7.1-cp27-none-any.whl你可訪問官方文檔來確認所安裝的版本。
  • TensorFlow安裝與卷積模型
    1:Tensorflow 安裝:首先目前已學習的方法中有兩種方法可用於安裝TensorFlow:一是pip,二是Anaconda。另外 TensorFlow還有cpu和Gpu兩個版本。使用pip安裝1)下載安裝Python 2)打開windows的命令行窗口,安裝CPU版本pip installtensorflow安裝GPU版本Pip install tensorflow-gpu之後驗證是否安裝了 TensorFlow 可以嘗試一下代碼>>> importtensorflow
  • 【官方教程】TensorFlow在圖像識別中的應用
    人類在ImageNet挑戰賽上的表現如何呢?Andrej Karpathy寫了一篇博文來測試他自己的表現。他的top-5 錯誤率是5.1%。這篇教程將會教你如何使用Inception-v3。你將學會如何用Python或者C++把圖像分為1000個類別。我們也會討論如何從模型中提取高層次的特徵,在今後其它視覺任務中可能會用到。
  • 玩轉TensorFlow Lite:有道雲筆記實操案例分享
    OCR 的結果中生成帶有格式的筆記。依照官方文檔,bazel 編譯的 target 是 "//tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo",這樣得到的是一個 demo app。
  • TensorFlow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。
  • tensorflow極速入門
    首先是一些基礎概念,包括計算圖,graph 與 session,基礎數據結構,Variable,placeholder 與 feed_dict 以及使用它們時需要注意的點。最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、 tensorflow是什麼?
  • 機器如何閱讀圖片?能看破並說破一切的TensorFlow
    但機器不是這樣,因而它需要目標檢測,這是在圖片中定位目標實例的計算機視覺問題。目標檢測是深度學習和計算機視覺領域最有趣的概念之一,構建一個能瀏覽圖片並告知圖片中有什麼對象的模型,多麼奇妙的感覺!好消息是,開發目標檢測應用程式比以往更加容易了。如今的方法專注於端到端管道,極大地提高了性能,有助於開發實時用例。
  • 用TensorFlow構建一個中文分詞模型需要幾個步驟
    注意,所謂半監督學習(Semi-supervised learning),其實是一大類算法、方法的統稱,這裡使用的方法只是某種非常簡單的半監督學習方法的應用。模型在模型上,我們選擇使用Albert-small版本的模型,這個版本的模型大小不到30MB,適合比較輕量級的任務,我們可以先嘗試實現一個最簡單的序列標註模型。
  • 使用pix2pix-tensorflow 的交互式圖象到圖象翻譯的演示
    原文地址:連結最近,我將Isola等人做的pix2pix移植到了Tensorflow 平臺。Tensorflow 平臺包含在Tensorflow的圖像到圖像翻譯(Image-to-ImageTranslation in Tensorflow)論文中。我採用了一些預訓練的模型,並製作了一個網絡互動的程序可以直接嘗試玩玩。
  • TensorFlow 資源大全中文版
    TensorFlow 是一個採用數據流圖(data flow graphs),用於數值計算的開源軟體庫。節點(Nodes)在圖中表示數學操作,圖中的線(edges)則表示在節點間相互聯繫的多維數據數組,即張量(tensor)。它靈活的架構讓你可以在多種平臺上展開計算,例如臺式計算機中的一個或多個CPU(或GPU)、伺服器、行動裝置等等。