tensorflow的數據輸入

2021-02-21 LeadAI OpenLab

tensorflow有兩種數據輸入方法,比較簡單的一種是使用feed_dict,這種方法在畫graph的時候使用placeholder來站位,在真正run的時候通過feed字典把真實的輸入傳進去。比較簡單不再介紹。

比較惱火的是第二種方法,直接從文件中讀取數據(其實第一種也可以我們自己從文件中讀出來之後使用feed_dict傳進去,但方法二tf提供很完善的一套類和函數形成一個類似pipeline一樣的讀取線):

1.使用tf.train.string_input_producer函數把我們需要的全部文件打包為一個tf內部的queue類型,之後tf開文件就從這個queue中取目錄了,要注意一點的是這個函數的shuffle參數默認是True,也就是你傳給他文件順序是1234,但是到時候讀就不一定了,我一開始每次跑訓練第一次迭代的樣本都不一樣,還納悶了好久,就是這個原因。

files_in = ["./data/data_batch%d.bin" % i for i in range(1, 6)]

files = tf.train.string_input_producer(files_in)

2.搞一個reader,不同reader對應不同的文件結構,比如度bin文件tf.FixedLengthRecordReader就比較好,因為每次讀等長的一段數據。如果要讀什麼別的結構也有相應的reader。

reader = tf.FixedLengthRecordReader(record_bytes=1+32*32*3)

3.用reader的read方法,這個方法需要一個IO類型的參數,就是我們上邊string_input_producer輸出的那個queue了,reader從這個queue中取一個文件目錄,然後打開它經行一次讀取,reader的返回是一個tensor(這一點很重要,我們現在寫的這些讀取代碼並不是真的在讀數據,還是在畫graph,和定義神經網絡是一樣的,這時候的操作在run之前都不會執行,這個返回的tensor也沒有值,他僅僅代表graph中的一個結點)。

key, value = reader.read(files)

4.對這個tensor做些數據與處理,比如CIFAR1-10中label和image數據是糅在一起的,這裡用slice把他們切開,切成兩個tensor(注意這個兩個tensor是對應的,一個image對一個label,對叉了後便訓練就完了),然後對image的tensor做data augmentation。

data = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.slice(data, [0], [1]), tf.int64)
raw_image = tf.reshape(tf.slice(data, [1], [32*32*3]), [3, 32, 32])
image = tf.cast(tf.transpose(raw_image, [1, 2, 0]), tf.float32)

lr_image = tf.image.random_flip_left_right(image)
br_image = tf.image.random_brightness(lr_image, max_delta=63)
rc_image = tf.image.random_contrast(br_image, lower=0.2, upper=1.8)

std_image = tf.image.per_image_standardization(rc_image)

5.這時候可以發現,這個tensor代表的是一個樣本([高寬管道]),但是訓練網絡的時候的輸入一般都是一推樣本([樣本數高寬*管道]),我們就要用tf.train.batch或者tf.train.shuffle_batch這個函數把一個一個小樣本的tensor打包成一個高一維度的樣本batch,這些函數的輸入是單個樣本,輸出就是4D的樣本batch了,其內部原理似乎是創建了一個queue,然後不斷調用你的單樣本tensor獲得樣本,直到queue裡邊有足夠的樣本,然後一次返回一堆樣本,組成樣本batch。

images, labels = tf.train.batch([std_image, label],
                          batch_size=100,
                          num_threads=16,

                          capacity=int(50000* 0.4 + 3 * batch_size))

5.事實上一直到上一部的images這個tensor,都還沒有真實的數據在裡邊,我們必須用Session run一下這個4D的tensor,才會真的有數據出來。這個原理就和我們定義好的神經網絡run一下出結果一樣,你一run這個4D tensor,他就會順著自己的operator找自己依賴的其他tensor,一路最後找到最開始reader那裡。

除了上邊講的原理,其中還要注意幾點:

1.tf.train.start_queue_runners(sess=sess)這一步一定要運行,且其位置要在定義好讀取graph之後,在真正run之前,其作用是把queue裡邊的內容初始化,不跑這句一開始string_input_producer那裡就沒用,整個讀取流水線都沒用了。

training_images = tf.train.batch(XXXXXXXXXXXXXXX)
tf.train.start_queue_runners(sess=self.sess)

real_images = sess.run(training_images)

