VGG是在2014年由牛津大學著名研究組VGG(Visual Geometry Group)提出,斬獲該年ImageNet競賽中Localization Task(定位任務)第一名和Classification Task(分類任務)第二名。
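具體來說,論文中提到用兩個堆疊的3×3卷積核代替一個5×5卷積核,用三個堆疊的3×3卷積核代替一個7×7卷積核。也就是說,多個小卷積核堆疊後與一個大卷積核具有相同的感受野。不過兩者提取的特徵並不完全相同:堆疊方案在層與層之間引入了更多非線性激活(ReLU),而且參數更少。以輸入、輸出通道數均為C為例,三個3×3卷積共有3×(3×3×C×C)=27C²個權重,而一個7×7卷積有7×7×C×C=49C²個權重。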
在卷積神經網絡中,某一層輸出結果中的一個元素所對應的輸入層區域大小,被稱為感受野(receptive field)。通俗地說,就是輸出feature map上的一個單元對應到輸入層上的區域大小。例如,在步距為1時,兩個堆疊的3×3卷積的感受野為5×5,三個堆疊的3×3卷積的感受野為7×7。
下面使用深度學習框架PyTorch搭建VGG網絡,開發環境為PyCharm 2020.2.3、PyTorch 1.7.1、Python 3.6.8。數據集以花分類數據集為例,最後進行訓練得到模型,並從網上隨機下載一張鬱金香圖片進行預測,結果預測正確。
"""Author: LGDFileName: modelDateTime: 2021/1/24 21:30 SoftWare: PyCharm"""import torch.nn as nnimport torch
class VGG(nn.Module):
    """VGG classifier: a convolutional feature extractor followed by an MLP head.

    Args:
        features: Module that maps an input batch to a feature map with 512
            channels (e.g. the ``nn.Sequential`` returned by ``make_features``).
        num_class: Number of output logits produced by the final ``Linear``.
        init_weights: If True, re-initialise every Conv2d/Linear layer with
            Xavier-uniform weights and zero biases.
    """

    def __init__(self, features, num_class=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        # Fully-connected head; hidden width is 2048 here (the original paper
        # uses 4096). Input size 512 * 7 * 7 matches the pooled feature map.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * 7 * 7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_class),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return logits of shape (N, num_class) for the input batch ``x``."""
        x = self.features(x)
        # Resize the feature map to the 7x7 grid the classifier expects. When
        # the features are already 7x7 (e.g. 224x224 input through the five
        # max-pools of the cfgs) this is an exact identity; for other input
        # sizes it prevents a shape mismatch in the first Linear layer. It is
        # applied functionally so the module tree / state_dict are unchanged.
        x = nn.functional.adaptive_avg_pool2d(x, (7, 7))
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Xavier-uniform init for Conv2d/Linear weights; zero their biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                # Layers built with bias=False have no bias tensor to reset.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
def make_features(cfg: list, in_channels: int = 3):
    """Build the VGG convolutional trunk from a layer configuration.

    Args:
        cfg: Sequence of entries. An int ``v`` appends a 3x3 convolution
            (padding 1) with ``v`` output channels followed by an in-place
            ReLU; the string ``'M'`` appends a 2x2 max-pool with stride 2.
        in_channels: Channel count of the input fed to the first convolution.
            Defaults to 3, which is the previous hard-coded value.

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=v, kernel_size=3, padding=1))
            layers.append(nn.ReLU(True))
            # The next convolution consumes this one's output channels.
            in_channels = v
    return nn.Sequential(*layers)
# Layer configurations for the VGG variants, consumed by make_features:
# an int is the output-channel count of a 3x3 conv + ReLU, 'M' is a 2x2
# max-pool. Each list has five 'M' entries, one line per stage below.
cfgs = {
    'vgg11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'vgg19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
def vgg(model_name='vgg16', **kwargs):
    """Build a ``VGG`` model for one of the configurations in ``cfgs``.

    Args:
        model_name: Key into ``cfgs`` ('vgg11', 'vgg13', 'vgg16' or 'vgg19').
        **kwargs: Forwarded unchanged to ``VGG`` (e.g. ``num_class``,
            ``init_weights``).

    Returns:
        The constructed ``VGG`` instance.

    If ``model_name`` is not a key of ``cfgs``, a warning is printed to stderr
    and ``sys.exit(-1)`` is called (raising ``SystemExit``).
    """
    import sys

    try:
        cfg = cfgs[model_name]
    except (KeyError, TypeError):
        # KeyError: unknown name; TypeError: unhashable name (e.g. a list).
        # Narrowed from a bare except so unrelated errors are not masked.
        print("Warning: model name {} not in cfgs dict!".format(model_name), file=sys.stderr)
        sys.exit(-1)

    model = VGG(make_features(cfg), **kwargs)
    return model