add ssd vgg backbone support

This commit is contained in:
CaoJian 2021-02-25 23:17:49 +08:00
parent e4b5336fec
commit f4b3bafb70
8 changed files with 284 additions and 5 deletions

View File

@ -35,11 +35,12 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o
The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Then add auxiliary structure to the network to produce detections.
We present three different base architecture.
We present four different base architecture.
- **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper present.
- ***ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predcitors.
- ***ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predcitors.
- **ssd-vgg16**, reference from the paper. Using vgg16 as backbone and the same bbox predictor as the paper present.
## [Dataset](#contents)

View File

@ -21,7 +21,7 @@ import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
from src.eval_utils import metrics
@ -34,6 +34,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
is_training=False, use_multiprocessing=False)
if config.model == "ssd300":
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
elif config.model == "ssd_vgg16":
net = ssd_vgg16(config=config)
elif config.model == "ssd_mobilenet_v1_fpn":
net = ssd_mobilenet_v1_fpn(config=config)
elif config.model == "ssd_resnet50_fpn":

View File

@ -19,7 +19,7 @@ import numpy as np
import mindspore
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.config import config
from src.box_utils import default_boxes
@ -40,6 +40,8 @@ if args.device_target == "Ascend":
if __name__ == '__main__':
if config.model == "ssd300":
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
elif config.model == "ssd_vgg16":
net = ssd_vgg16(config=config)
elif config.model == "ssd_mobilenet_v1_fpn":
net = ssd_mobilenet_v1_fpn(config=config)
elif config.model == "ssd_resnet50_fpn":

View File

@ -11,18 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
# ============================================================================
"""Config parameters for SSD models."""
from .config_ssd300 import config as config_ssd300
from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn
from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn
from .config_ssd_vgg16 import config as config_ssd_vgg16
using_model = "ssd300"
config_map = {
"ssd300": config_ssd300,
"ssd_vgg16": config_ssd_vgg16,
"ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn,
"ssd_resnet50_fpn": config_ssd_resnet50_fpn
}

View File

@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
from easydict import EasyDict as ed
config = ed({
"model": "ssd_vgg16",
"img_shape": [300, 300],
"num_ssd_boxes": 7308,
"match_threshold": 0.5,
"nms_threshold": 0.6,
"min_score": 0.1,
"max_boxes": 100,
"ssd_vgg_bn": False,
# learing rate settings
"lr_init": 0.001,
"lr_end_rate": 0.001,
"warmup_epochs": 2,
"momentum": 0.9,
"weight_decay": 1.5e-4,
# network
"num_default": [3, 6, 6, 6, 6, 6],
"extras_in_channels": [256, 512, 1024, 512, 256, 256],
"extras_out_channels": [512, 1024, 512, 256, 256, 256],
"extras_strides": [1, 1, 2, 2, 2, 2],
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
"feature_size": [38, 19, 10, 5, 3, 1],
"min_scale": 0.2,
"max_scale": 0.95,
"aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
"steps": (8, 16, 32, 64, 100, 300),
"prior_scaling": (0.1, 0.2),
"gamma": 2.0,
"alpha": 0.75,
# `mindrecord_dir` and `coco_root` are better to use absolute path.
"feature_extractor_base_param": "",
"pretrain_vgg_bn": False,
"checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],
"mindrecord_dir": "/data/MindRecord_COCO",
"coco_root": "/data/coco2017",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instances_set": "annotations/instances_{}.json",
"classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,
# The annotation.json position of voc validation dataset.
"voc_json": "annotations/voc_instances_val.json",
# voc original dataset.
"voc_root": "/data/voc_dataset",
# if coco or voc used, `image_dir` and `anno_path` are useless.
"image_dir": "",
"anno_path": ""
})

View File

