這個項目的代碼可以在我的Github上找到
https://github.com/HOD101s/RockPaperScissor-AI-簡介
這個項目的基礎是深度學習和圖像分類,目的是創建一個簡單而有趣的石頭剪刀布遊戲。首先,這個項目是我在5月份的COVID19隔離期中無聊的產物,希望當你讀到這個時,一切都恢復正常了。我的目的是通過這篇文章用簡單的術語向初學者解釋這個項目的基本原理。讓我們開始吧!
在構建任何類型的深度學習應用程式時,有三個主要步驟:
收集和處理數據建立一個合適的人工智慧模型部署使用整個項目都引用了我的Github repo,並與之攜手並進,所以請做好參考準備。
項目地址:https://github.com/HOD101s/RockPaperScissor-AI-
收集我們的數據
任何深度學習模型的基礎都是數據,任何一位機器學習工程師都會同意這一點,在ML中,數據遠比算法本身重要。我們需要收集石頭,布和剪刀的符號圖像,我沒有下載別人的數據並在上面進行訓練,而是製作了自己的數據集,鼓勵你也建立自己的數據集。之後嘗試更改數據並重新訓練模型,以查看數據對深度學習模型究竟有怎樣的影響。
PATH = os.getcwd()+'\\'cap = cv2.VideoCapture(0)label = sys.argv[1]SAVE_PATH = os.path.join(PATH, label)try: os.mkdir(SAVE_PATH)except FileExistsError:passct = int(sys.argv[2])maxCt = int(sys.argv[3])+1print("Hit Space to Capture Image")whileTrue: ret, frame = cap.read() cv2.imshow('Get Data : '+label,frame[50:350,100:450])if cv2.waitKey(1) & 0xFF == ord(' '): cv2.imwrite(SAVE_PATH+'\\'+label+'{}.jpg'.format(ct),frame[50:350,100:450]) print(SAVE_PATH+'\\'+label+'{}.jpg Captured'.format(ct)) ct+=1if ct >= maxCt:breakcap.release()cv2.destroyAllWindows()
我使用了Python的OpenCV庫進行所有與相機相關的操作,所以這裡的label指的是圖像屬於哪個類,根據標籤,圖像保存在適當的目錄中。ct和maxCt是用來保存圖像的起始索引和最終索引,剩下的是標準的OpenCV代碼,用於獲取網絡攝像頭源並將圖像保存到目錄中。需要注意的一點是,我所有的圖片維數都是300 x 300的。運行此目錄樹後,我的目錄樹如下所示。
C:.├───paper │ paper0.jpg │ paper1.jpg │ paper2.jpg│├───rock │ rock0.jpg │ rock1.jpg │ rock2.jpg│└───scissor scissor0.jpg scissor1.jpg scissor2.jpg
如果你引用的是Github存儲庫(https://github.com/HOD101s/RockPaperScissor-AI-) ,則getData.py會為你完成這項工作!
預處理我們的數據
我們需要使用圖像,而計算機可以識別數字,因此,我們將所有圖像轉換為它們各自的矢量表示,另外,我們的標籤尚待生成,由於已建立的標籤不能是文本,因此我使用shape_to_label字典為每個類手動構建了「獨熱編碼」表示。
DATA_PATH = sys.argv[1] # Path to folder containing datashape_to_label = {'rock':np.array([1.,0.,0.,0.]),'paper':np.array([0.,1.,0.,0.]),'scissor':np.array([0.,0.,1.,0.]),'ok':np.array([0.,0.,0.,1.])}arr_to_shape = {np.argmax(shape_to_label[x]):x for x in shape_to_label.keys()}imgData = list()labels = list()for dr in os.listdir(DATA_PATH):if dr notin ['rock','paper','scissor']:continue print(dr) lb = shape_to_label[dr] i = 0for pic in os.listdir(os.path.join(DATA_PATH,dr)): path = os.path.join(DATA_PATH,dr+'/'+pic) img = cv2.imread(path) imgData.append([img,lb]) imgData.append([cv2.flip(img, 1),lb]) #horizontally flipped image imgData.append([cv2.resize(img[50:250,50:250],(300,300)),lb]) # zoom : crop in and resize i+=3 print(i)np.random.shuffle(imgData)imgData,labels = zip(*imgData)imgData = np.array(imgData)labels = np.array(labels)
當我們根據類將圖像保存在目錄中時,目錄名用作標籤,該標籤使用shape_to_label字典轉換為獨熱表示。在我們遍歷系統中的文件以訪問圖像之後,cv2.imread()函數返回圖像的矢量表示。
我們通過翻轉圖像並放大圖像來進行一些手動的數據增強,這增加了我們的數據集大小,而無需拍攝新照片,數據增強是生成數據集的關鍵部分。最後,圖像和標籤存儲在單獨的numpy數組中。
cv2.imread()函數https://www.geeksforgeeks.org/python-opencv-cv2-imread-method/更多關於數據增強的信息。
https://towardsdatascience.com/data-augmentation-for-deep-learning-4fe21d1a4eb9通過遷移學習建立我們的模型:
在處理圖像數據時,有許多經過預訓練的模型可供使用,這些模型已經在具有數千個標籤的數據集上進行了訓練,由於這些模型通過其應用程式api的Tensorflow和Keras分布,我們可以使用這些模型,這使得在我們的應用程式中包含這些預先訓練的模型看起來很容易!
總之,遷移學習採用的是經過預訓練的模型,並且不包含進行最終預測的最終層,能夠區分這種情況下圖像中的特徵,並將這些信息傳遞給我們自己的Dense神經網絡。
為什麼不訓練你自己的模型呢?完全取決於你!然而,使用遷移學習可以在很多時候使你的進步更快,從某種意義上說,你避免了重複造輪子。
其他一些受歡迎的預訓練模型:
InceptionV3VGG16/19ResNetMobileNet這是一篇關於遷移學習的有趣文章!
https://ruder.io/transfer-learning/註:每當我們處理圖像數據時,幾乎都會使用卷積神經層,這裡使用的遷移學習模型就有這些層。有關CNNs的更多信息,請訪問:
https://medium.com/@RaghavPrabhu/understanding-of-convolutional-neural-network-cnn-deep-learning-99760835f148實現
我已經使用DenseNet121模型進行特徵提取,其輸出最終將輸入到我自己的Dense神經網絡中。
densenet = DenseNet121(include_top=False, weights='imagenet', classes=3,input_shape=(300,300,3))densenet.trainable=TruedefgenericModel(base): model = Sequential() model.add(base) model.add(MaxPool2D()) model.add(Flatten()) model.add(Dense(3,activation='softmax')) model.compile(optimizer=Adam(),loss='categorical_crossentropy',metrics=['acc'])return modeldnet = genericModel(densenet)history = dnet.fit( x=imgData, y=labels, batch_size = 16, epochs=8, callbacks=[checkpoint,es], validation_split=0.2)
關鍵點 :
由於我們的圖片尺寸為300x300,因此指定的輸入形狀也為3x300x300,3代表RGB的維度信息,因此該層具有足夠的神經元來處理整個圖像。我們將DenseNet層用作第一層,然後使用我們自己的Dense神經網絡。我已將可訓練參數設置為True,這也會重新訓練DenseNet的權重。儘管花了很多時間,但是這給了我更好的結果。我建議你在自己的實現中嘗試通過更改此類參數(也稱為超參數)來嘗試不同的迭代。由於我們有3類Rock-Paper-Scissor,最後一層是具有3個神經元和softmax激活的全連接層。最後一層返回圖像屬於3類中特定類的概率。如果你引用的是GitHub repo(https://github.com/HOD101s/RockPaperScissor-AI-) 的train.py,則要注意數據準備和模型訓練!至此,我們已經收集了數據,建立並訓練了模型,剩下的部分是使用OpenCV進行部署
OpenCV實現:
此實現的流程很簡單:
啟動網絡攝像頭並讀取每個幀將此框架傳遞給模型進行分類,即預測類用電腦隨意移動計算分數defprepImg(pth):return cv2.resize(pth,(300,300)).reshape(1,300,300,3)with open('model.json', 'r') as f: loaded_model_json = f.read()loaded_model = model_from_json(loaded_model_json)loaded_model.load_weights("modelweights.h5")print("Loaded model from disk")for rounds in range(NUM_ROUNDS): pred = ""for i in range(90): ret,frame = cap.read()# Countdown if i//20 < 3 : frame = cv2.putText(frame,str(i//20+1),(320,100),cv2.FONT_HERSHEY_SIMPLEX,3,(250,250,0),2,cv2.LINE_AA)# Predictionelif i/20 < 3.5: pred = arr_to_shape[np.argmax(loaded_model.predict(prepImg(frame[50:350,100:400])))]# Get Bots Moveelif i/20 == 3.5: bplay = random.choice(options) print(pred,bplay)# Update Scoreelif i//20 == 4: playerScore,botScore = updateScore(pred,bplay,playerScore,botScore)break cv2.rectangle(frame, (100, 150), (300, 350), (255, 255, 255), 2) frame = cv2.putText(frame,"Player : {} Bot : {}".format(playerScore,botScore),(120,400),cv2.FONT_HERSHEY_SIMPLEX,1,(250,250,0),2,cv2.LINE_AA) frame = cv2.putText(frame,pred,(150,140),cv2.FONT_HERSHEY_SIMPLEX,1,(250,250,0),2,cv2.LINE_AA) frame = cv2.putText(frame,"Bot Played : {}".format(bplay),(300,140),cv2.FONT_HERSHEY_SIMPLEX,1,(250,250,0),2,cv2.LINE_AA) cv2.imshow('Rock Paper Scissor',frame)if cv2.waitKey(1) & 0xff == ord('q'):break
上面的代碼片段包含相當重要的代碼塊,其餘部分只是使遊戲易於使用,RPS規則和得分。
所以我們開始加載我們訓練過的模型,它在開始程序的預測部分之前顯示倒計時,預測後,分數會根據球員的動作進行更新。
我們使用cv2.rectangle()顯式地繪製目標區域,使用prepImg()函數預處理後,只有幀的這一部分傳遞給模型進行預測。
整個play.py在我的repo上有代碼(https://github.com/HOD101s/RockPaperScissor-AI-/blob/master/play.py)。
結論:
我們已經成功地實現並學習了這個項目的工作原理,所以請繼續使用我的實現進行其它實驗學習。我做的一個主要的改進可能是增加了手部檢測,所以我們不需要顯式地繪製目標區域,模型將首先檢測手部位置,然後進行預測。我鼓勵你改進這個項目,並給我你的建議。精益求精!
原文連結:https://towardsdatascience.com/building-a-rock-paper-scissors-ai-using-tensorflow-and-opencv-d5fc44fc8222
☆ END ☆