深度學習綜合案例:手寫數字識別

2021-12-26 圖靈人工智慧

在前文了解了卷積神經網絡的原理,接下看如何用卷積神經網絡處理MNIST數據集,實現對手寫數字識別。

4.4.1 MNIST數據集初識

MNIST是一個初級的計算機視覺數據集,數據中每一個樣本都是0到9的手寫數字,其數據包括四個部分,包括訓練數據集,文件名為train-images.idx3-ubyte,其中包括了50000張訓練圖片;訓練標籤集,文件名為train-labels.idx1-ubyte,其中包括了50000個標籤;測試數據集,文件名為t10k-images.idx3-ubyte,其中包括了10000張測試圖片;測試標籤集,文件名t10k-labels.idx1-ubyte,包含了10000個測試標籤。

其中每張圖片的像素都是28*28大小,為了方便存儲好下載,官方對該數據集的圖片都進行了處理,每一張圖片被拉伸為(1,784)的向量,因此每一張圖片就是1行784列的數據,括號中每一個值代表一個像素。圖片數據可以以可視化的方式呈現的,如以下代碼。

1.   #MNIST_data代表存放MNIST數據的文件夾

2.   mnist=input_data.read_data_sets("MNIST_data",one_hot=True)

3.   #獲取第十張圖片

4.   image=mnist.train.images[9,:]

5.   #將圖像數據還原成28*28的解析度

6.   image=image.reshape(28,28)

7.   #列印對應的標籤

8.   print(mnist.train.labels[9])

9.   #列印圖片的數量、大小,標籤大小

10.  print(mnist.train.images.shape)

11.  print(mnist.train.labels.shape)

12.  plt.figure()

13.  plt.imshow(image)

14.  plt.show()

運行代碼後,獲取的是第十張照片,可以看到如圖4-29所示。

其另一個列印結果如圖4-30所示,為[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.],而訓練圖片共55000張,表示圖片的向量長度為784,標籤長度為10,但是讀者可能會有疑問,怎麼多出了500張,這是因為在input_data中,人為地增加了訓練集,統共合計5500張圖片。

 

4.4.2 手寫數字識別模型構建和訓練(1)下載並讀取MNIST數據集

MNIST數據集可以到官網http://yann.lecun.com/exdb/mnist/下載,由四個文件組成。其數據集也可以通過TensorFlow對MNIST數據集進行讀取和格式轉換,然後從MNIST引入input_data這個類,然後調用read_data_sets()方法。

1.   fromtensorflow.examples.tutorials.mnistimportinput_data

2.   mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

以上代碼會下載MNIST_data保存在本地,如果所在目錄不存在MNIST_data文件夾,在該目錄下會自動生成這個目錄,然後把下載的文件存放在該文件夾下,需要注意的是下載的文件不需要解壓,因為input_data.read_data_sets讀取的是壓縮包。

(2)構建模型

在下載完數據集以後,要開始訓練模型。本節將會用卷積神經網絡來構建模型,為了更加容易理解以及更好地實踐,這裡選擇的卷積網絡層數為6層,其結構如圖4-31所示,構建模型的按順序來分別為輸入層、卷積層1、卷積層2、全連接層1、全連接層2、輸出層。

以下為代碼,其中x代表輸入,y_代表輸出,其中placeholder()函數是在神經網絡構建計算流圖時在模型中的佔位,換而言之就是分配一定的空間,這時還沒有把數據傳輸到模型中去,這樣可以在後續訓練時動態分配。建立了session後,運行模型時會通過feed_dict()函數將數據放入佔位符。

 

1.   #輸入輸出數據的placeholder

2.   x=tf.placeholder("float",[None,784])

3.   y_=tf.placeholder("float",[None,10])

4.    

5.   #對數據進行重新排列,形成圖像,適用於CNN的特徵提取

6.   x_image=tf.reshape(x,[-1,28,28,1])

以下代碼定義幾個函數,前兩個函數,使用truncated_normal()產生隨機數,使用Variable()定義學習參數。一個是定義隨機生成權重的函數,另一個是定義生成偏置的函數。後兩個函數,一個是定義卷積層,另一個定義池化層。為後面創建神經網絡模型做鋪墊。

1.   defweight_variable(shape):

2.   initial=tf.truncated_normal(shape,stddev=0.1)

3.   returntf.Variable(initial)

4.    

5.   defbias_variable(shape):

6.   initial=tf.constant(0.1,shape=shape)

7.   returntf.Variable(initial)

8.    

9.   defconv2d(x,W):

10.  returntf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')

11.   

12.  defmax_pool_2x2(x):

13.  returntf.nn.max_pool(x,ksize=[1,2,2,1],

14.  strides=[1,2,2,1],padding='SAME')

定義完函數之後,創建卷積層和池化層的結構。這裡的卷積核為5*5,然後給學習參數賦值,接著創建卷積層,進行卷積操作,並使用ReLU激活函數進行激活,最後進行池化操作。

1.   #卷積層1

2.   #卷積核為5*5

3.   filter1=[5,5,1,32]

4.   W_conv1=weight_variable(filter1)

5.   b_conv1=bias_variable([32])

6.   #進行ReLU操作,輸出大小為28*28*32

7.   h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)

8.   #池化操作,輸出大小為14*14*32

9.   h_pool1=max_pool_2x2(h_conv1)

10.   

11.  #卷積層2

12.  #卷積核為5*5

13.  filter2=[5,5,32,64]

14.  W_conv2=weight_variable(filter)

15.  b_conv2=bias_variable([64])

16.  #進行ReLU操作,輸出大小為14*14*64

17.  h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)

18.  #Pooling操作,輸出大小為7*7*64

19.  h_pool2=max_pool_2x2(h_conv2)

然後是定義全連接層的結構,如以下代碼所示。

1.   #全連接層1

2.   W_fc1=weight_variable([7*7*64,1024])

3.   b_fc1=bias_variable([1024])

4.   #輸入數據變換

5.   h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])

6.   #進行全連接操作

7.   h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)

8.    

9.   #dropout的比例

10.  keep_prob=tf.placeholder("float")

11.  #防止過擬合,dropout

12.  h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)

13.   

14.  #softmax

15.  #全連接層2

16.  W_fc2=weight_variable([1024,10])

17.  b_fc2=bias_variable([10])

(3)訓練模型

最後開始訓練模型,訓練測試後,還要列印卷積神經網絡識別測試集的準確率。

1.   #預測

2.   y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)

3.    

4.   #計算Loss

5.   cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))

6.   #神經網絡訓練

7.   train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

8.   correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))

9.   accuracy=tf.reduce_mean(tf.cast(correct_prediction,"double"))

10.  sess.run(tf.initialize_all_variables())

11.  foriinrange(20000):

12.  batch=mnist.train.next_batch(50)

13.  ifi%100==0:

14.  train_accuracy=accuracy.eval(feed_dict={

15.  x:batch[0],y_:batch[1],keep_prob:1.0})

16.  print("step%d,trainingaccuracy%f"%(i,train_accuracy))

17.  train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})

18.   

