文 | McGL
源 | 知乎
寫深度學習網絡代碼,最大的挑戰之一,尤其對新手來說,就是把所有的張量維度正確對齊。如果以前就有TensorSensor這個工具,相信我的頭髮一定比現在更濃密茂盛!
TensorSensor,碼痴教授 Terence Parr 出品,他也是著名 parser 工具 ANTLR 的作者。
在包含多個張量和張量運算的複雜表達式中,張量的維數很容易忘了。即使只是將數據輸入到預定義的 TensorFlow 網絡層,維度也要弄對。當你要求進行錯誤的計算時,通常會得到一些沒啥用的異常消息。為了幫助自己和其他程式設計師調試張量代碼,Terence Parr 寫了一個名叫 TensorSensor 的庫(pip install tensor-sensor 直接安裝) 。TensorSensor 通過增加消息和可視化 Python 代碼來展示張量變量的形狀,讓異常更清晰(見下圖)。它可以兼容 TensorFlow、PyTorch 和 Numpy以及 Keras 和 fastai 等高級庫。
在張量代碼中定位問題令人抓狂!即使是專家,執行張量操作的 Python 代碼行中發生異常,也很難快速定位原因。調試過程通常是在有問題的行前面添加一個 print 語句,以打出每個張量的形狀。這需要編輯代碼添加調試語句並重新運行訓練過程。或者,我們可以使用交互式調試器手動單擊或鍵入命令來請求所有張量形狀。(這在像 PyCharm 這樣的 IDE 中不太實用,因為在調試模式很慢。)下面將詳細對比展示看了讓人貧血的預設異常消息和 TensorSensor 提出的方法,而不用調試器或 print 大法。
調試一個簡單的線性層讓我們來看一個簡單的張量計算,來說明預設異常消息提供的信息不太理想。下面是一個包含張量維度錯誤的硬編碼單(線性)網絡層的簡單 NumPy 實現。
import numpy as np
n = 200 # number of instances
d = 764 # number of instance features
n_neurons = 100 # how many neurons in this layer?
W = np.random.rand(d,n_neurons) # Ooops! Should be (n_neurons,d)
b = np.random.rand(n_neurons,1)
X = np.random.rand(n,d) # fake input matrix with n rows of d-dimensions
Y = W @ X.T + b # pass all X instances through layer10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
執行該代碼會觸發一個異常,其重要元素如下:
...
---> 10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)異常顯示了出錯的行以及是哪個操作(matmul: 矩陣乘法),但是如果給出完整的張量維數會更有用。此外,這個異常也無法區分在 Python 的一行中的多個矩陣乘法。
接下來,讓我們看看 TensorSensor 如何使調試語句更加容易的。如果我們使用 Python with 和tsensor 的 clarify()包裝語句,我們將得到一個可視化和增強的錯誤消息。
import tsensor
with tsensor.clarify():
Y = W @ X.T + b...
ValueError: matmul: Input operand ...
Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)從可視化中可以清楚地看到,W 的維度應該翻轉為 n _ neurons x d; W 的列必須與 X.T 的行匹配。您還可以檢查一個完整的帶有和不帶闡明()的並排圖像,以查看它在筆記本中的樣子。下面是帶有和沒有 clarify() 的例子在notebook 中的比較。
clarify() 功能在沒有異常時不會增加正在執行的程序任何開銷。有異常時, clarify():
給出出錯操作所涉及的張量大小的可視化表示; 只突出顯示異常涉及的操作對象和運算符,而其他 Python 元素則不突出顯示。
TensorSensor 還區分了 PyTorch 和 TensorFlow 引發的與張量相關的異常。下面是等效的代碼片段和增強的異常錯誤消息(Cause: @ on tensor ...)以及 TensorSensor 的可視化:
PyTorch 消息沒有標識是哪個操作觸發了異常,但 TensorFlow 的消息指出了是矩陣乘法。兩者都顯示操作對象維度。
調試複雜的張量表達式預設消息缺乏具體細節,在包含大量操作符的更複雜的語句中,識別出有問題的子表達式很難。例如,下面是從一個門控循環單元(GRU)實現的內部提取的一個語句:
h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
這是什麼計算或者變量代表什麼不重要,它們只是張量變量。有兩個矩陣乘法,兩個向量加法,還有一個向量逐元素修改(r*h)。如果沒有增強的錯誤消息或可視化,我們就無法知道是哪個操作符或操作對象導致了異常。為了演示 TensorSensor 在這種情況下是如何分清異常的,我們需要給語句中使用的變量(為 h _ 賦值)一些偽定義,以得到可執行代碼:
nhidden = 256
Whh_ = torch.eye(nhidden, nhidden) # Identity matrix
Uxh_ = torch.randn(d, nhidden)
bh_ = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1) # fake previous hidden state h
r = torch.randn(nhidden, 1) # fake this computation
X = torch.rand(n,d) # fake input
with tsensor.clarify():
h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)同樣,你可以忽略代碼執行的實際計算,將重點放在張量變量的形狀上。
對於我們大多數人來說,僅僅通過張量維數和張量代碼是不可能識別問題的。當然,默認的異常消息是有幫助的,但是我們中的大多數人仍然難以定位問題。以下是默認異常消息的關鍵部分(注意對 C++ 代碼的不太有用的引用) :
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41我們需要知道的是哪個操作符和操作對象出錯了,然後我們可以通過維數來確定問題。以下是 TensorSensor 的可視化和增強的異常消息:
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: @ on tensor operand Uxh_ w/shape [764, 256] and operand X.T w/shape [764, 200]人眼可以迅速鎖定在指示的算子和矩陣相乘的維度上。哎呀, Uxh 的列必須與 X.T的行匹配,Uxh_的維度翻轉了,應該為:
Uxh_ = torch.randn(nhidden, d)
現在,我們只在 with 代碼塊中使用我們自己直接指定的張量計算。那麼在張量庫的內置預建網絡層中觸發的異常又會如何呢?
理清預建層中觸發的異常TensorSensor 可視化進入你選擇的張量庫前的最後一段代碼。例如,讓我們使用標準的 PyTorch nn.Linear 線性層,但輸入一個 X 矩陣維度是 n x n,而不是正確的 n x d:
L = torch.nn.Linear(d, n_neurons)
X = torch.rand(n,n) # oops! Should be n x d
with tsensor.clarify():
Y = L(X)增強的異常信息
RuntimeError: size mismatch, m1: [200 x 200], m2: [764 x 100] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: L(X) tensor arg X w/shape [200, 200]TensorSensor 將張量庫的調用視為操作符,無論是對網絡層還是對 torch.dot(a,b) 之類的簡單操作的調用。在庫函數中觸發的異常會產生消息,消息標示了函數和任何張量參數的維數。
後臺回復關鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
後臺回復關鍵詞【頂會】
獲取ACL、CIKM等各大頂會論文集!
[1] https://explained.ai/tensor-sensor/index.html