大家上午好!
今天向各位分享TensorFlow中張量Tensor的轉置函數tf.transpose()的用法,重點介紹其參數perm及原理。
Tensor 的階在開始介紹轉置函數之前,我們先來看一下Tensor的階
當張量Tensor為一個標量時,即不帶方向的純量,其階為0;
x0 = tf.constant(1)
print(x0) # 輸出 tf.Tensor(1, shape=(), dtype=int32)
當Tensor為一個向量時,如[1, 2, 3]時,其階為1;
x1 = tf.constant([1, 2, 3])
print(x1) # 輸出 tf.Tensor([1 2 3], shape=(3,), dtype=int32)
當Tensor為矩陣時,其階為2,如
x2 = tf.constant([[1, 2], [3, 4]])
print(x2) # 輸出 tf.Tensor([[1 2] [3 4]], shape=(2, 2), dtype=int32)
而3階Tensor可以被認為是一個立方體的數字集合,由多個小立方體組成,每個小立方體上存放了一個數字,如下圖所示:
x3 = tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print(x3) # 輸出 tf.Tensor([[[ 1 2 3] [ 4 5 6]] [[ 7 8 9] [10 11 12]]], shape=(2, 2, 3), dtype=int32)
接下來我們對Tensor的轉置進行討論
0階,1階Tensor的轉置,可以說沒有意義;2階Tensor轉置就相當於矩陣轉置,比如
的轉置就為
屬於大學線性代數部分,也無需過多介紹;
我們重點來討論3階Tensor的轉置,這時就需要用到tf.transpose()函數了
tf.transpose()函數的官方文檔中,介紹了該函數存在一個參數perm,通過指定perm的值,來完成的Tensor的轉置。
perm表示張量階的指定變化。假設Tensor是2階的,且其shape=(x, y),此狀態下默認perm = [0, 1]。當對2階Tensor進行轉置時,如果指定tf.transpose(perm=[1, 0]),就直接完成了矩陣的轉置,此時Tensor的shape=(y, x).
x2_ = tf.transpose(x2)
print(x2_) # 輸出 tf.Tensor([[1 3] [2 4]], shape=(2, 2), dtype=int32)
而處理對象為3階Tensor時,在下方例子中,官方文檔中給出了這麼一句話:
(https://tensorflow.google.cn/api_docs/python/tf/transpose)
# 'perm' is more useful for n-dimensional tensors, for n > 2
於是問題來了,為什麼要設置perm=[0, 2, 1]?當參數perm=[0, 2, 1]設置完成後,為什麼會得到這樣的轉置結果呢?
tf.transpose()函數及perm參數詳解這就要和原Tensor本身的shape有關了。
首先看Tensor x3是如何組成的。該Tensor中,最外層1個中括號包含了2個中括號,這兩個中括號又分別包含了2個中括號,這兩個中括號又包含了3個int型數值,所以其shape值為(2, 2, 3)。當我們將這個3維Tensor畫成立體圖時,如下圖所示。
x3 = tf.constant([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print(x3) # 輸出 tf.Tensor([[[ 1 2 3] [ 4 5 6]] [[ 7 8 9] [10 11 12]]], shape=(2, 2, 3), dtype=int32)
關鍵來了,這裡我們可以將perm理解為切割該立方體的切割順序。我們已知Tensor x3的shape是(2, 2, 3),它對應著原perm的切割順序。這個順序就是,先豎著與側邊平行切一次,再橫著切一次,再豎著平行於橫邊切一次,如下圖所示,就得到了Tensor原本的形狀。
我們將這種切割順序依次定義為0,1,2,於是perm=[0, 1, 2],如下圖所示:
在搞懂這個對應關係後。再來看如果不通過代碼結果,我們如何確定轉置後的Tensor形狀。
當我們對這個3維Tensor x3進行轉置,並設定perm值為[0, 2, 1]時,則此時對應的shape形狀就會轉化為(2, 3, 2)。為什麼呢?
perm=[0, 2, 1]就意味著,對立方體要按照如下順序進行切割:先豎著與側邊平行切一次,再豎著平行於橫邊切一次,再橫著切一次,如下圖所示,就得到了轉置後Tensor的形狀。
這時,我們使用函數語句 tf.transpose(x3, perm = [0, 2, 1]) 進行驗證,轉置結果與推演結果一致。也就是說,shape=(2, 2, 3) 的Tensor經過perm=[0, 2, 1]轉置後,變為shape=(2, 3, 2)的Tensor。
x3_ = tf.transpose(x3, perm = [0, 2, 1])
print(x3_) # 輸出 tf.Tensor([[[1 4] [2 5] [3 6]] [[7 10] [8 11] [9 12]]], shape=(2, 3, 2), dtype=int32)
這也是為什麼在TensorFlow2.0官網教程中,官方推薦在Tensor維度大於2時,使用perm參數進行轉置操作,會更方便的達到效果。當然前提是你要明確原Tensor shape及你想要的變形後的Tensor shape,根據後續需求確定參數perm的值。
希望這篇文章對大家理解張量Tensor有幫助!畫圖排版不易,歡迎【在看】和【打賞】!