基於AttentionXML的Extreme Multi-Label Text Classification

2021-02-20 自然語言處理算法與實踐
1 前言

今天分享一篇2019年NIPS會議上一篇paper,方向為multi-label classification。論文題目為:AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification。論文下載連結為:https://arxiv.org/pdf/1811.01727.pdf,項目也開源出了代碼:https://github.com/yourh/AttentionXML。

概要來說,本篇paper是提出一種基於Attention機制的label 樹模型,來解決大規模多標籤文本分類(Extreme multi-label text classification (XMTC))問題。研究出發點為:(1)先前的方法沒有充分學習輸入的文本與每個label之間的關係信息;(2)對大規模的label並沒有進行一個可伸縮性的學習。針對兩個問題,文中對應提出對應創新點:(1)引入multi-label attention mechanism for XMTC (AttentionXML ),為每個label捕捉最相關的特徵信息;(2)引入probabilistic label tree (PLT)結構體系,處理百萬級別的label集合。

2 Model

本篇論文的核心就是兩個概念AttentionXML與PLT,就重點講述這兩個概念,先說下probabilistic label tree (PLT)。

2.1 Probabilistic Label Tree(PLT)

該概念是最先由K. Jasinska在2016年提出(Extreme f-measure maximization using sparse probability estimates.),解決extreme-scale datasets。該概念的想法是基於:在extreme labels集合中,是存在樹狀層級結構的,樹結構可能是真實存在,也可能是潛在的(本文解決的方向)。所以像之前XML-CNN之類的方法把所有labels看成一個平行的結構來看待,這樣導致所有的label都基於一個共同的表徵向量來學習預測,沒法差異性的學到與每個label最相關的信息。

Probabilistic Label Tree(PLT)的提出,就是利用概率的思想在label集合中構建一個Label Tree,基於樹的結構去訓練模型,預測label。構建的基本方法就是:通過每類的label文本,學到到該label表徵向量,然後用遞歸聚類(KMeans)構建Label Tree,即葉子節點是一個真的標籤,非葉子節點是一個虛擬標籤。Parabel在2018年提出一種遞歸分裂聚類方式構建一顆二叉平衡樹(主要意思是:每個節點下的子節點是有數量限制的,不能超過一個範圍),在XMTC任務取得當時最好的效果。

本文就是Parabel的基礎上提出了改進思路:認為Parabel方法構建的樹深度H(不包括根節點和葉子節點)太深,而樹的深度越深,label聚類錯誤可能性增大,訓練預測效率也降低;此外,許多尾部標籤與其他不同的標籤組合在一起,並分組到一個集群中,損害了尾部標籤的識別效果。所以,針對上述問題,本文提出一種方法,構建一個淺的(H很小)並且寬的PLT。

上圖為一個PLT示意圖,方形是樹的葉子,代表所有的標籤;圓形代表樹的節點,是構建的偽標籤;即L為label集合,H=2為去掉root與leaf樹的高度,K=M=4代表KMeans聚類的K值和每個節點下面的最大容量。在這樣的結構下,一個樣本在每個節點(z_n)的概率有如下計算模式:

其中Pa(n)是節點n的父節點,Path(n)是從根節點到node n上路徑的節點集合。

接著介紹本文是如何生成PLT:

具體的是,作者通過將每個標籤下文本的BOW特徵求和獲得該標籤的特徵向量,然後通過一個K-Means循環將這些標籤切分成兩個cluster,直到每個節點下面的標籤數小於M,這些cluster對應著樹的內部節點。這裡我理解的是,先完全按照二叉樹的方式解析成一個樹T_0,只是到最後一層的時候,按不超過M進行合併展開;接著按一種方式進行剪裁,將T_0樹遞歸方式壓縮到一個淺且寬的樹T_h,如下圖。這部分文中說的不是特別詳細,如果想弄清楚,需去了解Parabel的paper。

上圖顯示的是構建一顆PLT樹的過程,K=M=8,H=3,L=8000(M是max_leaf是最大葉子節點數,如果某個葉子裡面的標籤數超過8,就會切分該節點,H表示樹高-2,L表示一共的標籤數)。T0表示level=0的樹,裡面的數字表示樹每個高度的節點數。Th中的紅色數字對應的節點會被移掉,以為了獲得Th+1 樹。可以理解為:最後一層節點數10248>8000,已涵蓋所有label了,而1288=1024>512>256,所以512,256這兩層可以刪掉;同樣的方式,一直把樹的深度裁剪到預定的H=3的高度,且每個節點的容量都不超過M=8。其中注意,root節點是不受M值限制的。**

