編碼的世界 / 優質文選 / 文明

RNN以及LSTM的介紹和公式梳理


2022年8月04日
-   

前言
好久沒用正兒八經地寫博客了,csdn居然也有了markdown的編輯器了,最近花了不少時間看RNN以及LSTM的論文,在組內『夜校』分享過了,再在這裏總結一下發出來吧,按照我講解的思路,理解RNN以及LSTM的算法流程並推導一遍應該是沒有問題的。
RNN最近做出了很多非常漂亮的成果,比如Alex Graves的手寫文字生成、名聲大振的『根據圖片生成描述文字』、輸出類似訓練語料的文字等應用,都讓人感到非常神奇。這裏就不細說這些應用了,我其實也沒看過他們的paper,就知道用到了RNN和LSTM而已O(∩_∩)O
本文就假設你對傳統的NN很熟悉了,不會的話參考http://ufldl.stanford.edu/wiki/index.php/UFLDL_Tutorial和我之前的文章http://blog.csdn.net/dark_scope/article/details/9421061學習一下~~
RNN(Recurrent Neural Network)
今天我這裏講到的RNN主要是上圖這種結構的,即是Hidden Layer會有連向下一時間Hidden Layer的邊,還有一種結構是Bidirectional Networks,也就是說會有來自下一時間的Hidden Layer傳回來的邊,但這不在我們今天的討論範圍內,講完LSTM,如果你想推導一下Bidirectional Network,應該也是順理成章的。為了方便推導和描述,我們後面都將左邊簡化為右邊這樣一個結構。
RNN和傳統的多層感知機不同的就是跟時間沾上邊了,下一時間(理解為step)會受本時間的影響,為了更好地說明這個東西,我們可以將網絡按照時間進行展開:
BPTT(Back Propagation Through Time)算法
將RNN展開之後,似乎一切都很明了了,前向傳播(Forward Propagation)就是依次按照時間的順序計算一次就好了,反向傳播(Back Propagation)就是從最後一個時間將累積的殘差傳遞回來即可,跟普通的神經網絡訓練並沒有本質上的不同。
前向傳播
直接上公式啦: 本文用到的公式基本來自Alex的論文,其中a表示匯集計算的值,b表示經過激活函數計算的值,w是不同節點之間連接的參數(具體睡誰連誰看下標),帶下標k的是輸出層,帶下標h的是隱藏層相關的,除此之外你看到所有帶括號的的函數都是激活函數,
ϵ
epsilon 和
δ
delta 的定義看公式,

