如何對機器學習做單元測試

2021-03-02 AI公園

在過去的一年裡,我把大部分的工作時間都花在了深度學習研究和實習上。那一年,我犯了很多大錯誤,這些錯誤不僅幫助我了解了ML,還幫助我了解了如何正確而穩健地設計這些系統。我在谷歌Brain學到的一個主要原則是,單元測試可以決定算法的成敗,可以為你節省數周的調試和訓練時間。

然而,在如何為神經網絡代碼編寫單元測試方面,似乎沒有一個可靠的在線教程。即使是像OpenAI這樣的地方,也只是通過盯著他們代碼的每一行,並試著思考為什麼它會導致bug來發現bug的。顯然,我們大多數人都沒有這樣的時間,所以希望本教程能夠幫助你開始理智地測試你的系統!

讓我們從一個簡單的例子開始。試著找出這段代碼中的錯誤。

def make_convnet(input_image):    net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")    net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")    net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')    net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")    net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool2')    net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool3')    net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")    return net

你看到了嗎?網絡實際上並沒有堆積起來。在編寫這段代碼時,我複製並粘貼了slim.conv2d(…)行,並且只修改了內核大小,而沒有修改實際的輸入。

我很不好意思地說,這件事在一周前就發生在我身上了……但這是很重要的一課!由於一些原因,這些bug很難捕獲。

幾個小時後,這些值就會收斂,但結果卻非常糟糕,讓你摸不著頭腦,不知道需要修復什麼。

當你唯一的反饋是最終的驗證錯誤時,你惟一需要搜索的地方就是你的整個網絡體系結構。不用說,你需要一個更好的系統。

那麼,在我們進行完整的多日訓練之前,我們如何真正抓住這個機會呢?關於這個最容易注意到的是層的值實際上不會到達函數外的任何其他張量。假設我們有某種類型的損失和一個優化器,這些張量永遠不會得到優化,所以它們總是有它們的默認值。

我們可以通過簡單的訓練步驟和前後對比來檢測它。

