Google Research的TabNet於2019年發布,在預印稿中被宣稱優於表格數據的現有方法。它是如何工作的,又如何可以嘗試呢?
表格數據可能構成當今大多數業務數據。考慮諸如零售交易,點擊流數據,工廠中的溫度和壓力傳感器,銀行使用的KYC (Know Your Customer) 信息或製藥公司使用的模型生物的基因表達數據之類的事情。
論文稱為TabNet: Attentive Interpretable Tabular Learning(https://arxiv.org/pdf/1908.07442.pdf),很好地總結了作者正在嘗試做的事情。「Net」部分告訴我們這是一種神經網絡,「Attentive 」部分表示它正在使用一種注意力機制,旨在實現可解釋性,並用於表格數據的機器學習。
它是如何工作的?
TabNet使用一種軟功能選擇將重點僅放在對當前示例很重要的功能上。這是通過順序的多步驟決策機制完成的。即,以多個步驟自上而下地處理輸入信息。正如論文所指出的那樣,「自上而下關注的思想是從處理視覺和語言數據或強化學習中得到的啟發,可以在高維輸入中搜索一小部分相關信息。」
儘管它們與BERT等流行的NLP模型中使用的transformer 有些不同,但執行這種順序關注的構件卻稱為transformer 塊。這些transformer 使用自注意力機制,試圖模擬句子中不同單詞之間的依賴關係。這裡使用的transformer類型試圖使用「軟」特性選擇,一步一步地消除與示例無關的那些特性,這是通過使用sparsemax函數完成的。
這篇論文的第一個圖,如下重現,描繪了信息是如何聚集起來形成預測的。
TabNet的一個好特性是它不需要特性預處理。另一個原因是,它具有內置的可解釋性,即為每個示例選擇最相關的特性。這意味著您不必應用外部解釋模塊,如shap或LIME。
在閱讀本文時,要理解這個架構中發生了什麼並不容易,但幸運的是,已經發表的代碼稍微澄清了一些問題,並表明它並不像您可能認為的那樣複雜。
我怎麼使用它?
現在TabNet有了更好的實現,如下所述:一個是PyTorch的接口,它有一個類似scikit學習的接口,還有一個是FastAI的接口。
根據作者readme描述要點如下:
為每個數據集創建新的train.csv,val.csv和test.csv文件,我不如讀取整個數據集並在內存中進行拆分(當然,只要可行),所以我寫了一個在我的代碼中為Pandas提供了新的輸入功能。
修改data_helper.py文件可能需要一些工作,至少在最初不確定您要做什麼以及應該如何定義功能列時(至少我是這樣)。還有許多參數需要更改,但它們位於主訓練循環文件中,而不是數據幫助器文件中。有鑑於此,我還嘗試在我的代碼中概括和簡化此過程。
我添加了一些快速的代碼來進行超參數優化,但到目前為止僅用於分類。
還值得一提的是,作者提供的示例代碼僅顯示了如何進行分類,而不是回歸,因此用戶也必須編寫額外的代碼。我添加了具有簡單均方誤差損失的回歸功能。
使用命令行運行測試
python train_tabnet.py \
--csv-path data/adult.csv \
--target-name "
--categorical-features workclass,education,marital.status,\
occupation,relationship,race,sex,native.country\
--feature_dim 16 \
--output_dim 16 \
--batch-size 4096 \
--virtual-batch-size 128 \
--batch-momentum 0.98 \
--gamma 1.5 \
--n_steps 5 \
--decay-every 2500 \
--lambda-sparsity 0.0001 \
--max-steps 7700
強制性參數包括--csv-path(指向CSV文件的位置),-target-name(具有預測目標的列的名稱)和-category-featues(逗號分隔列表) 應該視為分類的功能)。其餘輸入參數是需要針對每個特定問題進行優化的超參數。但是,上面顯示的值直接取自TabNet論文,因此作者已經針對成人普查數據集對其進行了優化。
默認情況下,訓練過程會將信息寫入執行腳本的位置的tflog子文件夾。您可以將tensorboard指向此文件夾以查看訓練和驗證統計信息:
tensorboard --logdir tflog
如果您沒有GPU ...
…您可以嘗試這款Colaboratory筆記(https://colab.research.google.com/drive/1AWnaS6uQVDw0sdWjfh-E77QlLtD0cpDa)。請注意,如果您想查看Tensorboard日誌,最好的選擇是創建一個Google Storage存儲桶,並讓腳本在其中寫入日誌。這可以通過使用tb-log-location參數來完成。例如。如果您的存儲桶名稱是camembert-skyscrape,則可以在腳本的調用中添加--tb-log-location gs:// camembert-skyscraper。(不過請注意,您必須正確設置存儲桶的權限。這可能有點麻煩。)
然後可以將tensorboard從自己的本地計算機指向該存儲桶:
tensorboard --logdir gs://camembert-skyscraper
超參數優化
在存儲庫(opt_tabnet.py)中也有一個用於完成超參數優化的快捷腳本。同樣,在協作筆記本中顯示了一個示例。該腳本僅適用於到目前為止的分類,值得注意的是,某些訓練參數雖然實際上並不需要,但仍進行了硬編碼(例如,用於儘早停止的參數[您可以繼續執行多少步,而 驗證準確性沒有提高]。)
優化腳本中變化的參數為N_steps,feature_dim,batch-momentum,gamma,lambda-sparsity。(正如下面的優化技巧所建議的那樣,output_dim設置為等於feature_dim。)
論文中具有以下有關超參數優化的提示:
大多數數據集對N_steps∈[3,10]產生最佳結果。通常,更大的數據集和更複雜的任務需要更大的N_steps。N_steps的非常高的值可能會過度擬合併導致不良的泛化。
調整Nd [feature_dim]和Na [output_dim]的值是獲得性能與複雜性之間折衷的最有效方法。Nd = Na是大多數數據集的合理選擇。Nd和Na的非常高的值可能會過度擬合,導致泛化效果差。
γ的最佳選擇對整體性能具有重要作用。通常,較大的N_steps值有利於較大的γ。
批量較大對性能有利-如果內存限制允許,建議最大訓練數據集總大小的1-10%。虛擬批次大小通常比批次大小小得多。
最初,較高的學習率很重要,應逐漸降低直至收斂。
結果
我已經通過此命令行界面嘗試了TabNet的多個數據集,作者提供了他們在那裡找到的最佳參數設置。使用這些設置重複運行後,我注意到最佳驗證誤差(和測試誤差)往往在86%左右,類似於不進行超參數調整的CatBoost。作者報告論文中測試集的性能為85.7%。當我使用hyperopt進行超參數優化時,儘管使用了不同的參數設置,但我毫不奇怪地達到了約86%的相似性能。
對於其他數據集,例如Poker Hand 數據集,TabNet被認為遠遠擊敗了其他方法。我還沒有花很多時間,但是當然每個人都應邀請他們自己對各種數據集進行超參數優化的TabNet!
TabNet是一個有趣的體系結構,似乎有望用於表格數據分析。它直接對原始數據進行操作,並使用順序注意機制對每個示例執行顯式特徵選擇。此屬性還使其具有某種內置的可解釋性。
我試圖通過圍繞它編寫一些包裝器代碼來使TabNet稍微容易一些。下一步是將其與各種數據集中的其他方法進行比較。
tabnet的各種實現
google官方:https://github.com/google-research/google-research/tree/master/tabnet
pytorch:https://github.com/dreamquark-ai/tabnet
本文作者的一些改進:https://github.com/hussius/tabnet_fork
作者:Mikael Huss
deephub翻譯組