機器之心專欄
作者:張皓
RNN 在處理時序數據時十分成功。但是,對 RNN 及其變種 LSTM 和 GRU 結構的理解仍然是一個困難的任務。本文介紹一種理解 LSTM 和 GRU 的簡單通用的方法。通過對 LSTM 和 GRU 數學形式化的三次簡化,最後將數據流形式畫成一張圖,可以簡潔直觀地對其中的原理進行理解與分析。此外,本文介紹的三次簡化一張圖的分析方法具有普適性,可廣泛用於其他門控網絡的分析。
1. RNN、梯度爆炸與梯度消失
1.1 RNN
近些年,深度學習模型在處理有非常複雜內部結構的數據時十分有效。例如,圖像數據的像素之間的 2 維空間關係非常重要,CNN(convolution neural networks,卷積神經網絡)處理這種空間關係十分有效。而時序數據(sequential data)的變長輸入序列之間時序關係非常重要,RNN(recurrent neural networks,循環神經網絡,注意和 recursive neural networks,遞歸神經網絡的區別)處理這種時序關係十分有效。
我們使用下標 t 表示輸入時序序列的不同位置,用 h_t 表示在時刻 t 的系統隱層狀態向量,用 x_t 表示時刻 t 的輸入。t 時刻的隱層狀態向量 h_t 依賴於當前詞 x_t 和前一時刻的隱層狀態向量 h_(t-1):
其中 f 是一個非線性映射函數。一種通常的做法是計算 x_t 和 h_(t-1) 的線性變換後經過一個非線性激活函數,例如
其中 W_(xh) 和 W_(hh) 是可學習的參數矩陣,激活函數 tanh 獨立地應用到其輸入的每個元素。
為了對 RNN 的計算過程做一個可視化,我們可以畫出下圖:
圖中左邊是輸入 x_t 和 h_(t-1)、右邊是輸出 h_t。計算從左向右進行,整個運算包括三步:輸入 x_t 和 h_(t-1) 分別乘以 W_(xh) 和 W_(hh) 、相加、經過 tanh 非線性變換。
我們可以認為 h_t 儲存了網絡中的記憶(memory),RNN 學習的目標是使得 h_t 記錄了在 t 時刻之前(含)的輸入信息 x_1, x_2,..., x_t。在新詞 x_t 輸入到網絡之後,之前的隱狀態向量 h_(t-1) 就轉換為和當前輸入 x_t 有關的 h_t。
1.2 梯度爆炸與梯度消失
雖然理論上 RNN 可以捕獲長距離依賴,但實際應用中,RNN 將會面臨兩個挑戰:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。
我們考慮一種簡單情況,即激活函數是恆等(identity)變換,此時
在進行誤差反向傳播(error backpropagation)時,當我們已知損失函數
對 t 時刻隱狀態向量 h_t 的偏導數
時,利用鏈式法則,我們計算損失函數
對 t 時刻隱狀態向量 h_0 的偏導數
我們可以利用 RNN 的依賴關係,沿時間維度展開,來計算
也就是說,在誤差反向傳播時我們需要反覆乘以參數矩陣 W_(hh)。我們對矩陣 W_(hh) 進行奇異值分解(SVD)
其中 r 是矩陣 W_(hh) 的秩(rank)。因此,
那麼我們最後要計算的目標
當 t 很大時,該偏導數取決於矩陣 W_(hh) 的最大的奇異值
是大於 1 還是小於 1,要麼結果太大,要麼結果太小:
(1). 梯度爆炸。當
> 1,
,那麼
此時偏導數
將會變得非常大,實際在訓練時將會遇到 NaN 錯誤,會影響訓練的收斂,甚至導致網絡不收斂。這好比要把本國的產品賣到別的國家,結果被加了層層關稅,等到了別國市場的時候,價格已經變得非常高,老百姓根本買不起。在 RNN 中,梯度(偏導數)就是價格,隨著向前推移,梯度越來越大。這種現象稱為梯度爆炸。
梯度爆炸相對比較好處理,可以用梯度裁剪(gradient clipping)來解決:
這好比是不管前面的關稅怎麼加,設置一個最高市場價格,通過這個最高市場價格保證老百姓是買的起的。在 RNN 中,不管梯度回傳的時候大到什麼程度,設置一個梯度的閾值,梯度最多是這麼大。
(2). 梯度消失。當
< 1,
,那麼
此時偏導數
將會變得十分接近 0,從而在梯度更新前後沒有什麼區別,這會使得網絡捕獲長距離依賴(long-term dependency)的能力下降。這好比打仗的時候往前線送糧食,送糧食的隊伍自己也得吃糧食。當補給點離前線太遠時,還沒等送到,糧食在半路上就已經被吃完了。在 RNN 中,梯度(偏導數)就是糧食,隨著向前推移,梯度逐漸被消耗殆盡。這種現象稱為梯度消失。
梯度消失現象解決起來困難很多,如何緩解梯度消失是 RNN 及幾乎其他所有深度學習方法研究的關鍵所在。LSTM 和 GRU 通過門(gate)機制控制 RNN 中的信息流動,用來緩解梯度消失問題。其核心思想是有選擇性的處理輸入。比如我們在看到一個商品的評論時
Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
我們會重點關注其中的一些詞,對它們進行處理
Amazing!This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
LSTM 和 GRU 的關鍵是會選擇性地忽略其中一些詞,不讓其參與到隱層狀態向量的
更新中,最後只保留相關的信息進行預測。
2. LSTM
2.1 LSTM 的數學形式
LSTM(Long Short-Term Memory)由 Hochreiter 和 Schmidhuber 提出,其數學上的形式化表示如下:
其中
代表逐元素相乘,sigm 代表 sigmoid 函數
和 RNN 相比,LSTM 多了一個隱狀態變量 c_t,稱為細胞狀態(cell state),用來記錄信息。
這個公式看起來似乎十分複雜,為了更好的理解 LSTM 的機制,許多人用圖來描述 LSTM 的計算過程。比如下面這張圖:
似乎看完之後,對 LSTM 的理解仍然是一頭霧水?這是因為這些圖想把 LSTM 的所有細節一次性都展示出來,但是突然暴露這麼多的細節會使你眼花繚亂,從而無處下手。
2.2 三次簡化一張圖
因此,本文提出的方法旨在簡化門控機制中不重要的部分,從而更關注在 LSTM 的核心思想。整個過程是三次簡化一張圖,具體流程如下:
(1). 第一次簡化:忽略門控單元 i_t 、f_t 、o_t 的來源。3 個門控單元的計算方法完全相同,都是由輸入經過線性映射得到的,區別只是計算的參數不同:
使用相同計算方式的目的是它們都扮演了門控的角色,而使用不同參數的目的是為了誤差反向傳播時對三個門控單元獨立地進行更新。在理解 LSTM 運行機制的時候,為了對圖進行簡化,我們不在圖中標註三個門控單元的計算過程,並假定各門控單元是給定的。
(2). 第二次簡化:考慮一維門控單元 i_t 、 f_t 、 o_t。LSTM 中對各維是獨立進行門控的,所以為了表示和理解方便,我們只需要考慮一維情況,在理解 LSTM 原理之後,將一維推廣到多維是很直接的。經過這兩次簡化,LSTM 的數學形式只有下面三行
由於門控單元變成了一維,所以向量和向量的逐元素相乘符號
變成了數和向量相乘 · 。
(3). 第三次簡化:各門控單元二值輸出。門控單元 i_t 、f_t 、o_t 的由於經過了 sigmoid 激活函數,輸出是範圍是 [0, 1]。激活函數使用 sigmoid 的目的是為了近似 0/1 階躍函數,這樣 sigmoid 實數值輸出單調可微,可以基於誤差反向傳播進行更新。
既然 sigmoid 激活函數是為了近似 0/1 階躍函數,那麼,在進行 LSTM 理解分析的時候,為了理解方便,我們認為各門控單元 {0, 1} 二值輸出,即門控單元扮演了電路中開關的角色,用於控制信息的通斷。
(4). 一張圖。將三次簡化的結果用電路圖表述出來,左邊是輸入,右邊是輸出。在 LSTM 中,有一點需要特別注意,LSTM 中的細胞狀態 c_t 實質上起到了 RNN 中隱層單元 h_t 的作用,這點在其他文獻資料中不常被提到,所以整個圖的輸入是 x_t 和 c_{t-1},而不是 x_t 和 h_(t-1)。為了方便畫圖,我們需要將公式做最後的調整
最終結果如下:
和 RNN 相同的是,網絡接受兩個輸入,得到一個輸出。其中使用了兩個參數矩陣 W_(xc) 和 W_(hc),以及 tanh 激活函數。不同之處在於,LSTM 中通過 3 個門控單元 i_t 、f_t 、o_t 來對的信息交互進行控制。當 i_t=1(開關閉合)、f_t=0(開關打開)、o_t=1(開關閉合)時,LSTM 退化為標準的 RNN。
2.3 LSTM 各單元作用分析
根據這張圖,我們可以對 LSTM 中各單元作用進行分析:
輸出門 o_(t-1):輸出門的目的是從細胞狀態 c_(t-1) 產生隱層單元 h_(t-1)。並不是 c_(t-1) 中的全部信息都和隱層單元 h_(t-1) 有關,c_(t-1) 可能包含了很多對 h_(t-1) 無用的信息。因此,o_t 的作用就是判斷 c_(t-1) 中哪些部分是對 h_(t-1) 有用的,哪些部分是無用的。輸入門 i_t。i_t 控制當前詞 x_t 的信息融入細胞狀態 c_t。在理解一句話時,當前詞 x_t 可能對整句話的意思很重要,也可能並不重要。輸入門的目的就是判斷當前詞 x_t 對全局的重要性。當 i_t 開關打開的時候,網絡將不考慮當前輸入 x_t。遺忘門 f_t: f_t 控制上一時刻細胞狀態 c_(t-1) 的信息融入細胞狀態 c_t。在理解一句話時,當前詞 x_t 可能繼續延續上文的意思繼續描述,也可能從當前詞 x_t 開始描述新的內容,與上文無關。和輸入門 i_t 相反,f_t 不對當前詞 x_t 的重要性作判斷,而判斷的是上一時刻的細胞狀態c_(t-1)對計算當前細胞狀態 c_t 的重要性。當 f_t 開關打開的時候,網絡將不考慮上一時刻的細胞狀態 c_(t-1)。細胞狀態 c_t :c_t 綜合了當前詞 x_t 和前一時刻細胞狀態 c_(t-1) 的信息。這和 ResNet 中的殘差逼近思想十分相似,通過從 c_(t-1) 到 c_t 的「短路連接」,梯度得已有效地反向傳播。當 f_t 處於閉合狀態時,c_t 的梯度可以直接沿著最下面這條短路線傳遞到c_(t-1),不受參數 W_(xh) 和 W_(hh) 的影響,這是 LSTM 能有效地緩解梯度消失現象的關鍵所在。
3. GRU
3.1 GRU 的數學形式
GRU 是另一種十分主流的 RNN 衍生物。RNN 和 LSTM 都是在設計網絡結構用於緩解梯度消失問題,只不過是網絡結構有所不同。GRU 在數學上的形式化表示如下:
3.2 三次簡化一張圖
為了理解 GRU 的設計思想,我們再一次運用三次簡化一張圖的方法來進行分析:
(1). 第一次簡化:忽略門控單元 z_t 和 r_t 的來源。
(2). 考慮一維門控單元 z_t 和 r_t。經過這兩次簡化,GRU 的數學形式是以下兩行
(3). 第三次簡化:各門控單元二值輸出。這裡和 LSTM 略有不同的地方在於,當 z_t=1 時h_t = h_(t-1) ;而當 z_t = 0 時,h_t =
。因此,z_t 扮演的角色是一個個單刀雙擲開關。
(4). 一張圖。將三次簡化的結果用電路圖表述出來,左邊是輸入,右邊是輸出。
與 LSTM 相比,GRU 將輸入門 i_t 和遺忘門 f_t 融合成單一的更新門 z_t,並且融合了細胞狀態 c_t 和隱層單元 h_t。當 r_t=1(開關閉合)、 z_t=0(開關連通上面)GRU 退化為標準的 RNN。
3.3 GRU 各單元作用分析
根據這張圖, 我們可以對 GRU 的各單元作用進行分析:
重置門 r_t : r_t 用於控制前一時刻隱層單元 h_(t-1) 對當前詞 x_t 的影響。如果 h_(t-1) 對 x_t 不重要,即從當前詞 x_t 開始表述了新的意思,與上文無關。那麼開關 r_t 可以打開,使得 h_(t-1) 對 x_t 不產生影響。更新門 z_t : z_t 用於決定是否忽略當前詞 x_t。類似於 LSTM 中的輸入門 i_t,z_t 可以判斷當前詞 x_t 對整體意思的表達是否重要。當 z_t 開關接通下面的支路時,我們將忽略當前詞 x_t,同時構成了從 h_(t-1) 到 h_t 的短路連接,這使得梯度得已有效地反向傳播。和 LSTM 相同,這種短路機制有效地緩解了梯度消失現象,這個機制於 highway networks 十分相似。
4. 小結
儘管 RNN、LSTM、和 GRU 的網絡結構差別很大,但是他們的基本計算單元是一致的,都是對 x_t 和 h_t 做一個線性映射加 tanh 激活函數,見三個圖的紅色框部分。他們的區別在於如何設計額外的門控機制控制梯度信息傳播用以緩解梯度消失現象。LSTM 用了 3 個門、GRU 用了 2 個,那能不能再少呢?MGU(minimal gate unit)嘗試對這個問題做出回答,它只有一個門控單元。最後留個小練習,參考 LSTM 和 GRU 的例子,你能不能用三次簡化一張圖的方法來分析一下 MGU 呢?
參考文獻
Yoshua Bengio, Patrice Y. Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks 5(2): 157-166, 1994.Kyunghyun Cho, Bart van Merrienboer, aglar Gülehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, pages 1724-1734, 2014.Junyoung Chung, aglar Gülehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. In NIPS Workshop, pages 1-9, 2014.Felix Gers. Long short-term memory in recurrent neural networks. PhD Dissertation, Ecole Polytechnique Fédérale de Lausanne, 2001.Ian J. Goodfellow, Yoshua Bengio, and Aaron C. Courville. Deep learning. Adaptive Computation and Machine Learning, MIT Press, ISBN 978-0-262-03561-3, 2016.Alex Graves. Supervised sequence labelling with recurrent neural networks. Studies in Computational Intelligence 385, Springer, ISBN 978-3-642-24796-5, 2012.Klaus Greff, Rupesh Kumar Srivastava, Jan Koutník, Bas R. Steunebrink, and Jürgen Schmidhuber. LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems. 28(10): 2222-2232, 2017.Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770-778, 2016.Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In ECCV, pages 630-645, 2016.Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation 9(8): 1735-1780, 1997.Rafal Józefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In ICML, pages 2342-2350, 2015.Zachary Chase Lipton. A critical review of recurrent neural networks for sequence learning. CoRR abs/1506.00019, 2015.Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In ICML, pages 1310-1318, 2013.Rupesh Kumar Srivastava, Klaus Greff, and Jürgen Schmidhuber. Highway networks. In ICML Workshop, pages 1-6, 2015.Guo-Bing Zhou, Jianxin Wu, Chen-Lin Zhang, and Zhi-Hua Zhou. Minimal gated unit for recurrent neural networks. International Journal of Automation and Computing, 13(3): 226-234, 2016.
本文為機器之心專欄,轉載請聯繫本公眾號獲得授權。