forked from mindspore-Ecosystem/mindspore
add ssd vgg backbone support
This commit is contained in:
parent
e4b5336fec
commit
f4b3bafb70
|
@ -35,11 +35,12 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o
|
||||||
|
|
||||||
The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Then add auxiliary structure to the network to produce detections.
|
The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Then add auxiliary structure to the network to produce detections.
|
||||||
|
|
||||||
We present three different base architecture.
|
We present four different base architecture.
|
||||||
|
|
||||||
- **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper present.
|
- **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper present.
|
||||||
- ***ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predcitors.
|
- ***ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predcitors.
|
||||||
- ***ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predcitors.
|
- ***ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predcitors.
|
||||||
|
- **ssd-vgg16**, reference from the paper. Using vgg16 as backbone and the same bbox predictor as the paper present.
|
||||||
|
|
||||||
## [Dataset](#contents)
|
## [Dataset](#contents)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
|
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||||
from src.dataset import create_ssd_dataset, create_mindrecord
|
from src.dataset import create_ssd_dataset, create_mindrecord
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.eval_utils import metrics
|
from src.eval_utils import metrics
|
||||||
|
@ -34,6 +34,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
|
||||||
is_training=False, use_multiprocessing=False)
|
is_training=False, use_multiprocessing=False)
|
||||||
if config.model == "ssd300":
|
if config.model == "ssd300":
|
||||||
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
|
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
|
||||||
|
elif config.model == "ssd_vgg16":
|
||||||
|
net = ssd_vgg16(config=config)
|
||||||
elif config.model == "ssd_mobilenet_v1_fpn":
|
elif config.model == "ssd_mobilenet_v1_fpn":
|
||||||
net = ssd_mobilenet_v1_fpn(config=config)
|
net = ssd_mobilenet_v1_fpn(config=config)
|
||||||
elif config.model == "ssd_resnet50_fpn":
|
elif config.model == "ssd_resnet50_fpn":
|
||||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
||||||
import mindspore
|
import mindspore
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||||
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
|
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.box_utils import default_boxes
|
from src.box_utils import default_boxes
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ if args.device_target == "Ascend":
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if config.model == "ssd300":
|
if config.model == "ssd300":
|
||||||
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
|
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
|
||||||
|
elif config.model == "ssd_vgg16":
|
||||||
|
net = ssd_vgg16(config=config)
|
||||||
elif config.model == "ssd_mobilenet_v1_fpn":
|
elif config.model == "ssd_mobilenet_v1_fpn":
|
||||||
net = ssd_mobilenet_v1_fpn(config=config)
|
net = ssd_mobilenet_v1_fpn(config=config)
|
||||||
elif config.model == "ssd_resnet50_fpn":
|
elif config.model == "ssd_resnet50_fpn":
|
||||||
|
|
|
@ -11,18 +11,20 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#" ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Config parameters for SSD models."""
|
"""Config parameters for SSD models."""
|
||||||
|
|
||||||
from .config_ssd300 import config as config_ssd300
|
from .config_ssd300 import config as config_ssd300
|
||||||
from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn
|
from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn
|
||||||
from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn
|
from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn
|
||||||
|
from .config_ssd_vgg16 import config as config_ssd_vgg16
|
||||||
|
|
||||||
using_model = "ssd300"
|
using_model = "ssd300"
|
||||||
|
|
||||||
config_map = {
|
config_map = {
|
||||||
"ssd300": config_ssd300,
|
"ssd300": config_ssd300,
|
||||||
|
"ssd_vgg16": config_ssd_vgg16,
|
||||||
"ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn,
|
"ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn,
|
||||||
"ssd_resnet50_fpn": config_ssd_resnet50_fpn
|
"ssd_resnet50_fpn": config_ssd_resnet50_fpn
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Config parameters for SSD models."""
|
||||||
|
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
config = ed({
|
||||||
|
"model": "ssd_vgg16",
|
||||||
|
"img_shape": [300, 300],
|
||||||
|
"num_ssd_boxes": 7308,
|
||||||
|
"match_threshold": 0.5,
|
||||||
|
"nms_threshold": 0.6,
|
||||||
|
"min_score": 0.1,
|
||||||
|
"max_boxes": 100,
|
||||||
|
"ssd_vgg_bn": False,
|
||||||
|
|
||||||
|
# learing rate settings
|
||||||
|
"lr_init": 0.001,
|
||||||
|
"lr_end_rate": 0.001,
|
||||||
|
"warmup_epochs": 2,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 1.5e-4,
|
||||||
|
|
||||||
|
# network
|
||||||
|
"num_default": [3, 6, 6, 6, 6, 6],
|
||||||
|
"extras_in_channels": [256, 512, 1024, 512, 256, 256],
|
||||||
|
"extras_out_channels": [512, 1024, 512, 256, 256, 256],
|
||||||
|
"extras_strides": [1, 1, 2, 2, 2, 2],
|
||||||
|
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
|
||||||
|
"feature_size": [38, 19, 10, 5, 3, 1],
|
||||||
|
"min_scale": 0.2,
|
||||||
|
"max_scale": 0.95,
|
||||||
|
"aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
|
||||||
|
"steps": (8, 16, 32, 64, 100, 300),
|
||||||
|
"prior_scaling": (0.1, 0.2),
|
||||||
|
"gamma": 2.0,
|
||||||
|
"alpha": 0.75,
|
||||||
|
|
||||||
|
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||||
|
"feature_extractor_base_param": "",
|
||||||
|
"pretrain_vgg_bn": False,
|
||||||
|
"checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],
|
||||||
|
"mindrecord_dir": "/data/MindRecord_COCO",
|
||||||
|
"coco_root": "/data/coco2017",
|
||||||
|
"train_data_type": "train2017",
|
||||||
|
"val_data_type": "val2017",
|
||||||
|
"instances_set": "annotations/instances_{}.json",
|
||||||
|
"classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||||
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||||
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||||
|
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||||
|
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||||
|
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||||
|
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||||
|
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||||
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||||
|
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||||
|
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||||
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||||
|
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||||
|
'teddy bear', 'hair drier', 'toothbrush'),
|
||||||
|
"num_classes": 81,
|
||||||
|
# The annotation.json position of voc validation dataset.
|
||||||
|
"voc_json": "annotations/voc_instances_val.json",
|
||||||
|
# voc original dataset.
|
||||||
|
"voc_root": "/data/voc_dataset",
|
||||||
|
# if coco or voc used, `image_dir` and `anno_path` are useless.
|
||||||
|
"image_dir": "",
|
||||||
|
"anno_path": ""
|
||||||
|
})
|
|
@ -27,6 +27,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
|
||||||
from .fpn import mobilenet_v1_fpn, resnet50_fpn
|
from .fpn import mobilenet_v1_fpn, resnet50_fpn
|
||||||
|
from .vgg16 import vgg16
|
||||||
|
|
||||||
|
|
||||||
def _make_divisible(v, divisor, min_value=None):
|
def _make_divisible(v, divisor, min_value=None):
|
||||||
|
@ -641,3 +642,78 @@ def ssd_resnet50_fpn(**kwargs):
|
||||||
|
|
||||||
def ssd_mobilenet_v2(**kwargs):
|
def ssd_mobilenet_v2(**kwargs):
|
||||||
return SSDWithMobileNetV2(**kwargs)
|
return SSDWithMobileNetV2(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SSD300VGG16(nn.Cell):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(SSD300VGG16, self).__init__()
|
||||||
|
|
||||||
|
# VGG16 backbone: block1~5
|
||||||
|
self.backbone = vgg16()
|
||||||
|
|
||||||
|
# SSD blocks: block6~7
|
||||||
|
self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad')
|
||||||
|
self.b6_2 = nn.Dropout(0.5)
|
||||||
|
|
||||||
|
self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1)
|
||||||
|
self.b7_2 = nn.Dropout(0.5)
|
||||||
|
|
||||||
|
# Extra Feature Layers: block8~11
|
||||||
|
self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad')
|
||||||
|
self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid')
|
||||||
|
|
||||||
|
self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad')
|
||||||
|
self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid')
|
||||||
|
|
||||||
|
self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
|
||||||
|
self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')
|
||||||
|
|
||||||
|
self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
|
||||||
|
self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')
|
||||||
|
|
||||||
|
# boxes
|
||||||
|
self.multi_box = MultiBox(config)
|
||||||
|
if not self.training:
|
||||||
|
self.activation = P.Sigmoid()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
# VGG16 backbone: block1~5
|
||||||
|
block4, x = self.backbone(x)
|
||||||
|
|
||||||
|
# SSD blocks: block6~7
|
||||||
|
x = self.b6_1(x) # 1024
|
||||||
|
x = self.b6_2(x)
|
||||||
|
|
||||||
|
x = self.b7_1(x) # 1024
|
||||||
|
x = self.b7_2(x)
|
||||||
|
block7 = x
|
||||||
|
|
||||||
|
# Extra Feature Layers: block8~11
|
||||||
|
x = self.b8_1(x) # 256
|
||||||
|
x = self.b8_2(x) # 512
|
||||||
|
block8 = x
|
||||||
|
|
||||||
|
x = self.b9_1(x) # 128
|
||||||
|
x = self.b9_2(x) # 256
|
||||||
|
block9 = x
|
||||||
|
|
||||||
|
x = self.b10_1(x) # 128
|
||||||
|
x = self.b10_2(x) # 256
|
||||||
|
block10 = x
|
||||||
|
|
||||||
|
x = self.b11_1(x) # 128
|
||||||
|
x = self.b11_2(x) # 256
|
||||||
|
block11 = x
|
||||||
|
|
||||||
|
# boxes
|
||||||
|
multi_feature = (block4, block7, block8, block9, block10, block11)
|
||||||
|
pred_loc, pred_label = self.multi_box(multi_feature)
|
||||||
|
if not self.training:
|
||||||
|
pred_label = self.activation(pred_label)
|
||||||
|
pred_loc = F.cast(pred_loc, mstype.float32)
|
||||||
|
pred_label = F.cast(pred_label, mstype.float32)
|
||||||
|
return pred_loc, pred_label
|
||||||
|
|
||||||
|
|
||||||
|
def ssd_vgg16(**kwargs):
|
||||||
|
return SSD300VGG16(**kwargs)
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""VGG16 backbone for SSD"""
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from .config_ssd_vgg16 import config
|
||||||
|
|
||||||
|
pretrain_vgg_bn = config.pretrain_vgg_bn
|
||||||
|
ssd_vgg_bn = config.ssd_vgg_bn
|
||||||
|
|
||||||
|
|
||||||
|
def _get_key_mapper():
|
||||||
|
vgg_key_num = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]
|
||||||
|
size = len(vgg_key_num)
|
||||||
|
|
||||||
|
pretrain_vgg_bn_false = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
|
||||||
|
pretrain_vgg_bn_true = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
|
||||||
|
ssd_vgg_bn_false = [0, 2, 0, 2, 0, 2, 4, 0, 2, 4, 0, 2, 4]
|
||||||
|
ssd_vgg_bn_true = [0, 3, 0, 3, 0, 3, 6, 0, 3, 6, 0, 3, 6]
|
||||||
|
|
||||||
|
pretrain_vgg_keys = pretrain_vgg_bn_true if pretrain_vgg_bn else pretrain_vgg_bn_false
|
||||||
|
ssd_vgg_keys = ssd_vgg_bn_true if ssd_vgg_bn else ssd_vgg_bn_false
|
||||||
|
|
||||||
|
pretrain_vgg_keys = ['layers.' + str(pretrain_vgg_keys[i]) for i in range(size)]
|
||||||
|
ssd_vgg_keys = ['b' + str(vgg_key_num[i]) + '.' + str(ssd_vgg_keys[i]) for i in range(size)]
|
||||||
|
|
||||||
|
return {pretrain_vgg_keys[i]: ssd_vgg_keys[i] for i in range(size)}
|
||||||
|
|
||||||
|
|
||||||
|
ssd_vgg_key_mapper = _get_key_mapper()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_layer(channels):
|
||||||
|
in_channels = channels[0]
|
||||||
|
layers = []
|
||||||
|
for out_channels in channels[1:]:
|
||||||
|
layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3))
|
||||||
|
if ssd_vgg_bn:
|
||||||
|
layers.append(nn.BatchNorm2d(out_channels))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
in_channels = out_channels
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
|
||||||
|
class VGG16(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(VGG16, self).__init__()
|
||||||
|
self.b1 = _make_layer([3, 64, 64])
|
||||||
|
self.b2 = _make_layer([64, 128, 128])
|
||||||
|
self.b3 = _make_layer([128, 256, 256, 256])
|
||||||
|
self.b4 = _make_layer([256, 512, 512, 512])
|
||||||
|
self.b5 = _make_layer([512, 512, 512, 512])
|
||||||
|
|
||||||
|
self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||||
|
self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||||
|
self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||||
|
self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||||
|
self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME')
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
# block1
|
||||||
|
x = self.b1(x)
|
||||||
|
x = self.m1(x)
|
||||||
|
|
||||||
|
# block2
|
||||||
|
x = self.b2(x)
|
||||||
|
x = self.m2(x)
|
||||||
|
|
||||||
|
# block3
|
||||||
|
x = self.b3(x)
|
||||||
|
x = self.m3(x)
|
||||||
|
|
||||||
|
# block4
|
||||||
|
x = self.b4(x)
|
||||||
|
block4 = x
|
||||||
|
x = self.m4(x)
|
||||||
|
|
||||||
|
# block5
|
||||||
|
x = self.b5(x)
|
||||||
|
x = self.m5(x)
|
||||||
|
|
||||||
|
return block4, x
|
||||||
|
|
||||||
|
|
||||||
|
def vgg16():
|
||||||
|
return VGG16()
|
|
@ -25,7 +25,7 @@ from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.common import set_seed, dtype
|
from mindspore.common import set_seed, dtype
|
||||||
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
|
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.dataset import create_ssd_dataset, create_mindrecord
|
from src.dataset import create_ssd_dataset, create_mindrecord
|
||||||
from src.lr_schedule import get_lr
|
from src.lr_schedule import get_lr
|
||||||
|
@ -86,6 +86,17 @@ def ssd_model_build(args_opt):
|
||||||
param_dict["network.feature_extractor.resnet." + x] = param_dict[x]
|
param_dict["network.feature_extractor.resnet." + x] = param_dict[x]
|
||||||
del param_dict[x]
|
del param_dict[x]
|
||||||
load_param_into_net(ssd.feature_extractor.resnet, param_dict)
|
load_param_into_net(ssd.feature_extractor.resnet, param_dict)
|
||||||
|
elif config.model == "ssd_vgg16":
|
||||||
|
ssd = ssd_vgg16(config=config)
|
||||||
|
init_net_param(ssd)
|
||||||
|
if config.feature_extractor_base_param != "":
|
||||||
|
param_dict = load_checkpoint(config.feature_extractor_base_param)
|
||||||
|
from src.vgg16 import ssd_vgg_key_mapper
|
||||||
|
for k in ssd_vgg_key_mapper:
|
||||||
|
v = ssd_vgg_key_mapper[k]
|
||||||
|
param_dict["network.backbone." + v + ".weight"] = param_dict[k + ".weight"]
|
||||||
|
del param_dict[k + ".weight"]
|
||||||
|
load_param_into_net(ssd.backbone, param_dict)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'config.model: {config.model} is not supported')
|
raise ValueError(f'config.model: {config.model} is not supported')
|
||||||
return ssd
|
return ssd
|
||||||
|
@ -106,6 +117,8 @@ def main():
|
||||||
init()
|
init()
|
||||||
if config.model == "ssd_resnet50_fpn":
|
if config.model == "ssd_resnet50_fpn":
|
||||||
context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279])
|
context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279])
|
||||||
|
if config.model == "ssd_vgg16":
|
||||||
|
context.set_auto_parallel_context(all_reduce_fusion_config=[20, 41, 62])
|
||||||
else:
|
else:
|
||||||
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
|
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
|
|
Loading…
Reference in New Issue