def test_convnet():  image = tf.placeholder(tf.float32, (None, 100, 100, 3)  model = Model(image)  sess = tf.Session()  sess.run(tf.global_variables_initializer())  before = sess.run(tf.trainable_variables())  _ = sess.run(model.train, feed_dict={               image: np.ones((1, 100, 100, 3)),               })  after = sess.run(tf.trainable_variables())  for b, a, n in zip(before, after):      # Make sure something changed.      assert (b != a).any()

在不到15行代碼中,我們現在驗證了至少我們創建的所有變量都得到了訓練。

這個測試超級簡單,超級有用。假設我們修復了前面的問題,現在我們要開始添加一些批歸一化。看看你能否發現這個bug。

  def make_convnet(image_input):        # Try to normalize the input before convoluting        net = slim.batch_norm(image_input)        net = slim.conv2d(net, 32, [11, 11], scope="conv1_11x11")        net = slim.conv2d(net, 64, [5, 5], scope="conv2_5x5")        net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')        net = slim.conv2d(net, 64, [5, 5], scope="conv3_5x5")        net = slim.conv2d(net, 128, [3, 3], scope="conv4_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool2')        net = slim.conv2d(net, 128, [3, 3], scope="conv5_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool3')        net = slim.conv2d(net, 32, [1, 1], scope="conv6_1x1")        return net

你看到了嗎?這個非常微妙。您可以看到,在tensorflow batch_norm中,is_training的默認值是False,所以添加這行代碼並不能使你在訓練期間的輸入正常化!值得慶幸的是,我們編寫的最後一個單元測試將立即發現這個問題!(我知道,因為這是三天前發生在我身上的事。)

再看一個例子。這實際上來自我一天看到的一篇文章(https://www.reddit.com/r/MachineLearning/comments/6qyvvg/p_tensorflow_response_is_making_no_sense/)。我不會講太多細節,但是基本上這個人想要創建一個輸出範圍為(0,1)的分類器。

class Model:  def __init__(self, input, labels):    """Classifier model    Args:      input: Input tensor of size (None, input_dims)      label: Label tensor of size (None, 1).         Should be of type tf.int32.    """    prediction = self.make_network(input)    # Prediction size is (None, 1).    self.loss = tf.nn.softmax_cross_entropy_with_logits(        logits=prediction, labels=labels)    self.train_op = tf.train.AdamOptimizer().minimize(self.loss)

注意到這個錯誤嗎?這是真的很難提前發現,並可能導致超級混亂的結果。基本上,這裡發生的是預測只有一個輸出,當你將softmax交叉熵應用到它上時,它的損失總是0。

一個簡單的測試方法是確保損失不為0。

def test_loss():  in_tensor = tf.placeholder(tf.float32, (None, 3))  labels = tf.placeholder(tf.int32, None, 1))  model = Model(in_tensor, labels)  sess = tf.Session()  loss = sess.run(model.loss, feed_dict={    in_tensor:np.ones(1, 3),    labels:[[1]]  })  assert loss != 0

另一個很好的測試與我們的第一個測試類似,但是是反向的。你可以確保只有你想訓練的變量得到了訓練。以GAN為例。出現的一個常見錯誤是在進行優化時不小心忘記設置要訓練的變量。這樣的代碼經常發生。

class GAN:  def __init__(self, z_vector, true_images):    # Pretend these are implemented.    with tf.variable_scope("gen"):      self.make_geneator(z_vector)    with tf.variable_scope("des"):      self.make_descriminator(true_images)    opt = tf.AdamOptimizer()    train_descrim = opt.minimize(self.descrim_loss)    train_gen = opt.minimize(self.gen_loss)

這裡最大的問題是優化器有一個默認設置來優化所有變量。在像GANs這樣的高級架構中,這是對你所有訓練時間的死刑判決。但是,你可以通過編寫這樣的測試來輕鬆地發現這些錯誤:

def test_gen_training():  model = Model  sess = tf.Session()  gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen')  des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des')  before_gen = sess.run(gen_vars)  before_des = sess.run(des_vars)  # Train the generator.  sess.run(model.train_gen)  after_gen = sess.run(gen_vars)  after_des = sess.run(des_vars)  # Make sure the generator variables changed.  for b,a in zip(before_gen, after_gen):    assert (a != b).any()  # Make sure descriminator did NOT change.  for b,a in zip(before_des, after_des):    assert (a == b).all()

可以為鑑別器編寫一個非常類似的測試。同樣的測試也可以用於許多強化學習算法。許多行為-批評模型有單獨的網絡,需要根據不同的損失進行優化。

下面是一些我推薦你進行測試的模式。

讓測試具有確定性。如果一個測試以一種奇怪的方式失敗,卻永遠無法重現這個錯誤,那就太糟糕了。如果你真的想要隨機輸入,確保使用種子隨機數,這樣你就可以輕鬆地重新運行測試。

保持測試簡短。不要使用單元測試來訓練收斂性並檢查驗證集。這樣做是在浪費自己的時間。

總之,這些黑箱算法仍然有很多方法需要測試!花一個小時寫一個測試可以節省你幾天的重新運行訓練模型,並可以大大提高你的研究效率。因為我們的實現有缺陷而不得不放棄完美的想法,這不是很糟糕嗎?

這個列表顯然不全面,但它是一個堅實的開始!

