今天分享一篇2019年NIPS會議上一篇paper,方向為multi-label classification。論文題目為:AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification。論文下載連結為:https://arxiv.org/pdf/1811.01727.pdf,項目也開源出了代碼:https://github.com/yourh/AttentionXML。
概要來說,本篇paper是提出一種基於Attention機制的label 樹模型,來解決大規模多標籤文本分類(Extreme multi-label text classification (XMTC))問題。研究出發點為:(1)先前的方法沒有充分學習輸入的文本與每個label之間的關係信息;(2)對大規模的label並沒有進行一個可伸縮性的學習。針對兩個問題,文中對應提出對應創新點:(1)引入multi-label attention mechanism for XMTC (AttentionXML ),為每個label捕捉最相關的特徵信息;(2)引入probabilistic label tree (PLT)結構體系,處理百萬級別的label集合。
2 Model本篇論文的核心就是兩個概念AttentionXML與PLT,就重點講述這兩個概念,先說下probabilistic label tree (PLT)。
2.1 Probabilistic Label Tree(PLT)該概念是最先由K. Jasinska在2016年提出(Extreme f-measure maximization using sparse probability estimates.),解決extreme-scale datasets。該概念的想法是基於:在extreme labels集合中,是存在樹狀層級結構的,樹結構可能是真實存在,也可能是潛在的(本文解決的方向)。所以像之前XML-CNN之類的方法把所有labels看成一個平行的結構來看待,這樣導致所有的label都基於一個共同的表徵向量來學習預測,沒法差異性的學到與每個label最相關的信息。
Probabilistic Label Tree(PLT)的提出,就是利用概率的思想在label集合中構建一個Label Tree,基於樹的結構去訓練模型,預測label。構建的基本方法就是:通過每類的label文本,學到到該label表徵向量,然後用遞歸聚類(KMeans)構建Label Tree,即葉子節點是一個真的標籤,非葉子節點是一個虛擬標籤。Parabel在2018年提出一種遞歸分裂聚類方式構建一顆二叉平衡樹(主要意思是:每個節點下的子節點是有數量限制的,不能超過一個範圍),在XMTC任務取得當時最好的效果。
本文就是Parabel的基礎上提出了改進思路:認為Parabel方法構建的樹深度H(不包括根節點和葉子節點)太深,而樹的深度越深,label聚類錯誤可能性增大,訓練預測效率也降低;此外,許多尾部標籤與其他不同的標籤組合在一起,並分組到一個集群中,損害了尾部標籤的識別效果。所以,針對上述問題,本文提出一種方法,構建一個淺的(H很小)並且寬的PLT。
上圖為一個PLT示意圖,方形是樹的葉子,代表所有的標籤;圓形代表樹的節點,是構建的偽標籤;即L為label集合,H=2為去掉root與leaf樹的高度,K=M=4代表KMeans聚類的K值和每個節點下面的最大容量。在這樣的結構下,一個樣本在每個節點(z_n)的概率有如下計算模式:
其中Pa(n)是節點n的父節點,Path(n)是從根節點到node n上路徑的節點集合。
接著介紹本文是如何生成PLT:
具體的是,作者通過將每個標籤下文本的BOW特徵求和獲得該標籤的特徵向量,然後通過一個K-Means循環將這些標籤切分成兩個cluster,直到每個節點下面的標籤數小於M,這些cluster對應著樹的內部節點。這裡我理解的是,先完全按照二叉樹的方式解析成一個樹T_0,只是到最後一層的時候,按不超過M進行合併展開;接著按一種方式進行剪裁,將T_0樹遞歸方式壓縮到一個淺且寬的樹T_h,如下圖。這部分文中說的不是特別詳細,如果想弄清楚,需去了解Parabel的paper。
上圖顯示的是構建一顆PLT樹的過程,K=M=8,H=3,L=8000(M是max_leaf是最大葉子節點數,如果某個葉子裡面的標籤數超過8,就會切分該節點,H表示樹高-2,L表示一共的標籤數)。T0表示level=0的樹,裡面的數字表示樹每個高度的節點數。Th中的紅色數字對應的節點會被移掉,以為了獲得Th+1 樹。可以理解為:最後一層節點數10248>8000,已涵蓋所有label了,而1288=1024>512>256,所以512,256這兩層可以刪掉;同樣的方式,一直把樹的深度裁剪到預定的H=3的高度,且每個節點的容量都不超過M=8。其中注意,root節點是不受M值限制的。**
下兩圖為PLT生成的偽代碼和文本作者在三個數據集上生成的PLT情況:
2.2 AttentionXML在構建好了PLT後,文中作者採用的是層級方式來訓練模型的。具體包括:
(1)從上至下地給每個level單獨的訓練一個模型,每個模型都是一個多標籤學習;
(2)level-d的樹AttentionXML模型,是通過每個樣本的候選標籤g(x)訓練的。我們對第(d-1)層AttentionXML模型預測的標籤的從正到負,得分由高到底進行排序。我們選擇第(d-1)層的top C標籤作為下一層訓練的候選標籤g(x)。這就像是一種額外的負採樣,相比於只使用節點的正樣本,我們可以得到一個更精確的對數似然近似值。
(3)在預測階段,對於第i個樣本,第j個標籤的預測得分y_i,j通過概率鏈式法則很容易獲得。為了預測效率,我們使用beam search算法,對於第d層的,我們只預測d-1層top C的標籤。
每層的AttentionXML模型結構如下圖,主要包含5層:Word Representation Layer,Bidirectional LSTM Layer,Multi-label Attention Layer,Fully Connected Layer,Output Layer。
在Multi-label Attention Layer上,計算方式如下,就是常規的attention計算方式,讓不同的label學習到與文本向量h_i不同的權重信息。
在Fully Connected Layer上,文中是採用各個層級模型共享的方式,主要目的減少模型複雜度。
預測階段的偽代碼
3 總結這次主要分享的目的就是在處理大規模多標籤文本分類任務時,如何使用層級分類的思路解決該任務,提高識別效果。本文的實驗結果就不分析了。文中提出的構建淺且寬的PLT樹思路,可以借鑑,類似是一種折衷的方案,既不把label都視為平行結構,也不能把label構建成特別深的樹結構,影響學習效率。但其中使用KMeans去構建PLT樹,我內心是有點懷疑的,這應該會產生分類誤差,沒有一個衡量構建好壞或更合理的指標,不過目前我也沒想到好的方法,留給大家去思考了.