mathcal L 是最後的Loss function,這裏沒有給出具體的計算方法,因為這和NN是一樣的,可以看到輸出層和普通的NN是完全一樣的,接收隱藏層傳入的數據並乘以參數求和,只是每一個計算出來的值都有個時間上標t,表示它是t時刻的那個節點。
而隱藏層的計算就是和NN不同的地方,從之前的拓撲圖也看到了,隱藏層會接受來自上一時間隱藏層傳入的數據,在公式裏也體現出來了:第一個求和是和NN一致的,接收來自輸入層的數據,第二個是接收來自上一隱藏層的數據。
後向傳播
這裏主要給出的是計算隱藏層的累積殘差的公式,因為輸出層和經典的NN是一回事,可以看到第一個公式括號中的兩個部分,一個是接收當前時間輸出層傳回的殘差,第二個是接收下一時間隱藏層傳回的殘差,看著上面的圖其實非常好理解。
LSTM(Long-Short Term Memory)
原生的RNN會遇到一個很大的問題,叫做 The vanishing gradient problem for RNNs,也就是後面時間的節點對於前面時間的節點感知力下降,也就是忘事兒,這也是NN在很長一段時間內不得志的原因,網絡一深就沒法訓練了,深度學習那一套東西暫且不表,RNN解決這個問題用到的就叫LSTM,簡單來說就是你不是忘事兒嗎?我給你拿個小本子把事記上,好記性不如爛筆頭嘛,所以LSTM引入一個核心元素就是Cell。
與其說LSTM是一種RNN結構,倒不如說LSTM是RNN的一個魔改組件,把上面看到的網絡中的小圓圈換成LSTM的block,就是所謂的LSTM了。那它的block長什麼樣子呢? 怎麼這麼複雜……不要怕,下文慢慢幫你縷清楚。理解LSTM最方便的就是結合上面這個圖,先簡單介紹下裏面有幾個東西:
  • Cell,就是我們的小本子,有個叫做state的參數東西來記事兒的
  • Input Gate,Output Gate,在參數輸入輸出的時候起點作用,算一算東西
  • Forget Gate:不是要記東西嗎,咋還要Forget呢。這個沒找到為啥就要加入這樣一個東西,因為原始的LSTM在這個位置就是一個值1,是連接到下一時間的那個參數,估計是以前的事情記太牢了,最近的就不住就不好了,所以要選擇性遺忘一些東西。(沒找到解釋設置這個東西的動機,還望指正)

  • 在閱讀下面公式說明的時候時刻記得這個block上面有一個輸出節點,下面有一個輸入節點,block只是中間的隱層小圓圈~~~
    前向傳播
    一大波公式正在路上。。。。。公式均來自Alex的論文 我們按照一般算法的計算順序來給出每個部分的公式:
    Input Gate

    帶下標L的就是跟Input Gate相關的,回去看上面那個圖,看都有誰連向了Input Gate:外面的輸入,來自Cell的那個虛線(虛線叫做peephole連接),這在公式立體現在4.2的第一項和第三項,計算就是普通的累積求和。那中間那個是個什麼鬼? 帶H的是一個泛指,因為LSTM的一個重要特點是其靈活性,cell之間可以互聯,hidden units之間可以互聯,至於連不連都看你(所以你可能在不同地方看到的LSTM公式結構都不一樣)所以這個H就是泛指這些連進來的東西,可以看成是從外面連進了的三條邊的一部分。 至於4.3就是簡單的激活函數計算而已
    Forget Gate
    再回去看那個圖,連到Forget Gate都有哪些:輸入層的輸入、泛指的輸入、來自cell的虛線,這個和Input Gate就是一回事嘛
    Cells
    還是老樣子,回去看都有啥連到了Cell(這裏的cell不是指中間那個Cell,而是最下面那個小圓圈,中間的Cell表示的其實是那個狀態值S[c][t]):輸入層的輸入,泛指的輸入。(這體現在4.6式中) 再看看中間的那個Cell狀態值都有誰連過去了:這次好像不大一樣,連過去的都是經過一個小黑點匯合的,從公式也能體現出來,分別是:ForgetGate*上一時間的狀態 + InputGate*Cell激活後的值
    Output Gate
    老樣子,看誰連到了Output Gate:跟其他幾個Gate好像完全一樣嘛~咦,4.8那個S[c][t]為啥是t,以前都是t-1啊。 這裏我也沒找到相關的原因,可以理解為在計算OG的時候,S[c][t]已經被計算出來了,所以就不用使用上一時間的狀態值了(同樣動機不明~~這就是設定好嘛。。。)
    最後最後的輸出
    小黑點,用到了激活後的狀態值和Output Gate的結果。 一定按照圖的連接來捋一捋,公式還是非常清晰的。
    後向傳播
    又一波公式來襲。。。。。。
    這次就只貼公式了,因為要每個都講一下實在是太費功夫了,記住一個要點就是『看上面的圖!!』,看看每個要求偏導的東西都有誰會反向傳回東西給它,可以看到最複雜的就是4.13了,因為這是對那個狀態值求導,它不光連向了三個門(公式後三項,兩個本下一時刻,FG是本時刻的),還連向了最後的輸出b[c][t](公式第一項)以及下一時刻的自己(公式第二項),反向傳播公式推導用到的唯一數學工具就是鏈式法則,你要覺得求偏導看不懂,就把它拆成鏈看就好了。
    還有一點,記得最後的Loss Function是每一時間的一個求和,所以當你算當前層輸出層傳回來的殘差都時候就可以忽略其它東西了,舉個例子:4.11是對b[c][t]求偏導,而b[c][t]是正向傳播LSTM block的輸出,輸出到誰了?當前層的輸出層,下一層的Hidden Layer,這兩個東西的最後的Loss function是分開的,彼此之間沒有關系,所以公式裏是兩部分相加。4.11中的G和之前的H一樣,也是泛指,因為它不一定只輸出到下一時間的自己,可能還會到下一時間的其他隱層unit,G代表什麼純看你怎麼確定的網絡結構。
    ϵxt=∂∂btc=∑kK∂∂atk∂atk∂btc+∑gG∂∂at+1g∂at+1g∂btc=(4.11)
    egin{equation}
    epsilon_{t}^{x} = dfrac{partial mathcal L}{partial b_{c}^{t}} =sum_k^Kdfrac{partial mathcal L}{partial a_{k}^{t}}dfrac{partial a_{k}^{t}}{partial b_{c}^{t}}+ sum_g^Gdfrac{partial mathcal L}{partial a_{g}^{t+1}}dfrac{partial a_{g}^{t+1}}{partial b_{c}^{t}}=(4.11)
    end{equation}
    後記
    推導一遍之後你完全可以自己實現一次了,用到的東西也不複雜,可惜對於RNN和DL這些東西來說,確定網絡結構和調參才是對最後效果有著決定性的影響,RNN和LSTM裏可以調的東西太多了,每一個未知的激活函數選擇,具體網絡到底怎麼連接,還有學習速率這種老問題。也是個大工程的說 ps.這MD的編輯器還可以啊~~!!
    引用
    1A. Graves. Supervised Sequence Labelling with Recurrent Neural Networks. Textbook, Studies in Computational Intelligence, Springer, 2012.

    熱門文章