Self-Attention GAN 中的 self-attention 機制

2021-12-17 PaperWeekly

作者丨尹相楠

學校丨裡昂中央理工博士在讀

研究方向丨人臉識別、對抗生成網絡

Self Attention GAN 用到了很多新的技術。最大的亮點當然是 self-attention 機制,該機制是 Non-local Neural Networks [1] 這篇文章提出的。其作用是能夠更好地學習到全局特徵之間的依賴關係。因為傳統的 GAN 模型很容易學習到紋理特徵:如皮毛,天空,草地等,不容易學習到特定的結構和幾何特徵,例如狗有四條腿,既不能多也不能少。 

除此之外,文章還用到了 Spectral Normalization for GANs [2] 提出的譜歸一化。譜歸一化的解釋見本人這篇文章:詳解GAN的譜歸一化(Spectral Normalization)。

但是,該文代碼中的譜歸一化和原始的譜歸一化運用方式略有差別: 

1. 原始的譜歸一化基於 W-GAN 的理論,只用在 Discriminator 中,用以約束 Discriminator 函數為 1-Lipschitz 連續。而在 Self-Attention GAN 中,Spectral Normalization 同時出現在了 Discriminator 和 Generator 中,用於使梯度更穩定。除了生成器和判別器的最後一層外,每個卷積/反卷積單元都會上一個 SpectralNorm。 

2. 當把譜歸一化用在 Generator 上時,同時還保留了 BatchNorm。Discriminator 上則沒有 BatchNorm,只有 SpectralNorm。 

3. 譜歸一化用在 Discriminator 上時最後一層不加 Spectral Norm。 

最後,self-attention GAN 還用到了 cGANs With Projection Discriminator 提出的 conditional normalizationprojection in the discriminator。這兩個技術我還沒有來得及看,而且 PyTorch 版本的 self-attention GAN 代碼中也沒有實現,就先不管它們了。

本文主要說的是 self-attention 這部分內容。

 圖1. Self-Attention

Self-Attention

在卷積神經網絡中,每個卷積核的尺寸都是很有限的(基本上不會大於 5),因此每次卷積操作只能覆蓋像素點周圍很小一塊鄰域。

對於距離較遠的特徵,例如狗有四條腿這類特徵,就不容易捕獲到了(也不是完全捕獲不到,因為多層的卷積、池化操作會把 feature map 的高和寬變得越來越小,越靠後的層,其卷積核覆蓋的區域映射回原圖對應的面積越大。但總而言之,畢竟還得需要經過多層映射,不夠直接)。

Self-Attention 通過直接計算圖像中任意兩個像素點之間的關係,一步到位地獲取圖像的全局幾何特徵。 

論文中的公式不夠直觀,我們直接看文章的 PyTorch 的代碼,核心部分為 sagan_models.py:

class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) 
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) 
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) 
        energy =  torch.bmm(proj_query,proj_key) 
        attention = self.softmax(energy) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) 

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)

        out = self.gamma*out + x
        return out,attention

構造函數中定義了三個 1 × 1 的卷積核,分別被命名為 query_conv , key_conv 和 value_conv 。

為啥命名為這三個名字呢?這和作者給它們賦予的含義有關。query 意為查詢,我們希望輸入一個像素點,查詢(計算)到 feature map 上所有像素點對這一點的影響。而 key 代表字典中的鍵,相當於所查詢的資料庫。query 和 key 都是輸入的 feature map,可以看成把 feature map 複製了兩份,一份作為 query 一份作為 key。 

需要用一個什麼樣的函數,才能針對 query 的 feature map 中的某一個位置,計算出 key 的 feature map 中所有位置對它的影響呢?作者認為這個函數應該是可以通過「學習」得到的。那麼,自然而然就想到要對這兩個 feature map 分別做卷積核為 1 × 1 的卷積了,因為卷積核的權重是可以學習得到的。 

至於 value_conv ,可以看成對原 feature map 多加了一層卷積映射,這樣可以學習到的參數就更多了,否則 query_conv 和 key_conv 的參數太少,按代碼中只有 in_dims × in_dims//8 個。 

接下來逐行研究 forward 函數:

proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)