下兩圖為PLT生成的偽代碼和文本作者在三個數據集上生成的PLT情況:

2.2 AttentionXML

在構建好了PLT後,文中作者採用的是層級方式來訓練模型的。具體包括:
(1)從上至下地給每個level單獨的訓練一個模型,每個模型都是一個多標籤學習;
(2)level-d的樹AttentionXML模型,是通過每個樣本的候選標籤g(x)訓練的。我們對第(d-1)層AttentionXML模型預測的標籤的從正到負,得分由高到底進行排序。我們選擇第(d-1)層的top C標籤作為下一層訓練的候選標籤g(x)。這就像是一種額外的負採樣,相比於只使用節點的正樣本,我們可以得到一個更精確的對數似然近似值。
(3)在預測階段,對於第i個樣本,第j個標籤的預測得分y_i,j通過概率鏈式法則很容易獲得。為了預測效率,我們使用beam search算法,對於第d層的,我們只預測d-1層top C的標籤。

每層的AttentionXML模型結構如下圖,主要包含5層:Word Representation Layer,Bidirectional LSTM Layer,Multi-label Attention Layer,Fully Connected Layer,Output Layer。

在Multi-label Attention Layer上,計算方式如下,就是常規的attention計算方式,讓不同的label學習到與文本向量h_i不同的權重信息。

在Fully Connected Layer上,文中是採用各個層級模型共享的方式,主要目的減少模型複雜度。

                                        預測階段的偽代碼

3 總結

這次主要分享的目的就是在處理大規模多標籤文本分類任務時,如何使用層級分類的思路解決該任務,提高識別效果。本文的實驗結果就不分析了。文中提出的構建淺且寬的PLT樹思路,可以借鑑,類似是一種折衷的方案,既不把label都視為平行結構,也不能把label構建成特別深的樹結構,影響學習效率。但其中使用KMeans去構建PLT樹,我內心是有點懷疑的,這應該會產生分類誤差,沒有一個衡量構建好壞或更合理的指標,不過目前我也沒想到好的方法,留給大家去思考了.

相關焦點

  • Multi-Label Classification with Deep Learning
    In this tutorial, you will discover how to develop deep learning models for multi-label classification.
  • Keras-TextClassification 文本分類工具包
    /data'step3: goto # Train&Usage(調用) and Predict&Usage(調用)keras_textclassification(代碼主體,未完待續...)
  • 如何用 Python 和 BERT 做多標籤(multi-label)文本分類?
    一文裡,給你講過遷移學習的範例 ULMfit (Universal language model fine-tuning for text classification)。DATA_PATH = Path('demo-multi-label-classification-bert/sample/data/')LABEL_PATH = Path('demo-multi-label-classification-bert/sample/labels/')BERT_PRETRAINED_MODEL = "bert-base-uncased
  • 分類問題-----多標籤(multilabel)、多類別(multiclass)
    另外,以下幾個問題是需要關注和進一步研究的Dimensionality Reduction 降維Label Dependence 標註依賴Active learning 主動學習Multi-instance multi-label learning (MIML) 多實例多標籤Multi-view learning.
  • 使用Multi-Label訓練CNN能否達到Detection的效果?
    作者:我愛機器學習(52ml.net)連結:https://www.zhihu.com/question/52143412/answer/130037578問題在只需要時候輸出image含有object的label而不需要定位的情況下,使用multi-label訓練一個分類網絡(例如 resnet)能否達到object detection的效果。
  • 基於TorchText的PyTorch文本分類
    import torchimport torchtextfrom torchtext.datasets import text_classificationimport osimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import DataLoaderimport timefrom
  • NLP03:基於TF-IDF和LogisticRegression的文本分類
    "基於tfidf和樸素貝葉斯或者LogisticRegression的文本分類"""def textSegment(filePath):    """    讀取文本文件並進行分詞    :param filePath:文件路徑    :return:    """    textLines = open(filePath
  • Extreme US weather 美國遭遇極端惡劣天氣
    The USA has been suffering from extreme weather conditions美國連日來不斷遭遇惡劣天氣襲擊。The extreme conditions are being blamed for at least 25 deaths, mostly in traffic accidents on roads that are slick with ice.