Google的神經網絡表格處理模型TabNet介紹

2020-12-20 騰訊網

Google Research的TabNet於2019年發布,在預印稿中被宣稱優於表格數據的現有方法。它是如何工作的,又如何可以嘗試呢?

表格數據可能構成當今大多數業務數據。考慮諸如零售交易,點擊流數據,工廠中的溫度和壓力傳感器,銀行使用的KYC (Know Your Customer) 信息或製藥公司使用的模型生物的基因表達數據之類的事情。

論文稱為TabNet: Attentive Interpretable Tabular Learning(https://arxiv.org/pdf/1908.07442.pdf),很好地總結了作者正在嘗試做的事情。「Net」部分告訴我們這是一種神經網絡,「Attentive 」部分表示它正在使用一種注意力機制,旨在實現可解釋性,並用於表格數據的機器學習。

它是如何工作的?

TabNet使用一種軟功能選擇將重點僅放在對當前示例很重要的功能上。這是通過順序的多步驟決策機制完成的。即,以多個步驟自上而下地處理輸入信息。正如論文所指出的那樣,「自上而下關注的思想是從處理視覺和語言數據或強化學習中得到的啟發,可以在高維輸入中搜索一小部分相關信息。」

儘管它們與BERT等流行的NLP模型中使用的transformer 有些不同,但執行這種順序關注的構件卻稱為transformer 塊。這些transformer 使用自注意力機制,試圖模擬句子中不同單詞之間的依賴關係。這裡使用的transformer類型試圖使用「軟」特性選擇,一步一步地消除與示例無關的那些特性,這是通過使用sparsemax函數完成的。

這篇論文的第一個圖,如下重現,描繪了信息是如何聚集起來形成預測的。

TabNet的一個好特性是它不需要特性預處理。另一個原因是,它具有內置的可解釋性,即為每個示例選擇最相關的特性。這意味著您不必應用外部解釋模塊,如shap或LIME。

在閱讀本文時,要理解這個架構中發生了什麼並不容易,但幸運的是,已經發表的代碼稍微澄清了一些問題,並表明它並不像您可能認為的那樣複雜。

我怎麼使用它?

現在TabNet有了更好的實現,如下所述:一個是PyTorch的接口,它有一個類似scikit學習的接口,還有一個是FastAI的接口。

根據作者readme描述要點如下:

為每個數據集創建新的train.csv,val.csv和test.csv文件,我不如讀取整個數據集並在內存中進行拆分(當然,只要可行),所以我寫了一個在我的代碼中為Pandas提供了新的輸入功能。

修改data_helper.py文件可能需要一些工作,至少在最初不確定您要做什麼以及應該如何定義功能列時(至少我是這樣)。還有許多參數需要更改,但它們位於主訓練循環文件中,而不是數據幫助器文件中。有鑑於此,我還嘗試在我的代碼中概括和簡化此過程。

我添加了一些快速的代碼來進行超參數優化,但到目前為止僅用於分類。

還值得一提的是,作者提供的示例代碼僅顯示了如何進行分類,而不是回歸,因此用戶也必須編寫額外的代碼。我添加了具有簡單均方誤差損失的回歸功能。

使用命令行運行測試

python train_tabnet.py \

--csv-path data/adult.csv \

--target-name "

--categorical-features workclass,education,marital.status,\

occupation,relationship,race,sex,native.country\

--feature_dim 16 \

--output_dim 16 \

--batch-size 4096 \

--virtual-batch-size 128 \

--batch-momentum 0.98 \

--gamma 1.5 \

--n_steps 5 \

--decay-every 2500 \

--lambda-sparsity 0.0001 \

--max-steps 7700

強制性參數包括--csv-path(指向CSV文件的位置),-target-name(具有預測目標的列的名稱)和-category-featues(逗號分隔列表) 應該視為分類的功能)。其餘輸入參數是需要針對每個特定問題進行優化的超參數。但是,上面顯示的值直接取自TabNet論文,因此作者已經針對成人普查數據集對其進行了優化。

默認情況下,訓練過程會將信息寫入執行腳本的位置的tflog子文件夾。您可以將tensorboard指向此文件夾以查看訓練和驗證統計信息:

tensorboard --logdir tflog

如果您沒有GPU ...

…您可以嘗試這款Colaboratory筆記(https://colab.research.google.com/drive/1AWnaS6uQVDw0sdWjfh-E77QlLtD0cpDa)。請注意,如果您想查看Tensorboard日誌,最好的選擇是創建一個Google Storage存儲桶,並讓腳本在其中寫入日誌。這可以通過使用tb-log-location參數來完成。例如。如果您的存儲桶名稱是camembert-skyscrape,則可以在腳本的調用中添加--tb-log-location gs:// camembert-skyscraper。(不過請注意,您必須正確設置存儲桶的權限。這可能有點麻煩。)

然後可以將tensorboard從自己的本地計算機指向該存儲桶:

tensorboard --logdir gs://camembert-skyscraper

超參數優化

在存儲庫(opt_tabnet.py)中也有一個用於完成超參數優化的快捷腳本。同樣,在協作筆記本中顯示了一個示例。該腳本僅適用於到目前為止的分類,值得注意的是,某些訓練參數雖然實際上並不需要,但仍進行了硬編碼(例如,用於儘早停止的參數[您可以繼續執行多少步,而 驗證準確性沒有提高]。)

優化腳本中變化的參數為N_steps,feature_dim,batch-momentum,gamma,lambda-sparsity。(正如下面的優化技巧所建議的那樣,output_dim設置為等於feature_dim。)

論文中具有以下有關超參數優化的提示:

大多數數據集對N_steps∈[3,10]產生最佳結果。通常,更大的數據集和更複雜的任務需要更大的N_steps。N_steps的非常高的值可能會過度擬合併導致不良的泛化。

調整Nd [feature_dim]和Na [output_dim]的值是獲得性能與複雜性之間折衷的最有效方法。Nd = Na是大多數數據集的合理選擇。Nd和Na的非常高的值可能會過度擬合,導致泛化效果差。

γ的最佳選擇對整體性能具有重要作用。通常,較大的N_steps值有利於較大的γ。

批量較大對性能有利-如果內存限制允許,建議最大訓練數據集總大小的1-10%。虛擬批次大小通常比批次大小小得多。

最初,較高的學習率很重要,應逐漸降低直至收斂。

結果

我已經通過此命令行界面嘗試了TabNet的多個數據集,作者提供了他們在那裡找到的最佳參數設置。使用這些設置重複運行後,我注意到最佳驗證誤差(和測試誤差)往往在86%左右,類似於不進行超參數調整的CatBoost。作者報告論文中測試集的性能為85.7%。當我使用hyperopt進行超參數優化時,儘管使用了不同的參數設置,但我毫不奇怪地達到了約86%的相似性能。

對於其他數據集,例如Poker Hand 數據集,TabNet被認為遠遠擊敗了其他方法。我還沒有花很多時間,但是當然每個人都應邀請他們自己對各種數據集進行超參數優化的TabNet!

TabNet是一個有趣的體系結構,似乎有望用於表格數據分析。它直接對原始數據進行操作,並使用順序注意機制對每個示例執行顯式特徵選擇。此屬性還使其具有某種內置的可解釋性。

我試圖通過圍繞它編寫一些包裝器代碼來使TabNet稍微容易一些。下一步是將其與各種數據集中的其他方法進行比較。

tabnet的各種實現

google官方:https://github.com/google-research/google-research/tree/master/tabnet

pytorch:https://github.com/dreamquark-ai/tabnet

本文作者的一些改進:https://github.com/hussius/tabnet_fork

作者:Mikael Huss

deephub翻譯組

相關焦點

  • Google雲計算AI平臺內置TabNet表格模型,可替代傳統決策樹算法
    Google在雲計算平臺新加入內置深度學習模型TabNet,用戶可以簡單地使用深度學習來處理表格資料,TabNet具有兩項優點,除了可解釋性之外,還提供高效的執行性能,Google提到,這個算法很適合用於零售、金融和保險業,實例像是預測信用評分、欺詐偵測和數值預測等應用。
  • 清華大學圖神經網絡綜述:模型與應用
    這篇文章對圖神經網絡進行了廣泛的總結,並做出了以下貢獻:文章詳細介紹了圖神經網絡的經典模型。主要包括其原始模型,不同的變體和幾個通用框架。文章將圖神經網絡的應用系統地歸類為結構化場景、非結構化場景和其他場景中,並介紹了不同場景中的主要應用。本文為未來的研究提出四個未解決的問題。文章對每個問題進行了詳細分析,並提出未來的研究方向。
  • ICLR 2020 | 神經正切,5行代碼打造無限寬的神經網絡模型
    深度學習在自然語言處理,對話智能體和連接組學等多個領域都獲得了成功應用,這種學習方式已經改變了機器學習的研究格局,並給研究人員帶來了許多有趣而重要的開放性問題,例如:為什麼深度神經網絡(DNN)在被過度參數化的情況下仍能如此良好地泛化
  • 神經網絡提取PDF表格工具來了,支持圖片,還能白嫖谷歌GPU資源
    別著急,一種使用深度神經網絡識別提取表格的開源工具可以幫助你。兼容圖片、高準確率、還不佔用本地運算資源,如此實用的工具值得你擁有。測試實例如果在輸入的PDF文件中檢測的表格,模型會在邊界框(bounding box)標出表格邊框:然後,表格數據會被轉化為Panda數據框架,方便後續處理:怎麼樣,是不是很實用?那這個工具如何使用呢?
  • 谷歌重磅開源Neural Tangents:5行代碼打造無限寬神經網絡模型
    但是,問題來了:推導有限網絡的無限寬度限制需要大量的數學知識,並且必須針對不同研究的體系結構分別進行計算。對工程技術水平的要求也很高。谷歌最新開源的Neural Tangents,旨在解決這個問題,讓研究人員能夠輕鬆建立、訓練無限寬神經網絡。甚至只需要5行代碼,就能夠打造一個無限寬神經網絡模型。這一研究成果已經中了ICLR 2020。
  • 神經網絡中避免過擬合5種方法介紹
    打開APP 神經網絡中避免過擬合5種方法介紹 THU數據派 發表於 2020-02-04 11:30:00 本文介紹了5種在訓練神經網絡中避免過擬合的技術。
  • 深度學習之卷積神經網絡經典模型
    LeNet-5模型在CNN的應用中,文字識別系統所用的LeNet-5模型是非常經典的模型。LeNet-5模型是1998年,Yann LeCun教授提出的,它是第一個成功大規模應用在手寫數字識別問題的卷積神經網絡,在MNIST數據集中的正確率可以高達99.2%。下面詳細介紹一下LeNet-5模型工作的原理。
  • 神經網絡的工作原理介紹
    概要單純的講神經網絡的概念有些抽象,先通過一個實例展示一下機器學習中的神經網絡進行數據處理的完整過程。神經網絡的實例1.1 案例介紹實例:訓練一個神經網絡模型擬合 廣告投入(TV,radio,newspaper 3種方式)和銷售產出的關係,實現根據廣告投放來預測銷售情況。
  • 谷歌開源Neural Tangents:簡單快速訓練無限寬度神經網絡
    在自然語言處理、會話智能體和連接組學等許多領域,深度學習都已取得了廣泛的成功,機器學習領域的研究圖景也已經發生了變革。不過,仍還有一些有趣而又重要的問題有待解答,比如:為什麼即使在過度參數化時,深度神經網絡(DNN)也能取得非常好的泛化能力?深度網絡的架構、訓練和性能之間有何關係?如何提取出深度學習模型中的顯著特徵?
  • 詳解NLP中的預訓練模型、圖神經網絡、模型壓縮、知識圖譜
    為了真正全面系統的培養NLP人才,貪心學院推出了《自然語言處理終身升級版》課程覆蓋了從經典的機器學習、文本處理技術、序列模型、深度學習、預訓練模型、知識圖譜、圖神經網絡所有必要的技術。並落地實操工業級項目,由資深的NLP負責人全程直播講解,幫助你融會貫通,輕鬆拿offer。
  • VGG卷積神經網絡模型解析
    一:VGG介紹與模型結構VGG全稱是Visual Geometry Group屬於牛津大學科學工程系,其發布了一些列以
  • AutoML : 更有效地設計神經網絡模型
    字幕組雙語原文:AutoML : 更有效地設計神經網絡模型英語原文:AutoML: Creating Top-Performing Neural Networks WithoutDefining Architectures翻譯:雷鋒字幕組(chenx2ovo)自動化機器學習,通常被稱為AutoML,是自動化構建指神經網絡結構。
  • CNN系列-神經網絡模型結構設計的演變和理解
    思路一:加深加寬想提升網絡精度,最樸素的方法就是加深加寬網絡,提升模型複雜度來增強擬合能力。但會存在兩個問題:一是容易過擬合,測試集不準;二是梯度容易不穩定,模型難訓練。所以早期的alexnet、inception、resnet研究各種方法在模型加寬加深的防止模型過擬合和梯度不穩定等。思路二:網絡優化如果不加寬加深,則需要進一步優化現有網絡的性能。那該怎麼優化呢?
  • 利用AI Builder創建表格處理模型
    那麼今天就來介紹一下如何利用Power Platform中的AI Builder來創建表格處理模型。使用此模型,我們可以從表格或發票中輕鬆提取我們需要的文本。第一步:進入"Power Apps"選擇"生成"。
  • 基於TensorFlow Eager Execution的簡單神經網絡模型
    作者 | Yu Xuan Lee來源 | Medium編輯 | 代碼醫生團隊介紹
  • IBM的8位浮點精度深度神經網絡模型解析
    本文引用地址:http://www.eepw.com.cn/article/201901/396743.htm  IBM的研究人員聲稱,他們已開發出一個更加高效的模型用於處理神經網絡,該模型只需使用8位浮點精度進行訓練,推理(inferencing)時更是僅需4位浮點精度。
  • 基於深度神經網絡的脫硫系統預測模型及應用
    本文還結合某 2×350MW 燃煤電廠提供的實際工數據,以石灰石供漿密度對系統脫硫性能的影響為例,詳細介紹了利用所建立的深度神經網絡模型測試溼法脫硫系統各參數指標對脫硫效果的影響,並結合化學機理和工業實際進行的診斷過程。
  • 神經網絡的基礎是MP模型?南大周志華組提出新型神經元模型FT
    據論文介紹,這項研究為神經網絡提供了一種新的基本構造單元,展示了開發具有神經元可塑性的人工神經網絡的可行性。當前的神經網絡大多基於 MP 模型,即按照生物神經元的結構和工作原理構造出來的抽象和簡化模型。此類模型通常將神經元形式化為一個「激活函數複合上輸入信號加權和」的形式。
  • 用飛槳做自然語言處理:神經網絡語言模型應用實例 - 量子位
    但這種方法會有一個很大的問題,那就是前面提到的維度災難,而這裡要實現的神經網絡語言模型(Neural Network Language Model),便是用神經網絡構建語言模型,通過學習分布式詞表示(即詞向量)的方式解決了這個問題。
  • 人工神經網絡算法介紹及其參數講解
    算法介紹神經網絡是一種運算模型,由大量的節點(或稱神經元)之間相互聯接構成。