隨機森林是一個包含多個決策樹的分類器,因其運算速度快、分類精度高、算法穩定等特點,被廣泛應用到遙感圖像的分類研究中。Scikit-Learn是Python程式語言的開源機器學習庫,提供了對隨機森林算法的支持,但沒有提供針對遙感影像分類的相關函數。因此,本篇文章將為讀者介紹利用Python及其擴展包Scikit-Learn對遙感影像進行隨機森林分類的完整過程,包括:ShapeFile格式樣本數據的讀取、柵格數據的讀取和裁剪、利用Scikit-Learn的RandomForestClassifier類進行樣本訓練和遙感影像分類。
一、Scikit-Learn的安裝
直接執行命令pip install scikit-learn,所有依賴庫都會自動安裝。安裝完成後,添加代碼from sklearn.ensemble import RandomForestClassifier即可使用。
二、樣本繪製
在ArcGIS中繪製訓練樣本,格式為ShapeFile(.shp),可以是點類型或單部件面(多邊形)類型,並建立一個整數欄位(例如Value)用於存儲分類編號,例如1-溼地,2-湖泊,3-水稻。注意:下文createClassifier函數的field參數默認讀取名為Id的欄位,若分類編號存放在其他欄位(如Value),需要顯式傳入該欄位名。
三、柵格數據裁剪
要通過shp樣本獲取對應的柵格值,需要用樣本多邊形裁剪柵格數據;這裡使用射線法(ray casting)逐像元判斷其是否位於多邊形內部,位於多邊形之外的像元會被標記為無效值。全部代碼如下(TrainByRandomForest.py):
from sklearn.ensemble import RandomForestClassifier
import osgeo
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
def GetSubRaster(inraster,polygonPoints:list):
    """Read the raster window under a polygon and mask pixels outside it.

    The polygon's bounding rectangle is converted to a pixel window using the
    raster geotransform, that window is read, and every pixel whose sample
    point fails an even-odd ray-casting test is overwritten with a fill value.

    Args:
        inraster: Path of the raster, passed to ``gdal.Open``.
        polygonPoints: Polygon vertices as objects exposing ``X`` and ``Y``
            attributes. Presumably expressed in the raster's coordinate
            system (the visible caller reprojects them first) -- confirm for
            other callers. The list is NOT modified.

    Returns:
        A tuple ``(arr, width, height, BandsCount, leftX, upY)``: the masked
        window array, its size in pixels, the raster band count, and the
        left / top coordinates of the polygon's bounding rectangle.

    Raises:
        IndexError: If ``polygonPoints`` is empty.
    """
    # Close the ring on a copy so the caller's list is left untouched.
    ring = list(polygonPoints)
    ring.append(ring[0])
    print("當前多邊形節點數量:"+str(len(ring)))

    # Minimum bounding rectangle of the polygon.
    leftX = min(p.X for p in ring)
    rightX = max(p.X for p in ring)
    bottomY = min(p.Y for p in ring)
    upY = max(p.Y for p in ring)

    rds = gdal.Open(inraster)
    transform = rds.GetGeoTransform()
    lX = transform[0]  # top-left corner of the raster
    lY = transform[3]
    rX = transform[1]  # pixel size along X / Y
    rY = transform[5]

    # Pixel window covering the bounding rectangle.
    wpos = int((leftX-lX)/rX)
    hpos = int((upY-lY)/rY)
    width = int((rightX-leftX)/rX)
    height = int((bottomY-upY)/rY)
    BandsCount = rds.RasterCount
    arr = rds.ReadAsArray(wpos,hpos,width,height)

    # GetNoDataValue() returns None when the band declares no nodata value;
    # writing None into an integer array raises, so fall back to 0.
    nodatavalue = rds.GetRasterBand(1).GetNoDataValue()
    fillvalue = 0 if nodatavalue is None else nodatavalue

    edges = list(zip(ring, ring[1:]))
    for i in range(height):
        if height>200:
            print(f"多邊形裁剪進度:{round(((i+1)/height)*100,4)}%")
        # Sample coordinate for this row; the small offset presumably keeps
        # the sample off exact vertex/edge coordinates.
        Y = upY+i*rY+.00001
        # Only edges whose Y range contains this row can be crossed by a
        # horizontal ray, so filter them once per row.
        rowedges = [
            (p1, p2) for p1, p2 in edges
            if (p1.Y>=Y and p2.Y<=Y) or (p1.Y<=Y and p2.Y>=Y)
        ]
        for j in range(width):
            X = leftX+j*rX+.00001
            count = 0  # crossings of a ray cast towards +X
            for p1, p2 in rowedges:
                if p1.X==p2.X:
                    # Vertical edge: the crossing is at the edge's X.
                    if p1.X>X:
                        count += 1
                else:
                    slope = (p2.Y-p1.Y)/(p2.X-p1.X)
                    if slope==0:
                        # Horizontal edge lying on the sample row.
                        if X<p1.X or X<p2.X:
                            count += 1
                    else:
                        intersectX = (Y-p1.Y)/slope+p1.X
                        if intersectX>X:
                            count += 1
            if count%2==0:
                # Even number of crossings: the pixel is outside the polygon.
                if BandsCount>1:
                    for bc in range(BandsCount):
                        arr[bc][i][j] = fillvalue
                else:
                    # NOTE(review): single-band windows are masked with -1
                    # rather than the nodata value; kept as in the original.
                    arr[i][j] = -1
    # To check the result visually, the window can be written to disk first:
    # WriteRaster("test.tif",arr,inraster,width,height,BandsCount,leftX,upY)
    return arr,width,height,BandsCount,leftX,upY
四、創建分類器並訓練樣本
def createClassifier(inraster,inshp,field:str="Id",treenum:int=100):
    """Train a random-forest classifier from a raster and a sample shapefile.

    Each polygon feature is reprojected to the raster's spatial reference,
    the raster is clipped to it with ``GetSubRaster``, and every remaining
    pixel becomes one training sample labelled with the feature's ``field``
    value.

    Args:
        inraster: Path of the training raster.
        inshp: Path of the sample shapefile; only its first layer is read.
        field: Name of the attribute holding the integer class code.
        treenum: Number of trees (``n_estimators``) of the forest.

    Returns:
        The fitted ``RandomForestClassifier``; the result of
        ``createClassifierByPoint`` when the first feature is a point; or
        ``False`` when the layer is empty or its first feature is neither a
        point nor a polygon.
    """
    rasterspatial = gdal.Open(inraster)
    spatial2 = osr.SpatialReference()
    spatial2.ImportFromWkt(rasterspatial.GetProjectionRef())
    shpspatial = ogr.Open(inshp)
    layer = shpspatial.GetLayer(0)
    spatial1 = layer.GetSpatialRef()
    ct = osr.CreateCoordinateTransformation(spatial1,spatial2)

    oFeature = layer.GetNextFeature()
    if oFeature is None:
        # An empty layer has no first feature to inspect.
        print("樣本圖層中沒有要素")
        return False

    # The geometry type of the first feature decides how samples are read.
    geomtype = oFeature.GetGeometryRef().GetGeometryType()
    if geomtype==ogr.wkbPoint:
        # NOTE(review): field and treenum are not forwarded here; the helper's
        # signature is not visible from this block -- confirm and forward.
        return createClassifierByPoint(inraster,inshp)
    if geomtype!=ogr.wkbPolygon:
        print("樣本必須為單部件多邊形")
        return False

    trainX = list()
    trainY = list()
    print("讀取樣本")
    while oFeature is not None:
        geom = oFeature.GetGeometryRef()
        points = WKTToPoints(geom.ExportToWkt())
        value = oFeature.GetField(field)

        # Reproject the polygon vertices into the raster's spatial reference.
        polygonPoints = []
        for point in points:
            pC = ct.TransformPoint(point.X,point.Y,0)
            polygonPoints.append(Point(pC[0],pC[1]))

        arr,width,height,BandsCount,leftX,upY = GetSubRaster(inraster,polygonPoints)
        if BandsCount==1 and arr.ndim==2:
            # GetSubRaster returns a 2-D array for single-band rasters; add a
            # band axis so the indexing below works for every band count.
            arr = arr.reshape((1,)+arr.shape)

        for i in range(height):
            for j in range(width):
                pixel = [int(arr[bc][i][j]) for bc in range(BandsCount)]
                # A pixel with no positive band value is treated as masked
                # and is not used for training.
                if not any(v>0 for v in pixel):
                    continue
                trainX.append(pixel)
                trainY.append(int(value))
        oFeature = layer.GetNextFeature()

    # Drop the references to the GDAL/OSR objects.
    ct = None
    spatial1 = None
    spatial2 = None

    print("訓練樣本")
    clf = RandomForestClassifier(n_estimators=treenum)
    clf.fit(trainX, trainY)
    print("訓練完成")
    return clf
五、隨機森林分類
import math
import os
import sys

import numpy
import osgeo
from osgeo import gdal
from osgeo import ogr
from osgeo import osr
from sklearn.ensemble import RandomForestClassifier

