本文介紹了高斯回歸的基本概念,並在 R 中利用高斯回歸實現對時間序列的預測。
為什麼要用高斯過程回歸現實實生活中,我們遇到的一個典型問題就是選擇合適的模型擬合訓練集中自變量 X 與因變量 y 之間的關係,並根據新的自變量 x 來預測對應的因變量 f
如果關係足夠簡單,那麼線性回歸就能實現很好的預測,但現實情況往往十分複雜,此時,高斯過程回歸就為我們提供了擬合複雜關係(quadratic, cubic, or even nonpolynomial)的絕佳方法
什麼是高斯過程回歸高斯過程可以看做是多維高斯分布向無限維的擴展,我們可以將 y=y1,y2,…,yn看作是從 n 維高斯分布中隨機抽取的一個點
對高斯過程的刻畫,如同高斯分布一樣,也是用均值和方差來刻畫。通常在應用高斯過程 f∼GP(m,K)的方法中,都是假設均值 m 為零,而協方差函數 K 則是根據具體應用而定
高斯回歸的本質其實就是通過一個映射把自變量從低維空間映射到高維空間(類似於支持向量機中的核函數將低維線性不可分映射為高維線性可分),只需找到合適的核函數,就可以知道 p(f|x,X,y)的分布,最常用的就是高斯核函數
高斯過程回歸的基本流程再利用高斯過程回歸時,不需要指明 f(x)的具體形式,如線性 f(x)=mx+c,或者二次等具體式,n 個訓練集的觀測值 y1,y2,…,yn會被看做多維(n 維)高斯分布中採樣出來的一個點
現在給定訓練集 x1,x2,…,xn與對應的觀測值y1,y2,…,yn,由於觀測通常是帶噪聲的,所以將每個觀測 y 建模為某個隱函數 f(x) 加上一個高斯噪聲,即
其中,f(x)被假定給予一個高斯過程先驗,即
其中協方差函數 k(x,x′)可以選擇不同的單一形式,也可以採用協方差函數的組合形式,由於假設均值為零,因此最後結果的好壞很大程度上取決於協方差函數的選擇。不同的協方差函數形式參見這篇文章對 Covariance Functions 的詳細介紹。常見的協方差函數如下,參見 Wikipedia-Gaussian Process
根據高斯分布的性質以及測試集和訓練集數據來自同一分布的特點,可以得到訓練數據與測試數據的聯合分布為高維的高斯分布,有了聯合分布就可以比較容易地求出預測數據 y∗ 的條件分布 p(y∗|y),對 y∗的估計的估計,我們就用分布的均值來作為其估計值
利用高斯過程進行時間序列預測R 中 kernlab 包的 gausspr 函數可以進行高斯回歸,並實現預測,以下面這個包含 46 個月的時間序列 ts7 為例
利用趨勢回歸併進行預測library(kernlab)
library(ggplot2)
library(gplots)
library(forecast)
library(data.table)
library(tidyr)
library(plotly)
temp <- data.table(ts7)
fit <- gausspr(demand~t, data=temp)
temp$fitted <- predict(fit, temp[,.(t)])
ggplot(temp, aes(x=year_month, group=1))
+ geom_line(aes(y=demand, col="demand"), size=1)
+ geom_line(aes(y=fitted, col="fitted"), size=1)
+ theme_bw() + theme(axis.text.x=element_text(angle=45,hjust=1,vjust=1))
+ scale_x_discrete(breaks=temp$year_month[seq(2,44,3)])
只利用趨勢項進行高斯回歸的擬合效果如下
然後用過去三年的時間序列作為訓練集對未來一個月的需求進行循環預測
temp1 <- data.table(ts7, fitted=0)
for (k in 0:9){
train <- temp1[(1+k):(36+k),3:4]
fit <- gausspr(demand~t, data=train)
temp1[(37+k), "fitted"] <- predict(fit, temp1[(37+k),.(t)])
}
利用趨勢+季節回歸併進行預測首先,去除趨勢之後,檢查去趨勢之後的時間序列是否具有明顯的季節性,並找出 CV 最小的前三個季節
temp$demand_detrend <- temp$demand - temp$fitted
ggplot(temp, aes(x=year_month, group=1)) + geom_line(aes(y=demand, col="demand"), size=1) + geom_line(aes(y=fitted, col="fitted"), size=1) + geom_line(aes(y=demand_detrend, col="demand_detrend")) + theme_bw() + theme(axis.text.x=element_text(angle=45,hjust=1,vjust=1)) + scale_x_discrete(breaks=temp$year_month[seq(2,44,3)])
季節性雷達圖
ggseasonplot(temp$demand_detrend, polar=TRUE) + ggtitle("Seasonal Plot") + geom_line(size=1) + theme_bw()
季節性箱形圖
ggplot(temp, aes(x=month, y=demand_detrend)) + geom_boxplot() + theme_bw()
獲取季節 cv 最小的前 3 個季節分別是12月、2月、10月
as.numeric(temp[, .(cv=sd(demand_detrend)/mean(abs(demand_detrend))), by=month][order(cv)][1:3,month])
加入全部 12 個月作為季節性之後,再對最後的 10 個月進行循環預測
temp2 <- data.table(ts7, fitted=0)
for (k in 0:9){
train <- temp2[(1+k):(36+k),3:15]
fit <- gausspr(demand~., data=train)
temp2[(37+k), "fitted"] <- predict(fit, temp2[(37+k),4:15])
}
預測結果比較temp <- data.frame(temp1[,c(1:3,16)], temp2[,16])
colnames(temp)[3:5] <- c("Actual_demand","Predict_trend","Predict_trend+seasonal")
temp[4:5] <- round(temp[4:5], 2)
temp[temp==0] <- NA
temp <- gather(temp[,-2], key="Series", value="value", -year_month)
p <- ggplot(temp, aes(x=year_month, y=value, group=Series, col=Series))
+ geom_line(size=1) + geom_point() + theme_bw()
+ theme(axis.text.x=element_text(angle=45,hjust=1,vjust=1))
+ scale_x_discrete(breaks=temp$year_month[seq(2,44,3)])
postlink <- plotly_POST(p, filename = "GPR Prediction Example")
postlink
總結當隨機變量呈現明顯的非線性趨勢時,高斯過程回歸能夠很好地預測線性預測的不足
季節性並不一定能夠提高預測效果,當某些月份的需求變動幅度很大時,加入季節虛擬變量反而會增大預測誤差
高斯過程不僅能用於回歸預測,還能用於解決分類問題,有興趣的讀者請自行探究
Tips:iOS 9系統顯示上面代碼可能不正常,iOS 8\10\11都沒問題,電腦瀏覽器和Android顯示都沒問題,如果顯示代碼不正常,可以用電腦瀏覽器打開試試。