19.  print("testaccuracy%f"%accuracy.eval(feed_dict={

20.  x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

運行代碼,得到的運行結果如圖4-32所示。

(4)保存模型

首先需要建立一個saver和一個保存路徑,然後通過調用save方法,自動將session中的參數保存起來。

1.   saver=tf.train.Saver()#定義saver

上面代碼是用來建立saver的,該代碼需要在session運行之前。接著調用save方法保存所訓練的模型,如以下代碼所示。

1.   saver.save(sess,'G:/**/**/model.ckpt')#模型儲存位置

保存模型後會在對應路徑生成四個文件,如圖4-33所示。

(5)測試模型

模型存儲後,接下來進行測試,以下代碼為使用保存的模型來測試識別手寫字體,restore方法中的路徑為保存模型時所使用的路徑。

1.  saver.restore(sess,"G:/**/**/model.ckpt")#使用模型,參數和之前的代碼保持一致

2.    

3.   prediction=tf.argmax(y_conv,1)

4.   predint=prediction.eval(feed_dict={x:[result],keep_prob:1.0},session=sess)

5.   print('識別結果:')

6.   print(predint[0])

這裡讀取的圖,是作者通過AdobePhotoshopCS6製作的,當然讀者也可以從MNIST數據集中讀取,其中讀取的圖結果如圖4-34所示。

運行相關的代碼,得到如圖4-35所示,得到識別的結果為7。

內容選自國防科技大學鄧勁生教授的《實戰深度學習—原理、框架及應用》

掃描如下二維碼購買:

鄧勁生 國防科技大學前沿交叉學科學院教授,當前主要從事大數據、人工智慧、情報科學等方面的研究。主持國家和省部級科研項目十餘項,獲得省部級科技進步獎、教學成果獎多項,著譯圖書十餘本,發表論文數十篇。

1. 知識覆蓋全面。系統介紹深度學習的主要原理、方法和實踐,涵蓋深度學習領域的基礎知識、主要框架和常見場景,從宏觀到微觀全方位把握知識體系。

2. 案例即學即用。從主流工具的使用、神經網絡的搭建,到常用的深度學習模型的演練,再到目標檢測、圖像分割、人臉識別、文本自動生成等熱門應用。

3. 教學脈絡清晰。針對新興學科的認知特點和導入策略,學習曲線呈循序漸進由淺入深,材料組織上通俗易懂實戰性強,以任務驅動教學法牽引學習過程。

4. 配套資源豐富。配有課程課件和上機操作說明,將深度學習原理融入具體的實際開發工程,指導切實掌握要領,並培養獨立探索、開拓進取的能力。

往期精彩必讀文章(單擊就可查看):

1.三位深度學習之父共獲2019年圖靈獎,學術人生令人讚嘆!!!

2.人工智慧的現狀與未來

3.國防科技大學教授:殷建平——計算機科學理論的過去、現在與未來

4.圖靈獎得主Hamming的22年前經典演講:如何做研究,才能不被歷史遺忘

5.當這位70歲的Hinton老人還在努力推翻自己積累了30年的學術成果時,我才知道什麼叫做生命力(附Capsule最全解析)

6.科學正在證明,科學並不科學

7.沉痛!中國半導體 」芯酸「史!

8.數學的深淵

9. 計算的極限(續)

10.計算的極限

11.高考大數據:哪個省才是高考地獄模式?結論和想像不太一樣

12.統計了最近10年的高考分數線,大數據分析告訴你哪些大學最難考?誰是京滬之後的教育第三城?

相關焦點

  • 構建深度學習3-手寫數字識別
    我們堅信,只有動手創建一個模型,才能真正掌握一個模型,本系列全是用Python代碼創建深度學習,不直接用現成的框架。
  • 深度學習筆記13:Tensorflow實戰之手寫mnist手寫數字識別
    作者:魯偉一個數據科學踐行者的學習日記。
  • TF2.0深度學習實戰(一):分類問題之手寫數字識別
    本著學習的心,希望和大家相互交流,一起進步!手寫數字識別是一個非常經典的圖像分類任務,經常被作為深度學習入門的第一個指導案例。相當於我們學程式語言時,編寫的第一個程序「Hello World !」。不一樣的是,入門深度學習,需要有一定量的理論基礎。手寫數字識別是基於MNIST數據集的一個圖像分類任務,目的是通過搭建深度神經網絡,實現對手寫數字的識別(分類)。
  • 菜鳥深度學習教程:識別手寫數字
    引言歡迎來到快速學習"深度學習"系列教程第一篇,其中既包括基本原理,又包括複雜細節.MNIST(手寫數字分類)和CIFAR-10(飛機,汽車,鳥,貓,鹿,狗,蛙,馬,船,卡車十種不同小圖片分類)是已制定的機器學習標準,本文會儘量達到這些標準.
  • 深度學習系列:PaddlePaddle之手寫數字識別
    上周在搜索關於深度學習分布式運行方式的資料時,無意間搜到了paddlepaddle,發現這個框架的分布式訓練方案做的還挺不錯的,想跟大家分享一下。不過呢,這塊內容太複雜了,所以就簡單的介紹一下paddlepaddle的第一個「hello word」程序----mnist手寫數字識別。下一次再介紹用PaddlePaddle做分布式訓練的方案。
  • 35丨深度學習(下):如何用Keras搭建深度學習網絡做手寫數字識別?
    通過上節課的講解,我們已經對神經網絡和深度學習有了基本的了解。這節課我就用Keras這個深度學習框架做一個識別手寫數字的練習。
  • PyTorch深度學習框架入門——使用PyTorch實現手寫數字識別
    3、如何優化和訓練我們搭建好的模型 註:本案例使用的PyTorch為0.4版本簡介Pytorch是目前非常流行的深度學習框架,因為它具備了Python的特性所以極易上手和使用,同時又兼具了NumPy的特性,因此在性能上也並不遜於任何一款深度學習框架。
  • 手把手教你學Python之手寫數字識別
    問題描述:手寫數字識別是指給定一系列的手寫數字圖片以及對應的數字標籤,構建模型進行學習,目標是對於一張新的手寫數字圖片能夠自動識別出對應的數字
  • Python大數據綜合應用 :零基礎入門機器學習、深度學習算法原理與案例
    機器學習、深度學習算法原理與案例實現暨Python大數據綜合應用高級研修班一、課程簡介
  • 機器學習、深度學習算法原理與案例實踐暨Python大數據綜合應用...
    原標題:機器學習、深度學習算法原理與案例實踐暨Python大數據綜合應用高級研修班通信和信息技術創新人才培養工程項目辦公室 通人辦〔2018〕 第5號 機器學習、深度學習算法原理與案例實踐暨Python
  • 使用AI算法進行手寫數字識別
    因此,人工智慧、機器學習、深度學習的關係如下圖所示。至今已有數種深度學習模型,如深度神經網絡、卷積神經網絡和深度置信網絡和遞歸神經網絡已被應用在計算機視覺、語音識別、自然語言處理、音頻識別與生物信息學等領域並獲取了極好的效果。目前,業內也已經產生了多種優秀的深度學習框架,例如TensorFlow、PyTorch、Caffe、Mxnet等等。
  • 手寫公式識別 :基於深度學習的端到端方法
    該論文是2017年發表在ICDAR上的文章[1]的升級版,主要解決了在線手寫數學公式的識別問題。該論文中介紹的方法獲得了國際最大在線手寫數學公式比賽CROHME2019的冠軍,且是在未使用額外數據的情況下超過了有大量額外數據的國際企業參賽隊伍,如MyScript,Wiris,MathType等,突出了該算法較傳統數學公式識別算法的優勢。
  • 附完整代碼:【AI實戰】訓練第一個AI模型「MNIST手寫數字識別模型」
    MNIST是一個經典的手寫數字數據集,來自美國國家標準與技術研究所,由不同人手寫的0至9的數字構成,由60000個訓練樣本集和10000個測試樣本集構成,每個樣本的尺寸為28x28,以二進位格式存儲,如下圖所示:
  • 手寫數字識別
    這是一個數字識別任務。因此,預測結果是將圖片中的手寫數字識別為有10個數字(0到9)。使用預測準確度來報告結果,優異的結果能夠達到99%以上的預測準確度。目前,大型卷積神經網絡可以實現約0.2%的預測誤差的高準確度。
  • Python人工智慧 | 七.TensorFlow實現分類學習及MNIST手寫體識別案例
    本篇文章將通過TensorFlow實現分類學習,以MNIST數字圖片為例進行講解。本文主要結合作者之前的博客、AI經驗和"莫煩大神"的視頻介紹,後面隨著深入會講解更多的Python人工智慧案例及應用。基礎性文章,希望對您有所幫助,如果文章中存在錯誤或不足之處,還請海涵。作者作為人工智慧的菜鳥,希望大家能與我在這一筆一划的博客中成長起來,共勉。
  • 「人工智慧師資班」(Python機器學習,圖像識別與深度學習,深度學習與NLP,知識圖譜,強化學習)
    本次培訓分為Python機器學習,圖像識別與深度學習,深度學習與NLP,知識圖譜和強化學習五大專題。本次培訓由權威專家主講,提供實驗環境及實驗數據,並提供配套資料,通過剖析工程案例展現機器學習、深度學習落地全過程。培訓暫定2021年1月5日開始,每個專題6天左右,一共28天,直播集訓。本次培訓由淺入深,面向0基礎、不懂機器學習、不具備任何Python基礎的老師和同學。
  • python+flask搭建CNN在線識別手寫中文網站
    ,並經過圖片裁剪處理之後傳入CNN手寫中文識別的模型中進行識別,最後通過PIL將識別結果生成圖片,最後異步回傳給web端進行識別結果展示。這裡主要對常見的3755個漢字進行識別。代碼獲取:關注微信公眾號 datayx  然後回復 手寫識別 即可獲取。
  • 小白學機器學習|如何識別5000多個手寫數字
    今天我接著來分享一篇好玩的機器學習例子,我們如何識別手寫數字。怎麼玩呢:了解這個5000多個手寫數子清洗數據並用機器學習算法訓練讓機器來識別數字超參數調整提高準確率1.介紹一下這個數據集這個數據集也是非常有名的,是入門的經典數據集,而且時間也蠻久的了!
  • CRNN+CTCLoss中文手寫漢字識別
    而OCR則還停留在對列印字體的識別上。為什麼不能把手寫輸入法的算法用在OCR上呢。手寫識別和OCR是有一定區別的。1. 手寫識別通常包涵更多的信息(這裡指的是在線識別,我接觸到的高識別率手寫識別都是在線識別), 如筆畫順序, 連筆等。 這些細節看似簡單, 卻在無形之中給識別提供了不少的特徵, 有助於識別率的提升。2. 手寫識別的樣本預處理比較容易。
  • 如何實踐AI深度學習的十大驚豔案例
    你可能已經聽說過深度學習並認為它是駭人的數據科學裡的一個領域。怎麼可能讓機器像人類一樣學習呢?再者,對於某些人而言,更為駭人的是,我們為什麼要讓機器展現出類人的行為?這裡,請看深度學習在實際應用中的十大案例,以便將其潛能視覺化。  What is deep learning?  深度學習是什麼?