相關焦點

  • 高手如何給 Spring MVC 做單元測試?
    這裡僅僅做一個入門,對返回視圖和返回 Json 數據的方法進行測試演示,不會把所有的方法都介紹到,具體文檔詳見連結:Mock Test,本章節主要講解以下兩部分內容:1、Mock 測試簡介2、測試用例演示二、Mock 測試簡介1
  • 機器如何學習?一個老師機器人負責數千名學生,測試100萬道題
    還是說科學家真的創造了類似人腦的組件並令其在機器內部運行呢?機器會佔領世界嗎?別擔心,還沒到那個時候。本文就將為你祛魅機器學習,科學家是如何教計算機完成人類的任務,甚至在某些情況下比人類做得還好。筆者將用最直白的話給大家解釋這個問題,努力讓任何領域的人都可以理解。為什麼需要學習?
  • 如何寫好單元測試?
    二方三方接口可能存在日常沒法用,只能上預發/正式的情況,上預發測低效如何處理?本文分享三個單元測試神器及相關經驗總結。文末福利:《Linux運維學習路線》技術公開課。Q1:好代碼應具備可讀性,可測試性,可擴展性等等,那麼如何寫出好代碼?
  • 如何構建基於機器學習的入侵檢測系統
    因此機器學習在信息安全也正扮演著日益重要的角色。通過本文 peerlyst 的讀者不但能探索機器學習技術的基本原理還能進一步親手體驗如何從零開始,通過前沿技術,代碼庫和能公開獲取的數據集來構建一個真實的入侵檢測系統。我們也提供了有用的附件來幫助你更方便的選擇機器學習的算法。
  • R做機器學習簡易教程
    本文介紹如何利用R語言做一個完整的機器學習項目。一個小項目,端到端,遵循機器學習的實施的工作流,系統地完成項目。
  • 軟體測試學習教程:單元測試之UnitTest測試框架
    單元測試的概念單元測試(unit testing),是指對軟體中的最小可測試單元進行檢查和驗證。對於單元測試中單元的含義,要根據實際情況去判定其具體含義。一個單元可能是功能模塊、類、方法(函數)等。單元測試工具不同的程式語言都有比較成熟的單元測試框架,語法規則有些差別,其核心思想都是相通的。
  • 從「機器兒童」到「機器學習」:「學習」的概念是如何變化的?
    人類智能的一個重要特徵便是具備學習能力。人類大腦令人驚嘆的學習能力使得牙牙學語的嬰兒能夠成長為學識淵博、談吐自如的成年人。對於人類而言,學習是與生俱來的能力,這種能力的普遍存在使得我們忽略了它的奇異與珍貴之處。而對於人工智慧研究而言,如何令機器具備這種人類世界中最普遍的能力,是非常具備挑戰性的研究方向。
  • 系統測試:單元測試相關知識筆記
    一、單元測試概念單元測試也成為模塊測試,在模塊編寫完成且五編譯錯誤後就可以進行。單元測試側重模塊中的內部處理邏輯和數據結構。如果採用機器測試,一般用白盒測試法。二、單元測試檢查模塊特徵1、模塊接口模塊接口保證了測試模塊數據流可以正確地流入、流出。主要檢查一下要點:測試模塊輸入參數和形式參數在個數、屬性、單位是否一致。
  • 聊聊單元測試
    還有一種情況是,寫代碼的時候並沒有考慮這代碼要怎麼測,因此寫完了以後發現寫單元測試很難,沒有現成的測試入口。這時候項目交付的deadline又快到了,唉,要不先放著改天再寫吧。當然我們都知道,這個改天大概率再也不會做。我們有一萬個理由可以不做單元測試。但是這就好比,組裝一架飛機不用測試各個零件的運作是否符合預期,直接讓它飛起來再看有哪些問題。
  • 如何編寫屬於你的第一個 Android 單元測試?
    本文主要面向單元測試新手,首先簡單介紹了什麼是單元測試,為什麼要寫單元測試,討論了一下 Android 項目中哪些代碼適合做單元測試
  • 單元測試可測試程式設計師代碼編寫的正確性,如何使用VS2019測試項目
    單元測試是指編寫代碼來驗證開發者編寫代碼的正確性,一般單元測試也是由開發者完成的,自已開發單元測試代碼來檢查自己編寫代碼的通過性。定義:單元測試是開發人員編寫的、用於檢測在特定條件下目標代碼正確性的代碼,單元測試是代碼級別的測試。
  • Android單元測試——初探
    紅灰李的博客地址:http://www.jianshu.com/users/6ca395ea7e3e這篇文章主要是總結一下我自己在學習Android單元測試過程中的收穫及感悟,同時也希望可以幫助到正在學習Android單元測試的小夥伴們.由於時間及經驗有限,文中可能存在錯誤與不足,歡迎大家指出,我會在第一時間對文章進行修改糾正.
  • Android單元測試實踐
    為什麼要引入單元測試  一般來說我們都不會寫單元測試,為什麼呢?因為要寫多餘的代碼,而且還要進行一些學習,入門有些門檻,所以一般在工程中都不會寫單元測試。那麼為什麼我決定要寫單元測試。  這篇文章看完並不會讓你完全掌握單元測試,但是會給你在單元測試的開始有一個好的指引  大大提高工作效率  單元的概念比較模糊,可以是一個方法,可以是一個時機,但是不是一整套環節,一整套環節那就是集成測試了。為什麼說大大提高了工作效率。
  • 單元測試的藝術
    單元測試(Unit Testing):單元測試是一段代碼,這段代碼調用一個工作單元,使用偽對象擺脫對時間、網絡等的真實依賴而實現對工作單元的完全控制,並檢驗該工作單元的一個具體的最終結果。如果關於這個最終結果的期望與實際不一致,單元測試就失敗了。一個單元測試的範圍,可以小到一個方法,大到多個類。3.
  • PEP三年級下冊英語全套單元測試,很不錯的學習資料,可以收藏學
    PEP三年級下冊英語全套單元測試,很不錯的學習資料,可以收藏學如何教孩子們學習英語?小學三年級英語是小學階段英語學習的一個轉折點。而這一年齡段的學生對英語仍保持著新鮮感,大部分學生對英語學習饒有興趣。所以,在這個過程中,我們教孩子學英語更輕鬆。接下來,小西老師就分享一份PEP三年級下冊英語全套單元測試,建議家長幫同學們收藏一份,讓孩子在家也能輕鬆學習英語。文中資料「完整版」如何獲取呢?
  • 三年級數學上冊第一單元測試,輔導班學習效果如何,答完卷子便知
    很多家長為孩子暑假學習效果如何而擔心著。下面以三年級數學上冊第一單元測試卷為例,來分析本單元學習重難點:第一大題,看誰算得又對又快,屬於最基礎的口算題,只要三年級小學生遵循混合運算法則:在一個混合算式裡,既有乘除法、又有加減法、還有小括號時,要先算小括號裡面的,再算乘除法,最後再算加減法。
  • 四年級語文單元測試卷,附答案!做個小測試,檢驗孩子的學習成果
    孩子學習語文,在於積累,我們如果想知道孩子的學習情況,最簡單、直接的方法就是給孩子做個小測試。從孩子的答題試卷上來看,就能知道孩子在哪些知識方面比較薄弱,哪些知識沒有掌握到位。老師才能根據孩子的薄弱方面進行因材施教,重點攻克孩子的短板,提高孩子的學習效率、。
  • .NET 項目中的單元測試
    .NET 項目中的單元測試Intro「不會寫單元測試的程式設計師不是合格的程式設計師,不寫單元測試的程式設計師不是優秀的工程師。」—— 一隻想要成為一個優秀程式設計師的渣逼程序猿。那麼問題來了,什麼是單元測試,如何做單元測試。
  • 單元測試常用的方法
    概述   工廠在組裝一臺電視機之前,會對每個元件都進行測試,這,就是單元測試。其實我們每天都在做單元測試。你寫了一個函數,除了極簡單的外,總是要執行一下,看看功能是否正常,有時還要想辦法輸出些數據,如彈出信息窗口什麼的,這,也是單元測試,我們把這種單元測試稱為臨時單元測試。
  • 當Espresso遇見Android單元測試
    如果依賴Android環境,但是沒有UI相關或者UI比較簡單(如點擊按鈕)的單元測試可以使用開源庫Robolectric解決依賴問題,使測試運行在JVM上,而非模擬器上,大大提高測試運行效率。但是如果測試UI相關比較複雜的代碼,又可以如何進行測試呢?