長短期記憶是複雜和先進的神經網絡結構的重要組成部分。本文的主要思想是解釋其背後的數學原理,所以閱讀本文之前,建議首先對LSTM有一些了解。
介紹
上面是單個LSTM單元的圖表。我知道它看起來可怕,但我們會通過一個接一個的文章,希望它會很清楚。
解釋
基本上一個LSTM單元有4個不同的組件。忘記門、輸入門、輸出門和單元狀態。我們將首先簡要討論這些部分的使用,然後深入討論數學部分。
忘記門
顧名思義,這部分負責決定在最後一步中扔掉或保留哪些信息。這是由第一個s型層完成的。
根據ht-1(以前的隱藏狀態)和xt(時間步長t的當前輸入),它為單元格狀態C_t-1中的每個值確定一個介於0到1之間的值。
遺忘門和上一個狀態
如果為1,所有的信息保持原樣,如果為0,所有的信息都被丟棄,對於其他的值,它決定有多少來自前一個狀態的信息被帶入下一個狀態。
輸入門
Christopher Olah博客的解釋在輸入門發生了什麼:
下一步是決定在單元格狀態中存儲什麼新信息。這包括兩部分。首先,一個稱為「輸入門層」的sigmoid層決定我們將更新哪些值。接下來,一個tanh層創建一個新的候選值的向量,C~t,可以添加到狀態中。在下一步中,我們將結合這兩者來創建對狀態的更新。
現在這兩個值i。e i_t和c~t結合決定什麼新的輸入是被輸入到狀態。
單元狀態
單元狀態充當LSTM的內存。這就是它們在處理較長的輸入序列時比普通RNN表現得更好的地方。在每一個時間步長,前一個單元狀態(Ct-1)與遺忘門結合,以決定什麼信息要被傳送,然後與輸入門(it和c~t)結合,形成新的單元狀態或單元的新存儲器。
狀態的計算公式
輸出門
最後,LSTM單元必須給出一些輸出。從上面得到的單元狀態通過一個叫做tanh的雙曲函數,因此單元狀態值在-1和1之間過濾。
LSTM單元的基本單元結構已經介紹完成,繼續推導在實現中使用的方程。
推導先決條件
推導方程的核心概念是基於反向傳播、成本函數和損失。除此以外還假設您對高中微積分(計算導數和規則)有基本的了解。
變量:對於每個門,我們有一組權重和偏差,表示為:
Wf,bf->遺忘門的權重和偏差Wi,bi->輸入門的權重和偏差Wc,bc->單元狀態的權重和偏差Wo,bo->輸出門的權重和偏差Wv ,bv -> 與Softmax層相關的權重和偏差ft, it,ctiledet, o_t -> 輸出使用的激活函數af, ai, ac, ao -> 激活函數的輸入J是成本函數,我們將根據它計算導數。注意(下劃線(_)後面的字符是下標)
前向傳播推導
門的計算公式
狀態的計算公式
以遺忘門為例說明導數的計算。我們需要遵循下圖中紅色箭頭的路徑。
我們畫出一條從f_t到代價函數J的路徑,也就是
ft→Ct→h_t→J。
反向傳播完全發生在相同的步驟中,但是是反向的
ft←Ct←h_t←J。
J對ht求導,ht對Ct求導,Ct對f_t求導。
所以如果我們在這裡觀察,J和ht是單元格的最後一步,如果我們計算dJ/dht,那麼它可以用於像dJ/dC_t這樣的計算,因為:
dJ/dCt = dJ/dht * dht/dCt(鏈式法則)
同樣,對第一點提到的所有變量的導數也要計算。
現在我們已經準備好了變量並且清楚了前向傳播的公式,現在是時候通過反向傳播來推導導數了。我們將從輸出方程開始因為我們看到在其他方程中也使用了同樣的導數。這時就要用到鏈式法則了。我們現在開始吧。
反向傳播推導
lstm的輸出有兩個值需要計算。
Softmax:對於交叉熵損失的導數,我們將直接使用最終的方程。
隱藏狀態是ht。ht是w.r的微分。根據鏈式法則,推導過程如下圖所示。
輸出門相關變量:ao和ot,微分的完整方程如下:
dJ/dVt * dVt/dht * dht/dO_t
dJ/dVt * dVt/dht可以寫成dJ/dht(我們從隱藏狀態得到這個值)。
ht的值= ot * tanh(ct) ->所以我們只需要對ht w.r求導。t o_t。其區別如下:
同樣,a_o和J之間的路徑也顯示出來。微分的完整方程如下:
dJ/dVt * dVt/dht * dt /da_o
dJ/dVt * dVt/dht * dht/dOt可以寫成dJ/dOt(我們從上面的o_t得到這個值)。
Ct是單元的單元狀態。除此之外,我們還處理候選單元格狀態ac和c~_t。
Ct的推導很簡單,因為從Ct到J的路徑很簡單。Ct→ht→Vt→j,因為我們已經有了dJ/dht,我們直接微分ht w.r。t Ct。
ht = ot * tanh(ct) ->所以我們只需要對ht w.r求導。t C_t。
微分的完整方程如下:
dJ/dht * dht/dCt * dCt/dc~_t
可以將dJ/dht * dht/dCt寫成dJ/dCt(我們在上面有這個值)。
Ct的值如圖9公式5所示(下圖第3行最後一個Ct缺少波浪號(~)符號->書寫錯誤)。所以我們只需要對C_t w.r求導。t c ~ _t。
ac:如下圖所示為ac到J的路徑。根據箭頭,微分的完整方程如下:
dJ/dht * dht/dCt * dCt/ da_c
dJ/dht * dht/dCt * dCt/dc_t可以寫成dJ/dc_t(我們在上面有這個值)。
所以我們只需要對c~t w.r求導。t ac。
輸入門相關變量:it和ai
微分的完整方程如下:
dt / dt * dt /dit
可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們在單元格狀態中有這個值)。所以我們只需要對Ct w.r求導。t it。
a_i:微分的完整方程如下:
dJ/dht * dht/dCt * dt /da_i
dJ/dht * dht/dCt * dCt/dit可以寫成dJ/dit(我們在上面有這個值)。所以我們只需要對i_t w.r求導。t ai。
遺忘門相關變量:ft和af
微分的完整方程如下:
dJ/dht * dht/dCt * dCt/df_t
可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們在單元格狀態中有這個值)。所以我們只需要對Ct w.r求導。t ft。
a_f:微分的完整方程如下:
dJ/dht * dht/dCt * dft/da_t
dJ/dht * dht/dCt * dCt/dft可以寫成dJ/dft(我們在上面有這個值)。所以我們只需要對ftw.r求導。t af。
Lstm的輸入
每個單元格i有兩個與輸入相關的變量。前一個單元格狀態C_t-1和前一個隱藏狀態與當前輸入連接,即
[ht-1,xt] > Z_t
C_t-1:這是Lstm單元的內存。圖5顯示了單元格狀態。c - t-1的推導很簡單因為只有c - t和c - t。
Zt:如下圖所示,Zt進入四個不同的路徑,af,ai,ao,ac。
Zt→af→ft→Ct→h_t→J。- >遺忘門
Zt→ai→it→Ct→h_t→J。- >輸入門
Zt→ac→c~t→Ct→h_t→J。->單元狀態
Zt→ao→ot→Ct→h_t→J。- >輸出門
權重和偏差
W和b的推導很簡單。下面的推導是針對Lstm的輸出門的。對於其餘的門,對權重和偏差也進行了類似的處理。
輸入和遺忘門的權重和偏差
輸出和輸出門的權重和偏差
J/dWf = dJ/daf。daf / dWf ->遺忘門
dJ/dWi = dJ/dai。dai / dWi ->輸入門
dJ/dWv = dJ/dVtdVt/ dWv ->輸出門
dJ/dWo = dJ/dao。dao / dWo ->輸出門
我們完成了所有的推導。但是有兩點需要強調
到目前為止,我們所做的只是一個時間步長。現在我們要讓它只進行一次迭代。
所以如果我們有總共T個時間步長,那麼每一個時間步長的梯度會在T個時間步長結束時相加,所以每次迭代結束時的累積梯度為:
每次迭代結束時的累積梯度用來更新權重
總結
LSTM是非常複雜的結構,但它們工作得非常好。具有這種特性的RNN主要有兩種類型:LSTM和GRU。
訓練LSTMs也是一項棘手的任務,因為有許多超參數,而正確地組合通常是一項困難的任務。
作者:Rahuljha
deephub翻譯組