From aeb6d2a59f0c670a233be8dab77933f077fcf30b Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Tue, 13 Oct 2020 15:59:31 +0800 Subject: [PATCH] add ssd-mobilenetv1-fpn --- model_zoo/official/cv/ssd/eval.py | 7 +- .../official/cv/ssd/src/anchor_generator.py | 92 +++++++++ model_zoo/official/cv/ssd/src/box_utils.py | 12 +- model_zoo/official/cv/ssd/src/config.py | 69 +------ .../official/cv/ssd/src/config_ssd300.py | 84 ++++++++ .../cv/ssd/src/config_ssd_mobilenet_v1_fpn.py | 88 ++++++++ model_zoo/official/cv/ssd/src/init_params.py | 4 +- .../official/cv/ssd/src/mobilenet_v1_fpn.py | 192 ++++++++++++++++++ model_zoo/official/cv/ssd/src/ssd.py | 153 +++++++++++++- model_zoo/official/cv/ssd/train.py | 106 ++++++---- 10 files changed, 688 insertions(+), 119 deletions(-) create mode 100644 model_zoo/official/cv/ssd/src/anchor_generator.py create mode 100644 model_zoo/official/cv/ssd/src/config_ssd300.py create mode 100644 model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py create mode 100644 model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index f98b98926f0..37e3f1efa29 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -21,7 +21,7 @@ import time import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.ssd import SSD300, ssd_mobilenet_v2 +from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn from src.dataset import create_ssd_dataset, create_mindrecord from src.config import config from src.eval_utils import metrics @@ -31,7 +31,10 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): batch_size = 1 ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False, use_multiprocessing=False) - net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + if config.model == "ssd300": + net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + else: + net = ssd_mobilenet_v1_fpn(config=config) print("Load Checkpoint!") param_dict = load_checkpoint(ckpt_path) net.init_parameters_data() diff --git a/model_zoo/official/cv/ssd/src/anchor_generator.py b/model_zoo/official/cv/ssd/src/anchor_generator.py new file mode 100644 index 00000000000..62e2676d167 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/anchor_generator.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Anchor Generator""" + +import numpy as np + + +class GridAnchorGenerator: + """ + Anchor Generator + """ + def __init__(self, image_shape, scale, scales_per_octave, aspect_ratios): + super(GridAnchorGenerator, self).__init__() + self.scale = scale + self.scales_per_octave = scales_per_octave + self.aspect_ratios = aspect_ratios + self.image_shape = image_shape + + + def generate(self, step): + scales = np.array([2**(float(scale) / self.scales_per_octave) + for scale in range(self.scales_per_octave)]).astype(np.float32) + aspects = np.array(list(self.aspect_ratios)).astype(np.float32) + + scales_grid, aspect_ratios_grid = np.meshgrid(scales, aspects) + scales_grid = scales_grid.reshape([-1]) + aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) + + feature_size = [self.image_shape[0] / step, self.image_shape[0] / step] + grid_height, grid_width = feature_size + + base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) + anchor_offset = step / 2.0 + + ratio_sqrt = np.sqrt(aspect_ratios_grid) + heights = scales_grid / ratio_sqrt * base_size[0] + widths = scales_grid * ratio_sqrt * base_size[1] + + y_centers = np.arange(grid_height).astype(np.float32) + y_centers = y_centers * step + anchor_offset + x_centers = np.arange(grid_width).astype(np.float32) + x_centers = x_centers * step + anchor_offset + x_centers, y_centers = np.meshgrid(x_centers, y_centers) + + x_centers_shape = x_centers.shape + y_centers_shape = y_centers.shape + + widths_grid, x_centers_grid = np.meshgrid(widths, x_centers.reshape([-1])) + heights_grid, y_centers_grid = np.meshgrid(heights, y_centers.reshape([-1])) + + x_centers_grid = x_centers_grid.reshape(*x_centers_shape, -1) + y_centers_grid = y_centers_grid.reshape(*y_centers_shape, -1) + widths_grid = widths_grid.reshape(-1, *x_centers_shape) + heights_grid = heights_grid.reshape(-1, *y_centers_shape) + + + bbox_centers = np.stack([y_centers_grid, x_centers_grid], axis=3) + bbox_sizes = np.stack([heights_grid, widths_grid], axis=3) + bbox_centers = bbox_centers.reshape([-1, 2]) + bbox_sizes = bbox_sizes.reshape([-1, 2]) + bbox_corners = np.concatenate([bbox_centers - 0.5 * bbox_sizes, bbox_centers + 0.5 * bbox_sizes], axis=1) + self.bbox_corners = bbox_corners / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) + self.bbox_centers = np.concatenate([bbox_centers, bbox_sizes], axis=1) + self.bbox_centers = self.bbox_centers / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) + + print(self.bbox_centers.shape) + return self.bbox_centers, self.bbox_corners + + def generate_multi_levels(self, steps): + bbox_centers_list = [] + bbox_corners_list = [] + for step in steps: + bbox_centers, bbox_corners = self.generate(step) + bbox_centers_list.append(bbox_centers) + bbox_corners_list.append(bbox_corners) + + self.bbox_centers = np.concatenate(bbox_centers_list, axis=0) + self.bbox_corners = np.concatenate(bbox_corners_list, axis=0) + return self.bbox_centers, self.bbox_corners diff --git a/model_zoo/official/cv/ssd/src/box_utils.py b/model_zoo/official/cv/ssd/src/box_utils.py index dfb2e7a03e7..25fd9ae4289 100644 --- a/model_zoo/official/cv/ssd/src/box_utils.py +++ b/model_zoo/official/cv/ssd/src/box_utils.py @@ -19,6 +19,7 @@ import math import itertools as it import numpy as np from .config import config +from .anchor_generator import GridAnchorGenerator class GeneratDefaultBoxes(): @@ -36,7 +37,7 @@ class GeneratDefaultBoxes(): sk1 = scales[idex] sk2 = scales[idex + 1] sk3 = math.sqrt(sk1 * sk2) - if idex == 0: + if idex == 0 and not config.aspect_ratios[idex]: w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) all_sizes = [(0.1, 0.1), (w, h), (h, w)] else: @@ -61,9 +62,12 @@ class GeneratDefaultBoxes(): self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') self.default_boxes = np.array(self.default_boxes, dtype='float32') - -default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr -default_boxes = GeneratDefaultBoxes().default_boxes +if 'use_anchor_generator' in config and config.use_anchor_generator: + generator = GridAnchorGenerator(config.img_shape, 4, 2, [1.0, 2.0, 0.5]) + default_boxes, default_boxes_tlbr = generator.generate_multi_levels(config.steps) +else: + default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr + default_boxes = GeneratDefaultBoxes().default_boxes y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) vol_anchors = (x2 - x1) * (y2 - y1) matching_threshold = config.match_threshold diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py index a41831c0be4..a9f0b59f931 100644 --- a/model_zoo/official/cv/ssd/src/config.py +++ b/model_zoo/official/cv/ssd/src/config.py @@ -15,68 +15,15 @@ """Config parameters for SSD models.""" -from easydict import EasyDict as ed +from .config_ssd300 import config as config_ssd300 +from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn -config = ed({ - "img_shape": [300, 300], - "num_ssd_boxes": 1917, - "neg_pre_positive": 3, - "match_threshold": 0.5, - "nms_threshold": 0.6, - "min_score": 0.1, - "max_boxes": 100, - # learing rate settings - "lr_init": 0.001, - "lr_end_rate": 0.001, - "warmup_epochs": 2, - "momentum": 0.9, - "weight_decay": 1.5e-4, +using_model = "ssd300" - # network - "num_default": [3, 6, 6, 6, 6, 6], - "extras_in_channels": [256, 576, 1280, 512, 256, 256], - "extras_out_channels": [576, 1280, 512, 256, 256, 128], - "extras_strides": [1, 1, 2, 2, 2, 2], - "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], - "feature_size": [19, 10, 5, 3, 2, 1], - "min_scale": 0.2, - "max_scale": 0.95, - "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], - "steps": (16, 32, 64, 100, 150, 300), - "prior_scaling": (0.1, 0.2), - "gamma": 2.0, - "alpha": 0.75, +config_map = { + "ssd300": config_ssd300, + "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn +} - # `mindrecord_dir` and `coco_root` are better to use absolute path. - "mindrecord_dir": "/data/MindRecord_COCO", - "coco_root": "/data/coco2017", - "train_data_type": "train2017", - "val_data_type": "val2017", - "instances_set": "annotations/instances_{}.json", - "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', - 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', - 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', - 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', - 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', - 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', - 'teddy bear', 'hair drier', 'toothbrush'), - "num_classes": 81, - # The annotation.json position of voc validation dataset. - "voc_json": "annotations/voc_instances_val.json", - # voc original dataset. - "voc_root": "/data/voc_dataset", - # if coco or voc used, `image_dir` and `anno_path` are useless. - "image_dir": "", - "anno_path": "", - "export_format": "MINDIR", - "export_file": "ssd.mindir" -}) +config = config_map[using_model] diff --git a/model_zoo/official/cv/ssd/src/config_ssd300.py b/model_zoo/official/cv/ssd/src/config_ssd300.py new file mode 100644 index 00000000000..8c46e86c5cf --- /dev/null +++ b/model_zoo/official/cv/ssd/src/config_ssd300.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "model": "ssd300", + "img_shape": [300, 300], + "num_ssd_boxes": 1917, + "neg_pre_positive": 3, + "match_threshold": 0.5, + "nms_threshold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learing rate settings + "lr_init": 0.001, + "lr_end_rate": 0.001, + "warmup_epochs": 2, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [3, 6, 6, 6, 6, 6], + "extras_in_channels": [256, 576, 1280, 512, 256, 256], + "extras_out_channels": [576, 1280, 512, 256, 256, 128], + "extras_strides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [19, 10, 5, 3, 2, 1], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (16, 32, 64, 100, 150, 300), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "feature_extractor_base_param": "", + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_json": "annotations/voc_instances_val.json", + # voc original dataset. + "voc_root": "/data/voc_dataset", + # if coco or voc used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "", + "export_format": "MINDIR", + "export_file": "ssd.mindir" +}) diff --git a/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py b/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py new file mode 100644 index 00000000000..8904737ad61 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "model": "ssd_mobilenet_v1_fpn", + "img_shape": [640, 640], + "num_ssd_boxes": 51150, + "neg_pre_positive": 3, + "match_threshold": 0.5, + "nms_threshold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learning rate settings + "global_step": 0, + "lr_init": 0.01333, + "lr_end_rate": 0.0, + "warmup_epochs": 2, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [6, 6, 6, 6, 6], + "extras_in_channels": [256, 512, 1024, 256, 256], + "extras_out_channels": [256, 256, 256, 256, 256], + "extras_strides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [80, 40, 20, 10, 5], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (8, 16, 32, 64, 128), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + "num_addition_layers": 4, + "use_anchor_generator": True, + "use_global_norm": True, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt", + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_json": "annotations/voc_instances_val.json", + # voc original dataset. + "voc_root": "/data/voc_dataset", + # if coco or voc used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "", + "export_format": "MINDIR", + "export_file": "ssd.mindir" +}) diff --git a/model_zoo/official/cv/ssd/src/init_params.py b/model_zoo/official/cv/ssd/src/init_params.py index 6527706fbbe..6ffb2ed58f0 100644 --- a/model_zoo/official/cv/ssd/src/init_params.py +++ b/model_zoo/official/cv/ssd/src/init_params.py @@ -22,14 +22,14 @@ def init_net_param(network, initialize_mode='TruncatedNormal'): for p in params: if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: if initialize_mode == 'TruncatedNormal': - p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype)) + p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) else: p.set_data(initialize_mode, p.data.shape, p.data.dtype) def load_backbone_params(network, param_dict): """Init the parameters from pre-train model, default is mobilenetv2.""" - for _, param in net.parameters_and_names(): + for _, param in network.parameters_and_names(): param_name = param.name.replace('network.backbone.', '') name_split = param_name.split('.') if 'features_1' in param_name: diff --git a/model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py b/model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py new file mode 100644 index 00000000000..da35c041541 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/mobilenet_v1_fpn.py @@ -0,0 +1,192 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'): + output = [] + output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", + group=1 if not depthwise else in_channel)) + output.append(nn.BatchNorm2d(out_channel)) + if activation: + output.append(nn.get_activation(activation)) + return nn.SequentialCell(output) + + +class MobileNetV1(nn.Cell): + """ + MobileNet V1 backbone + """ + def __init__(self, class_num=1001, features_only=False): + super(MobileNetV1, self).__init__() + self.features_only = features_only + cnn = [ + conv_bn_relu(3, 32, 3, 2, False), # Conv0 + + conv_bn_relu(32, 32, 3, 1, True), # Conv1_depthwise + conv_bn_relu(32, 64, 1, 1, False), # Conv1_pointwise + conv_bn_relu(64, 64, 3, 2, True), # Conv2_depthwise + conv_bn_relu(64, 128, 1, 1, False), # Conv2_pointwise + + conv_bn_relu(128, 128, 3, 1, True), # Conv3_depthwise + conv_bn_relu(128, 128, 1, 1, False), # Conv3_pointwise + conv_bn_relu(128, 128, 3, 2, True), # Conv4_depthwise + conv_bn_relu(128, 256, 1, 1, False), # Conv4_pointwise + + conv_bn_relu(256, 256, 3, 1, True), # Conv5_depthwise + conv_bn_relu(256, 256, 1, 1, False), # Conv5_pointwise + conv_bn_relu(256, 256, 3, 2, True), # Conv6_depthwise + conv_bn_relu(256, 512, 1, 1, False), # Conv6_pointwise + + conv_bn_relu(512, 512, 3, 1, True), # Conv7_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv7_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv8_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv8_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv9_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv9_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv10_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv10_pointwise + conv_bn_relu(512, 512, 3, 1, True), # Conv11_depthwise + conv_bn_relu(512, 512, 1, 1, False), # Conv11_pointwise + + conv_bn_relu(512, 512, 3, 2, True), # Conv12_depthwise + conv_bn_relu(512, 1024, 1, 1, False), # Conv12_pointwise + conv_bn_relu(1024, 1024, 3, 1, True), # Conv13_depthwise + conv_bn_relu(1024, 1024, 1, 1, False), # Conv13_pointwise + ] + + if self.features_only: + self.network = nn.CellList(cnn) + else: + self.network = nn.SequentialCell(cnn) + self.fc = nn.Dense(1024, class_num) + + def construct(self, x): + output = x + if self.features_only: + features = () + for block in self.network: + output = block(output) + features = features + (output,) + return features + output = self.network(x) + output = P.ReduceMean()(output, (2, 3)) + output = self.fc(output) + return output + + +class FpnTopDown(nn.Cell): + """ + Fpn to extract features + """ + def __init__(self, in_channel_list, out_channels): + super(FpnTopDown, self).__init__() + self.lateral_convs_list_ = [] + self.fpn_convs_ = [] + for channel in in_channel_list: + l_conv = nn.Conv2d(channel, out_channels, kernel_size=1, stride=1, + has_bias=True, padding=0, pad_mode='same') + fpn_conv = conv_bn_relu(out_channels, out_channels, kernel_size=3, stride=1, depthwise=False) + self.lateral_convs_list_.append(l_conv) + self.fpn_convs_.append(fpn_conv) + self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) + self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) + self.num_layers = len(in_channel_list) + + def construct(self, inputs): + image_features = () + for i, feature in enumerate(inputs): + image_features = image_features + (self.lateral_convs_list[i](feature),) + + features = (image_features[-1],) + for i in range(len(inputs) - 1): + top = len(inputs) - i - 1 + down = top - 1 + size = F.shape(inputs[down]) + top_down = P.ResizeBilinear((size[2], size[3]))(features[-1]) + top_down = top_down + image_features[down] + features = features + (top_down,) + + extract_features = () + num_features = len(features) + for i in range(num_features): + extract_features = extract_features + (self.fpn_convs_list[i](features[num_features - i - 1]),) + + return extract_features + + +class BottomUp(nn.Cell): + """ + Bottom Up feature extractor + """ + def __init__(self, levels, channels, kernel_size, stride): + super(BottomUp, self).__init__() + self.levels = levels + bottom_up_cells = [ + conv_bn_relu(channels, channels, kernel_size, stride, False) for x in range(self.levels) + ] + self.blocks = nn.CellList(bottom_up_cells) + + def construct(self, features): + for block in self.blocks: + features = features + (block(features[-1]),) + return features + + +class FeatureSelector(nn.Cell): + """ + Select specific layers from an entire feature list + """ + def __init__(self, feature_idxes): + super(FeatureSelector, self).__init__() + self.feature_idxes = feature_idxes + + def construct(self, feature_list): + selected = () + for i in self.feature_idxes: + selected = selected + (feature_list[i],) + return selected + + +class MobileNetV1Fpn(nn.Cell): + """ + MobileNetV1 with FPN as SSD backbone. + """ + def __init__(self, config): + super(MobileNetV1Fpn, self).__init__() + self.mobilenet_v1 = MobileNetV1(features_only=True) + + self.selector = FeatureSelector([10, 22, 26]) + + self.layer_indexs = [10, 22, 26] + self.fpn = FpnTopDown([256, 512, 1024], 256) + self.bottom_up = BottomUp(2, 256, 3, 2) + + def construct(self, x): + features = self.mobilenet_v1(x) + features = self.selector(features) + features = self.fpn(features) + features = self.bottom_up(features) + return features + + +def mobilenet_v1_fpn(config): + return MobileNetV1Fpn(config) + + +def mobilenet_v1(class_num=1001): + return MobileNetV1(class_num) diff --git a/model_zoo/official/cv/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py index a91e9d819af..6e9d1df45df 100644 --- a/model_zoo/official/cv/ssd/src/ssd.py +++ b/model_zoo/official/cv/ssd/src/ssd.py @@ -26,6 +26,8 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C +from .mobilenet_v1_fpn import mobilenet_v1_fpn + def _make_divisible(v, divisor, min_value=None): """nsures that all layers have a channel number that is divisible by 8.""" @@ -67,6 +69,7 @@ class ConvBNReLU(nn.Cell): kernel_size (int): Input kernel size. stride (int): Stride size for the first convolutional layer. Default: 1. groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + shared_conv(Cell): Use the weight shared conv, default: None. Returns: Tensor, output tensor. @@ -74,18 +77,21 @@ class ConvBNReLU(nn.Cell): Examples: >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) """ - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, shared_conv=None): super(ConvBNReLU, self).__init__() padding = 0 in_channels = in_planes out_channels = out_planes - if groups == 1: - conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) + if shared_conv is None: + if groups == 1: + conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) + else: + out_channels = in_planes + conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', + padding=padding, group=in_channels) + layers = [conv, _bn(out_planes), nn.ReLU6()] else: - out_channels = in_planes - conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', - padding=padding, group=in_channels) - layers = [conv, _bn(out_planes), nn.ReLU6()] + layers = [shared_conv, _bn(out_planes), nn.ReLU6()] self.features = nn.SequentialCell(layers) def construct(self, x): @@ -205,6 +211,86 @@ class MultiBox(nn.Cell): return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) +class WeightSharedMultiBox(nn.Cell): + """ + Weight shared Multi-box conv layers. Each multi-box layer contains class conf scores and localization predictions. + All box predictors shares the same conv weight in different features. + + Args: + config (dict): The default config of SSD. + loc_cls_shared_addition(bool): Whether the location predictor and classifier prediction share the + same addition layer. + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + """ + def __init__(self, config, loc_cls_shared_addition=False): + super(WeightSharedMultiBox, self).__init__() + num_classes = config.num_classes + out_channels = config.extras_out_channels[0] + num_default = config.num_default[0] + num_features = len(config.feature_size) + num_addition_layers = config.num_addition_layers + self.loc_cls_shared_addition = loc_cls_shared_addition + + if not loc_cls_shared_addition: + loc_convs = [ + _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) + ] + cls_convs = [ + _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) + ] + addition_loc_layer_list = [] + addition_cls_layer_list = [] + for _ in range(num_features): + addition_loc_layer = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, loc_convs[x]) for x in range(num_addition_layers) + ] + addition_cls_layer = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, cls_convs[x]) for x in range(num_addition_layers) + ] + addition_loc_layer_list.append(nn.SequentialCell(addition_loc_layer)) + addition_cls_layer_list.append(nn.SequentialCell(addition_cls_layer)) + self.addition_layer_loc = nn.CellList(addition_loc_layer_list) + self.addition_layer_cls = nn.CellList(addition_cls_layer_list) + else: + convs = [ + _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) + ] + addition_layer_list = [] + for _ in range(num_features): + addition_layers = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, convs[x]) for x in range(num_addition_layers) + ] + addition_layer_list.append(nn.SequentialCell(addition_layers)) + self.addition_layer = nn.SequentialCell(addition_layer_list) + + loc_layers = [_conv2d(out_channels, 4 * num_default, + kernel_size=3, stride=1, pad_mod='same')] + cls_layers = [_conv2d(out_channels, num_classes * num_default, + kernel_size=3, stride=1, pad_mod='same')] + + self.loc_layers = nn.SequentialCell(loc_layers) + self.cls_layers = nn.SequentialCell(cls_layers) + self.flatten_concat = FlattenConcat(config) + + def construct(self, inputs): + loc_outputs = () + cls_outputs = () + num_heads = len(inputs) + for i in range(num_heads): + if self.loc_cls_shared_addition: + features = self.addition_layer[i](inputs[i]) + loc_outputs += (self.loc_layers(features),) + cls_outputs += (self.cls_layers(features),) + else: + features = self.addition_layer_loc[i](inputs[i]) + loc_outputs += (self.loc_layers(features),) + features = self.addition_layer_cls[i](inputs[i]) + cls_outputs += (self.cls_layers(features),) + return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) + + class SSD300(nn.Cell): """ SSD300 Network. Default backbone is resnet34. @@ -255,6 +341,40 @@ class SSD300(nn.Cell): return pred_loc, pred_label +class SsdMobilenetV1Fpn(nn.Cell): + """ + SSD Network using mobilenetV1 with fpn to extract features + + Args: + config (dict): The default config of SSD. + is_training (bool): Used for training, default is True. + + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + + Examples:backbone + SsdMobilenetV1Fpn(config, True). + """ + def __init__(self, config, is_training=True): + super(SsdMobilenetV1Fpn, self).__init__() + self.multi_box = WeightSharedMultiBox(config) + self.is_training = is_training + if not is_training: + self.activation = P.Sigmoid() + + self.feature_extractor = mobilenet_v1_fpn(config) + + def construct(self, x): + features = self.feature_extractor(x) + pred_loc, pred_label = self.multi_box(features) + if not self.is_training: + pred_label = self.activation(pred_label) + pred_loc = F.cast(pred_loc, mstype.float32) + pred_label = F.cast(pred_label, mstype.float32) + return pred_loc, pred_label + + class SigmoidFocalClassificationLoss(nn.Cell): """" Sigmoid focal-loss for classification. @@ -328,6 +448,12 @@ class SSDWithLossCell(nn.Cell): return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) +grad_scale = C.MultitypeFuncGraph("grad_scale") +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * P.Reciprocal()(scale) + + class TrainingWrapper(nn.Cell): """ Encapsulation class of SSD network training. @@ -339,8 +465,9 @@ class TrainingWrapper(nn.Cell): network (Cell): The training network. Note that loss function should have been added. optimizer (Optimizer): Optimizer for updating the weights. sens (Number): The adjust parameter. Default: 1.0. + use_global_nrom(bool): Whether apply global norm before optimizer. Default: False """ - def __init__(self, network, optimizer, sens=1.0): + def __init__(self, network, optimizer, sens=1.0, use_global_norm=False): super(TrainingWrapper, self).__init__(auto_prefix=False) self.network = network self.network.set_grad() @@ -350,6 +477,7 @@ class TrainingWrapper(nn.Cell): self.sens = sens self.reducer_flag = False self.grad_reducer = None + self.use_global_norm = use_global_norm self.parallel_mode = context.get_auto_parallel_context("parallel_mode") if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: self.reducer_flag = True @@ -360,6 +488,7 @@ class TrainingWrapper(nn.Cell): else: degree = get_group_size() self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + self.hyper_map = C.HyperMap() def construct(self, *args): weights = self.weights @@ -369,6 +498,9 @@ class TrainingWrapper(nn.Cell): if self.reducer_flag: # apply grad reducer on grads grads = self.grad_reducer(grads) + if self.use_global_norm: + grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads) + grads = C.clip_by_global_norm(grads) return F.depend(loss, self.optimizer(grads)) @@ -439,5 +571,10 @@ class SSDWithMobileNetV2(nn.Cell): def get_out_channels(self): return self.last_channel + +def ssd_mobilenet_v1_fpn(**kwargs): + return SsdMobilenetV1Fpn(**kwargs) + + def ssd_mobilenet_v2(**kwargs): return SSDWithMobileNetV2(**kwargs) diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index 2094077e258..c34c76b4c0a 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -25,7 +25,7 @@ from mindspore.train import Model from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed, dtype -from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 +from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn from src.config import config from src.dataset import create_ssd_dataset, create_mindrecord from src.lr_schedule import get_lr @@ -74,63 +74,85 @@ def main(): context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) init() + context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) rank = get_rank() mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) - if not args_opt.only_create_dataset: - loss_scale = float(args_opt.loss_scale) - if args_opt.run_platform == "CPU": - loss_scale = 1.0 - # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. - use_multiprocessing = (args_opt.run_platform != "CPU") - dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, - device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) + if args_opt.only_create_dataset: + return - dataset_size = dataset.get_dataset_size() - print("Create dataset done!") + loss_scale = float(args_opt.loss_scale) + if args_opt.run_platform == "CPU": + loss_scale = 1.0 - backbone = ssd_mobilenet_v2() + # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. + use_multiprocessing = (args_opt.run_platform != "CPU") + dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, + device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) + + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + backbone = ssd_mobilenet_v2() + if config.model == "ssd300": ssd = SSD300(backbone=backbone, config=config) - if args_opt.run_platform == "GPU": - ssd.to_float(dtype.float16) - net = SSDWithLossCell(ssd, config) - init_net_param(net) + elif config.model == "ssd_mobilenet_v1_fpn": + ssd = ssd_mobilenet_v1_fpn(config=config) + else: + raise ValueError(f'config.model: {config.model} is not supported') + if args_opt.run_platform == "GPU": + ssd.to_float(dtype.float16) + net = SSDWithLossCell(ssd, config) - # checkpoint - ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) - save_ckpt_path = './ckpt_' + str(rank) + '/' - ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) + init_net_param(net) - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - if args_opt.filter_weight: - filter_checkpoint_parameter(param_dict) - load_param_into_net(net, param_dict) + if config.feature_extractor_base_param != "": + param_dict = load_checkpoint(config.feature_extractor_base_param) + for x in list(param_dict.keys()): + param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x] + del param_dict[x] + load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict) - if args_opt.freeze_layer == "backbone": - for param in backbone.feature_1.trainable_params(): - param.requires_grad = False + # checkpoint + ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) + save_ckpt_path = './ckpt_' + str(rank) + '/' + ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) - lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, - lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, - warmup_epochs=config.warmup_epochs, - total_epochs=args_opt.epoch_size, - steps_per_epoch=dataset_size)) + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + if args_opt.filter_weight: + filter_checkpoint_parameter(param_dict) + load_param_into_net(net, param_dict) + if args_opt.freeze_layer == "backbone": + for param in backbone.feature_1.trainable_params(): + param.requires_grad = False + + lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, + lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs=config.warmup_epochs, + total_epochs=args_opt.epoch_size, + steps_per_epoch=dataset_size)) + + if "use_global_norm" in config and config.use_global_norm: + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, + config.momentum, config.weight_decay, 1.0) + net = TrainingWrapper(net, opt, loss_scale, True) + else: opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, loss_scale) - net = TrainingWrapper(net, opt, loss_scale) - callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] - model = Model(net) - dataset_sink_mode = False - if args_opt.mode == "sink" and args_opt.run_platform != "CPU": - print("In sink mode, one epoch return a loss.") - dataset_sink_mode = True - print("Start train SSD, the first epoch will be slower because of the graph compilation.") - model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) + + callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] + model = Model(net) + dataset_sink_mode = False + if args_opt.mode == "sink" and args_opt.run_platform != "CPU": + print("In sink mode, one epoch return a loss.") + dataset_sink_mode = True + print("Start train SSD, the first epoch will be slower because of the graph compilation.") + model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) if __name__ == '__main__': main()