手推公式:LSTM單元梯度的詳細的數學推導

2020-12-11 deephub

長短期記憶是複雜和先進的神經網絡結構的重要組成部分。本文的主要思想是解釋其背後的數學原理,所以閱讀本文之前,建議首先對LSTM有一些了解。

介紹

上面是單個LSTM單元的圖表。我知道它看起來可怕,但我們會通過一個接一個的文章,希望它會很清楚。

解釋

基本上一個LSTM單元有4個不同的組件。忘記門、輸入門、輸出門和單元狀態。我們將首先簡要討論這些部分的使用,然後深入討論數學部分。

忘記門

顧名思義,這部分負責決定在最後一步中扔掉或保留哪些信息。這是由第一個s型層完成的。

根據ht-1(以前的隱藏狀態)和xt(時間步長t的當前輸入),它為單元格狀態C_t-1中的每個值確定一個介於0到1之間的值。

遺忘門和上一個狀態

如果為1,所有的信息保持原樣,如果為0,所有的信息都被丟棄,對於其他的值,它決定有多少來自前一個狀態的信息被帶入下一個狀態。

輸入

Christopher Olah博客的解釋在輸入門發生了什麼:

下一步是決定在單元格狀態中存儲什麼新信息。這包括兩部分。首先,一個稱為「輸入門層」的sigmoid層決定我們將更新哪些值。接下來,一個tanh層創建一個新的候選值的向量,C~t,可以添加到狀態中。在下一步中,我們將結合這兩者來創建對狀態的更新。

現在這兩個值i。e i_t和c~t結合決定什麼新的輸入是被輸入到狀態。

單元狀態

單元狀態充當LSTM的內存。這就是它們在處理較長的輸入序列時比普通RNN表現得更好的地方。在每一個時間步長,前一個單元狀態(Ct-1)與遺忘門結合,以決定什麼信息要被傳送,然後與輸入門(it和c~t)結合,形成新的單元狀態或單元的新存儲器。

狀態的計算公式

輸出門

最後,LSTM單元必須給出一些輸出。從上面得到的單元狀態通過一個叫做tanh的雙曲函數,因此單元狀態值在-1和1之間過濾。

LSTM單元的基本單元結構已經介紹完成,繼續推導在實現中使用的方程。

推導先決條件

推導方程的核心概念是基於反向傳播、成本函數和損失。除此以外還假設您對高中微積分(計算導數和規則)有基本的了解。

變量:對於每個門,我們有一組權重和偏差,表示為:

Wf,bf->遺忘門的權重和偏差Wi,bi->輸入門的權重和偏差Wc,bc->單元狀態的權重和偏差Wo,bo->輸出門的權重和偏差Wv ,bv -> 與Softmax層相關的權重和偏差ft, it,ctiledet, o_t -> 輸出使用的激活函數af, ai, ac, ao -> 激活函數的輸入J是成本函數,我們將根據它計算導數。注意(下劃線(_)後面的字符是下標)

前向傳播推導

門的計算公式

狀態的計算公式

以遺忘門為例說明導數的計算。我們需要遵循下圖中紅色箭頭的路徑。

我們畫出一條從f_t到代價函數J的路徑,也就是

ft→Ct→h_t→J。

反向傳播完全發生在相同的步驟中,但是是反向的

ft←Ct←h_t←J。

J對ht求導,ht對Ct求導,Ct對f_t求導。

所以如果我們在這裡觀察,J和ht是單元格的最後一步,如果我們計算dJ/dht,那麼它可以用於像dJ/dC_t這樣的計算,因為:

dJ/dCt = dJ/dht * dht/dCt(鏈式法則)

同樣,對第一點提到的所有變量的導數也要計算。

現在我們已經準備好了變量並且清楚了前向傳播的公式,現在是時候通過反向傳播來推導導數了。我們將從輸出方程開始因為我們看到在其他方程中也使用了同樣的導數。這時就要用到鏈式法則了。我們現在開始吧。

反向傳播推導

lstm的輸出有兩個值需要計算。

Softmax:對於交叉熵損失的導數,我們將直接使用最終的方程。

隱藏狀態是ht。ht是w.r的微分。根據鏈式法則,推導過程如下圖所示。

輸出門相關變量:ao和ot,微分的完整方程如下:

dJ/dVt * dVt/dht * dht/dO_t

dJ/dVt * dVt/dht可以寫成dJ/dht(我們從隱藏狀態得到這個值)。