這行代碼先對輸入的 feature map 卷積了一次,相當於對 query feature map 做了一次投影,所以叫做 proj_query。由於是 1 × 1 的卷積,所以不改變 feature map 的長和寬。feature map 的每個通道為如 (1) 所示的矩陣,矩陣共有 N 個元素(像素)。

然後重新改變了輸出的維度,變成:

 (m_batchsize,-1,width*height) 

batch size 保持不變,width 和 height 融合到一起,把如 (1) 所示二維的 feature map 每個 channel 拉成一個長度為 N 的向量。

因此,如果 m_batchsize 取 1,即單獨觀察一個樣本,該操作的結果是得到一個矩陣,矩陣的的行數為 query_conv 卷積輸出的 channel 的數目 C( in_dim//8 ),列數為 feature map 像素數 N。

然後作者又通過 .permute(0, 2, 1) 轉置了矩陣,矩陣的行數變成了 feature map 的像素數 N,列數變成了通道數 C。因此矩陣維度為 N × C 。該矩陣每行代表一個像素位置上所有通道的值,每列代表某個通道中所有的像素值。

 圖2. proj_query 的維度

proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height)

這行代碼和上一行類似,只不過取消了轉置操作。得到的矩陣行數為通道數 C,列數為像素數 N,即矩陣維度為 C × N。該矩陣每行代表一個通道中所有的像素值,每列代表一個像素位置上所有通道的值。

 圖3. proj_key的維度

energy =  torch.bmm(proj_query,proj_key)

這行代碼中, torch.bmm 的意思是 batch matrix multiplication。就是說把相同 batch size 的兩組 matrix 一一對應地做矩陣乘法,最後得到同樣 batchsize 的新矩陣。

若 batch size=1,就是普通的矩陣乘法。已知 proj_query 維度是 N × C, proj_key 的維度是 C × N,因此 energy 的維度是 N × N:

 圖4. energy的維度

energy 是 attention 的核心,其中第 i 行 j 列的元素,是由 proj_query 第 i 行,和 proj_key 第 j 列通過向量點乘得到的。而 proj_query 第 i 行表示的是 feature map 上第 i 個像素位置上所有通道的值,也就是第 i 個像素位置的所有信息,而 proj_key 第 j 列表示的是 feature map 上第 j 個像素位置上的所有通道值,也就是第 j 個像素位置的所有信息。

這倆相乘,可以看成是第 j 個像素對第 i 個像素的影響。即,energy 中第 i 行 j 列的元素值,表示第 j 個像素點對第 i 個像素點的影響。

attention = self.softmax(energy)

這裡 sofmax 是構造函數中定義的,為按「行」歸一化。這個操作之後的矩陣,各行元素之和為 1。這也比較好理解,因為 energy 中第 i 行元素,代表 feature map 中所有位置的像素對第 i 個像素的影響,而這個影響被解釋為權重,故加起來應該是 1,故應對其按行歸一化。attention 的維度也是 N × N。

proj_value = self.value_conv(x).view(m_batchsize,-1,width*height)

上面的代碼中,先對原 feature map 作一次卷積映射,然後把得到的新 feature map 改變形狀,維度變為 C × N ,其中 C 為通道數(注意和上面計算 proj_query   proj_key 的 C 不同,上面的 C 為 feature map 通道數的 1/8,這裡的 C 與 feature map 通道數相同),N 為 feature map 的像素數。

 圖5. proj_value的維度

out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)

然後,再把 proj_value (C × N)矩陣同  attention 矩陣的轉置(N × N)相乘,得到 out (C × N)。之所以轉置,是因為 attention 中每行的和為 1,其意義是權重,需要轉置後變為每列的和為 1,施加於 proj_value 的行上,作為該行的加權平均。 proj_value 第 i 行代表第 i 個通道所有的像素值, attention 第 j 列,代表所有像素施加到第 j 個像素的影響。

因此, out 中第 i 行包含了輸出的第 i 個通道中的所有像素,第 j 列表示所有像素中的第 j 個像素,合起來也就是: out 中的第 i 行第 j 列的元素,表示被 attention 加權之後的 feature map 的第 i 個通道的第 j 個像素的像素值。再改變一下形狀, out 就恢復了 channel×width×height 的結構。

 圖6. out的維度

