轉自 | GiantPandaCV
作者 | zzk
【導讀】GiantPandaCV導語:這篇文章為大家介紹了一下Transformer模型,Transformer模型原本是NLP中的一個Idea,後來也被引入到計算機視覺中,例如前面介紹過的DETR就是將目標檢測算法和Transformer進行結合,另外基於Transformer的魔改工作最近也層出不窮,感興趣的同學可以了解一下。前言Google於2017年提出了《Attention is all you need》,拋棄了傳統的RNN結構,「設計了一種Attention機制,通過堆疊Encoder-Decoder結構」,得到了一個Transformer模型,在機器翻譯任務中「取得了BLEU值的新高」。在後續很多模型也基於Transformer進行改進,也得到了很多表現不錯的NLP模型,前段時間,相關工作也引申到了CV中的目標檢測,可參考FAIR的DETR模型
引入問題常見的時間序列任務採用的模型通常都是RNN系列,然而RNN系列模型的順序計算方式帶來了兩個問題
某個時間狀態因此我們設計了一個全新的結構Transformer,通過Attention注意力機制,來對時間序列更好的建模。同時我們不需要像RNN那樣順序計算,從而能讓模型更能充分發揮並行計算性能。
模型架構TransFormer模型架構一覽上圖展示的就是Transformer的結構,左邊是編碼器Encoder,右邊是解碼器Decoder。通過多次堆疊,形成了Transformer。下面我們分別看下Encoder和Decoder的具體結構
Encoder編碼器架構Encoder結構如上,它由以下sublayer構成
Multi-Head Attention 多頭注意力Self AttentionMulti-Head Attention多頭注意力層是由多個self attention來組成的,因此我們先講解下模型的自注意力機制。
在一句話中,如果給每個詞都分配相同的權重,那麼會很難讓模型去學習詞與詞對應的關係。舉個例子
The animal didn't cross the street because it was too tired我們需要讓模型去推斷 it 所指代的東西,當我們給模型加了注意力機制,它的表現如下
注意力機制效果我們通過注意力機制,讓模型能看到輸入的各個單詞,然後它會更加關注於 The animal,從而更好的進行編碼。
論文裡將attention模塊記為「Scaled Dot-Product Attention」,計算如下
自注意力機制一覽其中 Q, K, V(向量長度為64)是由輸入X經過三個不同的權重矩陣(shape=512x64)計算得來,
經過Embedding的向量X,與右邊三個權重矩陣相乘,分別得到Query,Key,Value三個向量下面我們看一個具體例子
注意力機制運算過程以Thinking這個單詞為例,我們需要計算整個句子所有詞與它的score。
X1是Thinking對應的Embedding向量。然後我們與Key向量進行相乘,來計算相關性,這裡記作Score。「這個過程可以看作是當前詞的搜索q1,與其他詞的key去匹配」。當相關性越高,說明我們需要放更多注意力在上面。Softmax後的結果與Value向量相乘,得到最終結果MultiHead-Attention理解了自注意力機制後,我們可以很好的理解多頭注意力機制。簡單來說,多頭注意力其實就是合併了多個自注意力機制的結果
多頭注意力機制概覽,將多個自注意力機制並在一起我們以原文的8個注意力頭為例子,多頭注意力的操作如下
分別計算出每個自注意力模塊的結果Z0, Z1, Z2.....Z7
經過一層全連接層,得到最終的輸出最後多頭注意力的表現類似如下
多頭注意力機制效果Feed Forward Neural Network這個FFN模塊比較簡單,本質上全是兩層全連接層加一個Relu激活
Positional Encoding摒棄了CNN和RNN結構,我們無法很好的利用序列的順序信息,因此我們採用了額外的一個位置編碼來進行緩解
然後與輸入相加,通過引入位置編碼,給詞向量中賦予了單詞的位置信息
位置編碼下圖是總Encoder的架構
Encoder的整體結構DecoderDecoder的結構與Encoder的結構很相似
Decoder結構「只不過額外引入了當前翻譯和編碼特徵向量的注意力」,這裡就不展開了。
代碼這裡參考的是TensorFlow的官方實現notebook transformer.ipynb
位置編碼def get_angles(pos, i, d_model):
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
return pos * angle_rates
def positional_encoding(position, d_model):
angle_rads = get_angles(np.arange(position)[:, np.newaxis],
np.arange(d_model)[np.newaxis, :],
d_model)
# apply sin to even indices in the array; 2i
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
# apply cos to odd indices in the array; 2i+1
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
pos_encoding = angle_rads[np.newaxis, ...]
return tf.cast(pos_encoding, dtype=tf.float32)這裡就是根據公式,生成位置編碼
Scaled-Dot Attentiondef scaled_dot_product_attention(q, k, v, mask):
"""Calculate the attention weights.
q, k, v must have matching leading dimensions.
k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
The mask has different shapes depending on its type(padding or look ahead)
but it must be broadcastable for addition.
Args:
q: query shape == (..., seq_len_q, depth)
k: key shape == (..., seq_len_k, depth)
v: value shape == (..., seq_len_v, depth_v)
mask: Float tensor with shape broadcastable
to (..., seq_len_q, seq_len_k). Defaults to None.
Returns:
output, attention_weights
"""
matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)
# scale matmul_qk
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
# add the mask to the scaled tensor.
if mask is not None:
scaled_attention_logits += (mask * -1e9)
# softmax is normalized on the last axis (seq_len_k) so that the scores
# add up to 1.
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
return output, attention_weights輸入的是Q, K, V矩陣和一個mask掩碼向量根據公式進行矩陣相乘,得到最終的輸出,以及注意力權重
MultiheadAttention這裡的代碼就是將多個注意力結果組合在一起
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % self.num_heads == 0
self.depth = d_model // self.num_heads
self.wq = tf.keras.layers.Dense(d_model)
self.wk = tf.keras.layers.Dense(d_model)
self.wv = tf.keras.layers.Dense(d_model)
self.dense = tf.keras.layers.Dense(d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth).
Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
"""
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, v, k, q, mask):
batch_size = tf.shape(q)[0]
q = self.wq(q) # (batch_size, seq_len, d_model)
k = self.wk(k) # (batch_size, seq_len, d_model)
v = self.wv(v) # (batch_size, seq_len, d_model)
q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
scaled_attention, attention_weights = scaled_dot_product_attention(
q, k, v, mask)
scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)
return output, attention_weights
FFNdef point_wise_feed_forward_network(d_model, dff):
return tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff)
tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model)
])有了這三個模塊,就可以組合成Encoder和Decoder了,這裡限於篇幅就不展開,有興趣的可以看下官方notebook
總結Transformer這個模型設計還是很有特點的,雖然本質上還是全連接層的各個組合,但是通過不同的權重矩陣,對序列進行注意力機制建模。並且根據模型無法利用序列順序信息的缺陷,設計了一套位置編碼機制,賦予詞向量位置信息。近年來對Transformer的魔改也有很多,相信這個模型還有很大的潛力去挖掘。
相關資料參考Tensorflow官方notebook transformer.ipynb: ('https://github.com/tensorflow/docs/blob/master/site/en/tutorials/text/transformer.ipynb')illustrated-transformer: (http://jalammar.github.io/illustrated-transformer/) 該作者的圖示很明晰,相對容易理解✄---看到這裡,說明你喜歡這篇文章,請點擊「在看」或順手「轉發」「點讚」。
歡迎微信搜索「panchuangxx」,添加小編磐小小仙微信,每日朋友圈更新一篇高質量推文(無廣告),為您提供更多精彩內容。
▼ ▼ 掃描二維碼添加小編 ▼ ▼