閱讀本文需要的背景知識點:拉格朗日乘數法、一丟丟編程知識
前面學習了一種用回歸的方式來做分類的算法——對數機率回歸算法,下面再來學習另一種分類算法——線性判別分析算法1(Linear Discriminant Analysis Algorithm/LDA),該算法由羅納德·艾爾默·費希爾在1936年提出,所以也被稱為費希爾的線性鑑別方法(Fisher's linear discriminant)先來看下圖,假設有二分類的數據集,「+」表示正例,「-」表示反例。線性判別分析算法就是要設法找到一條直線,使得同一個類別的點在該直線上的投影儘可能的接近,同時不同分類的點在直線上的投影儘可能的遠。該算法的主要思想總結來說就是要「類內小、類間大」,非常類似於在軟體設計時說的「低耦合、高內聚」。
來源:《機器學習》-周志華
假設有樣本數為N的數據集,X_i表示第i個樣本點的特徵向量,y_i表示第i個樣本點的標籤值,w表示直線的權重係數。
樣本點到直線的投影向量
均值向量與協方差矩陣
我們知道樣本點的協方差可以用于衡量兩個變量的總體誤差,那麼可以使用協方差的大小來表示類內。而樣本點的均值點可以用來表示相對位置,那麼可以使用均值點來表示類間。我們的目標是讓投影的「類內小、類間大」,那麼可以寫出對應的代價函數如下:
分子為均值向量大小之差的平方,該值越大代表類間越大。分母為兩類樣本點的協方差之和,該值越小代表類內越小,我們的目標就是求使得該代價函數最大時的w:(2)可以將公共的w的轉置與w提出來,觀察後可以寫成兩類樣本點的均值向量之差(3)中間兩項為實數可以提到前面,w為單位向量,與自己相乘為1(3)可以將公共的w的轉置與w提出來,中間改寫成樣本點向量與樣本點均值向量之差代價函數最優化
(1)代價函數的新形式,為S_b與S_w的"廣義瑞利商3(generalized Rayleigh quotient)"
(2)可以看到代價函數分子分母都是w的二次項,所以代價函數與w的長度無關,即縮放w不影響代價函數,不妨令分母為1。可以將問題轉化為當分母為1時,求分子前面加一個負號的最小值。
(3)可以運用拉格朗日乘數法4,引入一個新的變量λ,可以將(2)式改寫成新的形式
(4)對(3)式求偏導並令其等於零向量
(5)觀察後發現S_b*w的方向恆為兩類樣本點的均值向量之差的方向,不妨令其為λ倍的兩類樣本點的均值向量之差
(6)這樣就可以求出了w的方向
線性判別分析的核心思想在前面也介紹過——「類內小、類間大」,按照最後求得的公式直接計算即可。
(1)分別計算每一類的均值向量
(2)分別計算每一類的協方差矩陣
(3)計算每類協方差矩陣之和的逆矩陣,可以使用SVD矩陣分解來簡化求逆的複雜度
(4)帶入公式求出權重係數w
求新樣本的分類時,只需判斷新樣本點離哪一個分類的均值向量更近,則新樣本就是哪個分類,如下所示: 1def lda(X, y):
2 """
3 線性判別分析(LDA)
4 args:
5 X - 訓練數據集
6 y - 目標標籤值
7 return:
8 W - 權重係數
9 """
10 # 標籤值
11 y_classes = np.unique(y)
12 # 第一類
13 c1 = X[y==y_classes[0]][:]
14 # 第二類
15 c2 = X[y==y_classes[1]][:]
16 # 第一類均值向量
17 mu1 = np.mean(c1, axis=0)
18 # 第二類均值向量
19 mu2 = np.mean(c2, axis=0)
20 sigma1 = c1 - mu1
21 # 第一類協方差矩陣
22 sigma1 = sigma1.T.dot(sigma1) / c1.shape[0]
23 sigma2 = c2 - mu2
24 # 第二類協方差矩陣
25 sigma2 = sigma2.T.dot(sigma2) / c2.shape[0]
26 # 求權重係數
27 return np.linalg.pinv(sigma1 + sigma2).dot(mu1 - mu2), mu1, mu2
28
29def discriminant(X, w, mu1, mu2):
30 """
31 判別新樣本點
32 args:
33 X - 訓練數據集
34 w - 權重係數
35 mu1 - 第一類均值向量
36 mu2 - 第二類均值向量
37 return:
38 分類結果
39 """
40 a = np.abs(X.dot(w) - mu1.dot(w))
41 b = np.abs(X.dot(w) - mu2.dot(w))
42 return np.argmin(np.array([a, b]), axis=0)
1from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
2
3# 初始化線性判別分析器
4lda = LinearDiscriminantAnalysis()
5# 擬合線性模型
6lda.fit(X, y)
7# 權重係數
8W = lda.coef_
9# 截距
10b = lda.intercept_
下圖展示了存在二種分類時的演示數據,其中紅色表示標籤值為0的樣本、藍色表示標籤值為1的樣本:
下圖為擬合數據的結果,其中淺紅色表示擬合後根據權重係數計算出預測值為0的部分,淺藍色表示擬合後根據權重係數計算出預測值為1的部分:https://en.wikipedia.org/wiki/Linear_discriminant_analysishttps://en.wikipedia.org/wiki/Indicator_function
https://en.wikipedia.org/wiki/Rayleigh_quotient
https://en.wikipedia.org/wiki/Lagrange_multiplier
https://scikit-learn.org/stable/modules/generated/sklearn.discriminant_analysis.LinearDiscriminantAnalysis.html
註:本文力求準確並通俗易懂,但由於筆者也是初學者,水平有限,如文中存在錯誤或遺漏之處,懇請讀者通過留言的方式批評指正