決策樹的復興?結合神經網絡,提升ImageNet分類準確率且可解釋

2020-12-21 機器之心Pro

機器之心報導

機器之心編輯部

魚和熊掌我都要!BAIR公布神經支持決策樹新研究,兼顧準確率與可解釋性。

隨著深度學習在金融、醫療等領域的不斷落地,模型的可解釋性成了一個非常大的痛點,因為這些領域需要的是預測準確而且可以解釋其行為的模型。然而,深度神經網絡缺乏可解釋性也是出了名的,這就帶來了一種矛盾。可解釋性人工智慧(XAI)試圖平衡模型準確率與可解釋性之間的矛盾,但 XAI 在說明決策原因時並沒有直接解釋模型本身。

決策樹是一種用於分類的經典機器學習方法,它易於理解且可解釋性強,能夠在中等規模數據上以低難度獲得較好的模型。之前很火的微軟小冰讀心術極可能就是使用了決策樹。小冰會先讓我們想像一個知名人物(需要有點名氣才行),然後向我們詢問 15 個以內的問題,我們只需回答是、否或不知道,小冰就可以很快猜到我們想的那個人是誰。

周志華老師曾在「西瓜書」中展示過決策樹的示意圖:

決策樹示意圖。

儘管決策樹有諸多優點,但歷史經驗告訴我們,如果遇上 ImageNet 這一級別的數據,其性能還是遠遠比不上神經網絡。

「準確率」和「可解釋性」,「魚」與「熊掌」要如何兼得?把二者結合會怎樣?最近,來自加州大學伯克利分校和波士頓大學的研究者就實踐了這種想法。

他們提出了一種神經支持決策樹「Neural-backed decision trees」,在 ImageNet 上取得了 75.30% 的 top-1 分類準確率,在保留決策樹可解釋性的同時取得了當前神經網絡才能達到的準確率,比其他基於決策樹的圖像分類方法高出了大約 14%。

BAIR 博客地址:https://bair.berkeley.edu/blog/2020/04/23/decisions/

論文地址:https://arxiv.org/abs/2004.00221

開源項目地址:https://github.com/alvinwan/neural-backed-decision-trees

這種新提出的方法可解釋性有多強?我們來看兩張圖。

OpenAI Microscope 中深層神經網絡可視化後是這樣的:

而論文所提方法在 CIFAR100 上分類的可視化結果是這樣的:

哪種方法在圖像分類上的可解釋性強已經很明顯了吧。

決策樹的優勢與缺陷

在深度學習風靡之前,決策樹是準確性和可解釋性的標杆。下面,我們首先闡述決策樹的可解釋性。

如上圖所示,這個決策樹不只是給出輸入數據 x 的預測結果(是「超級漢堡」還是「華夫薯條」),還會輸出一系列導致最終預測的中間決策。我們可以對這些中間決策進行驗證或質疑。

然而,在圖像分類數據集上,決策樹的準確率要落後神經網絡 40%。神經網絡和決策樹的組合體也表現不佳,甚至在 CIFAR10 數據集上都無法和神經網絡相提並論。

這種準確率缺陷使其可解釋性的優點變得「一文不值」:我們首先需要一個準確率高的模型,但這個模型也要具備可解釋性。

走近神經支持決策樹

現在,這種兩難處境終於有了進展。加州大學伯克利分校和波士頓大學的研究者通過建立既可解釋又準確的模型來解決這個問題。

研究的關鍵點是將神經網絡和決策樹結合起來,保持高層次的可解釋性,同時用神經網絡進行低層次的決策。如下圖所示,研究者稱這種模型為「神經支持決策樹(NBDT)」,並表示這種模型在保留決策樹的可解釋性的同時,也能夠媲美神經網絡的準確性。

在這張圖中,每一個節點都包含一個神經網絡,上圖放大標記出了一個這樣的節點與其包含的神經網絡。在這個 NBDT 中,預測是通過決策樹進行的,保留高層次的可解釋性。但決策樹上的每個節點都有一個用來做低層次決策的神經網絡,比如上圖的神經網絡做出的低層決策是「有香腸」或者「沒有香腸」。

NBDT 具備和決策樹一樣的可解釋性。並且 NBDT 能夠輸出預測結果的中間決策,這一點優於當前的神經網絡。

如下圖所示,在一個預測「狗」的網絡中,神經網絡可能只輸出「狗」,但 NBDT 可以輸出「狗」和其他中間結果(動物、脊索動物、肉食動物等)。

此外,NBDT 的預測層次軌跡也是可視化的,可以說明哪些可能性被否定了。

與此同時,NBDT 也實現了可以媲美神經網絡的準確率。在 CIFAR10、CIFAR100 和 TinyImageNet200 等數據集上,NBDT 的準確率接近神經網絡(差距

神經支持決策樹是如何解釋的

對於個體預測的辯證理由

最有參考價值的辯證理由是面向該模型從未見過的對象。例如,考慮一個 NBDT(如下圖所示),同時在 Zebra 上進行推演。雖然此模型從未見過斑馬,但下圖所顯示的中間決策是正確的-斑馬既是動物又是蹄類動物。對於從未見過的物體而言,個體預測的合理性至關重要。

對於模型行為的辯證理由

此外,研究者發現使用 NBDT,可解釋性隨著準確性的提高而提高。這與文章開頭中介紹的準確性與可解釋性的對立背道而馳,即:NBDT 不僅具有準確性和可解釋性,還可以使準確性和可解釋性成為同一目標。

ResNet10 層次結構(左)不如 WideResNet 層次結構(右)。

例如,ResNet10 的準確度比 CIFAR10 上的 WideResNet28x10 低 4%。相應地,較低精度的 ResNet ^ 6 層次結構(左)將青蛙,貓和飛機分組在一起且意義較小,因為很難找到三個類共有的視覺特徵。而相比之下,準確性更高的 WideResNet 層次結構(右)更有意義,將動物與車完全分離開了。因此可以說,準確性越高,NBDT 就越容易解釋。

了解決策規則

使用低維表格數據時,決策樹中的決策規則很容易解釋,例如,如果盤子中有麵包,然後分配給合適的孩子(如下所示)。然而,決策規則對於像高維圖像的輸入而言則不是那麼直接。模型的決策規則不僅基於對象類型,而且還基於上下文,形狀和顏色等等。

此案例演示了如何使用低維表格數據輕鬆解釋決策的規則。

為了定量解釋決策規則,研究者使用了 WordNet3 的現有名詞層次;通過這種層次結構可以找到類別之間最具體的共享含義。例如,給定類別 Cat 和 Dog,WordNet 將反饋哺乳動物。在下圖中,研究者定量驗證了這些 WordNet 假設。

左側從屬樹(紅色箭頭)的 WordNet 假設是 Vehicle。右邊的 WordNet 假設(藍色箭頭)是 Animal。

值得注意的是,在具有 10 個類(如 CIFAR10)的小型數據集中,研究者可以找到所有節點的 WordNet 假設。但是,在具有 1000 個類別的大型數據集(即 ImageNet)中,則只能找到節點子集中的 WordNet 假設。

How it Works

Neural-Backed 決策樹的訓練與推斷過程可分解為如下四個步驟:

為決策樹構建稱為誘導層級「Induced Hierarchy」的層級;

該層級產生了一個稱為樹監督損失「Tree Supervision Loss」的獨特損失函數;

通過將樣本傳遞給神經網絡主幹開始推斷。在最後一層全連接層之前,主幹網絡均為神經網絡;

以序列決策法則方式運行最後一層全連接層結束推斷,研究者將其稱為嵌入決策法則「Embedded Decision Rules」。

Neural-Backed 決策樹訓練與推斷示意圖。

運行嵌入決策法則

這裡首先討論推斷問題。如前所述,NBDT 使用神經網絡主幹提取每個樣本的特徵。為便於理解接下來的操作,研究者首先構建一個與全連接層等價的退化決策樹,如下圖所示:

以上產生了一個矩陣-向量乘法,之後變為一個向量的內積,這裡將其表示為$\hat{y}$。以上輸出最大值的索引即為對類別的預測。

簡單決策樹(naive decision tree):研究者構建了一個每一類僅包含一個根節點與一個葉節點的基本決策樹,如上圖中「B—Naive」所示。每個葉節點均直接與根節點相連,並且具有一個表徵向量(來自 W 的行向量)。

使用從樣本提取的特徵 x 進行推斷意味著,計算 x 與每個子節點表徵向量的內積。類似於全連接層,最大內積的索引即為所預測的類別。

全連接層與簡單決策樹之間的直接等價關係,啟發研究者提出一種特別的推斷方法——使用內積的決策樹。

構建誘導層級

該層級決定了 NBDT 需要決策的類別集合。由於構建該層級時使用了預訓練神經網絡的權重,研究者將其稱為誘導層級。

具體地,研究者將全連接層中權重矩陣 W 的每個行向量,看做 d 維空間中的一點,如上圖「Step B」所示。接下來,在這些點上進行層級聚類。連續聚類之後便產生了這一層級。

使用樹監督損失進行訓練

考慮上圖中的「A-Hard」情形。假設綠色節點對應於 Horse 類。這只是一個類,同時它也是動物(橙色)。對結果而言,也可以知道到達根節點(藍色)的樣本應位於右側的動物處。到達節點動物「Animal」的樣本也應再次向右轉到「Horse」。所訓練的每個節點用於預測正確的子節點。研究者將強制實施這種損失的樹稱為樹監督損失(Tree Supervision Loss)。換句話說,這實際上是每個節點的交叉熵損失。

使用指南

我們可以直接使用 Python 包管理工具來安裝 nbdt:

pip install nbdt

安裝好 nbdt 後即可在任意一張圖片上進行推斷,nbdt 支持網頁連結或本地圖片。

nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32

# OR run on a local image

nbdt /imaginary/path/to/local/image.png

不想安裝也沒關係,研究者為我們提供了網頁版演示以及 Colab 示例,地址如下:

Demo:http://nbdt.alvinwan.com/demo/

Colab:http://nbdt.alvinwan.com/notebook/

下面的代碼展示了如何使用研究者提供的預訓練模型進行推斷:

from nbdt.model import SoftNBDT

from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10 # use wrn28_10 for TinyImagenet200

model = wrn28_10_cifar10()

model = SoftNBDT(

pretrained=True,

dataset='CIFAR10',

arch='wrn28_10_cifar10',

model=model)

另外,研究者還提供了如何用少於 6 行代碼將 nbdt 與我們自己的神經網絡相結合,詳細內容請見其 GitHub 開源項目。

相關焦點

  • TPU加AutoML:50美元快速訓練高效的ImageNet圖像分類網絡
    昨日,Jeff Dean 在推特上表示他們在 ImageNet 圖像分類上發布了新的 DAWNBench 結果,新型 AmoebaNet-D 網絡在 TPU 上實現了最低的訓練時間和最少的訓練成本。在這一個基準測試上,基於進化策略的 DAWNBench 要比殘差網絡效果更好,且訓練成本降低了一倍。
  • 深度森林第三彈:周志華組提出可做表徵學習的多層梯度提升決策樹
    今日,南京大學的馮霽、俞揚和周志華提出了多層梯度提升決策樹模型,它通過堆疊多個回歸 GBDT 層作為構建塊,並探索了其學習層級表徵的能力。此外,與層級表徵的神經網絡不同,他們提出的方法並不要求每一層都是可微,也不需要使用反向傳播更新參數。因此,多層分布式表徵學習不僅有深度神經網絡,同時還有決策樹!近十年來,深層神經網絡的發展在機器學習領域取得了顯著進展。
  • 深度神經網絡的灰色區域:可解釋性問題
    【編者按】在解決視覺、聽覺問題方面表現出色的深度學習系統中,分類器和特徵模塊都是自動學習的,神經網絡可解釋性問題就成了一個灰色區域,思考這個問題對神經網絡效率的保證是有必要的。在這篇博客文章中,機器學習PhD、軟體架構師Adnan Masood針對這個問題進行了多方面的反思。
  • 告別AI模型黑盒子:可解釋的神經網絡研究(一)
    眾所周知,機器學習模型,如神經網絡,深度神經網絡等,有非常不錯的預測能力,但是讓人信任一個模型的結果除了有良好的精度之外,可解釋性也是一個重要的因素。本文將介紹機器學習模型可解釋性的定義、性質和方法,並在後續的文章中,著重介紹不同解釋模型的方法,力求在維持模型精度的同時,通過更好的解釋模型方法,提高模型的可解釋性,從而提高人們對模型和模型結果的信任和接受程度。
  • 適用於特殊類型自然語言分類的自適應特徵譜神經網絡
    為此,提出了一種新型的神經網絡結構——自適應特徵譜神經網絡。該算法有效減少了運算時間,可以自適應地選擇對分類最有用的特徵,形成最高效的特徵譜,得到的分類結果具有一定的可解釋性,而且由於其運行速度快、內存佔用小,因此非常適用於學習輔助軟體等方面。以此算法為基礎,開發了相應的個性化學習平臺。該算法使古詩文分類的準確率由93.84%提升到了99%。
  • 圖神經網絡讓預估到達準確率提升50%,谷歌地圖實現新突破
    所以,預估到達時間(ETA)準確率成為非常實際的研究課題。近日,DeepMind 與谷歌地圖展開合作,利用圖神經網絡等 ML 技術,極大了提升了柏林、東京、雪梨等大城市的實時 ETA 準確率。DeepMind 研究者與 Google Maps 團隊展開合作,嘗試通過圖神經網絡等高級機器學習技術,提升柏林、雅加達、聖保羅、雪梨、東京和華盛頓哥倫比亞特區等地的實時 ETA 準確率,最高提升了
  • 圖神經網絡讓預估到達準確率提升50%,谷歌地圖實現新突破
    所以,預估到達時間(ETA)準確率成為非常實際的研究課題。近日,DeepMind 與谷歌地圖展開合作,利用圖神經網絡等 ML 技術,極大了提升了柏林、東京、雪梨等大城市的實時 ETA 準確率。很多人使用谷歌地圖(Google Maps)獲取精確的交通預測和預估到達時間(Estimated Time of Arrival,ETA)。
  • ThunderGBM:快成一道閃電的梯度提升決策樹
    機器之心報導參與:淑婷、思源想在 GPU 上使用使用閃電般快速的提升方法?了解這個庫就好了。在很多任務上,它都比 LightGBM 和 XGBoost 快。儘管近年來神經網絡復興並大為流行,但提升算法在訓練樣本量有限、所需訓練時間較短、缺乏調參知識等場景依然有其不可或缺的優勢。
  • 知識蒸餾:如何用一個神經網絡訓練另一個神經網絡
    如果你曾經用神經網絡來解決一個複雜的問題,你就會知道它們的尺寸可能非常巨大,包含數百萬個參數。例如著名的BERT模型約有1億1千萬參數。為了說明這一點,參見下圖中的NLP中最常見架構的參數數量。目前,有三種方法可以壓縮神經網絡,同時保持預測性能:權值裁剪量化知識蒸餾在這篇文章中,我的目標是向你介紹「知識蒸餾」的基本原理,這是一個令人難以置信的令人興奮的想法,它的基礎是訓練一個較小的網絡來逼近大的網絡。
  • Facebook:易於解釋的神經元可能阻礙深度神經網絡的學習
    編輯:張倩、杜偉易於解釋的神經元對於提升神經網絡的性能來說是必要的嗎?Facebook 的研究者給出了出人意料的答案。AI 模型能「理解」什麼?為什麼能理解這些東西?回答這些問題對於復現和改進 AI 系統至關重要。
  • 神經網絡可以
    神經網絡的強大眾所周知,但是極易受到對抗樣本的攻擊——輸入樣本上的微小擾動就能讓其預測錯誤。儘管目前已經湧現出許多抵禦對抗攻擊的方法,但這些方法一般都會造成模型準確率的下降。因此,大部分前人工作認為在分類任務上,必須對模型的準確率(魚)和魯棒性(熊掌)做一個折中,兩者是無法兼得的。《A Closer Look at Accuracy vs.
  • 簡單且可擴展的圖神經網絡
    將 d 維節點特徵排列成 n × d 維矩陣X此處 n 表示節點數),在流行的 GCN (圖卷積網絡)模型【2】中實現的對圖最簡單的卷積類操作,將節點方向的變換和跨相鄰節點的特徵擴散相結合:Y= ReLU(AXW)這裡W是所有節點共享的可學習矩陣,A是線性擴散操作符,相當於鄰域中特徵的加權平均值【3】。這種形式的多層可以像傳統的卷積神經網絡中那樣按順序應用。
  • 簡單且可擴展的圖神經網絡
    所提出的可擴展架構,我們稱之為可擴展初始類圖網絡(Scalable Inception-like Graph Network,SIGN),對於節點分類任務,其形式如下:Y= softmax(ReLU(XW₀ |
  • 21秒看盡ImageNet屠榜模型,60+模型架構同臺獻藝
    如上展示了 13 到 19 年的分類任務 SOTA 效果演進,真正有大幅度提升的方法很多都在 13 到 15 年提出,例如 Inception 結構、殘差模塊等等。Leaderboard 地址:https://www.paperswithcode.com/sota/image-classification-on-imagenet機器之心根據視頻和網站內容進行了整理。以下為一些著名的模型、發布時間、Top-1 準確率、參數量,以及相關的論文連結。發布時取得 SOTA 的模型名以紅色字體標出。
  • 教程|從檢查過擬合到數據增強,一文簡述提升神經網絡性能方法
    本文簡要介紹了提升神經網絡性能的方法,如檢查過擬合、調參、算法集成、數據增強。神經網絡是一種在很多用例中能夠提供最優準確率的機器學習算法。但是,很多時候我們構建的神經網絡的準確率可能無法令人滿意,或者無法讓我們在數據科學競賽中拿到領先名次。所以,我們總是在尋求更好的方式來改善模型的性能。有很多技術可以幫助我們達到這個目標。
  • 決策樹算法——選擇困難症的「良藥」
    01新冠檢測和決策樹的基本原理決策樹算法是一種典型的、逼近離散函數值的分類方法。(註:以上例子僅為了解釋決策樹算法的模擬描述,不一定代表真實情況)02決策樹算法解決選擇困難症隨著新冠疫情逐步得到緩解85%,可以說預測準確率還不錯,應該能夠為小明解決出行的選擇問題了。
  • 什麼是人工神經網絡(ANN)?
    人工神經元的結構,人工神經網絡的基本組成部分(來源:維基百科)從本質上講,這聽起來像是一個非常瑣碎的數學運算。但是,當您將成千上萬的神經元多層放置並堆疊在一起時,您將獲得一個人工神經網絡,可以執行非常複雜的任務,例如對圖像進行分類或識別語音。
  • 人工智慧瓶頸之神經網絡的可解釋探討
    可理解性,即深度神經網絡內部工作原理透明,各模塊作用意義可見,能對模型輸出結果做出解釋,揭示其背後的決策邏輯,並能有效地分析模型內部的邏輯漏洞和數據死角,解決基於深度神經網絡的人工智慧系統所面臨的不可審查問題。因此,隨著基於深度神經網絡的人工智慧系統的廣泛應用,亟須對神經網絡的可解釋性進行研究並構造可解釋的神經網絡,從而提高人工智慧系統的安全性,保障人工智慧應用在各大領域能安全有效地運行。
  • 卷積神經網絡性能優化(提高準確率)
    神經網絡是一種在很多用例中能夠提供最優準確率的機器學習算法。但是,很多時候我們構建的神經網絡的準確率可能無法令人滿意,或者無法讓我們在數據科學競賽中拿到領先名次。所以,我們總是在尋求更好的方式來改善模型的性能。有很多技術可以幫助我們達到這個目標。本文將介紹這些技術,幫助大家構建更準確的神經網絡。
  • 人工智慧從寒冬到復興:從神經網絡到DNN
    而且IBM在這方面是了不起的,他們一個做語音的經理有次說,每次我們加一倍的數據,準確率就往上升;我們每炒掉一個語言學家,準確率也上去。  決策樹也是第一個被語音研究者所使用。然後就是貝葉斯網絡(Bayesian Network),幾年前紅得不得了,當然現在都是用深度學習網絡(deep neural network, DNN,在輸入和輸出之間有多個隱含層的人工神經網絡)了。