CNN主要處理圖像信息,主要應用于計算機視覺領域。
RNN(recurrent neural network)主要就是處理序列數(shù)據(jù)(自然語言處理、語音識別、視頻分類、文本情感分析、翻譯),核心就是它能保持過去的記憶。但RNN有著梯度消失問題,專家之后接著改進為LSTM和GRU結構。下面將用通俗的語言分別詳細介紹。
《無廢話的機器學習筆記》、《一文極速理解深度學習》、《一文總結經(jīng)典卷積神經(jīng)網(wǎng)絡CNN模型》
RNN(Recurrent Neural Network)
RNN中的處理單元,中間綠色就是過去處理的結果,左邊第一幅圖就是正常的DNN,不會保存過去的結果,右邊的圖都有一個特點,輸出的結果(藍色)不僅取決于當前的輸入,還取決于過去的輸入!不同的單元能賦予RNN不同的能力,如 多對一就能對一串文本進行分類,輸出離散值,比如根據(jù)你的言語判斷你今天高不高興。
RNN中保存著過去的信息,輸出取決于現(xiàn)在與過去。如果大伙學過數(shù)電,這就是狀態(tài)機!這玩意跟觸發(fā)器很像。
有個很重要的點:
這個權重fw沿時間維度是一致的,權值共享。就像CNN中一個卷積核在卷積過程中參數(shù)一致。所以CNN是沿著空間維度權值共享;RNN是沿著時間維度權值共享。
具體來說有三個權重,過去與現(xiàn)在各一個權重,加起來再來一個權重。 它們都沿著時間維度權值共享。不然每個時間都不一樣權重,參數(shù)量會很恐怖。
整體的計算圖(多對多):
每次的輸出y可以與標簽值構建損失函數(shù),這樣就跟之前DNN訓練模型思想一樣,訓練3套權重使損失函數(shù)不斷下降至滿意。
反向傳播要沿時間反向傳回去(backpropagation through time,BPTT)
Forward through entire sequence to compute loss, then backward through entire sequence to compute gradient.
這樣會有問題,就是一下子把全部序列弄進來求梯度,運算量非常大。實際我們會將大序列分成等長的小序列,分別處理:
不同隱含層中不同的值負責的是語料庫中不同的特征,所以隱含狀態(tài)的個數(shù)越多,模型就越能捕獲文本的底層特征。
下面來看一個例子:字符級語言模型(由上文預測下文):
我想輸入hell,然后模型預測我會輸出o;或者我輸入h,模型輸出e,我再輸入e,模型輸出l…
首先對h,e,l,o進行獨熱編碼,然后構建模型進行訓練。
輸入莎士比亞的劇本,讓模型自己生成劇本,訓練過程:
輸入latex文本,讓模型自己生成內(nèi)容,公式寫得有模有樣的,就不知道對不對:
當然輸入代碼,模型也會輸出代碼。所以現(xiàn)在火熱的Chatgpt的本質就是RNN。
對于圖像描述,專家會先用CNN對圖像進行特征抽取(編碼器),然后將特征再輸入RNN進行圖像描述(解碼器)。
還可以結合注意力機制(Image captioning with attention):
普通堆疊的RNN一旦隱含層變多變深,反向傳播時就很容易出現(xiàn)梯度消失/爆炸。
子豪兄總結得非常好,以最簡單的三層網(wǎng)絡來看,對于輸出的O3可以列出損失函數(shù)L3,對L3進行求偏導,分別對輸出權重w0,輸入權重wx,過去權重ws進行求導。我們發(fā)現(xiàn)對w0求偏導會很輕松。
但是,由于鏈式法則(chain rule),對輸入權重wx和過去權重ws求偏導就會很痛苦。在表達式里,對于越是前面層的鏈式求導,乘積項越多,所以很容易梯度消失/爆炸,梯度消失占大多數(shù)。
LSTM(Long Short-Term Memory)
長短時記憶神經(jīng)網(wǎng)絡(LSTM) 應運而生!
LSTM既有長期記憶也有短期記憶,包括遺忘門、輸入門、輸出門、長期記憶單元。右圖紅色函數(shù)是sigmoid,藍色函數(shù)是tanh。
C是長期記憶,h是短期記憶。
所以當前輸出ht是由短期記憶產(chǎn)生的。
我們看到長期記憶那條線是貫通的,且只有乘加操作。
LSTM算法詳解:
下面幾個圖完美解釋了:
所以總共有四個權重:Wf、Wi、Wc、Wo,當然還有它們對應的偏置項。
整體過程可以概括為:遺忘、更新、輸出。(更新包括先選擇保留信息,再更新最新記憶。)
原論文中的圖也非常形象:
現(xiàn)在反向傳播求偏導就舒服了
GRU(Gated Recurrent Unit)
GRU也能很好解決梯度消失問題,結構簡單一點,主要就是重置門和更新門。
GRU與LSTM對比:
- 參數(shù)數(shù)量:GRU的參數(shù)數(shù)量相對LSTM來說更少,因為它將LSTM中的輸入門、遺忘門和輸出門合并為了一個門控單元,從而減少了模型參數(shù)的數(shù)量。
LSTM中有三個門控單元:輸入門、遺忘門和輸出門。每個門控單元都有自己的權重矩陣和偏置向量。這些門控單元負責控制歷史信息的流入和流出。
GRU中只有兩個門控單元:更新門和重置門。它們共享一個權重矩陣和一個偏置向量。更新門控制當前輸入和上一時刻的輸出對當前時刻的輸出的影響,而重置門則控制上一時刻的輸出對當前時刻的影響。 - 計算速度:由于參數(shù)數(shù)量更少,GRU的計算速度相對LSTM更快。
- 長序列建模:在處理長序列數(shù)據(jù)時,LSTM更加優(yōu)秀。由于LSTM中引入了一個長期記憶單元(Cell State),使得它可以更好地處理長序列中的梯度消失和梯度爆炸問題。
GRU適用于:
處理簡單序列數(shù)據(jù),如語言模型和文本生成等任務。
處理序列數(shù)據(jù)時需要快速訓練和推斷的任務,如實時語音識別、語音合成等。
對計算資源有限的場景,如嵌入式設備、移動設備等。
LSTM適用于:
處理復雜序列數(shù)據(jù),如長文本分類、機器翻譯、語音識別等任務。
處理需要長時依賴關系的序列數(shù)據(jù),如長文本、長語音等。
對準確度要求較高的場景,如股票預測、醫(yī)學診斷等。
公式總結: