NLP

【自然語言處理】Longformer 中文長文本 BERT 模型 – 新聞分類實作

the chronicles of narnia book

使用 pre-trained model BERT 進行各種 NLU 下游任務 fine-tuning 時,會有輸入限制 512 字符的限制,真實世界文章往往超過這個長度,因此如何應用 BERT 到長文本上是一個重要的議題。

傳統 BERT 長文本輸入處理方式

我在搜集資料時,找到此篇文章整理了 3 種克服方式:

  1. 截斷法:就是很直觀的取其中的 512 字符輸入,也是已知最常用的暴力解
    • head 截斷:從頭開始保留限制的 token 數
    • tail 截斷:從文本末端往前擷取
    • head + tail 截斷:開頭和結尾各保留一部分
  2. Pooling 法
    • 把長文本分成多個 512 字以內的 segments (段落),對每個段落都轉換為 BERT Embedding 再對 [CLS] 進行 max-pooling (比 mean-pooling 推薦,因為 max-pooling 會保留突出的特徵,而 mean-pooling 會將特徵打平,原本的特徵已經很稀疏,打平更會減弱特徵效果。)
    • 作者提到這種方法的缺點:拆分文本的方式會考慮不到長距離字符之間的上下文關聯,因為他們被拆開到不同 segment 沒辦法進行 attention
  3. 壓縮法
    • 按照句子分成多個 segments,接著用規則或單獨訓練一個模型的方式來排除一些權重較低的 segments
    • 和 Pooling 法一樣有無法考慮長距離關聯的問題

另一文章也提到,BERT 之所以難以處理長文本,是因為傳統 Transformer-based 模型採用「全連接」的 Attention 機制,每一token都要和另一token組合,attention的時間&空間複雜度高達 O(n^2)

後續針對長文本的處理所提出的新模型算法,以下這 2 個都算是使用到「限制Attention範圍」的概念來降低前述提及的複雜度:

  • Longformer
  • CogLTX

這篇我們以介紹 Longformer 為主。

Longformer 算法介紹

  • 發表單位:Allen AI
  • 發表時間:2020 年
  • 是否開源代碼:Yes
  • 論文亮點:
    • 一種可高效處理長文本的升級版 Transformer
    • Longformer 改進 Transformer 傳統 attention 機制:對每一 token 「只對固定窗口大小附近的 token」計算 local attention,再結合下游任務計算少量 global attention,將原始 Transformer 的複雜度降至 O(n)
    • Longformer 可通用於不同文檔級任務
    • 論文作者使用 Longformer 的 attention 方法繼續預訓練 RoBERTa,所得的模型在多個文檔級的任務 fine-tuned 後超越原 RoBERTa 成效
  • 模型算法介紹
    • 作者以「降低複雜度」為目的提出三種新的 attention 模式:
      • 滑窗機制 (Sliding window attention):對每一 token 只對附近 w 個 token 計算 attention,w 大小隨任務不同調整,複雜度 O(w*n)
      • 膨脹滑窗機制 (Dilated sliding window):考量更全面的上下文資訊。在滑動窗格中,被 attention 的 2 token 之間存在大小為 d 的空隙,可展開比普通滑窗機制更廣的 attention 範圍、又不增加計算負荷。論文中實驗證明這種方法表現得比第一種好。
        • 注:此方法來自於影像辨識中常用的圖像編碼方式 Dilated CNN 空洞卷積,用來拓寬模型編碼的視野範圍
      • 融合全局訊息的滑窗機制 (Global + sliding window):根據下游任務的不同添加少量 global attention (例如分類任務的 [CLS] 處會有一個 global attention)。也就是部分 token 會被添加 global attention
  • 模型實驗關鍵發現
    • 如果 Transformer 由底層至高層遞增滑動窗格大小,有助於提高成效;反之則降低
    • 增加空隙的滑窗機制有小幅的成效提升
    • 驗證成效提升是來源於新型 attention 機制,而非對 RoBERTa 的加強預訓練,進行多組將舊 attention 機制用於 Longformer 的實驗,這些實驗並沒有取得跟使用新型 attention 機制的模型表現一樣好。

[實作] 訓練 Longformer 文本分類器


0. 下載 Longformer_zh 模型 並 git clone repo:

$ git clone https://github.com/ValkyriaLenneth/Longformer_ZH.git
$ cd Longformer_ZH
$ touch zh_trainer.py
  1. 載入所需套件

2. 自定義 Dataset 物件

3. 定義驗證/預測函式

4. 定義訓練函式

5. 載入資料集、輸入模型訓練、啟動訓練

開始訓練囉
  • 程式邏輯和一般在 fine-tune BERT 模型時一致,模型物件則是基於官方定義的分類模型物件 BertForSequenceClassification,載入 Longformer 中文版預訓練模型(這邊就沒有另外展開客製)
  • Longformer 的輸入長度需為 512 的倍數,我這邊是以最靠近樣本最大長度的 512 倍數為 max_length
  • 為了有充足的記憶體進行訓練,透過 nn.DataParallel 將資料放到多顆 GPUs 平行處理;在訓練函式中, loss.sum().backward() 則是需要將多顆 GPU 預測結果的 loss 加總一起計算
分散資料的處理用量到兩顆GPU中
  • 初步先訓練 6 epochs 的效果, Accuracy 約為 0.826

%d 位部落客按了讚: