K-最近鄰(KNN)是一種有監督的機器學習算法,可用於解決分類和回歸問題。它基於一個非常簡單的想法,數據點的值由它周圍的數據點決定。考慮的數據點數量由k值確定。因此,k值是算法的核心。
KNN分類器根據多數表決原則確定數據點的類別。如果k設置為5,則檢查5個最近點的類別。也可以根據多數類進行回歸預測,同樣,KNN回歸取5個最近點的平均值。
在本文中,我們將研究k值對於分類任務的重要性。
使用Scikit learn的make_classification函數創建一個示例分類數據集。
import numpy as npimport pandas as pdfrom sklearn.datasets import make_classificationX, y = make_classification( n_samples=1000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, class_sep=0.8)數據集包含屬於2個類的1000個樣本。還可以創建數據點的散點圖(即樣本)。
import matplotlib.pyplot as pltplt.figure(figsize=(12,8))plt.scatter(X[:,0], X[:,1], c=y)
選擇最優k值是建立一個合理、精確的knn模型的必要條件。
如果k值太低,則模型會變得過於具體,不能很好地泛化。它對噪音也很敏感。該模型在訓練組上實現了很高的精度,但對於新的、以前看不到的數據點,該模型的預測能力較差。因此,我們很可能最終得到一個過擬合的模型。如果k選擇得太大,模型就會變得過於泛化,無法準確預測訓練和測試集中的數據點。這種情況被稱為欠擬合。我們現在創建兩個不同的knn模型,k值為1和50。然後創建預測的散點圖,以查看差異。
第一步是將數據集拆分為測試子集。
from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)第一個模型是k=1的knn模型。
from sklearn.neighbors import KNeighborsClassifierknn1 = KNeighborsClassifier(n_neighbors=1)knn1.fit(X_train, y_train)predict1 = knn1.predict(X_test)plt.figure(figsize=(12,8))plt.title("KNN with k=1", fontsize=16)plt.scatter(X_test[:,0], X_test[:,1], c=predict1)
只需將n_neighbors參數更改為50,就可以創建下一個模型。下面是這個模型在測試集上的預測。
你可以看到當k增加時,泛化是如何變化的。
我們需要的不是一個過於籠統或過於具體的模型。我們的目標是創建一個健壯和精確的模型。
過擬合模型(過於具體)對數據點值的細微變化很敏感。因此,在數據點發生微小變化後,預測可能會發生巨大變化。
欠擬合模型(過於通用)可能在訓練和測試子集上都表現不佳。
不幸的是,沒有一個找到最佳k值的解決方案。它取決於數據集的底層結構。然而,有一些工具可以幫助我們找到最佳k值。
GridSearchCV函數可用於創建、訓練和評估具有不同超參數值的模型。k是knn算法中最重要的超參數。
我們將創建一個GridSearchCV對象來評估k值從1到20的20個不同knn模型的性能。參數值作為字典傳遞給param_grid parameter。
from sklearn.model_selection import GridSearchCVknn = GridSearchCV( estimator = KNeighborsClassifier(), param_grid = {'n_neighbors': np.arange(1,21)}, scoring='neg_log_loss', cv = 5)你可以使用scikit learn的任何評分標準。我使用了log丟失,這是分類任務中常用的度量。
我們現在可以將數據集調整到GridSearchCV對象。不需要分割訓練和測試子集,因為應用了5倍交叉驗證。
knn.fit(X, y)我們得到每個k值的交叉驗證的平均測試分數。
scores = pd.Series(abs(knn.cv_results_['mean_test_score']))scores.index = np.arange(1,21)我更喜歡將結果保存在Pandas系列中,以便能夠輕鬆地繪製它們。畫出分數。
scores.plot(figsize=(12,8))plt.title("Log loss of knn with k vales from 1 to 20", fontsize=16)
在k值為10之後,測試集上的損失似乎沒有多少改善。
結論
knn算法是一種應用廣泛的算法。它簡單易懂。由於它不作任何假設,所以也可以用來解決非線性問題。
消極的一面是,由於模型需要存儲所有的數據點,因此隨著數據點數量的增加,knn算法變得非常緩慢。由於這個原因,它也不具有內存效率。
最後,它對離群值很敏感,因為離群值在決策中也有投票權。
感謝你的閱讀。