2.image和label一定要一起run,要記清楚我們的image和label是在一張graph裡邊的,跑一次那個graph,這兩個tensor都會出結果,且同一次跑出來的image和label才是對應的,如果你run兩次,第一次為了拿image第二次為了拿label,那整個就叉了,因為第一次跑出來第0到100號image和0到100號label,第二次跑出來第100到200的image和第100到200的label,你拿到了0100的image和100200的label,整個樣本分類全不對,最後網絡肯定跑不出結果。

training_images, training_labels = read_image()
tf.train.start_queue_runners(sess=self.sess)

real_images = sess.run(training_images) # 讀出來是真的圖片,但是和label對不上

real_labels = sess.run(training_labels) # 讀出來是真的label,但是和image對不上

# 正確調用方法,通過跑一次graph,將成套的label和image讀出來

real_images, real_labels = sess.run([training_images, training_labels])

因為不懂這個道理的up主跑了一下午正確率還是10%。。。。(10類別分類10%正確率不就是亂猜嗎)

原文:【tensorflow的數據輸入】(https://goo.gl/Ls2N7s)

原文連結:https://www.jianshu.com/p/7e537cd96c6f


查閱更為簡潔方便的分類文章以及最新的課程、產品信息,請移步至全新呈現的「LeadAI學院官網」:

www.leadai.org


請關注人工智慧LeadAI公眾號,查看更多專業文章

大家都在看

相關焦點

  • tensorflow之並行讀入數據
    num_epochs表示列表遍歷的次數,主要是由於有時候訓練模型需要反覆的遍歷數據集便於更新模型參數,默認情況下是None(循環遍歷)。shuffle表示是否隨機遍歷,默認情況下是true,表示數據會隨機輸入隊列,當想順序讀入數據時shuffle設置為false。至於其他的capacity表示列表的容量,shared_name表示共享時的名字。
  • TensorFlow極速入門
    首先是一些基礎概念,包括計算圖,graph 與 session,基礎數據結構,Variable,placeholder 與 feed_dict 以及使用它們時需要注意的點。最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?te
  • tensorflow極速入門
    首先是一些基礎概念,包括計算圖,graph 與 session,基礎數據結構,Variable,placeholder 與 feed_dict 以及使用它們時需要注意的點。最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、 tensorflow是什麼?
  • TensorFlow應用實戰 | TensorFlow基礎知識
    節點的輸入與輸出都是Tensor張量。邊和節點共同構成了Graph 也就是數據流圖。數據流圖會被放進session會話中進行運行。會話可以在不同的設備上去運行,比如cpu和GPU。數據流圖:Tensor (張量) 邊裡流動的數據Operation(操作)
  • TensorFlow 安裝詳解
    那什麼是數據流圖(Data Flow Graphs)呢?數據流圖用「結點」(nodes)和「線」(edges)的有向圖來描述數學計算。「節點」 一般用來表示施加的數學操作,但也可以表示數據輸入(feed in)的起點/輸出(push out)的終點,或者是讀取/寫入持久變量(persistent variable)的終點。
  • 數據載入過慢?這裡有一份TensorFlow加速指南
    根據以往經驗,在TensorFlow中,feed-dict函數可能是最慢的一種數據載入方法,儘量少用。把數據輸入到模型的最佳方法是使用輸入流水線(input pipeline),來確保GPU無須等待新數據輸入。幸好,TensorFlow有一個內置接口,叫做Dataset。這個接口是為了更容易地實現數據輸入,在1.3版本已被提出。
  • 乾貨| TensorFlow的數據流圖
    TensorFlow是一個開源軟體庫,用於使用數據流圖進行數值計算。圖中的節點表示數學運算,而圖邊表示在它們之間傳遞的多維數據數組(張量,tensor)。該庫包括各種功能,使你能夠實現和探索用於圖像和文本處理的前沿卷積神經網絡(CNN)和循環神經網絡(RNN)架構。
  • TensorFlow學習
    ,學習內容為tensorflow!3.Placeholder 傳入值寫在前面'''Tensorflow 如果想要從外部傳入data, 那就需要用到 tf.placeholder(),然後以這種形式傳輸數據stat.run(***,feed_dict(key:value,key1:value))'''定義兩個placeholder
  • tensorflow安裝教程
    tensorflow是谷歌開源的人工智慧庫,有最完善的生態支持。是進行人工智慧領域開發和科研的必備工具。本文在windows10下,藉助anaconda建立成功進入tf2.0環境conda activate tf2.0安裝tensorflow2.0 pip install tensorflow==2.0.0-beta1下載的東西挺多,多等一會,最後成功如下命令行運行python,執行import
  • Windows配置tensorflow開發環境
    通過這篇文章,希望能夠幫助大家更加順利地配置tensorflow的開發環境。4、測試tensorflow是否可用可通過兩種方式測試tensorflow:(1)通過Anaconda Prompt窗口:首先,激活tensorflow環境。
  • TensorFlow安裝與卷積模型
    使用pip安裝1)下載安裝Python 2)打開windows的命令行窗口,安裝CPU版本pip installtensorflow安裝GPU版本Pip install tensorflow-gpu之後驗證是否安裝了 TensorFlow 可以嘗試一下代碼>>> importtensorflow
  • TensorFlow 的簡單例子 | Linux 中國
    你往其中輸入一組數據樣本用以訓練,接著給出另一組數據樣本基於訓練的數據而預測結果。這就是人工智慧了!◈ 支持 GPU 。你可以使用 GPU(圖像處理單元)替代 CPU 以更快的運算。TensorFlow 有兩個版本: CPU 版本和 GPU 版本。開始寫例子前,需要了解一些基本知識。什麼是張量?
  • 令人困惑的TensorFlow!
    如果你仔細閱讀,你甚至可能已經發現了這個頁面(https://www.tensorflow.org/programmers_guide/graphs),該頁面涵蓋了我將以更準確和技術化的方式去解釋的內容。本節是一篇高級攻略,把握重要的直覺概念,同時忽略一些技術細節。那麼:什麼是計算圖?它本質上是一個全局數據結構:是一個有向圖,用於捕獲有關如何計算的指令。
  • 【官方教程】TensorFlow在圖像識別中的應用
    tensorflow::ops::Div( tensorflow::ops::Sub( resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()), tensorflow::ops::Const({input_std}, b.opts()), b.opts().WithName(
  • TensorFlow發布JavaScript開發者的機器學習框架TensorFlow.js
    發布新的 TensorFlow 官方博客(http://blog.tensorflow.org/)與 TensorFlow YouTube 頻道;2. 面向 JavaScript 開發者的全新機器學習框架 TensorFlow.js;3.
  • tensorflow機器學習模型的跨平臺上線
    tensorflow模型的跨平臺上線的備選方案tensorflow模型的跨平臺上線的備選方案一般有三種:即PMML方式,tensorflow serving方式,以及跨語言API方式。PMML方式的主要思路在上一篇以及講過。
  • 玩轉TensorFlow?你需要知道這30功能
    3)TFX 數據驗證如何自動確保用於重新訓練模型的數據與最初用於訓練模型的數據具有相同的格式、源、命名約定等。Transform 不僅可以對單個樣本進行這些操作,還能批處理數據。網址是:https://www.tensorflow.org/tfx/transform/?hl=zh-cn
  • TensorFlow圖像分類教程
    有一個工具將隨機抓取一批圖像,使用模型猜測每種花的類型,測試猜測的準確性,重複執行,直到使用了大部分訓練數據為止。最後一批未被使用的圖像用於計算該訓練模型的準確性。分類:在新的圖像上使用模型。例如,輸入:IMG207.JPG,輸出:雛菊。這個步驟快速簡單,且衡量的代價小。
  • 談談Tensorflow的dropout
    許多文獻都對dropout有過描述,但解釋的含糊不清,這裡呢,我也不打算解釋清楚,只是通過tensorflow來看一看dropout的運行機理。文章分兩部分,第一部分介紹tensorflow中的dropout函數,第二部分是我的思考。
  • TensorFlow通過TFRecord高效讀寫數據
    利用tensorflow提供的tfrecord數據存儲格式工具,我們可以將我們已經進行過處理的數據保存起來,以便我們下次更高效地讀取,略過數據處理的過程,提高效率。具體的步驟大概分為以下幾步:將數據轉化為tf.train.Feature,然後存於字典;接著,將其轉化為tf.train.example,然後進行序列化,寫入tf.python_io.TFRecordWriter,到這裡就完成了寫入的操作;