機器之心報導
參與:魔王、陳萍
誕生五年的 TensorFlow 出現大 bug,使用對應訓練方式得到的模型甚至論文結果可能受到波及,然而相關 issue 提交 24 天后依然沒有 TensorFlow 開發團隊的處理。用戶表示很失望,「怒而轉用 PyTorch」。在事情發酵後,TensorFlow 團隊終於回復了,表示已經在改,但對應的功能將在 2.4 版本中才能用。
谷歌團隊 2015 年發布的 TensorFlow 框架是目前機器學習領域最流行的框架之一。雖然後起之秀 PyTorch 奮起直追,但 TensorFlow 框架的使用者仍然眾多。
TensorFlow 經常被吐槽難用、新版本也常常收到差評,但不管怎樣,已經誕生五年之久的 TensorFlow 應該不會有什麼太大的 bug 吧?然而,事實似乎並非如此。
最近,機器學習工程師 Santosh Gupta 在使用 TensorFlow 時發現了一個問題:使用 Keras 功能 API 創建的模型自定義層中的權重無法進行梯度更新。
issue 詳情:https://github.com/tensorflow/tensorflow/issues/40638
這個帖子在 reddit 上引起了熱議,網友紛紛表示:「這是在逼我用 PyTorch!」
到底是什麼驚天大 bug?
那麼這個令人震驚的 bug 到底是什麼呢?
Santosh Gupta 對此的描述是:由於 Tensorflow 的缺陷,阻止了 Keras 功能 API 創建模型的自定義層中權重的梯度更新,從而使這些權重基本上保持無法更新狀態。
而我們都知道,梯度更新對於訓練神經網絡來說相當重要,它是保證模型正常訓練的前提。
對於使用自定義圖層功能性 API 的研究人員來說,他們往往會運行下列程序:
fori, varinenumerate(model.trainable_variables): print(model.trainable_variables[i].name)
這個程序會保存你的訓練權重。而 Tensorflow 中出現的這個 bug,導致使用者在功能性 API 中使用自定義圖層時 trainable_variables 缺少權重。同樣地,這些權重在 non_trainable_variables 也會消失。
但是,如果這些權重不在可訓練變量中,則必須凍結這些權重,因為只有這些權重才會接收梯度更新,如下面的 Keras 模型訓練代碼所示:
gradients = tape.gradient(loss, trainable_variables) # Whether to aggregate gradients outside of optimizer. This requires support # of the optimizer and doesn't work with ParameterServerStrategy and # CentralStroageStrategy. aggregate_grads_outside_optimizer = ( optimizer._HAS_AGGREGATE_GRAD and # pylint: disable=protected-access not isinstance(strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended)) if aggregate_grads_outside_optimizer: # We aggregate gradients before unscaling them, in case a subclass of # LossScaleOptimizer all-reduces in fp16. All-reducing in fp16 can only be # done on scaled gradients, not unscaled gradients, for numeric stability. gradients = optimizer._aggregate_gradients(zip(gradients, # pylint: disable=protected-access trainable_variables)) if isinstance(optimizer, lso.LossScaleOptimizer): gradients = optimizer.get_unscaled_gradients(gradients) gradients = optimizer._clip_gradients(gradients) # pylint: disable=protected-access if trainable_variables: if aggregate_grads_outside_optimizer: optimizer.apply_gradients( zip(gradients, trainable_variables), experimental_aggregate_gradients=False) else: optimizer.apply_gradients(zip(gradients, trainable_variables))
通過 Colab gist [1],你可以看到此 bug。
針對上述 bug,也有研究者提出了解決方案。
一種解決方法是改用 Keras 子類創建模型。模型子類化導致所有權重出現在 trainable_variables 中。為了確保功能性 API 和子類模型完全相同,研究人員在每個筆記本底部使用相同的輸入對它們進行推論。模型的輸出完全相同。但是使用功能性 API 模型進行訓練會將許多權重視為凍結。
針對此帖,Keras 之父、谷歌軟體工程師 Francois Chollet 也不淡定了。
他表示,「如果第三方寫的代碼有 bug,且涉及到了 Keras 模型,這並不意味著『Keras 就有 bug』。」
此外,他認為:跟蹤自定義圖層中訓練參數的效果非常好,只需要 7 行代碼就可以進行測試。
最新動向:引發熱議後,谷歌回復
在 Francois Chollet 發推一小時後,谷歌工程師、TensorFlow 貢獻者 Tomer Kaftan 在 GitHub 上回復了該 issue:
目前,TensorFlow 的情況是這樣的:如果第一個參數中的所有輸入來自其他 Keras 層,則當前層進入「functional api construction」模式。但是,你的第一個位置參數輸入中包含 None,因此,無法觸發「functional api construction」模式。這導致該層與外部功能模型產生內聯(inlined),而不是正確地被納入外部模型。你可以更改層 API,排除掉輸入中的 Nones,這樣就可以解決該問題。
功能 API 的主要 cleanup/refactoring 已經大部分完成,以使功能 API 觸發機制更加清晰(即使輸入中出現任意符號值),並解決其他的一些 issue。但是,該功能將在 TensorFlow 2.4 版本中出現。
對此,issue 發起者 Santosh Gupta 表示同意:
網友:震驚,這是逼我用 PyTorch!
在這篇帖子的評論中,有網友復現了這個 bug,並表示震驚:「這個 bug 到底存在多久了?!這是不是意味著用這種方式訓練的每一個模型都失效了,基於這些模型的每一篇研究論文的結果也會被拖累。」
此外,該網友對 TensorFlow 開發者的維護效率也表示質疑:
Git issue 顯示 23 天前就有 TensorFlow 開發者承認了這個 bug 的存在,並將該 issue 指定給另一位開發者,而被指定者並沒有查看這個 issue。這就像一家食品公司 23 天就發現自己的產品中存在大腸桿菌,但是這麼多天過去了他們啥都沒幹。我見過很多對 TensorFlow 的抱怨,但是之前從未聽到過這樣的事情。
這件事也引發了開發者們對 TensorFlow 甚至谷歌產品的吐槽:
作為谷歌曾經的擁躉,現在我對它的所有產品感到厭倦。所有事情都半途而廢,看不到完成的可能性,也看不到對用戶的關注。TensorFlow 真是糟糕透了。開發團隊意識到 PyTorch 正在搶奪他們的用戶,但他們仍和以往一樣半途而廢,沒有將資源或 Keras 置於優先級較高的位置,因為他們內部並不使用。文檔也很糟糕,是因為任何有自尊心的工程師都不想為寫優秀的文檔費心嗎?
然而,競爭對手 PyTorch 的文檔可讀性就很強,PyTorch 官方甚至還提供了限時免費的權威官方教程書籍。或許有一天谷歌也會出現一位像薩提亞 · 納德拉那樣的人物,改變谷歌的內部文化,更加關注用戶和產品。而現在,谷歌只是停留在廣告業務帶來的收益上吃老底,這使得他們忽略了自己在幾乎其他所有業務上的無能。
即便在事情引發熱議後 TensorFlow 團隊進行了回復,但這個 bug 仍有可能對 TensorFlow 造成影響。
下面這句評論或許最能反映廣大開發者的心態:
「這將破壞用戶對 TensorFlow 的信任,可能有更多的開發者轉用 PyTorch。」
參考連結:
https://colab.research.google.com/gist/Santosh-Gupta/40c54e5b76e3f522fa78da6a248b6826/missingtrainablevarsinference_var.ipynb#scrollTo=28bP9FYpILJ9
https://www.reddit.com/r/MachineLearning/comments/hrawam/d_theres_a_flawbug_in_tensorflow_thats_preventing/