使用 pre-trained model BERT 進行各種 NLU 下游任務 fine-tuning 時,會有輸入限制 512 字符的限制,真實世界文章往往超過這個長度,因此如何應用 BERT 到長文本上是一個重要的議題。
傳統 BERT 長文本輸入處理方式
我在搜集資料時,找到此篇文章整理了 3 種克服方式:
- 截斷法:就是很直觀的取其中的 512 字符輸入,也是已知最常用的暴力解
- head 截斷:從頭開始保留限制的 token 數
- tail 截斷:從文本末端往前擷取
- head + tail 截斷:開頭和結尾各保留一部分
- Pooling 法
- 把長文本分成多個 512 字以內的 segments (段落),對每個段落都轉換為 BERT Embedding 再對
[CLS]
進行 max-pooling (比 mean-pooling 推薦,因為 max-pooling 會保留突出的特徵,而 mean-pooling 會將特徵打平,原本的特徵已經很稀疏,打平更會減弱特徵效果。) - 作者提到這種方法的缺點:拆分文本的方式會考慮不到長距離字符之間的上下文關聯,因為他們被拆開到不同 segment 沒辦法進行 attention
- 把長文本分成多個 512 字以內的 segments (段落),對每個段落都轉換為 BERT Embedding 再對
- 壓縮法
- 按照句子分成多個 segments,接著用規則或單獨訓練一個模型的方式來排除一些權重較低的 segments
- 和 Pooling 法一樣有無法考慮長距離關聯的問題
另一文章也提到,BERT 之所以難以處理長文本,是因為傳統 Transformer-based 模型採用「全連接」的 Attention 機制,每一token都要和另一token組合,attention的時間&空間複雜度高達 O(n^2)。
後續針對長文本的處理所提出的新模型算法,以下這 2 個都算是使用到「限制Attention範圍」的概念來降低前述提及的複雜度:
- Longformer
- CogLTX
這篇我們以介紹 Longformer 為主。
Longformer 算法介紹
- 發表單位:Allen AI
- 發表時間:2020 年
- 是否開源代碼:Yes
- 論文亮點:
- 一種可高效處理長文本的升級版 Transformer
- Longformer 改進 Transformer 傳統 attention 機制:對每一 token 「只對固定窗口大小附近的 token」計算 local attention,再結合下游任務計算少量 global attention,將原始 Transformer 的複雜度降至 O(n)
- Longformer 可通用於不同文檔級任務
- 論文作者使用 Longformer 的 attention 方法繼續預訓練 RoBERTa,所得的模型在多個文檔級的任務 fine-tuned 後超越原 RoBERTa 成效
- 模型算法介紹
- 作者以「降低複雜度」為目的提出三種新的 attention 模式:
- 滑窗機制 (Sliding window attention):對每一 token 只對附近 w 個 token 計算 attention,w 大小隨任務不同調整,複雜度 O(w*n)
- 膨脹滑窗機制 (Dilated sliding window):考量更全面的上下文資訊。在滑動窗格中,被 attention 的 2 token 之間存在大小為 d 的空隙,可展開比普通滑窗機制更廣的 attention 範圍、又不增加計算負荷。論文中實驗證明這種方法表現得比第一種好。
- 注:此方法來自於影像辨識中常用的圖像編碼方式 Dilated CNN 空洞卷積,用來拓寬模型編碼的視野範圍
- 融合全局訊息的滑窗機制 (Global + sliding window):根據下游任務的不同添加少量 global attention (例如分類任務的 [CLS] 處會有一個 global attention)。也就是部分 token 會被添加 global attention
- 作者以「降低複雜度」為目的提出三種新的 attention 模式:
- 模型實驗關鍵發現
- 如果 Transformer 由底層至高層遞增滑動窗格大小,有助於提高成效;反之則降低
- 增加空隙的滑窗機制有小幅的成效提升
- 驗證成效提升是來源於新型 attention 機制,而非對 RoBERTa 的加強預訓練,進行多組將舊 attention 機制用於 Longformer 的實驗,這些實驗並沒有取得跟使用新型 attention 機制的模型表現一樣好。
[實作] 訓練 Longformer 文本分類器
- 模型:
- 上面介紹的論文實作的是英文版本 allenai/longformer
- 中文版本我找到了由 ValkyriaLenneth 開源的預訓練中文 Longformer 模型,這篇教學我用他的模型對文本分類任務進行 fine-tune。
- 資料集:
- 使用 中國科大訊飛 公開資料集,共 1.7 萬多條關於 app 應用描述的長文本標注數據,共 119 個類別
- 更多中文語意理解評測基準資料集請見:https://github.com/CLUEbenchmark/CLUE
- 模型訓練程式:部分參考 《進擊的 BERT:NLP 界的巨人之力與遷移學習》、Transfer Learning for NLP: Fine-Tuning BERT for Classification
0. 下載 Longformer_zh 模型 並 git clone repo:
(2023.03 更新:原作者也將模型上傳到 huggingface 了,因此也可以從這邊下載 https://huggingface.co/ValkyriaLenneth/longformer_zh)
$ git clone https://github.com/ValkyriaLenneth/Longformer_ZH.git
$ cd Longformer_ZH
$ touch zh_trainer.py
- 載入所需套件
2. 自定義 Dataset 物件
3. 定義驗證/預測函式
4. 定義訓練函式
5. 載入資料集、輸入模型訓練、啟動訓練
- 程式邏輯和一般在 fine-tune BERT 模型時一致,模型物件則是基於官方定義的分類模型物件 BertForSequenceClassification,載入 Longformer 中文版預訓練模型(這邊就沒有另外展開客製)
- Longformer 的輸入長度需為 512 的倍數,我這邊是以最靠近樣本最大長度的 512 倍數為 max_length
- 為了有充足的記憶體進行訓練,透過
nn.DataParallel
將資料放到多顆 GPUs 平行處理;在訓練函式中,loss.sum().backward()
則是需要將多顆 GPU 預測結果的 loss 加總一起計算
- 初步先訓練 6 epochs 的效果, Accuracy 約為 0.826