!12677 add ssd-resnet50-fpn

From: @zhao_ting_v
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-01 15:48:20 +08:00 committed by Gitee
commit 973b2fa911
10 changed files with 504 additions and 108 deletions

View File

@ -35,10 +35,11 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o
The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Auxiliary structure is then added to the network to produce detections.
We present two different base architecture.
We present three different base architectures.
- **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as presented in the paper.
- **ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predictors.
- **ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predictors.
## [Dataset](#contents)

View File

@ -21,7 +21,7 @@ import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
from src.eval_utils import metrics
@ -34,8 +34,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
is_training=False, use_multiprocessing=False)
if config.model == "ssd300":
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
else:
elif config.model == "ssd_mobilenet_v1_fpn":
net = ssd_mobilenet_v1_fpn(config=config)
elif config.model == "ssd_resnet50_fpn":
net = ssd_resnet50_fpn(config=config)
else:
raise ValueError(f'config.model: {config.model} is not supported')
net = SsdInferWithDecoder(net, Tensor(default_boxes), config)
print("Load Checkpoint!")
@ -88,7 +92,7 @@ if __name__ == '__main__':
elif args_opt.dataset == "voc":
json_path = os.path.join(config.voc_root, config.voc_json)
else:
raise ValueError('SSD eval only supprt dataset mode is coco and voc!')
raise ValueError('SSD eval only support dataset mode is coco and voc!')
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)

View File

@ -19,7 +19,7 @@ import numpy as np
import mindspore
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.config import config
from src.box_utils import default_boxes
@ -40,8 +40,12 @@ if args.device_target == "Ascend":
if __name__ == '__main__':
if config.model == "ssd300":
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
else:
elif config.model == "ssd_mobilenet_v1_fpn":
net = ssd_mobilenet_v1_fpn(config=config)
elif config.model == "ssd_resnet50_fpn":
net = ssd_resnet50_fpn(config=config)
else:
raise ValueError(f'config.model: {config.model} is not supported')
net = SsdInferWithDecoder(net, Tensor(default_boxes), config)
param_dict = load_checkpoint(args.ckpt_file)

View File

@ -17,13 +17,14 @@
from .config_ssd300 import config as config_ssd300
from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn
from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn
using_model = "ssd300"
config_map = {
"ssd300": config_ssd300,
"ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn
"ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn,
"ssd_resnet50_fpn": config_ssd_resnet50_fpn
}
config = config_map[using_model]

View File

@ -0,0 +1,88 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
from easydict import EasyDict as ed
# Settings for the "ssd_resnet50_fpn" model variant, stored in a single EasyDict
# literal so that entries can be read as attributes (e.g. `config.model`).
config = ed({
"model": "ssd_resnet50_fpn",
"img_shape": [640, 640],
"num_ssd_boxes": -1,
"match_threshold": 0.5,
"nms_threshold": 0.6,
"min_score": 0.1,
"max_boxes": 100,
# learning rate settings
"global_step": 0,
"lr_init": 0.01333,
"lr_end_rate": 0.0,
"warmup_epochs": 2,
"weight_decay": 4e-4,
"momentum": 0.9,
# network
# NOTE(review): `extras_strides`, `extras_ratio` and `aspect_ratios` hold six
# entries while `num_default`, `extras_in_channels`, `extras_out_channels`,
# `feature_size` and `steps` hold five — confirm which entries are consumed.
"num_default": [6, 6, 6, 6, 6],
"extras_in_channels": [256, 512, 1024, 256, 256],
"extras_out_channels": [256, 256, 256, 256, 256],
"extras_strides": [1, 1, 2, 2, 2, 2],
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
"feature_size": [80, 40, 20, 10, 5],
"min_scale": 0.2,
"max_scale": 0.95,
"aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
"steps": (8, 16, 32, 64, 128),
"prior_scaling": (0.1, 0.2),
"gamma": 2.0,
"alpha": 0.25,
"num_addition_layers": 4,
"use_anchor_generator": True,
"use_global_norm": True,
"use_float16": True,
# Paths. Prefer absolute paths for `mindrecord_dir` and `coco_root`.
# `feature_extractor_base_param` is the checkpoint the training script loads
# into the feature extractor; the script skips loading when it is "".
"feature_extractor_base_param": "/ckpt/resnet50.ckpt",
"checkpoint_filter_list": ['network.multi_box.cls_layers.0.weight', 'network.multi_box.cls_layers.0.bias',
'network.multi_box.loc_layers.0.weight', 'network.multi_box.loc_layers.0.bias'],
"mindrecord_dir": "/data/MindRecord_COCO",
"coco_root": "/data/coco2017",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instances_set": "annotations/instances_{}.json",
# Class names; index 0 is 'background'.
"classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,
# The annotation.json position of voc validation dataset.
"voc_json": "annotations/voc_instances_val.json",
# voc original dataset.
"voc_root": "/data/voc_dataset",
# if coco or voc used, `image_dir` and `anno_path` are useless.
"image_dir": "",
"anno_path": ""
})

View File

@ -16,77 +16,8 @@
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
    """
    Build a Conv2d -> BatchNorm2d -> (optional) activation block.

    Args:
        in_channel (int): Input channel of the convolution.
        out_channel (int): Output channel of the convolution.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        depthwise (bool): If True the convolution uses `in_channel` groups, otherwise one group.
        activation (str): Name passed to `nn.get_activation`; a falsy value omits the activation.

    Returns:
        SequentialCell, the assembled block.
    """
    groups = in_channel if depthwise else 1
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", group=groups),
        nn.BatchNorm2d(out_channel),
    ]
    if activation:
        layers.append(nn.get_activation(activation))
    return nn.SequentialCell(layers)
class MobileNetV1(nn.Cell):
    """
    MobileNet V1 backbone.

    Args:
        class_num (int): Output size of the final dense layer. Default: 1001.
        features_only (bool): If True, the blocks are kept in a CellList, no dense
            layer is created and `construct` returns the output of every block.
            Default: False.
    """

    def __init__(self, class_num=1001, features_only=False):
        super(MobileNetV1, self).__init__()
        self.features_only = features_only
        # One row per block: (in_channel, out_channel, kernel_size, stride, depthwise).
        block_specs = [
            (3, 32, 3, 2, False),        # Conv0
            (32, 32, 3, 1, True),        # Conv1_depthwise
            (32, 64, 1, 1, False),       # Conv1_pointwise
            (64, 64, 3, 2, True),        # Conv2_depthwise
            (64, 128, 1, 1, False),      # Conv2_pointwise
            (128, 128, 3, 1, True),      # Conv3_depthwise
            (128, 128, 1, 1, False),     # Conv3_pointwise
            (128, 128, 3, 2, True),      # Conv4_depthwise
            (128, 256, 1, 1, False),     # Conv4_pointwise
            (256, 256, 3, 1, True),      # Conv5_depthwise
            (256, 256, 1, 1, False),     # Conv5_pointwise
            (256, 256, 3, 2, True),      # Conv6_depthwise
            (256, 512, 1, 1, False),     # Conv6_pointwise
            (512, 512, 3, 1, True),      # Conv7_depthwise
            (512, 512, 1, 1, False),     # Conv7_pointwise
            (512, 512, 3, 1, True),      # Conv8_depthwise
            (512, 512, 1, 1, False),     # Conv8_pointwise
            (512, 512, 3, 1, True),      # Conv9_depthwise
            (512, 512, 1, 1, False),     # Conv9_pointwise
            (512, 512, 3, 1, True),      # Conv10_depthwise
            (512, 512, 1, 1, False),     # Conv10_pointwise
            (512, 512, 3, 1, True),      # Conv11_depthwise
            (512, 512, 1, 1, False),     # Conv11_pointwise
            (512, 512, 3, 2, True),      # Conv12_depthwise
            (512, 1024, 1, 1, False),    # Conv12_pointwise
            (1024, 1024, 3, 1, True),    # Conv13_depthwise
            (1024, 1024, 1, 1, False),   # Conv13_pointwise
        ]
        blocks = [conv_bn_relu(*spec) for spec in block_specs]
        if not self.features_only:
            self.network = nn.SequentialCell(blocks)
            self.fc = nn.Dense(1024, class_num)
        else:
            self.network = nn.CellList(blocks)

    def construct(self, x):
        # Classification path: run all blocks, average over axes (2, 3), then the dense layer.
        if not self.features_only:
            logits = self.network(x)
            logits = P.ReduceMean()(logits, (2, 3))
            return self.fc(logits)
        # Feature path: collect the output of every block in order.
        collected = ()
        feature = x
        for layer in self.network:
            feature = layer(feature)
            collected = collected + (feature,)
        return collected
from .mobilenet_v1 import conv_bn_relu, MobileNetV1
from .resnet import resnet50
class FpnTopDown(nn.Cell):
@ -183,10 +114,26 @@ class MobileNetV1Fpn(nn.Cell):
features = self.bottom_up(features)
return features
class ResNetV1Fpn(nn.Cell):
    """
    ResNet with FPN as SSD backbone.

    Args:
        resnet (Cell): Backbone whose output unpacks into five feature maps;
            only the last three are passed on to the FPN.

    Returns:
        The output of the `BottomUp` cell applied to the FPN features.
    """

    def __init__(self, resnet):
        super(ResNetV1Fpn, self).__init__()
        self.resnet = resnet
        self.fpn = FpnTopDown([512, 1024, 2048], 256)
        self.bottom_up = BottomUp(2, 256, 3, 2)

    def construct(self, x):
        # The first two backbone outputs are not used by the pyramid.
        _, _, stage3, stage4, stage5 = self.resnet(x)
        pyramid = self.fpn((stage3, stage4, stage5))
        return self.bottom_up(pyramid)
def mobilenet_v1_fpn(config):
    """Create a `MobileNetV1Fpn` feature extractor from `config`."""
    extractor = MobileNetV1Fpn(config)
    return extractor
def mobilenet_v1(class_num=1001):
    """Create a `MobileNetV1` network whose dense layer has `class_num` outputs."""
    network = MobileNetV1(class_num)
    return network
def resnet50_fpn():
    """Create a `ResNetV1Fpn` feature extractor wrapped around a fresh `resnet50()` backbone."""
    return ResNetV1Fpn(resnet50())

View File

@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
    """
    Build a Conv2d -> BatchNorm2d -> (optional) activation block.

    Args:
        in_channel (int): Input channel of the convolution.
        out_channel (int): Output channel of the convolution.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        depthwise (bool): If True the convolution uses `in_channel` groups, otherwise one group.
        activation (str): Name passed to `nn.get_activation`; a falsy value omits the activation.

    Returns:
        SequentialCell, the assembled block.
    """
    groups = in_channel if depthwise else 1
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", group=groups),
        nn.BatchNorm2d(out_channel),
    ]
    if activation:
        layers.append(nn.get_activation(activation))
    return nn.SequentialCell(layers)
class MobileNetV1(nn.Cell):
    """
    MobileNet V1 backbone.

    Args:
        class_num (int): Output size of the final dense layer. Default: 1001.
        features_only (bool): If True, the blocks are kept in a CellList, no dense
            layer is created and `construct` returns the output of every block.
            Default: False.
    """

    def __init__(self, class_num=1001, features_only=False):
        super(MobileNetV1, self).__init__()
        self.features_only = features_only
        # One row per block: (in_channel, out_channel, kernel_size, stride, depthwise).
        block_specs = [
            (3, 32, 3, 2, False),        # Conv0
            (32, 32, 3, 1, True),        # Conv1_depthwise
            (32, 64, 1, 1, False),       # Conv1_pointwise
            (64, 64, 3, 2, True),        # Conv2_depthwise
            (64, 128, 1, 1, False),      # Conv2_pointwise
            (128, 128, 3, 1, True),      # Conv3_depthwise
            (128, 128, 1, 1, False),     # Conv3_pointwise
            (128, 128, 3, 2, True),      # Conv4_depthwise
            (128, 256, 1, 1, False),     # Conv4_pointwise
            (256, 256, 3, 1, True),      # Conv5_depthwise
            (256, 256, 1, 1, False),     # Conv5_pointwise
            (256, 256, 3, 2, True),      # Conv6_depthwise
            (256, 512, 1, 1, False),     # Conv6_pointwise
            (512, 512, 3, 1, True),      # Conv7_depthwise
            (512, 512, 1, 1, False),     # Conv7_pointwise
            (512, 512, 3, 1, True),      # Conv8_depthwise
            (512, 512, 1, 1, False),     # Conv8_pointwise
            (512, 512, 3, 1, True),      # Conv9_depthwise
            (512, 512, 1, 1, False),     # Conv9_pointwise
            (512, 512, 3, 1, True),      # Conv10_depthwise
            (512, 512, 1, 1, False),     # Conv10_pointwise
            (512, 512, 3, 1, True),      # Conv11_depthwise
            (512, 512, 1, 1, False),     # Conv11_pointwise
            (512, 512, 3, 2, True),      # Conv12_depthwise
            (512, 1024, 1, 1, False),    # Conv12_pointwise
            (1024, 1024, 3, 1, True),    # Conv13_depthwise
            (1024, 1024, 1, 1, False),   # Conv13_pointwise
        ]
        blocks = [conv_bn_relu(*spec) for spec in block_specs]
        if not self.features_only:
            self.network = nn.SequentialCell(blocks)
            self.fc = nn.Dense(1024, class_num)
        else:
            self.network = nn.CellList(blocks)

    def construct(self, x):
        # Classification path: run all blocks, average over axes (2, 3), then the dense layer.
        if not self.features_only:
            logits = self.network(x)
            logits = P.ReduceMean()(logits, (2, 3))
            return self.fc(logits)
        # Feature path: collect the output of every block in order.
        collected = ()
        feature = x
        for layer in self.network:
            feature = layer(feature)
            collected = collected + (feature,)
        return collected
def mobilenet_v1(class_num=1001):
    """Create a `MobileNetV1` network whose dense layer has `class_num` outputs."""
    network = MobileNetV1(class_num)
    return network

View File

@ -0,0 +1,216 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import mindspore.nn as nn
from mindspore.ops import operations as P
def _conv3x3(in_channel, out_channel, stride=1):
    """Return a 3x3 `nn.Conv2d` using pad_mode 'same' and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=0, pad_mode='same')
    return nn.Conv2d(in_channel, out_channel, **conv_options)
def _conv1x1(in_channel, out_channel, stride=1):
    """Return a 1x1 `nn.Conv2d` using pad_mode 'same' and the given stride."""
    conv_options = dict(kernel_size=1, stride=stride, padding=0, pad_mode='same')
    return nn.Conv2d(in_channel, out_channel, **conv_options)
def _conv7x7(in_channel, out_channel, stride=1):
    """Return a 7x7 `nn.Conv2d` using pad_mode 'same' and the given stride."""
    conv_options = dict(kernel_size=7, stride=stride, padding=0, pad_mode='same')
    return nn.Conv2d(in_channel, out_channel, **conv_options)
def _bn(channel):
    """Return a `nn.BatchNorm2d` over `channel` channels with gamma initialised to 1."""
    return nn.BatchNorm2d(
        channel,
        eps=1e-3,
        momentum=0.997,
        gamma_init=1,
        beta_init=0,
        moving_mean_init=0,
        moving_var_init=1,
    )
def _bn_last(channel):
    """Return a `nn.BatchNorm2d` over `channel` channels with gamma initialised to 0."""
    return nn.BatchNorm2d(
        channel,
        eps=1e-3,
        momentum=0.997,
        gamma_init=0,
        beta_init=0,
        moving_mean_init=0,
        moving_var_init=1,
    )
class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block: 1x1 -> 3x3 -> 1x1 convolutions added to a shortcut.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the 3x3 convolution and of the shortcut projection. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        # The inner convolutions work on out_channel // expansion channels.
        mid_channel = out_channel // self.expansion

        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1)
        self.bn1 = _bn(mid_channel)

        self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride)
        self.bn2 = _bn(mid_channel)

        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        # The shortcut needs a projection whenever the block changes stride or channel count.
        self.down_sample = stride != 1 or in_channel != out_channel
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
        else:
            self.down_sample_layer = None
        self.add = P.Add()

    def construct(self, x):
        shortcut = x
        if self.down_sample:
            shortcut = self.down_sample_layer(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(self.add(out, shortcut))
class ResNet(nn.Cell):
    """
    ResNet feature extractor that returns the output of every stage.

    Args:
        block (Cell): Block class used to build each stage.
        layer_nums (list): Number of blocks in each of the four stages.
        in_channels (list): Input channel of each stage.
        out_channels (list): Output channel of each stage.
        strides (list): Stride of the first block of each stage.

    Returns:
        Tuple of five Tensors `(c1, c2, c3, c4, c5)`: the max-pooled stem output
        followed by the outputs of `layer1` to `layer4`.

    Raises:
        ValueError: If `layer_nums`, `in_channels`, `out_channels` or `strides`
            does not have exactly four entries.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2])
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides):
        super(ResNet, self).__init__()

        # `strides` is indexed for all four stages below, so it is validated with the other lists.
        if not len(layer_nums) == len(in_channels) == len(out_channels) == len(strides) == 4:
            raise ValueError("the length of layer_nums, in_channels, out_channels and strides list must be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # The stages stay as explicitly named attributes so parameter names are unchanged.
        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first block of the stage.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        # Only the first block changes channel count and stride; the rest keep out_channel at stride 1.
        layers = [block(in_channel, out_channel, stride=stride)]
        for _ in range(1, layer_num):
            layers.append(block(out_channel, out_channel, stride=1))
        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        return c1, c2, c3, c4, c5
def resnet50():
    """
    Get ResNet50 neural network.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50()
    """
    return ResNet(
        block=ResidualBlock,
        layer_nums=[3, 4, 6, 3],
        in_channels=[64, 256, 512, 1024],
        out_channels=[256, 512, 1024, 2048],
        strides=[1, 2, 2, 2],
    )

View File

@ -26,7 +26,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from .mobilenet_v1_fpn import mobilenet_v1_fpn
from .fpn import mobilenet_v1_fpn, resnet50_fpn
def _make_divisible(v, divisor, min_value=None):
@ -371,6 +371,35 @@ class SsdMobilenetV1Fpn(nn.Cell):
pred_label = F.cast(pred_label, mstype.float32)
return pred_loc, pred_label
class SsdResNet50Fpn(nn.Cell):
    """
    SSD Network using ResNet50 with fpn to extract features.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.

    Examples:
        SsdResNet50Fpn(config).
    """

    def __init__(self, config):
        super(SsdResNet50Fpn, self).__init__()
        self.multi_box = WeightSharedMultiBox(config)
        self.activation = P.Sigmoid()
        self.feature_extractor = resnet50_fpn()

    def construct(self, x):
        pred_loc, pred_label = self.multi_box(self.feature_extractor(x))
        # The sigmoid is applied to the class scores only outside training.
        if not self.training:
            pred_label = self.activation(pred_label)
        # Both outputs are cast to float32 before being returned.
        return F.cast(pred_loc, mstype.float32), F.cast(pred_label, mstype.float32)
class SigmoidFocalClassificationLoss(nn.Cell):
""""
@ -607,6 +636,8 @@ class SsdInferWithDecoder(nn.Cell):
def ssd_mobilenet_v1_fpn(**kwargs):
    """Create an `SsdMobilenetV1Fpn` network, forwarding all keyword arguments."""
    network = SsdMobilenetV1Fpn(**kwargs)
    return network
def ssd_resnet50_fpn(**kwargs):
    """Create an `SsdResNet50Fpn` network, forwarding all keyword arguments."""
    network = SsdResNet50Fpn(**kwargs)
    return network
def ssd_mobilenet_v2(**kwargs):
    """Create an `SSDWithMobileNetV2` backbone, forwarding all keyword arguments."""
    network = SSDWithMobileNetV2(**kwargs)
    return network

View File

@ -25,7 +25,7 @@ from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, dtype
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
@ -60,6 +60,36 @@ def get_args():
args_opt = parser.parse_args()
return args_opt
def _load_feature_extractor_params(target_net, key_prefix):
    """Load `config.feature_extractor_base_param` into `target_net` after prefixing every key with `key_prefix`."""
    param_dict = load_checkpoint(config.feature_extractor_base_param)
    # Rename in place; iterate over a snapshot of the keys because the dict is modified.
    for name in list(param_dict):
        param_dict[key_prefix + name] = param_dict.pop(name)
    load_param_into_net(target_net, param_dict)


def ssd_model_build(args_opt):
    """
    Build and initialise the SSD network selected by `config.model`.

    Args:
        args_opt: Parsed command line arguments; only `freeze_layer` is read here.

    Returns:
        Cell, the initialised SSD network.

    Raises:
        ValueError: If `config.model` is not one of the supported model names.
    """
    if config.model == "ssd300":
        backbone = ssd_mobilenet_v2()
        ssd = SSD300(backbone=backbone, config=config)
        init_net_param(ssd)
        # Optionally exclude the first backbone feature block from training.
        if args_opt.freeze_layer == "backbone":
            for param in backbone.feature_1.trainable_params():
                param.requires_grad = False
        return ssd

    if config.model == "ssd_mobilenet_v1_fpn":
        ssd = ssd_mobilenet_v1_fpn(config=config)
        init_net_param(ssd)
        if config.feature_extractor_base_param != "":
            _load_feature_extractor_params(ssd.feature_extractor.mobilenet_v1.network,
                                           "network.feature_extractor.mobilenet_v1.")
        return ssd

    if config.model == "ssd_resnet50_fpn":
        ssd = ssd_resnet50_fpn(config=config)
        init_net_param(ssd)
        if config.feature_extractor_base_param != "":
            _load_feature_extractor_params(ssd.feature_extractor.resnet,
                                           "network.feature_extractor.resnet.")
        return ssd

    raise ValueError(f'config.model: {config.model} is not supported')
def main():
args_opt = get_args()
rank = 0
@ -74,7 +104,10 @@ def main():
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
init()
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
if config.model == "ssd_resnet50_fpn":
context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
rank = get_rank()
mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True)
@ -92,28 +125,12 @@ def main():
device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
backbone = ssd_mobilenet_v2()
if config.model == "ssd300":
ssd = SSD300(backbone=backbone, config=config)
elif config.model == "ssd_mobilenet_v1_fpn":
ssd = ssd_mobilenet_v1_fpn(config=config)
else:
raise ValueError(f'config.model: {config.model} is not supported')
if args_opt.run_platform == "GPU":
print(f"Create dataset done! dataset size is {dataset_size}")
ssd = ssd_model_build(args_opt)
if ("use_float16" in config and config.use_float16) or args_opt.run_platform == "GPU":
ssd.to_float(dtype.float16)
net = SSDWithLossCell(ssd, config)
init_net_param(net)
if config.feature_extractor_base_param != "":
param_dict = load_checkpoint(config.feature_extractor_base_param)
for x in list(param_dict.keys()):
param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x]
del param_dict[x]
load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict)
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
save_ckpt_path = './ckpt_' + str(rank) + '/'
@ -125,10 +142,6 @@ def main():
filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list)
load_param_into_net(net, param_dict, True)
if args_opt.freeze_layer == "backbone":
for param in backbone.feature_1.trainable_params():
param.requires_grad = False
lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
warmup_epochs=config.warmup_epochs,