如果大家一直在從事有關數據科學或機器學習的知識的研究,那麼大家肯定遇到過使用MNIST數據集的文章及項目。該數據集一共包括70,000張圖像,其中每個圖像是0到9十個手寫數字中的一個。我們使用相同的數據集來探索在微調機器學習模型參數時產生的前後差異。
本文我們結合代碼詳細的解釋了如何使用GridSearchCV來找到該數據集的最佳擬合參數,並使用它們來提高模型的預測準確性並改善混淆矩陣。
GridSearchCV
首先我們介紹一下GridSearchCV,GridSearchCV是一種調參手段;實質上是通過窮舉搜索:在所有候選的參數選擇中,通過循環遍歷,嘗試每一種可能性,表現最好的參數就是最終的結果。其原理就像是在數組裡找最大值。
為什麼叫網格搜索?以有兩個參數的模型為例,參數a有3種可能,參數b有4種可能,把所有可能性列出來,可以表示成一個3*4的表格,其中每個cell就是一個網格,循環過程就像是在每個網格裡遍歷、搜索,所以叫grid search。
導入庫和數據集
工程開始之前我們首先導入必要的庫並將訓練和測試數據作成.csv格式的文件。數據集中的每一行均由一個標籤和784個像素值組成,表示為28x28的圖像尺寸。
整個訓練集數據包括60,000張圖像,而測試數據集包括10,000張圖像。一旦我們有數據,我們便可以將它的特徵和標籤分別存儲在train_X,train_y,test_X和test_y中。
探索數據集
分析類的分布
正如我們在之前的文章中所討論的那樣,每個類的數據應該大致相同,以確保合適的模型訓練且基本沒有其他噪聲引起的偏差。
通過上圖我們發現每個數字的數量會有一些差異。然而,這些差異並不是太大,深度學習模型
查看訓練圖像
讓我們看看真實的圖像是什麼樣的。我們從訓練數據中隨機選擇10個圖像並使用plt.imshow()進行顯示。
從數據集中隨機選擇的圖像
我們在這10個隨機圖像中立即看到的是任何一種類型的數字之間的差異。看上面10張圖片中的所有數字為4的圖片。其中第一個是粗體和直線,第二個是粗體和對角線,而第三個是細和對角線。如果模型可以從數據中學習並實際檢測出所有不同的樣式的4,那將是非常了不起。
應用機器學習
我們決定使用隨機森林分類器訓練數據並預測測試數據。首先我們使用了所有參數的默認值。
接下來,使用預測,我們計算了準確度和混淆矩陣。
通過觀察我們發現該模型的預測準確率已經達到94.4%。混淆矩陣表明該模型能夠正確預測大量的圖像。接下來,我們決定調整模型參數以嘗試改進結果。
5. 參數調整
為了確定模型的最佳參數值組合,我們使用了GridSearchCV。這是一個由sklearn庫提供的方法,它允許我們定義一組我們希望為給定模型嘗試的可能值,通過它訓練數據並從參數值的組合中得到最佳估計器。
在這種特殊情況下,我們決定為一些參數選擇一系列估計值。估計值的數量可以是100或200,最大深度可以是10,50或100,最小樣本分為2或4,最大特徵可以基於sqrt或log2。
通過GridSearchCV,我們使用的例子是random_forest_classifier。我們將可能的參數值傳遞給param_grid,並將交叉驗證的值設置為5。設置verbose為5將日誌輸出到控制臺,並且njobs為-1使模型使用機器上的所有核心。然後,我們訓練這個網格,並用它來找到最好的估計。
最後,我們使用這個最佳模型來預測測試數據。
看一下上面的準確性,我們看到通過改變模型的參數,精度從94.42%提高到97.08%。混淆矩陣還表明更多圖像被正確分類。
機器學習不僅僅是讀取數據並應用多種算法,直到我們得到一個好的模型才能使用,但它還涉及對模型進行微調以使它們最適合當前的數據。
確定正確的參數是決定使用哪種算法並根據數據充分利用它的關鍵步驟之一。
結論
在本文中,我們討論了一個項目,我們通過選擇最佳的參數值組合來提高隨機森林分類器的準確性GridSearchCV。我們使用MNIST數據集並將準確度從94.42%提高到97.08%。