最後一行代碼,借鑑了殘差神經網絡(residual neural networks)的操作, gamma 是一個參數,表示整體施加了 attention 之後的 feature map 的權重,需要通過反向傳播更新。而 x 就是輸入的 feature map。

在初始階段, gamma 為 0,該 attention 模塊直接返回輸入的 feature map,之後隨著學習,該 attention 模塊逐漸學習到了將 attention 加權過的 feature map 加在原始的 feature map 上,從而強調了需要施加注意力的部分 feature map。

總結

可以把 self attention 看成是 feature map 和它自身的轉置相乘,讓任意兩個位置的像素直接發生關係,這樣就可以學習到任意兩個像素之間的依賴關係,從而得到全局特徵了。看論文時會被它複雜的符號迷惑,但是一看代碼就發現其實是很 naive 的操作。

參考文獻

[1] Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He, Non-local Neural Networks, CVPR 2018.

[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, Spectral Normalization for Generative Adversarial Networks, ICLR 2018.

相關焦點

  • 乾貨 | NLP中的self-attention【自-注意力】機制
    本人就這篇論文中的self-attention以及一些相關工作進行了學習總結(其中也參考借鑑了張俊林博士的博客"深度學習中的注意力機制(2017版)"和蘇劍林的"《Attention is All You Need》淺讀(簡介+代碼)"),和大家一起分享。
  • 乾貨|理解attention機制本質及self-attention
    的知識,這篇接上篇,更加深入的理解attention的有關思想和最新的self-attention機制一、Attention機制的本質思想如果把Attention機制從上文講述例子中的Encoder-Decoder框架中剝離,並進一步做抽象,可以更容易看懂Attention機制的本質思想。
  • Self-attention Mask詳解
    如果你要對transformer進行改動,就必然要去重新構建你的mask,那麼本文會對transformer中的mask機制進行詳細解讀,教你如何根據輸入的改變來修改你的mask。figure 6其中,每一行都為softmax之後歸一化得到的權重向量,代表q在K中的每一個k中的信息佔比。最後,與V矩陣進行相乘得到self-attention的結果,如下圖:
  • 動手推導Self-Attention
    基於Transformer的架構主要用於對自然語言理解任務進行建模,避免使用神經網絡中的遞歸神經網絡,而是完全依賴Self-Attention機制來繪製輸入和輸出之間的全局依存關係。但是,這背後的數學原理是什麼?這就是我們今天要發掘的問題。這篇文章的主要內容是引導您完成Self-Attention模塊中涉及的數學運算。在本文結尾處,您應該能夠從頭開始編寫或編寫Self-Attention模塊。
  • 【Self-Attention】幾篇較新的計算機視覺Self-Attention
    中,然後再上擴 channel 數,與原輸入 feature map X 殘差一下,完整的 bottleneck嵌入在 action recognition 框架中的attention map 可視化效果:
  • 從原始碼剖析Self-Attention知識點
    不考慮多頭的原因,self-attention中詞向量不乘QKV參數矩陣,會怎麼樣?對於 Attention 機制,都可以用統一的 query/key/value 模式去解釋,而對於  self-attention,一般會說它的 q=k=v,這裡的相等實際上是指它們來自同一個基礎向量,而在實際計算時,它們是不一樣的,因為這三者都是乘了 QKV 參數矩陣的。
  • Self-Attention與Transformer
    所以,在NMT(Neural Machine Translation,神經機器翻譯)任務上,還添加了attention的機制。1.3 注意力機制的本質思想關於注意力機制入門文章請看我之前的一篇文章:深度學習中的注意力機制,Microstrong,地址:https://mp.weixin.qq.com/s/3911D_FkTWrtKwBo30vENg熟悉了注意力機制的原理後,我們來探究一下注意力機制的本質思想。
  • Transformer + self-attention
    部分預熱1.1 計算順序首先了解NLP中self-attention計算順序:1.2 計算公式詳解有些突兀,不著急,接下來我們看看self-attention的公式長什麼樣子:公式1此公式在論文《attention is all your need》中出現,拋開Q、K、V與dk不看,則最開始的self-attention注意力計算公式為:
  • 從Seq2seq到Attention模型到Self Attention(二)
    假設我們在計算第一個字」Are」的self-attention,我們可能會將輸入句中的每個文字」Are」, 」you」, 『very』, 『big』分別和」Are」去做比較,這個分數決定了我們在encode某個特定位置的文字時,應該給予多少注意力(attention)。
  • 人人都能看得懂的Self-Attention詳解
    QKV是qkv的矩陣形式, 所以本質是搞清楚qkv是幹嘛的;在self-attention中,q、k、v都是輸入參數矩陣變換而來的(增加可學習性)其中q和k是算相似度的得權重的,v是用來跟權重做加權求和的
  • Transformer+self-attention超詳解(亦個人心得)
    部分預熱1.1 計算順序首先了解NLP中self-attention計算順序:1.2 計算公式詳解有些突兀,不著急,接下來我們看看self-attention的公式長什麼樣子:公式1此公式在論文《attention is all your need》中出現,拋開Q、K、V與dk不看,則最開始的self-attention注意力計算公式為:
  • nlp中的Attention注意力機制+Transformer詳解
    目錄一、Attention機制剖析1、為什麼要引入Attention機制?2、Attention機制有哪些?(怎麼分類?)3、Attention機制的計算流程是怎樣的?4、Attention機制的變種有哪些?5、一種強大的Attention機制:為什麼自注意力模型(self-Attention model)在長距離序列中如此強大?
  • 自然語言處理中的自注意力機制(Self-Attention Mechanism)
    提出了多頭注意力(Multi-headed Attention)機制方法,在編碼器和解碼器中大量的使用了多頭自注意力機制(Multi-headed self-attention)。3. 在 WMT2014 語料中的英德和英法任務上取得了先進結果,並且訓練速度比主流模型更快。
  • 超詳細圖解Self-Attention的那些事兒
    我們再想,Attention機制的核心是什麼?那麼權重從何而來呢?就是這些歸一化之後的數字。當我們關注"早"這個字的時候,我們應當分配0.4的注意力給它本身,剩下0.4關注"上",0.2關注"好"。當然具體到我們的Transformer,就是對應向量的運算了,這是後話。行文至此,我們對這個東西是不是有點熟悉?Python中的熱力圖Heatmap,其中的矩陣是不是也保存了相似度的結果?
  • CV中的Attention機制:簡單而有效的CBAM模塊
    什麼是注意力機制?注意力機制(Attention Mechanism)是機器學習中的一種數據處理方法,廣泛應用在自然語言處理、圖像識別及語音識別等各種不同類型的機器學習任務中。通俗來講:注意力機制就是希望網絡能夠自動學出來圖片或者文字序列中的需要注意的地方。
  • 一文讀懂Attention機制
    : self-Attention 在上文Attention的常見類型下,這裡著重介紹self-attention。Self-attention In Transform首先,還是來講一下Transformer中的self-attention機制。
  • 注意力機制Attention
    注意力機制(attention mechanism)Attention實質上是一種分配機制,其核心思想是突出對象的某些重要特徵。根據Attention對象的重要程度,重新分配資源,即權重,實現核心思想是基於原有的數據找到其之間的關聯性,然後突出其某些重要特徵。注意力可以被描述為將一個查詢和一組鍵值對映射到一個輸出,其中查詢、鍵、值和輸出都是向量。
  • 機器翻譯的Attention機制
    __init__() self.batch_sz = batch_sz self.dec_units = dec_units self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim) self.gru = tf.keras.layers.GRU(self.dec_units,
  • 推薦算法AutoInt模型:基於multi-head self-attention的特徵高階交叉
    因此,能否通過Transformer模型的核心結構multi-head self-attention模塊,實現對特徵的顯式的高階交叉呢?作為Transformer模型的核心模塊,multi-head self-attention在這裡主要是用來對特徵信息進行提取和交叉計算。
  • 關於attention機制的一些細節的思考
    之前看過的一些attention機制,除了self attention之外,基於rnn或者cnn的attention在處理文本問題的時候基本上是embedding之後經過了rnn或者cnn結構的映射之後得到了映射後的向量V,然後attention是針對於V進行注意力weights的計算,問題來了,能不能直接在embedding上進行score的計算?