圖像分類入門,輕鬆拿下90%準確率|教你用Keras搞Fashion-MNIST

2020-12-13 量子位

原作 Margaret Maynard-Reid

王小新 編譯自 TensorFlow的Medium

量子位 出品 | 公眾號 QbitAI

這篇教程會介紹如何用TensorFlow裡的tf.keras函數,對Fashion-MNIST數據集進行圖像分類。

只需幾行代碼,就可以定義和訓練模型,甚至不需要太多優化,在該數據集上的分類準確率能輕鬆超過90%。

在進入正題之前,我們先介紹一下上面提到的兩個名詞:

Fashion-MNIST,是去年8月底德國研究機構Zalando Research發布的一個數據集,其中訓練集包含60000個樣本,測試集包含10000個樣本,分為10類。樣本都來自日常穿著的衣褲鞋包,每一個都是28×28的灰度圖。

這個數據集致力於成為手寫數字數據集MNIST的替代品,可用作機器學習算法的基準測試,也同樣適合新手入門。

想深入了解這個數據集,推薦閱讀量子位之前的報導:

連LeCun都推薦的Fashion-MNIST數據集,是這位華人博士的成果

或者去GitHub:

https://github.com/zalandoresearch/fashion-mnist

tf.keras是用來在TensorFlow中導入Keras的函數。Keras是個容易上手且深受歡迎的深度學習高級庫,是一個獨立開源項目。在TensorFlow中,可以使用tf.keras函數來編寫Keras程序,這樣就能充分利用動態圖機制eager execution和tf.data函數。

下面可能還會遇到其他深度學習名詞,我們就不提前介紹啦。進入正題,教你用tf.keras完成Fashion-MNIST數據集的圖像分類~

運行環境

無需設置,只要使用Colab直接打開這個Jupyter Notebook連結,就能找到所有代碼。

https://colab.research.google.com/github/margaretmz/deep-learning/blob/master/fashion_mnist_keras.ipynb

數據處理

Fashion-MNIST數據集中有十類樣本,標籤分別是:

T恤 0褲子 1套頭衫 2裙子 3外套 4涼鞋 5襯衫 6運動鞋 7包 8踝靴 9數據集導入

下面是數據集導入,為後面的訓練、驗證和測試做準備。

只需一行代碼,就能用keras.datasets接口來加載fashion_mnist數據,再用另一行代碼來載入訓練集和測試集。

數據可視化

我喜歡用Jupyter Notebook來可視化,你也可以用matplotlib庫中imshow函數來可視化訓練集中的圖像。要注意,每個圖片都是大小為28x28的灰度圖。

數據歸一化

接著,進行數據歸一化,使得樣本值都處於0到1之間。

數據集劃分

這個數據集一共包含60000個訓練樣本和10000個測試樣本,我們會把訓練樣本進一步劃分為訓練集和驗證集。下面是深度學習中三種數據的作用:

訓練數據,用來訓練模型;驗證數據,用來調整超參數和評估模型;測試數據,用來衡量最優模型的性能。模型構建

下面是定義和訓練模型。

模型結構

在Keras中,有兩種模型定義方法,分別是序貫模型和功能函數。

在本教程中,我們使用序貫模型構建一個簡單CNN模型,用了兩個卷積層、兩個池化層和一個Dropout層。

要注意,第一層要定義輸入數據維度。最後一層為分類層,使用Softmax函數來分類這10種數據。

模型編譯

在訓練模型前,我們用model.compile函數來配置學習過程。在這裡,要選擇損失函數、優化器和訓練測試時的評估指標。

模型訓練

訓練模型時,Batch Size設為64,Epoch設為10。

測試性能

訓練得到的模型在測試集上的準確率超過了90%。

預測可視化

我們通過datasetmodel.predict(x_test)函數,用訓練好的模型對測試集進行預測並可視化預測結果。當標籤為紅色,則說明預測錯誤;當標籤為綠色,則說明預測正確。下圖為15個測試樣本的預測結果。

相關連結

最後,在這篇普通的入門教程基礎上,還有一些提升之路:

如果想深入了解本文使用的Google Colab,可以看這份官方介紹:

https://medium.com/tensorflow/colab-an-easy-way-to-learn-and-use-tensorflow-d74d1686e309

如果你是深度學習初學者,MNIST也應該了解一下。之前TensorFlow有一篇MNIST教程,可以拿來和本文比較一下,你就會發現,深度學習現在已經變得簡單了很多:

https://www.tensorflow.org/versions/r1.1/get_started/mnist/beginners

本文用到的是Keras裡的序貫模型,如果對功能函數感興趣,可查看這篇用Keras功能函數和TensorFlow來預測葡萄酒價格的博文:

https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03

— 完—

相關焦點

  • 從小白到入門:用Keras進行圖像基礎分類
    【IT168 資訊】在這篇文章中,將解釋一些在keras中經常需要的常見操作。首先,如何保存模型並使用它們進行預測,從數據集中顯示圖像並從加載系統中圖像並預測其類別。    如果你還沒有這樣做,可以啟動你的IDE跟著文章,一起來操作。
  • Wandb用起來,一行Python代碼實現Keras模型可視化
    (註:Keras使得構建神經網絡變得簡單明了,這一點深得人心)這樣好用的包如何下載呢?只需運行「pip install wandb」,就可以輕鬆地安裝wandb,然後所有的Keras示例就都可以運行了。
  • 評測| CNTK在Keras上表現如何?能實現比TensorFlow更好的深度學習嗎?
    在深度學習成為主流之前,優秀的機器學習模型在測試集上達到大約 88% 的分類準確率。第一個模型方法(imdb_bidirectional_lstm.py)使用了雙向 LSTM(Bidirectional LSTM),它通過詞序列對模型進行加權,同時採用向前(forward)傳播和向後(backward)傳播的方法。
  • Keras官方出調參工具了,然而Francois說先別急著用
    然而目前發布的版本還不成熟,Keras 作者 Franois Chollet 表示:大家先別用,API 還不穩定。Keras Tuner GitHub 地址:https://github.com/keras-team/keras-tuner早在上個月舉辦的谷歌 I/O 大會上,谷歌即展示了 Keras Tuner 的功能。Keras 作者 Franois Chollet 也發推介紹了該工具。
  • Keras R語言接口正式發布,同時公開20個完整示例
    這意味著Keras 本質上適合用於構建任意深度學習模型(從記憶網絡到神經圖靈機)兼容多種運行後端,例如TensorFlow、CNTK和 Theano如果你已經很熟悉Keras了,並且想要立刻體驗最新發布的R語言接口,請點擊如下網址:https://keras.rstudio.com,這裡有超過20個完整示例,相信有你需要的東西。
  • 5分鐘入門GANS:原理解釋和keras代碼實現
    它是一種可替代的自適應變分編碼器(VAEs)學習圖像的潛在空間,以生成合成圖像。它的目的是創造逼真的人工圖像,幾乎無法與真實的圖像區分。GAN的直觀解釋生成器和鑑別器網絡:生成器網絡的目的是將隨機圖像初始化並解碼成一個合成圖像。
  • 如何在PyTorch和TensorFlow中訓練圖像分類模型
    介紹圖像分類是計算機視覺的最重要應用之一。它的應用範圍包括從自動駕駛汽車中的物體分類到醫療行業中的血細胞識別,從製造業中的缺陷物品識別到建立可以對戴口罩與否的人進行分類的系統。在所有這些行業中,圖像分類都以一種或另一種方式使用。他們是如何做到的呢?他們使用哪個框架?
  • TensorFlow(Keras)中的正則化技術及其實現(附代碼)
    import tensorflow as tffrom tensorflow import keras我們將使用的數據集是瑣碎的fashion-MNIST數據集。fashion-MNIST數據集包含70,000件服裝圖像。更具體地說,它包括60,000個訓練示例和10,000個測試示例,它們都是尺寸為28 x 28的灰度圖像,分為十類。數據集的準備工作包括通過將每個像素值除以255.0來歸一化訓練圖像和測試圖像。這會將像素值置於0到1的範圍內。
  • 圖像分類比賽中,你可以用如下方案舉一反三
    為了對使用該資料庫得到的分類結果進行標準化評估,組織者提供了基於 F1 值的對比基準,你可以通過如下連結獲得這個數據集:https://vision.eng.au.dk/plant-seedlings-dataset/。
  • 小白學CNN以及Keras的速成
    一、為何要用Keras如今在深度學習大火的時候,第三方工具也層出不窮,比較出名的有Tensorflow,Caffe,Theano,MXNet,在如此多的第三方框架中頻繁的更換無疑是很低效的,只要你能夠好好掌握其中一個框架,熟悉其原理,那麼之後因為各種要求你想要更換框架也是很容易的。那麼sherlock用的是哪個框架呢?
  • 用Keras和「直方圖均衡」為深度學習實現「圖像擴充」
    這樣做不僅能夠獲得更多的訓練數據,還能讓我們的分類器應對光照和色彩更加複雜的環境,從而使我們的分類器功能越來越強大。以下是來自imgaug的不同的圖像擴充例子:你還可以用 keras.preprocessing 函數將擴充的圖像導出到一個文件夾,以便建立一個更龐大的擴充圖像數據集。在本文中,我們將看一些更直觀、有趣的擴充圖像。你可以在Keras文件中查看所有的ImageDataGenerator參數,以及keras.preprocessing中的其他方法。
  • Python小白深度學習教程:Keras 精講(上)
    ,比如data = boston_housingdata = cifar10data = cifar100data = imbddata = reutersdata = mnistdata = fashion_mnist或者直接寫 keras.dataset
  • 人工智慧入門:用python教你實現手寫數字識別!
    今天我給大家帶來一個用機器學習的方法來實現手寫數字識別的教程,就像C語言中輸出的那一行「Hellow World」一樣,這個教程也是入門圖像識別中需要學會的第一個技能,我們將會使用tensorflow深度學習框架來實現手寫數字識別,在觀看此教程之前你需具備以下基礎:python基本語法
  • 掌握深度學習,數據不足也能進行圖像分類!
    ——吳恩達圖像分類即根據固定類別對輸入的圖像設置標籤。儘管計算機視覺過於簡單,但是它在實際中仍有廣泛的應用,而圖像分類就是其中的核心問題之一。在本文中,小芯將示範如何在數據不足的情況下應用深度學習。現已創建特製汽車和巴士分類器兩個數據集,每個數據集包含100個圖像。其中,訓練集有70個圖像,驗證集有30個。挑戰1.
  • 手把手教你用Python庫Keras做預測(附代碼)
    本文將教你如何使用Keras這個Python庫完成深度學習模型的分類與回歸預測。當你在Keras中選擇好最合適的深度學習模型,就可以用它在新的數據實例上做預測了。但是很多初學者不知道該怎樣做好這一點,我經常能看到下面這樣的問題: 「我應該如何用Keras對我的模型作出預測?」
  • 用Keras+TensorFlow,實現ImageNet數據集日常對象的識別
    博客DeepLearningSandbox作者GregChu打算通過一篇文章,教你用Keras和TensorFlow,實現對ImageNet數據集中日常物體的識別。量子位翻譯了這篇文章:你想識別什麼?看看ILSVRC競賽中包含的物體對象。如果你要研究的物體對象是該列表1001個對象中的一個,運氣真好,可以獲得大量該類別圖像數據!
  • Colab超火的Keras/TPU深度學習實戰,會點Python就能看懂的課程
    這些可能都是阻礙你搭建第一個神經網絡的原因。谷歌開發者博客的Codelabs項目上面給出了一份教程(課程連結在文末),不只是教你搭建神經網絡,還給出四個實驗案例,手把手教你如何使用keras、TPU、Colab。
  • 使用TF2與Keras實現經典GNN的開源庫——Spektral
    機器之心機器之心報導參與:Racoon這裡有一個簡單但又不失靈活性的開源 GNN 庫推薦給你。我們可以使用 Spektral 來進行網絡節點分類、預測分子特性、使用 GAN 生成新的拓撲圖、節點聚類、預測連結以及其他任意數據是使用拓撲圖來描述的任務。