最近工作中需要對文本進行多標籤分類(Multi-label Text Classification),系統查閱了相關論文,藉此機會整理歸納一下,希望能幫助對此有需要的同學節約一些時間。
什麼是多標籤文本分類
學術上常用的說法有兩個,一個是ExtremeMulti-label Learning, 簡稱XML,另一個是Extreme Multi-label Text Classification,簡稱XMTC,本質上都是對文本進行多標籤分類,即對於一個給定的文本,可能有多個標籤,我們需要設計一個模型預測其標籤。和XMTC相似的一個問題叫Multi-class classification, 它講的是對於一個文本其標籤是標籤集合中的一個,標籤集合中有三個以上標籤,可見它和XMTC有本質區別,即XMTC要預測的一個文本可能有多個標籤,而Multi-classclassification要預測的一個文本只有一個標籤。
XMTC有很廣泛的應用,比如購物網站中對商品進行分類,所有商品種類有上百萬個,任意一種商品可能屬於其中的幾個或者幾十個,當用戶創建一個新商品時,網站需要根據商品的屬性預測其多個分類,以方便搜索使用。
主要方案
到目前為止主要解決方案有三類:基於樹的方法(Tree-based method),目標嵌入方法(Target-embedding method)和基於深度學習的方法(Deep learning method)。以下分別介紹。
基於樹的方法(Tree-based method)
和分類決策樹類似,此方法根據訓練數據創建一個樹形結構模型,當新數據需要預測的時候,從樹根開始尋找所屬的分支,最後找到的多個葉子節點就是對應分類。和傳統決策樹不同,此方法在創建樹形結構模型的時候,會根據所有特徵(features)學習一個權重參數,由此參數決定樹結構創建過程的每個節點的分支劃分,目的是要使各個分支有相似的標籤分布。此方法最有名的模型叫FastXML,其節點劃分時採用的打分方法叫nDCG,當模型建立後需要預測時,尋找葉子節點中分數最高的top n節點就是預測的標籤。FastXML是微軟印度研究院提出的,論文詳見[1],這個團隊在XMTC領域有重要的影響力,發表了了一系列的模型和改進,官方網站http://manikvarma.org,其包含了他們發表的主要文章、對應的源碼實現、測試的數據集、以及在測試數據集上各種算法對應的性能比較,其性能指標是這個領域進行性能比較的通用指標。
目標嵌入方法(Target-embedding method)
目標嵌入方法主要為解決XMTC訓練過程中稀疏數據的問題。試想如果目標標籤集合中有十萬條標籤,一般一條數據僅屬於其中的幾個或者十幾個標籤。如果用一個向量表示標籤集合,對於給定的一個文本,屬於某個標籤用1表示,否則用0表示,則其標籤向量中絕大部分數值都是0,這是典型的稀疏數據問題。目標嵌入方法的思路是把目標標籤集合從高維空間映射到低維空間,數據訓練和預測的目標都針對低維空間,最後對預測到低維空間的標籤再映射回原本的高維空間。從高維映射到低維空間過程稱為壓縮過程(compression),可採用線性或非線性變換,反之從低維映射回高維空間過程稱為解壓過程(decompression)。各種不同的模型主要區別在於壓縮和解壓算法不一樣。提出FastXML的團隊在這方面也有研究,論文見[2],另一篇這方面較新的論文見[3].
基於深度學習的方法(Deep learning method)
現在機器學習基本不能不提深度學習,自然有人嘗試用深度學習解決XMTC問題。
論文[4]使用CNN對多標籤文本分類,是一個常規的CNN網絡,文本通過詞嵌入(word embedding)轉換為向量作為CNN輸入,然後做max pooling和全連接層,最後通過sigmoid得到top n結果。實驗結果顯示其性能優於絕大部分之前的FastXML和目標嵌入模型。
當然,文本處理也不能忽略RNN和近年較火的attention based模型,論文[5]就是這麼幹的,其詞嵌入用word2vec,RNN用了Bidirectional LSTM,然後用了multi-label attention layer,最後使用全連接層。模型實驗結果表明其性能好於上面的CNN模型。
具體應用
我們項目的需求是對給定的軟體漏洞描述,預測其所屬的軟體library分類,一個漏洞可能屬於多個libraries,比如Apache Struts的一個漏洞可能屬於struct2-core, struts2-rest-plugin。我們沒有從性能可能是最好的深度學習模型開始,而是先嘗試FastXML,接下來再嘗試深度學習以比較性能。
實驗數據採用了8000條已標註的軟體漏洞,其中6000條用於訓練,2000條用於測試。
以上三種解決方案都是預測top n的標籤,即選取預測得分最高的前n個標籤,這和我們的需求不太一致,我們各個預測數據對應的標籤數量可能不同,所以對結果選取做了一些改變。我們設置一個閾值(threshold),對一個給定的文本,預測分數中的值如果大於此閾值就認為該文本屬於這個分數對應的標籤,否則不屬於,得到預測結果後可以算出該文本預測的precision和recall,對所有文本預測的結果取平均值即可得到該閾值對應的模型precision和recall。最後我們從0到1遍歷閾值,步長取0.01,這樣就可以得到各個閾值對應的模型precision和recall。一般情況下,precision越高recall越低,我們實驗中較好的結果是recall 0.42時precision為0.4。
接下來我們會進一步改進數據處理和模型,看看性能是否有改進。
[1] Y.Prabhu and M. Varma. FastXML: A fast, accurate and stable tree-classifierfor extreme multi-label learning. http://manikvarma.org/pubs/prabhu14.pdf
[2] K.Bhatia, H. Jain, P. Kar, M. Varma and P. Jain. Sparse local embeddings forextreme multi-label classification. http://manikvarma.org/pubs/bhatia15.pdf
[3] Wenjie Zhang,Junchi Yan, Deep Extreme Multi-label Learning.
https://arxiv.org/pdf/1704.03718.pdf
[4] Jingzhou Liu, Wei-Cheng Chang, Deep Learning for ExtremeMulti-label Text Classification. http://nyc.lti.cs.cmu.edu/yiming/Publications/jliu-sigir17.pdf
[5] Ronghui You, Suyang Dai, AttentionXML: ExtremeMulti-Label Text Classification with Multi-Label Attention Based RecurrentNeural Networks. https://arxiv.org/pdf/1811.01727.pdf