import TrainByRandomForest as tbrf
def RandomForestClassification(ClassifyRaster,TrainRaster,TrainShp,outfile,blockSize=0,treenum=100,max_depth=10):
    """Classify a raster with a random forest and write a one-band GeoTIFF.

    A classifier is trained with ``tbrf.createClassifier`` and applied to
    every pixel of ``ClassifyRaster``, either in one pass or in horizontal
    strips so that less of the raster is held in memory at once.

    Args:
        ClassifyRaster: Path of the raster to classify.
        TrainRaster: Path of the raster the training samples are read from.
        TrainShp: Path of the training-sample shapefile.
        outfile: Path of the GeoTIFF to create. Its band uses the data type
            of band 1 of ``ClassifyRaster``.
        blockSize: Number of strips. ``1`` classifies the whole raster at
            once; ``0`` derives the number from ``memory_usage()``.
        treenum: Number of trees, forwarded to ``tbrf.createClassifier``.
        max_depth: Currently unused. NOTE(review): the visible
            ``createClassifier`` has no matching parameter to forward it to.

    Raises:
        ValueError: If training did not return a usable classifier.
    """
    rds = gdal.Open(ClassifyRaster)
    transform = rds.GetGeoTransform()
    width = rds.RasterXSize
    height = rds.RasterYSize
    BandsCount = rds.RasterCount

    clf = tbrf.createClassifier(TrainRaster,TrainShp,treenum=treenum)
    if not hasattr(clf, "predict"):
        # createClassifier returns False when the samples are unusable.
        raise ValueError("createClassifier did not return a trained classifier")

    if blockSize == 0:
        # Derive the strip count from memory_usage(). Presumably its second
        # value is the available memory in MB -- TODO confirm.
        _, a = memory_usage()
        pv = 0.6 / 10000
        checkMemory(2000)  # presumably stops when under 2 GB -- TODO confirm
        bl = (a - 2000) / pv / height / BandsCount
        blockSize = math.ceil(height / bl)
        # NOTE(review): the source indentation was lost; these two
        # adjustments are assumed to apply to the estimated value only.
        if blockSize < 1:
            blockSize = 1
        if blockSize > 1:
            blockSize += 5
    # More strips than rows would give zero-height strips (and an empty
    # predict call), so clamp to the range 1..height.
    blockSize = max(1, min(int(blockSize), height))

    if blockSize != 1:
        blockHeight = height // blockSize
        modHeight = height % blockSize
        print(f"分塊大小{width}*{blockHeight}")
        parts = []
        for bs in range(blockSize):
            print(f"計算塊{bs+1}/{blockSize}")
            checkMemory(500)
            arr = rds.ReadAsArray(0,bs * blockHeight,width,blockHeight)
            samples = _block_to_samples(arr, BandsCount)
            arr = None
            print(f"分塊:{bs+1}/{blockSize}計算分類結果……")
            checkMemory(800)
            parts.append(clf.predict(samples))
            samples = None
        print(f"計算餘數:{width}*{modHeight}")
        if modHeight > 0:
            # Rows left over after the equal-height strips.
            checkMemory(500)
            arr = rds.ReadAsArray(0,blockSize * blockHeight,width,modHeight)
            samples = _block_to_samples(arr, BandsCount)
            arr = None
            print("餘塊:計算分類結果……")
            checkMemory(500)
            parts.append(clf.predict(samples))
            samples = None
        Z = numpy.concatenate(parts).reshape(height,width)
    else:
        checkMemory(1000)
        arr = rds.ReadAsArray(0,0,width,height)
        samples = _block_to_samples(arr, BandsCount)
        arr = None
        print("計算分類結果……")
        Z = numpy.asarray(clf.predict(samples)).reshape(height,width)

    driver = gdal.GetDriverByName("GTiff")
    print("創建輸出文件")
    out = driver.Create(outfile,width,height,1,rds.GetRasterBand(1).DataType)
    out.SetGeoTransform(transform)
    out.SetProjection(rds.GetProjectionRef())
    print("寫入數據……")
    out.GetRasterBand(1).WriteArray(Z)
    out.FlushCache()
    out = None  # dropping the reference closes the dataset
    print("計算完成")


def _block_to_samples(block, bands_count):
    """Flatten a raster block into an (n_pixels, n_bands) integer matrix.

    Pixels are ordered row by row. Accepts a (bands, rows, cols) array or,
    for single-band rasters, a (rows, cols) array.
    """
    block = numpy.asarray(block)
    return block.reshape(bands_count, -1).T.astype(numpy.int64)
六、分類圖像和結果
點擊閱讀原文查看數據獲取方法