ht的值= ot * tanh(ct) ->所以我們只需要對ht w.r求導。t o_t。其區別如下:

同樣,a_o和J之間的路徑也顯示出來。微分的完整方程如下:

dJ/dVt * dVt/dht * dt /da_o

dJ/dVt * dVt/dht * dht/dOt可以寫成dJ/dOt(我們從上面的o_t得到這個值)。

Ct是單元的單元狀態。除此之外,我們還處理候選單元格狀態ac和c~_t。

Ct的推導很簡單,因為從Ct到J的路徑很簡單。Ct→ht→Vt→j,因為我們已經有了dJ/dht,我們直接微分ht w.r。t Ct。

ht = ot * tanh(ct) ->所以我們只需要對ht w.r求導。t C_t。

微分的完整方程如下:

dJ/dht * dht/dCt * dCt/dc~_t

可以將dJ/dht * dht/dCt寫成dJ/dCt(我們在上面有這個值)。

Ct的值如圖9公式5所示(下圖第3行最後一個Ct缺少波浪號(~)符號->書寫錯誤)。所以我們只需要對C_t w.r求導。t c ~ _t。

ac:如下圖所示為ac到J的路徑。根據箭頭,微分的完整方程如下:

dJ/dht * dht/dCt * dCt/ da_c

dJ/dht * dht/dCt * dCt/dc_t可以寫成dJ/dc_t(我們在上面有這個值)。

所以我們只需要對c~t w.r求導。t ac。

輸入門相關變量:it和ai

微分的完整方程如下:

dt / dt * dt /dit

可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們在單元格狀態中有這個值)。所以我們只需要對Ct w.r求導。t it。

a_i:微分的完整方程如下:

dJ/dht * dht/dCt * dt /da_i

dJ/dht * dht/dCt * dCt/dit可以寫成dJ/dit(我們在上面有這個值)。所以我們只需要對i_t w.r求導。t ai。

遺忘門相關變量:ft和af

微分的完整方程如下:

dJ/dht * dht/dCt * dCt/df_t

可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們在單元格狀態中有這個值)。所以我們只需要對Ct w.r求導。t ft。

a_f:微分的完整方程如下:

dJ/dht * dht/dCt * dft/da_t

dJ/dht * dht/dCt * dCt/dft可以寫成dJ/dft(我們在上面有這個值)。所以我們只需要對ftw.r求導。t af。

Lstm的輸入

每個單元格i有兩個與輸入相關的變量。前一個單元格狀態C_t-1和前一個隱藏狀態與當前輸入連接,即

[ht-1,xt] > Z_t

C_t-1:這是Lstm單元的內存。圖5顯示了單元格狀態。c - t-1的推導很簡單因為只有c - t和c - t。

Zt:如下圖所示,Zt進入四個不同的路徑,af,ai,ao,ac。

Zt→af→ft→Ct→h_t→J。- >遺忘門

Zt→ai→it→Ct→h_t→J。- >輸入門

Zt→ac→c~t→Ct→h_t→J。->單元狀態

Zt→ao→ot→Ct→h_t→J。- >輸出門

權重和偏差

W和b的推導很簡單。下面的推導是針對Lstm的輸出門的。對於其餘的門,對權重和偏差也進行了類似的處理。

輸入和遺忘門的權重和偏差

輸出和輸出門的權重和偏差

J/dWf = dJ/daf。daf / dWf ->遺忘門

dJ/dWi = dJ/dai。dai / dWi ->輸入門

dJ/dWv = dJ/dVtdVt/ dWv ->輸出門

dJ/dWo = dJ/dao。dao / dWo ->輸出門

我們完成了所有的推導。但是有兩點需要強調

到目前為止,我們所做的只是一個時間步長。現在我們要讓它只進行一次迭代。

所以如果我們有總共T個時間步長,那麼每一個時間步長的梯度會在T個時間步長結束時相加,所以每次迭代結束時的累積梯度為:

每次迭代結束時的累積梯度用來更新權重

總結

LSTM是非常複雜的結構,但它們工作得非常好。具有這種特性的RNN主要有兩種類型:LSTM和GRU。

訓練LSTMs也是一項棘手的任務,因為有許多超參數,而正確地組合通常是一項困難的任務。

作者:Rahuljha

deephub翻譯組

相關焦點

  • InstanceNorm 梯度公式推導
    InstanceNorm 梯度公式推導【GiantPandaCV導語】本文主內容是推導 InstanceNorm 關於輸入和參數的梯度公式,同時還會結合 Pytorch 和 MXNet 裡面 InstanceNorm 的代碼來分析。
  • 六年級數學上冊 第5單元 圓的面積公式推導及簡單應用(微課·課文·課件·教案)
    1.使學生理解圓的面積的含義,理解圓面積計算公式的推導過程,掌握圓面積的計算公式。2.培養學生動手操作、抽象概括的能力,運用所學知識解決簡單實際問題。3.滲透轉化的數學思想。重點:圓的面積計算公式的推導與應用。難點:推導圓的面積計算公式。
  • 人人都能看懂的LSTM介紹及反向傳播算法推導(非常詳細)
    圖8 轉化後的窺視孔LSTM前向傳播:在t時刻的前向傳播公式為:反向傳播:對反向傳播算法了解不夠透徹的,請參考陳楠:反向傳播算法推導過程(非常詳細),這裡有詳細的推導過程,本文將直接使用其結論。已知:
  • 淺析最美數學公式——歐拉公式之推導歸納
    本文是基於作者在高等數學和複變函數這兩門課程教學過程中的一些思考, 整理並總結了有關於大家熟知的歐拉公式在不同數學分支裡的詳細推導方法和推導過程, 以便為相關學者提供參考和借鑑。學習過高等數學的的人都學過歐拉公式, 還知道歐拉公式是指以歐拉命名的諸多公式之一。
  • 線性回歸的求解:矩陣方程和梯度下降、數學推導及NumPy實現
    求解過程會用到一些簡單的微積分,因此複習一下微積分中偏導數部分,有助於理解機器學習的數學原理。另外,複習一下矩陣和求導等知識有助於我們理解深度學習的一些數學原理。梯度下降法求解損失函數最小問題,或者說求解使損失函數最小的最優化問題時,經常使用搜索的方法。具體而言,選擇一個初始點作為起點,然後開始不斷搜索,損失函數逐漸變小,當到達搜索迭代的結束條件時,該位置為搜索算法的最終結果。
  • RNN系列教程之四 | 利用LSTM或GRU緩解梯度消失問題
    今天,我們將詳細介紹LSTM(長短時記憶)神經網絡和GRU(門控循環單元)。LSTM於1997年由Sepp Hochreiter 和Jürgen Schmidhuber首次提出,是當前應用最廣的NLP深度學習模型之一。GRU於2014年首次被提出,是LSTM的簡單變體,兩者有諸多共性。先來看看LSTM,隨後再探究LSTM與GRU的差異。
  • 自己動手做聊天機器人 三十三-兩套代碼詳解LSTM-RNN——有記憶的神經網絡
    請看下面的推導你就明白了這用來存儲第二層(也就是輸出層)的殘差,對於輸出層,殘差計算公式推導如下(公式可以在http://deeplearning.stanford.edu/wiki/index.php/%E5%8F%8D%E5%90%91%E4%
  • 【重溫序列模型】再回首DeepLearning遇見了LSTM和GRU
    上面就是LSTM一個宏觀工作原理的體現, 當然還有一些細節,比如這個記憶是怎麼進行選擇的, 這個記憶是怎麼在時間步中傳遞的, 又是怎麼保持的等, 下面從數學的角度詳細的說說:首先, 是那條記憶線到底在單元裡面長什麼樣子:
  • 線性模型篇之Logistic Regression數學公式推導
    本系列文章會介紹四種線性模型函數的推導和優化過程。參數學習LR回歸採用交叉熵作為損失函數,並使用梯度下降法對參數進行優化。採用梯度下降算法,Logistic的回歸訓練過程為:初始化w_0 為0,然後通過下式來更新迭代參數(公式-8)。
  • 等額資金終值公式的詳細推導計算式
    【提問】能否將等額資金終值公式的詳細推導計算式提供下,以便理解和記憶。年的終值相加:F=A(1+i)n+A(1+i)n-1+A(1+i)n-2……+A(1+i)+A這個計算式子是個等比數列求和的計算式,S=a1(1-qn)/(1-q)  F=A(1+i)n+A(1+i)n-1+A(1+i)n-2……+A(1+i)+A式子中a1=A,q=(1+i)  所以F=A[1-(1+i)n]/[1-(1+i)]=A[(1+i)n-1]/i  這個公式的推導就是等比數列求和的計算公式
  • 一起推導自然數平方、立方甚至更高次方的前n項和公式
    我們已知,自然數數列是一種等差數列其前n項和很容易求出自然數平方的前n項和那麼通項公式如下形式的自然數平方的數列,其前n項和如何求解呢?假設其前n項和為Tn觀察下列等式將等式中的x替換成自然數1~n,可得到一系列等式我們將這n個等式相加,可得那麼我們就可得到自然數平方數列的前n項和公式我們通過一系列等式相加
  • 機器學習之多元線性回歸模型梯度下降公式與代碼實現(篇二)
    上一篇我們介紹了線性回歸的概述和最小二乘的介紹,對簡單的一元線性方程模型手推了公式和python代碼的實現。機器學習之線性回歸模型詳細手推公式與代碼實現(篇一)今天這一篇來介紹多元線性回歸模型多元線性回歸模型介紹在回歸分析中,如果有兩個或兩個以上的自變量,就稱為多元回歸
  • 中考數學考點:三角函數公式推導過程
    中考數學考點:三角函數公式推導過程 萬能公式推導 sin2α=2sinαcosα=2sinαcosα/(cos^2(α)+sin^2(α))...... 同理可推導餘弦的萬能公式。正切的萬能公式可通過正弦比餘弦得到。
  • 高等數學入門——基本導數公式的推導
    文章中的例題大多為紮實基礎的常規性題目和幫助加深理解的概念辨析題,並適當選取了一些考研數學試題。所選題目難度各異,對於一些難度較大或對理解所學知識有幫助的「經典好題」,我們會詳細講解。閱讀更多「高等數學入門」系列文章,歡迎關注數學若只如初見!上一節中介紹了導數和導函數的基本概念,對於求導運算有一套行之有效的方法,後面幾節會逐步介紹。
  • 淺談「兩角差的餘弦公式」之推導
    「兩角差的餘弦公式」在推導過程中具有重要的教育價值,蘊涵著換一個角度看問題的轉換思想,是數學家創造發明的法寶,也是我們進行再發現、再創造活動的探索方式。本文針對 「兩角差的餘弦公式的推導」章節進行學習,分析並推導兩角差的餘弦公式,實踐檢驗。筆者在近年來的各省數學高考試卷中發現,經常會出現考查數學教材中相關公式或定理的證明試題,比如證明兩角和的餘弦公式及餘弦定理等等。
  • 高二數學公式:三倍角公式及其推導
    高二數學公式:三倍角公式及其推導 2013-09-12 18:56 來源:網際網路 作者:新東方網整理
  • 【探究作業】三角形面積公式推導,看看孩子們的探究!
    三角形的面積公式推導,非常重要。任何一個多邊形都可以分割成若干個三角形,也就是可以轉化成求三角形的面積。那三角形面積的公式如何推導呢?雖然前面已經有了平行四邊形推導的轉化經驗,但是如果只用倍拼法加以解決,又顯得很單一,也不利於後續梯形面積公式的推導。三角形面積公式的推導方法一般可概括為「倍拼法」和「割補法」兩大類型。
  • 每天五分鐘自然語言理解NLP:RNN為什麼會有梯度消失和梯度爆炸?
    如上所示,是上一個章節推導的RNN的反向傳播的公式,這裡是計算的第三個時刻損失關於h(1)的偏導數,如果我們將其進行擴展,擴展為T時刻的損失對h(1)的偏導數,那麼此時可以表示為:     這個矩陣我們稱之為雅可比矩陣,將其代入到反向傳播公式,我們可以看到這個矩陣會連乘。
  • 很詳細的弧長公式推導過程,一看就明白
    為了消除△x的誤差,我們用微元(數學中的數學名詞),微元很小很小,要多小有多小,此時這個三角形就是直角三角形(為了方便理解上圖和下圖都對弧長畫大了,其實很小),我們知道直角三角形的計算公式,從而得出ds=dx+dy.這裡有人要問△x和dx究竟怎麼看,我個人理解這裡的dx是一個微元更加精確。
  • 初中數學知識點大全:三角函數公式推導過程
    初中數學知識點大全:三角函數公式推導過程 萬能公式推導 sin2α=2sinαcosα=2sinαcosα/(cos^2(α)+sin^2(α))...... 同理可推導餘弦的萬能公式。正切的萬能公式可通過正弦比餘弦得到。