作者 | 吳斯銘 穆永譽
單位| 東北大學自然語言處理實驗室
01 引言今天給大家介紹的外文精品博客是來自Huggingface研究員關于Reformer模型的講解。該博客從四個方面解讀了Reformer如何使用8GB的顯存完成模型在50萬token的長序列數據上的訓練,包括LSH自注意力、分塊FFN、可逆殘差、Axial位置編碼。該作者還通過實驗測試了各項技術對于内存開銷節省的實際效果。原博客内容較多,本文對其中核心的技術講解進行了概括。
02 作者介紹
Huggingface研究員:Patrick von Platen
03 譯者說Transformer模型由于其極強的序列分析能力而呈現出一統NLP和CV的趨勢,但是Transformer在處理長序列時會有很高的内存占用。這些内存占用通常來自以下幾個部分:
1.位置編碼。Transformer可以處理的序列的最長長度 被存儲位置編碼的張量的維度大小所限制。常用的解決辦法包括提高張量的維度大小(hidden_size),增強其存儲位置信息的能力從而提高模型處理序列的上限。但與此同時輸出還是超長序列(seq_len),這就導緻序列的表示(seq_len*hidden_size)變得非常大,最終導緻很高的内存開銷。
2.Self-Attention。在解碼時,為了避免重複計算,通常會緩存前面步驟的計算的key和value,當序列長度很長時,緩存結果會占用很多的内存。
3.FFN。FFN會将注意力模塊輸出的序列的表示線性變換為一個比詞嵌入維度更大的表示(seq_len*hidden_size ----> seq_len*ffn_size),随着序列長度的增加,這個表示的大小會急劇增加。
4.為了進行反向傳播而保存的中間結果。訓練模型時,為了通過反向傳播更新模型的參數,需要保留每一層計算的中間結果。當模型變深時,這些中間結果會造成很大的内存開銷。
常用的降低Transformer内存占用的技術包括:量化、模型壓縮、知識蒸餾。而Reformer則是一項針對長序列輸入應用場景的降低Transformer内存占用的工作。Reformer分别從改進模型結構和計算方式的角度出發,探究如何在模型性能、速度、内存占用這三者之間達到一個較好的trade-off。該工作發表于ICLR2020,截止22年6月8日谷歌學術查詢此工作已獲得779次引用。
另外,一個比較有意思的現象是,如果從速度的角度考慮,在目前的算力下優化Transformer,也是一個memory-bound的問題。本人參與了東北大學自然語言處理實驗室的神經機器翻譯開源項目NiuTrans.NMT,在對NiuTrans.NMT進行性能分析時,我發現數據移動是時間占比最大的操作。另外,MLSys 2021的best paper《Data Movement Is All You Need: A Case Study on Optimizing Transformers》也指出了在Transformer的訓練中,數據移動是瓶頸。我覺得這些現象可以啟示我們,在優化機器翻譯系統時,更多地從内存的角度考慮,或許可以達到事半功倍的效果。
04 原博客精華内容概括4.1 LSH(局部敏感哈希)自注意力
Reformer使用LSH自注意力作為全局自注意力的一個近似。LSH自注意力的想法是,當序列長度非常長時,一個query隻在某幾個key上的注意力權重會明顯大于0。因此隻對那些和query相似的key進行自注意力操作,在節省顯存的同時,也能得到一個對全局自注意力比較好的近似。
那麼,該如何尋找與query相似的key呢?
Reformer的作者發現共享query和key的投影矩陣不會影響模型性能。在Reformer中,query和key使用同一個投影矩陣構造。因此,尋找與query相似的key簡化成了對query進行聚類的問題。Reformer用餘弦相似度來衡量query間的相似性,并通過LSH算法将query分成若幹個類别。随後,Reformer根據類别重排序列,使在一個類别的query在序列中是相鄰的。最後,Reformer将序列分成若幹子序列,在每段子序列上執行局部自注意力,得到每個query的自注意力表示,并還原序列的順序。這樣,Reformer就通過LSH自注意力得到了全局自注意力的一個近似。
實驗證明随着序列長度的增長,局部敏感哈希自注意力很好地降低了顯存占用。
4.2 分塊FFN
如圖,從注意力模塊得到的序列的中間表示在FFN中做了一次線性變換之後,會使得這個中間表示的詞嵌入維度大幅升高,這一操作非常消耗顯存容量。比如Transformer Base中,線性變換前中間表示的詞嵌入維度是512,經過一次線性變換後,中間表示的詞嵌入維度就變成2048。而當序列很長時,這個中間表示的大小也會随着序列的增長而急劇增加,這樣的中間表示對顯存容量是個很大的挑戰。然而,我們真的需要保留這麼大的中間表示嗎?讓我們用下圖來分析FFN的計算過程:
如上圖所示,序列中每個token在FFN中的計算是相互獨立的。這就是說,對于每個token,實際上并不需要保留整個序列的中間表示來完成FFN中的計算。Reformer将原序列分成若幹子序列,在每段子序列上執行FFN,因此減少了中間表示的大小,從而緩解處理長序列時顯存不足的問題。如下圖所示。
實驗證明随着序列長度的增長,分塊FFN很好地降低了顯存占用。 4.3 可逆殘差 在訓練神經網絡時,由于反向傳播需要根據神經網絡每層的輸入輸出來計算該層的梯度,因此在訓練過程中通常需要保留神經網絡每一層的輸入和輸出,這導緻訓練模型比使用模型進行推理消耗更多的顯存。能不能既避免保留大量的輸入輸出,又正常計算梯度呢?Reformer通過隐式地保留張量之間的運算關系來做到這一點。
Reformer使用可逆殘差代替了Transformer的殘差連接,以此來節省中間計算結果的顯存占用。如圖所示為可逆殘差的結構,在前向計算時,Transformer每一層在使用了可逆殘差後有兩個輸出,分别是Y1和Y2(同時作為下一層的輸入X1、X2)。由于每一層的中間結果以及輸入都可以由該層的輸出Y1和Y2計算出來,因此該層除了輸出Y1、Y2以外,其它計算結果都是可以在完成該層的前向計算後抛棄的(如圖中的X1、X2、Z、Y),這節省了顯存的使用。由于一個層的輸出Y1、Y2也是下一層的輸入X1、X2,理論上,隻需要保存最後一層的激活,就足夠進行反向傳播。
博客中的實驗顯示,使用可逆殘差後,訓練層數更多的模型所需要的顯存有了明顯下降。
4.4 Axial位置編碼
存儲位置編碼的張量的維度大小決定了模型能處理的序列長度的上限。如果為了處理更長的序列,增加該張量的維度大小,那麼在面對長序列時會使内存占用變得非常大。例如,假設hidden_size是1024,需要處理的序列長度是50M,用來存儲位置編碼的張量的參數量達到了512M,也就是2GB的内存占用。能不能維持一個較小的維度,同時提高模型能處理的序列長度呢?
如圖所示,Reformer通過使用Axial位置編碼,對兩個短序列的位置編碼進行組合來表示長序列位置信息。将兩個短序列視為坐标軸,e'i的Axial位置編碼的值為對應坐标的位置編碼的拼接。Reformer使用的Axial位置編碼在不增加内存占用的情況下,提高了位置編碼可以表示的序列長度。
博客通過實驗證明,Axial位置編碼可以有效降低參數量,減少顯存占用。
詳細精彩内容請參見原文 The Reformer - Pushing the limits of language modeling (huggingface.co)
,更多精彩资讯请关注tft每日頭條,我们将持续为您更新最新资讯!