@ -27,6 +27,7 @@ from mindspore.ops import functional as F
from mindspore.ops import composite as C
from .fpn import mobilenet_v1_fpn, resnet50_fpn
from .vgg16 import vgg16
def _make_divisible(v, divisor, min_value=None):
@ -641,3 +642,78 @@ def ssd_resnet50_fpn(**kwargs):
def ssd_mobilenet_v2(**kwargs):
return SSDWithMobileNetV2(**kwargs)
class SSD300VGG16(nn.Cell):
def __init__(self, config):
super(SSD300VGG16, self).__init__()
# VGG16 backbone: block1~5
self.backbone = vgg16()
# SSD blocks: block6~7
self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad')
self.b6_2 = nn.Dropout(0.5)
self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1)
self.b7_2 = nn.Dropout(0.5)
# Extra Feature Layers: block8~11
self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad')
self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid')
self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad')
self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid')
self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')
self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')
# boxes
self.multi_box = MultiBox(config)
if not self.training:
self.activation = P.Sigmoid()
def construct(self, x):
# VGG16 backbone: block1~5
block4, x = self.backbone(x)
# SSD blocks: block6~7
x = self.b6_1(x) # 1024
x = self.b6_2(x)
x = self.b7_1(x) # 1024
x = self.b7_2(x)
block7 = x
# Extra Feature Layers: block8~11
x = self.b8_1(x) # 256
x = self.b8_2(x) # 512
block8 = x
x = self.b9_1(x) # 128
x = self.b9_2(x) # 256
block9 = x
x = self.b10_1(x) # 128
x = self.b10_2(x) # 256
block10 = x
x = self.b11_1(x) # 128
x = self.b11_2(x) # 256
block11 = x
# boxes
multi_feature = (block4, block7, block8, block9, block10, block11)
pred_loc, pred_label = self.multi_box(multi_feature)
if not self.training:
pred_label = self.activation(pred_label)
pred_loc = F.cast(pred_loc, mstype.float32)
pred_label = F.cast(pred_label, mstype.float32)
return pred_loc, pred_label
def ssd_vgg16(**kwargs):
return SSD300VGG16(**kwargs)

View File

@ -0,0 +1,99 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""VGG16 backbone for SSD"""
from mindspore import nn
from .config_ssd_vgg16 import config
pretrain_vgg_bn = config.pretrain_vgg_bn
ssd_vgg_bn = config.ssd_vgg_bn
def _get_key_mapper():
vgg_key_num = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]
size = len(vgg_key_num)
pretrain_vgg_bn_false = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
pretrain_vgg_bn_true = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
ssd_vgg_bn_false = [0, 2, 0, 2, 0, 2, 4, 0, 2, 4, 0, 2, 4]
ssd_vgg_bn_true = [0, 3, 0, 3, 0, 3, 6, 0, 3, 6, 0, 3, 6]
pretrain_vgg_keys = pretrain_vgg_bn_true if pretrain_vgg_bn else pretrain_vgg_bn_false
ssd_vgg_keys = ssd_vgg_bn_true if ssd_vgg_bn else ssd_vgg_bn_false
pretrain_vgg_keys = ['layers.' + str(pretrain_vgg_keys[i]) for i in range(size)]
ssd_vgg_keys = ['b' + str(vgg_key_num[i]) + '.' + str(ssd_vgg_keys[i]) for i in range(size)]
return {pretrain_vgg_keys[i]: ssd_vgg_keys[i] for i in range(size)}
ssd_vgg_key_mapper = _get_key_mapper()
def _make_layer(channels):
in_channels = channels[0]
layers = []
for out_channels in channels[1:]:
layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3))
if ssd_vgg_bn:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU())
in_channels = out_channels
return nn.SequentialCell(layers)
class VGG16(nn.Cell):
def __init__(self):
super(VGG16, self).__init__()
self.b1 = _make_layer([3, 64, 64])
self.b2 = _make_layer([64, 128, 128])
self.b3 = _make_layer([128, 256, 256, 256])
self.b4 = _make_layer([256, 512, 512, 512])
self.b5 = _make_layer([512, 512, 512, 512])
self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME')
def construct(self, x):
# block1
x = self.b1(x)
x = self.m1(x)
# block2
x = self.b2(x)
x = self.m2(x)
# block3
x = self.b3(x)
x = self.m3(x)
# block4
x = self.b4(x)
block4 = x
x = self.m4(x)
# block5
x = self.b5(x)
x = self.m5(x)
return block4, x
def vgg16():
return VGG16()

View File

@ -25,7 +25,7 @@ from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, dtype
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
@ -86,6 +86,17 @@ def ssd_model_build(args_opt):
param_dict["network.feature_extractor.resnet." + x] = param_dict[x]
del param_dict[x]
load_param_into_net(ssd.feature_extractor.resnet, param_dict)
elif config.model == "ssd_vgg16":
ssd = ssd_vgg16(config=config)
init_net_param(ssd)
if config.feature_extractor_base_param != "":
param_dict = load_checkpoint(config.feature_extractor_base_param)
from src.vgg16 import ssd_vgg_key_mapper
for k in ssd_vgg_key_mapper:
v = ssd_vgg_key_mapper[k]
param_dict["network.backbone." + v + ".weight"] = param_dict[k + ".weight"]
del param_dict[k + ".weight"]
load_param_into_net(ssd.backbone, param_dict)
else:
raise ValueError(f'config.model: {config.model} is not supported')
return ssd
@ -106,6 +117,8 @@ def main():
init()
if config.model == "ssd_resnet50_fpn":
context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279])
if config.model == "ssd_vgg16":
context.set_auto_parallel_context(all_reduce_fusion_config=[20, 41, 62])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
rank = get_rank()