diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md index 4a60e170e8f..aa0a7f2ed04 100644 --- a/model_zoo/official/cv/ssd/README.md +++ b/model_zoo/official/cv/ssd/README.md @@ -35,10 +35,11 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Then add auxiliary structure to the network to produce detections. -We present two different base architecture. +We present three different base architecture. - **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper present. - ***ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predcitors. +- ***ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predcitors. ## [Dataset](#contents) diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index 0e06d08c4f8..7a5c0e31676 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -21,7 +21,7 @@ import time import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn +from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn from src.dataset import create_ssd_dataset, create_mindrecord from src.config import config from src.eval_utils import metrics @@ -34,8 +34,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): is_training=False, use_multiprocessing=False) if config.model == "ssd300": net = SSD300(ssd_mobilenet_v2(), config, is_training=False) - else: + elif config.model == "ssd_mobilenet_v1_fpn": net = ssd_mobilenet_v1_fpn(config=config) + elif config.model == "ssd_resnet50_fpn": + net = ssd_resnet50_fpn(config=config) + else: + raise ValueError(f'config.model: {config.model} is not supported') net = SsdInferWithDecoder(net, Tensor(default_boxes), config) print("Load Checkpoint!") @@ -88,7 +92,7 @@ if __name__ == '__main__': elif args_opt.dataset == "voc": json_path = os.path.join(config.voc_root, config.voc_json) else: - raise ValueError('SSD eval only supprt dataset mode is coco and voc!') + raise ValueError('SSD eval only support dataset mode is coco and voc!') context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id) diff --git a/model_zoo/official/cv/ssd/export.py b/model_zoo/official/cv/ssd/export.py index 9a76420a211..83a23f91576 100644 --- a/model_zoo/official/cv/ssd/export.py +++ b/model_zoo/official/cv/ssd/export.py @@ -19,7 +19,7 @@ import numpy as np import mindspore from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn +from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn from src.config import config from src.box_utils import default_boxes @@ -40,8 +40,12 @@ if args.device_target == "Ascend": if __name__ == '__main__': if config.model == "ssd300": net = SSD300(ssd_mobilenet_v2(), config, is_training=False) - else: + elif config.model == "ssd_mobilenet_v1_fpn": net = ssd_mobilenet_v1_fpn(config=config) + elif config.model == "ssd_resnet50_fpn": + net = ssd_resnet50_fpn(config=config) + else: + raise ValueError(f'config.model: {config.model} is not supported') net = SsdInferWithDecoder(net, Tensor(default_boxes), config) param_dict = load_checkpoint(args.ckpt_file) diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py index 9f89df9db15..3c0ae01d663 100644 --- a/model_zoo/official/cv/ssd/src/config.py +++ b/model_zoo/official/cv/ssd/src/config.py @@ -17,13 +17,14 @@ from .config_ssd300 import config as config_ssd300 from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn - +from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn using_model = "ssd300" config_map = { "ssd300": config_ssd300, - "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn + "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn, + "ssd_resnet50_fpn": config_ssd_resnet50_fpn } config = config_map[using_model] diff --git a/model_zoo/official/cv/ssd/src/config_ssd_resnet50_fpn.py b/model_zoo/official/cv/ssd/src/config_ssd_resnet50_fpn.py new file mode 100644 index 00000000000..256a3bb4e1a --- /dev/null +++ b/model_zoo/official/cv/ssd/src/config_ssd_resnet50_fpn.py @@ -0,0 +1,88 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "model": "ssd_resnet50_fpn", + "img_shape": [640, 640], + "num_ssd_boxes": -1, + "match_threshold": 0.5, + "nms_threshold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learning rate settings + "global_step": 0, + "lr_init": 0.01333, + "lr_end_rate": 0.0, + "warmup_epochs": 2, + "weight_decay": 4e-4, + "momentum": 0.9, + + # network + "num_default": [6, 6, 6, 6, 6], + "extras_in_channels": [256, 512, 1024, 256, 256], + "extras_out_channels": [256, 256, 256, 256, 256], + "extras_strides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [80, 40, 20, 10, 5], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (8, 16, 32, 64, 128), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.25, + "num_addition_layers": 4, + "use_anchor_generator": True, + "use_global_norm": True, + "use_float16": True, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "feature_extractor_base_param": "/ckpt/resnet50.ckpt", + "checkpoint_filter_list": ['network.multi_box.cls_layers.0.weight', 'network.multi_box.cls_layers.0.bias', + 'network.multi_box.loc_layers.0.weight', 'network.multi_box.loc_layers.0.bias'], + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_json": "annotations/voc_instances_val.json", + # voc original dataset. + "voc_root": "/data/voc_dataset", + # if coco or voc used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "" +}) diff --git a/model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py b/model_zoo/official/cv/ssd/src/fpn.py similarity index 56% rename from model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py rename to model_zoo/official/cv/ssd/src/fpn.py index da35c041541..c180cf1f13e 100644 --- a/model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py +++ b/model_zoo/official/cv/ssd/src/fpn.py @@ -16,77 +16,8 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.ops import functional as F - -def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'): - output = [] - output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", - group=1 if not depthwise else in_channel)) - output.append(nn.BatchNorm2d(out_channel)) - if activation: - output.append(nn.get_activation(activation)) - return nn.SequentialCell(output) - - -class MobileNetV1(nn.Cell): - """ - MobileNet V1 backbone - """ - def __init__(self, class_num=1001, features_only=False): - super(MobileNetV1, self).__init__() - self.features_only = features_only - cnn = [ - conv_bn_relu(3, 32, 3, 2, False), # Conv0 - - conv_bn_relu(32, 32, 3, 1, True), # Conv1_depthwise - conv_bn_relu(32, 64, 1, 1, False), # Conv1_pointwise - conv_bn_relu(64, 64, 3, 2, True), # Conv2_depthwise - conv_bn_relu(64, 128, 1, 1, False), # Conv2_pointwise - - conv_bn_relu(128, 128, 3, 1, True), # Conv3_depthwise - conv_bn_relu(128, 128, 1, 1, False), # Conv3_pointwise - conv_bn_relu(128, 128, 3, 2, True), # Conv4_depthwise - conv_bn_relu(128, 256, 1, 1, False), # Conv4_pointwise - - conv_bn_relu(256, 256, 3, 1, True), # Conv5_depthwise - conv_bn_relu(256, 256, 1, 1, False), # Conv5_pointwise - conv_bn_relu(256, 256, 3, 2, True), # Conv6_depthwise - conv_bn_relu(256, 512, 1, 1, False), # Conv6_pointwise - - conv_bn_relu(512, 512, 3, 1, True), # Conv7_depthwise - conv_bn_relu(512, 512, 1, 1, False), # Conv7_pointwise - conv_bn_relu(512, 512, 3, 1, True), # Conv8_depthwise - conv_bn_relu(512, 512, 1, 1, False), # Conv8_pointwise - conv_bn_relu(512, 512, 3, 1, True), # Conv9_depthwise - conv_bn_relu(512, 512, 1, 1, False), # Conv9_pointwise - conv_bn_relu(512, 512, 3, 1, True), # Conv10_depthwise - conv_bn_relu(512, 512, 1, 1, False), # Conv10_pointwise - conv_bn_relu(512, 512, 3, 1, True), # Conv11_depthwise - conv_bn_relu(512, 512, 1, 1, False), # Conv11_pointwise - - conv_bn_relu(512, 512, 3, 2, True), # Conv12_depthwise - conv_bn_relu(512, 1024, 1, 1, False), # Conv12_pointwise - conv_bn_relu(1024, 1024, 3, 1, True), # Conv13_depthwise - conv_bn_relu(1024, 1024, 1, 1, False), # Conv13_pointwise - ] - - if self.features_only: - self.network = nn.CellList(cnn) - else: - self.network = nn.SequentialCell(cnn) - self.fc = nn.Dense(1024, class_num) - - def construct(self, x): - output = x - if self.features_only: - features = () - for block in self.network: - output = block(output) - features = features + (output,) - return features - output = self.network(x) - output = P.ReduceMean()(output, (2, 3)) - output = self.fc(output) - return output +from .mobilenet_v1 import conv_bn_relu, MobileNetV1 +from .resnet import resnet50 class FpnTopDown(nn.Cell): @@ -183,10 +114,26 @@ class MobileNetV1Fpn(nn.Cell): features = self.bottom_up(features) return features +class ResNetV1Fpn(nn.Cell): + """ + ResNet with FPN as SSD backbone. + """ + def __init__(self, resnet): + super(ResNetV1Fpn, self).__init__() + self.resnet = resnet + self.fpn = FpnTopDown([512, 1024, 2048], 256) + self.bottom_up = BottomUp(2, 256, 3, 2) + + def construct(self, x): + _, _, c3, c4, c5 = self.resnet(x) + features = self.fpn((c3, c4, c5)) + features = self.bottom_up(features) + return features + def mobilenet_v1_fpn(config): return MobileNetV1Fpn(config) - -def mobilenet_v1(class_num=1001): - return MobileNetV1(class_num) +def resnet50_fpn(): + resnet = resnet50() + return ResNetV1Fpn(resnet) diff --git a/model_zoo/official/cv/ssd/src/mobilenet_v1.py b/model_zoo/official/cv/ssd/src/mobilenet_v1.py new file mode 100644 index 00000000000..5e15ded23ae --- /dev/null +++ b/model_zoo/official/cv/ssd/src/mobilenet_v1.py @@ -0,0 +1,91 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.nn as nn +from mindspore.ops import operations as P + +def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'): + output = [] + output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", + group=1 if not depthwise else in_channel)) + output.append(nn.BatchNorm2d(out_channel)) + if activation: + output.append(nn.get_activation(activation)) + return nn.SequentialCell(output) + + +class MobileNetV1(nn.Cell): + """ + MobileNet V1 backbone + """ + def __init__(self, class_num=1001, features_only=False): + super(MobileNetV1, self).__init__() + self.features_only = features_only + cnn = [ + conv_bn_relu(3, 32, 3, 2, False), # Conv0 + + conv_bn_relu(32, 32, 3, 1, True), # Conv1_depthwise + conv_bn_relu(32, 64, 1, 1, False), # Conv1_pointwise + conv_bn_relu(64, 64, 3, 2, True), # Conv2_depthwise + conv_bn_relu(64, 128, 1, 1, False), # Conv2_pointwise + + conv_bn_relu(128, 128, 3, 1, True), # Conv3_depthwise + conv_bn_relu(128, 128, 1, 1, False), # Conv3_pointwise + conv_bn_relu(128, 128, 3, 2, True), # Conv4_depthwise + conv_bn_relu(128, 256, 1, 1, False), # Conv4_pointwise + + conv_bn_relu(256, 256, 3, 1, True), # Conv5_depthwise + conv_bn_relu(256, 256, 1, 1, False), # Conv5_pointwise + conv_bn_relu(256, 256, 3, 2, True), # Conv6_depthwise + conv_bn_relu(256, 512, 1, 1, False), # Conv6_pointwise + + conv_bn_relu(512, 512, 3, 1, True), # Conv7_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv7_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv8_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv8_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv9_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv9_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv10_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv10_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv11_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv11_pointwise + + conv_bn_relu(512, 512, 3, 2, True), # Conv12_depthwise + conv_bn_relu(512, 1024, 1, 1, False), # Conv12_pointwise + conv_bn_relu(1024, 1024, 3, 1, True), # Conv13_depthwise + conv_bn_relu(1024, 1024, 1, 1, False), # Conv13_pointwise + ] + + if self.features_only: + self.network = nn.CellList(cnn) + else: + self.network = nn.SequentialCell(cnn) + self.fc = nn.Dense(1024, class_num) + + def construct(self, x): + output = x + if self.features_only: + features = () + for block in self.network: + output = block(output) + features = features + (output,) + return features + output = self.network(x) + output = P.ReduceMean()(output, (2, 3)) + output = self.fc(output) + return output + +def mobilenet_v1(class_num=1001): + return MobileNetV1(class_num) diff --git a/model_zoo/official/cv/ssd/src/resnet.py b/model_zoo/official/cv/ssd/src/resnet.py new file mode 100644 index 00000000000..4729ca33ca3 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/resnet.py @@ -0,0 +1,216 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ResNet.""" +import mindspore.nn as nn +from mindspore.ops import operations as P + + +def _conv3x3(in_channel, out_channel, stride=1): + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same') + + +def _conv1x1(in_channel, out_channel, stride=1): + return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, pad_mode='same') + + +def _conv7x7(in_channel, out_channel, stride=1): + return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride, padding=0, pad_mode='same') + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.997, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.997, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + self.stride = stride + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1) + self.bn1 = _bn(channel) + self.conv2 = _conv3x3(channel, channel, stride=stride) + self.bn2 = _bn(channel) + + self.conv3 = _conv1x1(channel, out_channel, stride=1) + self.bn3 = _bn_last(out_channel) + self.relu = nn.ReLU() + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) + self.add = P.Add() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + self.conv1 = _conv7x7(3, 64, stride=2) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + return c1, c2, c3, c4, c5 + + +def resnet50(): + """ + Get ResNet50 neural network. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + >>> net = resnet50() + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2]) diff --git a/model_zoo/official/cv/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py index 29fdafa053f..4b114082e34 100644 --- a/model_zoo/official/cv/ssd/src/ssd.py +++ b/model_zoo/official/cv/ssd/src/ssd.py @@ -26,7 +26,7 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C -from .mobilenet_v1_fpn import mobilenet_v1_fpn +from .fpn import mobilenet_v1_fpn, resnet50_fpn def _make_divisible(v, divisor, min_value=None): @@ -371,6 +371,35 @@ class SsdMobilenetV1Fpn(nn.Cell): pred_label = F.cast(pred_label, mstype.float32) return pred_loc, pred_label +class SsdResNet50Fpn(nn.Cell): + """ + SSD Network using ResNet50 with fpn to extract features + + Args: + config (dict): The default config of SSD. + + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + + Examples:backbone + SsdResNet50Fpn(config). + """ + def __init__(self, config): + super(SsdResNet50Fpn, self).__init__() + self.multi_box = WeightSharedMultiBox(config) + self.activation = P.Sigmoid() + self.feature_extractor = resnet50_fpn() + + def construct(self, x): + features = self.feature_extractor(x) + pred_loc, pred_label = self.multi_box(features) + if not self.training: + pred_label = self.activation(pred_label) + pred_loc = F.cast(pred_loc, mstype.float32) + pred_label = F.cast(pred_label, mstype.float32) + return pred_loc, pred_label + class SigmoidFocalClassificationLoss(nn.Cell): """" @@ -608,6 +637,8 @@ class SsdInferWithDecoder(nn.Cell): def ssd_mobilenet_v1_fpn(**kwargs): return SsdMobilenetV1Fpn(**kwargs) +def ssd_resnet50_fpn(**kwargs): + return SsdResNet50Fpn(**kwargs) def ssd_mobilenet_v2(**kwargs): return SSDWithMobileNetV2(**kwargs) diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index 971e7c55685..f39d2aa1b80 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -25,7 +25,7 @@ from mindspore.train import Model from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed, dtype -from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn +from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn from src.config import config from src.dataset import create_ssd_dataset, create_mindrecord from src.lr_schedule import get_lr @@ -60,6 +60,36 @@ def get_args(): args_opt = parser.parse_args() return args_opt +def ssd_model_build(args_opt): + if config.model == "ssd300": + backbone = ssd_mobilenet_v2() + ssd = SSD300(backbone=backbone, config=config) + init_net_param(ssd) + if args_opt.freeze_layer == "backbone": + for param in backbone.feature_1.trainable_params(): + param.requires_grad = False + elif config.model == "ssd_mobilenet_v1_fpn": + ssd = ssd_mobilenet_v1_fpn(config=config) + init_net_param(ssd) + if config.feature_extractor_base_param != "": + param_dict = load_checkpoint(config.feature_extractor_base_param) + for x in list(param_dict.keys()): + param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x] + del param_dict[x] + load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict) + elif config.model == "ssd_resnet50_fpn": + ssd = ssd_resnet50_fpn(config=config) + init_net_param(ssd) + if config.feature_extractor_base_param != "": + param_dict = load_checkpoint(config.feature_extractor_base_param) + for x in list(param_dict.keys()): + param_dict["network.feature_extractor.resnet." + x] = param_dict[x] + del param_dict[x] + load_param_into_net(ssd.feature_extractor.resnet, param_dict) + else: + raise ValueError(f'config.model: {config.model} is not supported') + return ssd + def main(): args_opt = get_args() rank = 0 @@ -74,7 +104,10 @@ def main(): context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) init() - context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) + if config.model == "ssd_resnet50_fpn": + context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279]) + else: + context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) rank = get_rank() mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) @@ -92,28 +125,12 @@ def main(): device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) dataset_size = dataset.get_dataset_size() - print("Create dataset done!") - - backbone = ssd_mobilenet_v2() - if config.model == "ssd300": - ssd = SSD300(backbone=backbone, config=config) - elif config.model == "ssd_mobilenet_v1_fpn": - ssd = ssd_mobilenet_v1_fpn(config=config) - else: - raise ValueError(f'config.model: {config.model} is not supported') - if args_opt.run_platform == "GPU": + print(f"Create dataset done! dataset size is {dataset_size}") + ssd = ssd_model_build(args_opt) + if ("use_float16" in config and config.use_float16) or args_opt.run_platform == "GPU": ssd.to_float(dtype.float16) net = SSDWithLossCell(ssd, config) - init_net_param(net) - - if config.feature_extractor_base_param != "": - param_dict = load_checkpoint(config.feature_extractor_base_param) - for x in list(param_dict.keys()): - param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x] - del param_dict[x] - load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict) - # checkpoint ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) save_ckpt_path = './ckpt_' + str(rank) + '/' @@ -125,10 +142,6 @@ def main(): filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list) load_param_into_net(net, param_dict, True) - if args_opt.freeze_layer == "backbone": - for param in backbone.feature_1.trainable_params(): - param.requires_grad = False - lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, warmup_epochs=config.warmup_epochs,