GitHub地址:
mmdetection是商湯和香港中文大學基於pytorch開源的一個深度學習目標檢測工具,包括了RPN、Fast R-CNN、Faster R-CNN、Mask R-CNN、SSD、RetinaNet以及Cascade R-CNN等,還包括了各種提取特徵的主幹網絡ResNet、ResNext、SENet、VGG、HRNet,還有包括了其它的特徵如DCN、Group Normalization、Soft-NMS、Generalized Attention等,mmdetection已經成為目標檢測競賽的必備工具。
網絡
模型下載地址:
mmdetection提供了很多的預訓練模型,模型是基於COCO_2017_train訓練的,在COCO_2017_val上測試的,通過8 NVIDIA Tesla V100 GPU訓練的,訓練時每個batch size為16(每塊顯卡2張圖片)。默認下載地址使用的是AWS的鏡像,速度可能比較慢,大家可以改為阿里雲鏡像,將下載連結中的https://s3.ap-northeast-2.amazonaws.com/open-mmlab改為https://open-mmlab.oss-cn-beijing.aliyuncs.com,經過測試部分模型下載不支持阿里雲鏡像。
Cascade R-CNN模型
mmdetection安裝需要先安裝anaconda,具體安裝步驟我這裡就不重複的
conda create -n open-mmlab python=3.7 -yconda activate open-mmlab
conda install pytorch torchvision -c pytorch
上面安裝命令默認安裝的是最新的pytorch,安裝的時候需要先看自己cuda的版本,通過nvcc -V可以查看也可以直接通過cat /usr/local/cuda/version.txt查看,如果是cuda9.0請用下面的命令安裝
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
git clone https://github.com/open-mmlab/mmdetection.gitcd mmdetection
pip install mmcvpython setup.py develop 34;pip install -v -e .&step1 找到你環境的安裝目錄 pip --version
這裡需要先根據之前提供的模型下載地址先下載預訓練模型
from mmdet.apis import init_detector, inference_detector, show_resultimport mmcvconfig_file = &39;checkpoint_file = &39;39;cuda:0& test a single image and show the resultsimg = &39; visualize the results in a new windowshow_result(img, result, model.CLASSES)39;result.jpg& test a video and show the resultsvideo = mmcv.VideoReader(&39;)for frame in video: result = inference_detector(model, frame) show_result(frame, result, model.CLASSES, wait_time=1)
如果我們想要在自己的數據集上訓練一個目標檢測模型,我們需要先標記數據可以使用labelme或labelImg工具進行標記
利用mmdetection/mmdet/datasets/custom.py類來加載數據,數據格式如下,我們需要將所有的圖片和對應的標籤文件最終合成一個下面這樣數據格式的文件,可以自己寫一個腳本來進行轉換
Annotation format: [ { &39;: &39;, &39;: 1280, &39;: 720, &39;: { &39;: <np.ndarray> (n, 4), &39;: <np.ndarray> (n, ), &39;: <np.ndarray> (k, 4), (optional field) &39;: <np.ndarray> (k, 4) (optional field) } }, ... ]
根據自己選擇的模型在mmdetection/configs/目錄下找到對應的配置文件,可以根據自己的需要去修改一些參數和網絡的結構,這裡有幾個參數是必須要注意一下dataset_type,根據自己選擇的數據格式進行修改,如果你使用的datasets文件是custom.py,就需要改為dataset_type = &39;,就是對應文件裡面dataset的類名
data_root:為數據存放的目錄
ann_file:就是上面合成數據文件的路徑
img_prefix:圖片存放的路徑
checkpoint_config = dict(interval=1):保存模型間隔的epoch,為1表示每次epoch之後都保存模型
total_epochs:迭代總的epoch次數
work_dir:模型的保存目錄
load_from:預訓練模型的目錄,epoch從0開始訓練
resume_from:重新訓練模型的目錄,根據保存模型時的epoch開始訓練
單個GPU訓練
python tools/train.py ${CONFIG_FILE}
多GPU訓練
./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
可選參數: