diff --git a/RELEASE.md b/RELEASE.md index f919bd7a2fa..686f4ccac10 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,7 @@ * DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset. * DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-art models on the PASCAL VOC 2007 semantic image segmentation benchmark. * Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset. + * SSD: a single stage object detection methods on COCO 2017 dataset. * GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset. * Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset. * Frontend and User Interface diff --git a/example/ssd_coco2017/README.md b/example/ssd_coco2017/README.md deleted file mode 100644 index bd43344b8b8..00000000000 --- a/example/ssd_coco2017/README.md +++ /dev/null @@ -1,88 +0,0 @@ -# SSD Example - -## Description - -SSD network based on MobileNetV2, with support for training and evaluation. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Dataset - - We use coco2017 as training dataset in this example by default, and you can also use your own datasets. - - 1. If coco dataset is used. **Select dataset to coco when run script.** - Install Cython and pycocotool. - - ``` - pip install Cython - - pip install pycocotools - ``` - And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: - - - ``` - └─coco2017 - ├── annotations # annotation jsons - ├── train2017 # train dataset - └── val2017 # infer dataset - ``` - - 2. If your own dataset is used. **Select dataset to other when run script.** - Organize the dataset infomation into a TXT file, each row in the file is as follows: - - ``` - train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 - ``` - - Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. - - -## Running the example - -### Training - -To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `COCO_ROOT`(coco dataset) or `IMAGE_DIR` and `ANNO_PATH`(own dataset). **Note if MINDRECORD_DIR isn't empty, it will use MINDRECORD_DIR instead of raw images.** - - -- Stand alone mode - - ``` - python train.py --dataset coco - - ``` - - You can run ```python train.py -h``` to get more information. - - -- Distribute mode - - ``` - sh run_distribute_train.sh 8 150 coco /data/hccl.json - ``` - - The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.** - -You will get the loss value of each step as following: - -``` -epoch: 1 step: 455, loss is 5.8653416 -epoch: 2 step: 455, loss is 5.4292373 -epoch: 3 step: 455, loss is 5.458992 -... -epoch: 148 step: 455, loss is 1.8340507 -epoch: 149 step: 455, loss is 2.0876894 -epoch: 150 step: 455, loss is 2.239692 -``` - -### Evaluation - -for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file. - -``` -python eval.py --ckpt_path ssd.ckpt --dataset coco -``` - -You can run ```python eval.py -h``` to get more information. diff --git a/example/ssd_coco2017/config.py b/example/ssd_coco2017/config.py deleted file mode 100644 index 452aaf97008..00000000000 --- a/example/ssd_coco2017/config.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Config parameters for SSD models.""" - - -class ConfigSSD: - """ - Config parameters for SSD. - - Examples: - ConfigSSD(). - """ - IMG_SHAPE = [300, 300] - NUM_SSD_BOXES = 1917 - NEG_PRE_POSITIVE = 3 - MATCH_THRESHOLD = 0.5 - - NUM_DEFAULT = [3, 6, 6, 6, 6, 6] - EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256] - EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128] - EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2] - EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] - FEATURE_SIZE = [19, 10, 5, 3, 2, 1] - SCALES = [21, 45, 99, 153, 207, 261, 315] - ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)] - STEPS = (16, 32, 64, 100, 150, 300) - PRIOR_SCALING = (0.1, 0.2) - - - # `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path. - MINDRECORD_DIR = "MindRecord_COCO" - COCO_ROOT = "coco2017" - TRAIN_DATA_TYPE = "train2017" - VAL_DATA_TYPE = "val2017" - INSTANCES_SET = "annotations/instances_{}.json" - COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', - 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', - 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', - 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', - 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', - 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', - 'teddy bear', 'hair drier', 'toothbrush') - NUM_CLASSES = len(COCO_CLASSES) diff --git a/example/ssd_coco2017/dataset.py b/example/ssd_coco2017/dataset.py deleted file mode 100644 index b88b22c8626..00000000000 --- a/example/ssd_coco2017/dataset.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SSD dataset""" -from __future__ import division - -import os -import math -import itertools as it -import numpy as np -import cv2 - -import mindspore.dataset as de -import mindspore.dataset.transforms.vision.c_transforms as C -from mindspore.mindrecord import FileWriter -from config import ConfigSSD - -config = ConfigSSD() - -class GeneratDefaultBoxes(): - """ - Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). - `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [x, y, w, h]. - `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2]. - """ - def __init__(self): - fk = config.IMG_SHAPE[0] / np.array(config.STEPS) - self.default_boxes = [] - for idex, feature_size in enumerate(config.FEATURE_SIZE): - sk1 = config.SCALES[idex] / config.IMG_SHAPE[0] - sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0] - sk3 = math.sqrt(sk1 * sk2) - - if config.NUM_DEFAULT[idex] == 3: - all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)] - else: - all_sizes = [(sk1, sk1), (sk3, sk3)] - for aspect_ratio in config.ASPECT_RATIOS[idex]: - w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) - all_sizes.append((w, h)) - all_sizes.append((h, w)) - - assert len(all_sizes) == config.NUM_DEFAULT[idex] - - for i, j in it.product(range(feature_size), repeat=2): - for w, h in all_sizes: - cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] - box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)] - self.default_boxes.append(box) - - def to_ltrb(cx, cy, w, h): - return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2 - - # For IoU calculation - self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') - self.default_boxes = np.array(self.default_boxes, dtype='float32') - - -default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb -default_boxes = GeneratDefaultBoxes().default_boxes -x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) -vol_anchors = (x2 - x1) * (y2 - y1) -matching_threshold = config.MATCH_THRESHOLD - - -def ssd_bboxes_encode(boxes): - """ - Labels anchors with ground truth inputs. - - Args: - boxex: ground truth with shape [N, 5], for each row, it stores [x, y, w, h, cls]. - - Returns: - gt_loc: location ground truth with shape [num_anchors, 4]. - gt_label: class ground truth with shape [num_anchors, 1]. - num_matched_boxes: number of positives in an image. - """ - - def jaccard_with_anchors(bbox): - """Compute jaccard score a box and the anchors.""" - # Intersection bbox and volume. - xmin = np.maximum(x1, bbox[0]) - ymin = np.maximum(y1, bbox[1]) - xmax = np.minimum(x2, bbox[2]) - ymax = np.minimum(y2, bbox[3]) - w = np.maximum(xmax - xmin, 0.) - h = np.maximum(ymax - ymin, 0.) - - # Volumes. - inter_vol = h * w - union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol - jaccard = inter_vol / union_vol - return np.squeeze(jaccard) - - pre_scores = np.zeros((config.NUM_SSD_BOXES), dtype=np.float32) - t_boxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32) - t_label = np.zeros((config.NUM_SSD_BOXES), dtype=np.int64) - for bbox in boxes: - label = int(bbox[4]) - scores = jaccard_with_anchors(bbox) - mask = (scores > matching_threshold) - if not np.any(mask): - mask[np.argmax(scores)] = True - - mask = mask & (scores > pre_scores) - pre_scores = np.maximum(pre_scores, scores) - t_label = mask * label + (1 - mask) * t_label - for i in range(4): - t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] - - index = np.nonzero(t_label) - - # Transform to ltrb. - bboxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32) - bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 - bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] - - # Encode features. - bboxes_t = bboxes[index] - default_boxes_t = default_boxes[index] - bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.PRIOR_SCALING[0]) - bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1] - bboxes[index] = bboxes_t - - num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) - return bboxes, t_label.astype(np.int32), num_match_num - -def ssd_bboxes_decode(boxes, index): - """Decode predict boxes to [x, y, w, h]""" - boxes_t = boxes[index] - default_boxes_t = default_boxes[index] - boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] - boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4] - - bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32) - - bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 - bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 - - return bboxes - -def preprocess_fn(image, box, is_training): - """Preprocess function for dataset.""" - - def _rand(a=0., b=1.): - """Generate random.""" - return np.random.rand() * (b - a) + a - - def _infer_data(image, input_shape, box): - img_h, img_w, _ = image.shape - input_h, input_w = input_shape - - scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h)) - nw = int(img_w * scale) - nh = int(img_h * scale) - - image = cv2.resize(image, (nw, nh)) - - new_image = np.zeros((input_h, input_w, 3), np.float32) - dh = (input_h - nh) // 2 - dw = (input_w - nw) // 2 - new_image[dh: (nh + dh), dw: (nw + dw), :] = image - image = new_image - - #When the channels of image is 1 - if len(image.shape) == 2: - image = np.expand_dims(image, axis=-1) - image = np.concatenate([image, image, image], axis=-1) - - box = box.astype(np.float32) - - box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w - box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h - return image, np.array((img_h, img_w), np.float32), box - - def _data_aug(image, box, is_training, image_size=(300, 300)): - """Data augmentation function.""" - ih, iw, _ = image.shape - w, h = image_size - - if not is_training: - return _infer_data(image, image_size, box) - # Random settings - scale_w = _rand(0.75, 1.25) - scale_h = _rand(0.75, 1.25) - - flip = _rand() < .5 - nw = iw * scale_w - nh = ih * scale_h - scale = min(w / nw, h / nh) - nw = int(scale * nw) - nh = int(scale * nh) - - # Resize image - image = cv2.resize(image, (nw, nh)) - - # place image - new_image = np.zeros((h, w, 3), dtype=np.float32) - dw = (w - nw) // 2 - dh = (h - nh) // 2 - new_image[dh:dh + nh, dw:dw + nw, :] = image - image = new_image - - # Flip image or not - if flip: - image = cv2.flip(image, 1, dst=None) - - # Convert image to gray or not - gray = _rand() < .25 - if gray: - image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - - # When the channels of image is 1 - if len(image.shape) == 2: - image = np.expand_dims(image, axis=-1) - image = np.concatenate([image, image, image], axis=-1) - - box = box.astype(np.float32) - - # Transform box with shape[x1, y1, x2, y2]. - box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w - box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h - - if flip: - box[:, [0, 2]] = 1 - box[:, [2, 0]] - - box, label, num_match_num = ssd_bboxes_encode(box) - return image, box, label, num_match_num - return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE) - - -def create_coco_label(is_training): - """Get image path and annotation from COCO.""" - from pycocotools.coco import COCO - - coco_root = config.COCO_ROOT - data_type = config.VAL_DATA_TYPE - if is_training: - data_type = config.TRAIN_DATA_TYPE - - #Classes need to train or test. - train_cls = config.COCO_CLASSES - train_cls_dict = {} - for i, cls in enumerate(train_cls): - train_cls_dict[cls] = i - - anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type)) - - coco = COCO(anno_json) - classs_dict = {} - cat_ids = coco.loadCats(coco.getCatIds()) - for cat in cat_ids: - classs_dict[cat["id"]] = cat["name"] - - image_ids = coco.getImgIds() - image_files = [] - image_anno_dict = {} - - for img_id in image_ids: - image_info = coco.loadImgs(img_id) - file_name = image_info[0]["file_name"] - anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) - anno = coco.loadAnns(anno_ids) - image_path = os.path.join(coco_root, data_type, file_name) - annos = [] - for label in anno: - bbox = label["bbox"] - class_name = classs_dict[label["category_id"]] - if class_name in train_cls: - x_min, x_max = bbox[0], bbox[0] + bbox[2] - y_min, y_max = bbox[1], bbox[1] + bbox[3] - annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]]) - if len(annos) >= 1: - image_files.append(image_path) - image_anno_dict[image_path] = np.array(annos) - return image_files, image_anno_dict - - -def anno_parser(annos_str): - """Parse annotation from string to list.""" - annos = [] - for anno_str in annos_str: - anno = list(map(int, anno_str.strip().split(','))) - annos.append(anno) - return annos - - -def filter_valid_data(image_dir, anno_path): - """Filter valid image file, which both in image_dir and anno_path.""" - image_files = [] - image_anno_dict = {} - if not os.path.isdir(image_dir): - raise RuntimeError("Path given is not valid.") - if not os.path.isfile(anno_path): - raise RuntimeError("Annotation file is not valid.") - - with open(anno_path, "rb") as f: - lines = f.readlines() - for line in lines: - line_str = line.decode("utf-8").strip() - line_split = str(line_str).split(' ') - file_name = line_split[0] - image_path = os.path.join(image_dir, file_name) - if os.path.isfile(image_path): - image_anno_dict[image_path] = anno_parser(line_split[1:]) - image_files.append(image_path) - return image_files, image_anno_dict - - -def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): - """Create MindRecord file.""" - mindrecord_dir = config.MINDRECORD_DIR - mindrecord_path = os.path.join(mindrecord_dir, prefix) - writer = FileWriter(mindrecord_path, file_num) - if dataset == "coco": - image_files, image_anno_dict = create_coco_label(is_training) - else: - image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) - - ssd_json = { - "image": {"type": "bytes"}, - "annotation": {"type": "int32", "shape": [-1, 5]}, - } - writer.add_schema(ssd_json, "ssd_json") - - for image_name in image_files: - with open(image_name, 'rb') as f: - img = f.read() - annos = np.array(image_anno_dict[image_name], dtype=np.int32) - row = {"image": img, "annotation": annos} - writer.write_raw_data([row]) - writer.commit() - - -def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, - is_training=True, num_parallel_workers=4): - """Creatr SSD dataset with MindDataset.""" - ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, - num_parallel_workers=num_parallel_workers, shuffle=is_training) - decode = C.Decode() - ds = ds.map(input_columns=["image"], operations=decode) - compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) - - if is_training: - hwc_to_chw = C.HWC2CHW() - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "box", "label", "num_match_num"], - columns_order=["image", "box", "label", "num_match_num"], - operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers) - ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True, - num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) - else: - hwc_to_chw = C.HWC2CHW() - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "image_shape", "annotation"], - columns_order=["image", "image_shape", "annotation"], - operations=compose_map_func) - ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) - return ds diff --git a/example/ssd_coco2017/util.py b/example/ssd_coco2017/util.py deleted file mode 100644 index 6e102853757..00000000000 --- a/example/ssd_coco2017/util.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""metrics utils""" - -import numpy as np -from config import ConfigSSD -from dataset import ssd_bboxes_decode - - -def calc_iou(bbox_pred, bbox_ground): - """Calculate iou of predicted bbox and ground truth.""" - bbox_pred = np.expand_dims(bbox_pred, axis=0) - - pred_w = bbox_pred[:, 2] - bbox_pred[:, 0] - pred_h = bbox_pred[:, 3] - bbox_pred[:, 1] - pred_area = pred_w * pred_h - - gt_w = bbox_ground[:, 2] - bbox_ground[:, 0] - gt_h = bbox_ground[:, 3] - bbox_ground[:, 1] - gt_area = gt_w * gt_h - - iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0]) - ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1]) - - iw = np.maximum(iw, 0) - ih = np.maximum(ih, 0) - intersection_area = iw * ih - - union_area = pred_area + gt_area - intersection_area - union_area = np.maximum(union_area, np.finfo(float).eps) - - iou = intersection_area * 1. / union_area - return iou - - -def apply_nms(all_boxes, all_scores, thres, max_boxes): - """Apply NMS to bboxes.""" - x1 = all_boxes[:, 0] - y1 = all_boxes[:, 1] - x2 = all_boxes[:, 2] - y2 = all_boxes[:, 3] - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - - order = all_scores.argsort()[::-1] - keep = [] - - while order.size > 0: - i = order[0] - keep.append(i) - - if len(keep) >= max_boxes: - break - - xx1 = np.maximum(x1[i], x1[order[1:]]) - yy1 = np.maximum(y1[i], y1[order[1:]]) - xx2 = np.minimum(x2[i], x2[order[1:]]) - yy2 = np.minimum(y2[i], y2[order[1:]]) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - - ovr = inter / (areas[i] + areas[order[1:]] - inter) - - inds = np.where(ovr <= thres)[0] - - order = order[inds + 1] - return keep - - -def calc_ap(recall, precision): - """Calculate AP.""" - correct_recall = np.concatenate(([0.], recall, [1.])) - correct_precision = np.concatenate(([0.], precision, [0.])) - - for i in range(correct_recall.size - 1, 0, -1): - correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i]) - - i = np.where(correct_recall[1:] != correct_recall[:-1])[0] - - ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1]) - - return ap - -def metrics(pred_data): - """Calculate mAP of predicted bboxes.""" - config = ConfigSSD() - num_classes = config.NUM_CLASSES - - all_detections = [None for i in range(num_classes)] - all_pred_scores = [None for i in range(num_classes)] - all_annotations = [None for i in range(num_classes)] - average_precisions = {} - num = [0 for i in range(num_classes)] - accurate_num = [0 for i in range(num_classes)] - - for sample in pred_data: - pred_boxes = sample['boxes'] - boxes_scores = sample['box_scores'] - annotation = sample['annotation'] - - annotation = np.squeeze(annotation, axis=0) - - pred_labels = np.argmax(boxes_scores, axis=-1) - index = np.nonzero(pred_labels) - pred_boxes = ssd_bboxes_decode(pred_boxes, index) - - pred_boxes = pred_boxes.clip(0, 1) - boxes_scores = np.max(boxes_scores, axis=-1) - boxes_scores = boxes_scores[index] - pred_labels = pred_labels[index] - - top_k = 50 - - for c in range(1, num_classes): - if len(pred_labels) >= 1: - class_box_scores = boxes_scores[pred_labels == c] - class_boxes = pred_boxes[pred_labels == c] - - nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k) - - class_boxes = class_boxes[nms_index] - class_box_scores = class_box_scores[nms_index] - - cmask = class_box_scores > 0.5 - class_boxes = class_boxes[cmask] - class_box_scores = class_box_scores[cmask] - - all_detections[c] = class_boxes - all_pred_scores[c] = class_box_scores - - for c in range(1, num_classes): - if len(annotation) >= 1: - all_annotations[c] = annotation[annotation[:, 4] == c, :4] - - for c in range(1, num_classes): - false_positives = np.zeros((0,)) - true_positives = np.zeros((0,)) - scores = np.zeros((0,)) - num_annotations = 0.0 - - annotations = all_annotations[c] - num_annotations += annotations.shape[0] - detections = all_detections[c] - pred_scores = all_pred_scores[c] - - for index, detection in enumerate(detections): - scores = np.append(scores, pred_scores[index]) - if len(annotations) >= 1: - IoUs = calc_iou(detection, annotations) - assigned_anno = np.argmax(IoUs) - max_overlap = IoUs[assigned_anno] - - if max_overlap >= 0.5: - false_positives = np.append(false_positives, 0) - true_positives = np.append(true_positives, 1) - else: - false_positives = np.append(false_positives, 1) - true_positives = np.append(true_positives, 0) - else: - false_positives = np.append(false_positives, 1) - true_positives = np.append(true_positives, 0) - - if num_annotations == 0: - if c not in average_precisions.keys(): - average_precisions[c] = 0 - continue - accurate_num[c] = 1 - indices = np.argsort(-scores) - false_positives = false_positives[indices] - true_positives = true_positives[indices] - - false_positives = np.cumsum(false_positives) - true_positives = np.cumsum(true_positives) - - recall = true_positives * 1. / num_annotations - precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) - - average_precision = calc_ap(recall, precision) - - if c not in average_precisions.keys(): - average_precisions[c] = average_precision - else: - average_precisions[c] += average_precision - - num[c] += 1 - - count = 0 - for key in average_precisions: - if num[key] != 0: - count += (average_precisions[key] / num[key]) - - mAP = count * 1. / accurate_num.count(1) - return mAP diff --git a/model_zoo/ssd/README.md b/model_zoo/ssd/README.md new file mode 100644 index 00000000000..ded107e4992 --- /dev/null +++ b/model_zoo/ssd/README.md @@ -0,0 +1,119 @@ +# SSD Example + +## Description + +SSD network based on MobileNetV2, with support for training and evaluation. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Dataset + + We use coco2017 as training dataset in this example by default, and you can also use your own datasets. + + 1. If coco dataset is used. **Select dataset to coco when run script.** + Install Cython and pycocotool. + + ``` + pip install Cython + + pip install pycocotools + ``` + And change the coco_root and other settings you need in `config.py`. The directory structure is as follows: + + + ``` + . + └─cocodataset + ├─annotations + ├─instance_train2017.json + └─instance_val2017.json + ├─val2017 + └─train2017 + ``` + + 2. If your own dataset is used. **Select dataset to other when run script.** + Organize the dataset infomation into a TXT file, each row in the file is as follows: + + ``` + train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 + ``` + + Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `image_dir`(dataset directory) and the relative path in `anno_path`(the TXT file path), `image_dir` and `anno_path` are setting in `config.py`. + + +## Running the example + +### Training + +To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `coco_root`(coco dataset) or `iamge_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.** + + +- Stand alone mode + + ``` + python train.py --dataset coco + + ``` + + You can run ```python train.py -h``` to get more information. + + +- Distribute mode + + ``` + sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json + ``` + + The input parameters are device numbers, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.** + +You will get the loss value of each step as following: + +``` +epoch: 1 step: 458, loss is 3.1681802 +epoch time: 228752.4654865265, per step time: 499.4595316299705 +epoch: 2 step: 458, loss is 2.8847265 +epoch time: 38912.93382644653, per step time: 84.96273761232868 +epoch: 3 step: 458, loss is 2.8398118 +epoch time: 38769.184827804565, per step time: 84.64887516987896 +... + +epoch: 498 step: 458, loss is 0.70908034 +epoch time: 38771.079778671265, per step time: 84.65301261718616 +epoch: 499 step: 458, loss is 0.7974688 +epoch time: 38787.413120269775, per step time: 84.68867493508685 +epoch: 500 step: 458, loss is 0.5548882 +epoch time: 39064.8467540741, per step time: 85.29442522723602 +``` + +### Evaluation + +for evaluation , run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file. + +``` +python eval.py --checkpoint_path ssd.ckpt --dataset coco +``` + +You can run ```python eval.py -h``` to get more information. + +You will get the result as following: + +``` +Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.189 +Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.341 +Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.183 +Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.040 +Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.181 +Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.326 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.213 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.348 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.380 +Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.124 +Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.412 +Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.588 + +======================================== + +mAP: 0.18937438355383837 +``` diff --git a/example/ssd_coco2017/eval.py b/model_zoo/ssd/eval.py similarity index 78% rename from example/ssd_coco2017/eval.py rename to model_zoo/ssd/eval.py index d5e0d86b67a..9054bf6f244 100644 --- a/example/ssd_coco2017/eval.py +++ b/model_zoo/ssd/eval.py @@ -14,49 +14,51 @@ # ============================================================================ """Evaluation for SSD""" + import os import argparse import time +import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2 -from dataset import create_ssd_dataset, data_to_mindrecord_byte_image -from config import ConfigSSD -from util import metrics +from src.ssd import SSD300, ssd_mobilenet_v2 +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image +from src.config import config +from src.coco_eval import metrics def ssd_eval(dataset_path, ckpt_path): """SSD evaluation.""" - - ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False) - net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False) + batch_size = 1 + ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) + net = SSD300(ssd_mobilenet_v2(), config, is_training=False) print("Load Checkpoint!") param_dict = load_checkpoint(ckpt_path) net.init_parameters_data() load_param_into_net(net, param_dict) net.set_train(False) - i = 1. - total = ds.get_dataset_size() + i = batch_size + total = ds.get_dataset_size() * batch_size start = time.time() pred_data = [] print("\n========================================\n") print("total images num: ", total) print("Processing, please wait a moment.") for data in ds.create_dict_iterator(): + img_id = data['img_id'] img_np = data['image'] image_shape = data['image_shape'] - annotation = data['annotation'] output = net(Tensor(img_np)) for batch_idx in range(img_np.shape[0]): pred_data.append({"boxes": output[0].asnumpy()[batch_idx], "box_scores": output[1].asnumpy()[batch_idx], - "annotation": annotation, - "image_shape": image_shape}) - percent = round(i / total * 100, 2) + "img_id": int(np.squeeze(img_id[batch_idx])), + "image_shape": image_shape[batch_idx]}) + percent = round(i / total * 100., 2) print(f' {str(percent)} [{i}/{total}]', end='\r') - i += 1 + i += batch_size cost_time = int((time.time() - start) * 1000) print(f' 100% [{total}/{total}] cost {cost_time} ms') mAP = metrics(pred_data) @@ -73,22 +75,21 @@ if __name__ == '__main__': context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - config = ConfigSSD() prefix = "ssd_eval.mindrecord" - mindrecord_dir = config.MINDRECORD_DIR + mindrecord_dir = config.mindrecord_dir mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") if not os.path.exists(mindrecord_file): if not os.path.isdir(mindrecord_dir): os.makedirs(mindrecord_dir) if args_opt.dataset == "coco": - if os.path.isdir(config.COCO_ROOT): + if os.path.isdir(config.coco_root): print("Create Mindrecord.") data_to_mindrecord_byte_image("coco", False, prefix) print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: - print("COCO_ROOT not exits.") + print("coco_root not exits.") else: - if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") data_to_mindrecord_byte_image("other", False, prefix) print("Create Mindrecord Done, at {}".format(mindrecord_dir)) diff --git a/example/ssd_coco2017/run_distribute_train.sh b/model_zoo/ssd/scripts/run_distribute_train.sh similarity index 74% rename from example/ssd_coco2017/run_distribute_train.sh rename to model_zoo/ssd/scripts/run_distribute_train.sh index bd8519be415..accd35c50c7 100644 --- a/example/ssd_coco2017/run_distribute_train.sh +++ b/model_zoo/ssd/scripts/run_distribute_train.sh @@ -14,17 +14,16 @@ # limitations under the License. # ============================================================================ -echo "=================================================================================================================" +echo "==============================================================================================================" echo "Please run the scipt as: " -echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" -echo "for example: sh run_distribute_train.sh 8 350 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" +echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" echo "It is better to use absolute path." -echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script." echo "=================================================================================================================" -if [ $# != 4 ] && [ $# != 6 ] +if [ $# != 5 ] && [ $# != 7 ] then - echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [DATASET] \ + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ [MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" exit 1 fi @@ -36,38 +35,39 @@ echo "After running the scipt, the network runs in the background. The log will export RANK_SIZE=$1 EPOCH_SIZE=$2 -DATASET=$3 -PRE_TRAINED=$5 -PRE_TRAINED_EPOCH_SIZE=$6 -export MINDSPORE_HCCL_CONFIG_PATH=$4 - +LR=$3 +DATASET=$4 +PRE_TRAINED=$6 +PRE_TRAINED_EPOCH_SIZE=$7 +export MINDSPORE_HCCL_CONFIG_PATH=$5 for((i=0;i env.log - if [ $# == 4 ] + if [ $# == 5 ] then - python ../train.py \ + python train.py \ --distribute=1 \ - --lr=0.4 \ + --lr=$LR \ --dataset=$DATASET \ --device_num=$RANK_SIZE \ --device_id=$DEVICE_ID \ --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & fi - if [ $# == 6 ] + if [ $# == 7 ] then - python ../train.py \ + python train.py \ --distribute=1 \ - --lr=0.4 \ + --lr=$LR \ --dataset=$DATASET \ --device_num=$RANK_SIZE \ --device_id=$DEVICE_ID \ diff --git a/model_zoo/ssd/src/__init__.py b/model_zoo/ssd/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/ssd/src/box_utils.py b/model_zoo/ssd/src/box_utils.py new file mode 100644 index 00000000000..5e75ab6a4eb --- /dev/null +++ b/model_zoo/ssd/src/box_utils.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Bbox utils""" + +import math +import itertools as it +import numpy as np +from .config import config + + +class GeneratDefaultBoxes(): + """ + Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). + `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. + `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. + """ + def __init__(self): + fk = config.img_shape[0] / np.array(config.steps) + scale_rate = (config.max_scale - config.min_scale) / (len(config.num_default) - 1) + scales = [config.min_scale + scale_rate * i for i in range(len(config.num_default))] + [1.0] + self.default_boxes = [] + for idex, feature_size in enumerate(config.feature_size): + sk1 = scales[idex] + sk2 = scales[idex + 1] + sk3 = math.sqrt(sk1 * sk2) + if idex == 0: + w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) + all_sizes = [(0.1, 0.1), (w, h), (h, w)] + else: + all_sizes = [(sk1, sk1)] + for aspect_ratio in config.aspect_ratios[idex]: + w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) + all_sizes.append((w, h)) + all_sizes.append((h, w)) + all_sizes.append((sk3, sk3)) + + assert len(all_sizes) == config.num_default[idex] + + for i, j in it.product(range(feature_size), repeat=2): + for w, h in all_sizes: + cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] + self.default_boxes.append([cy, cx, h, w]) + + def to_ltrb(cy, cx, h, w): + return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 + + # For IoU calculation + self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') + self.default_boxes = np.array(self.default_boxes, dtype='float32') + + +default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb +default_boxes = GeneratDefaultBoxes().default_boxes +y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) +vol_anchors = (x2 - x1) * (y2 - y1) +matching_threshold = config.match_thershold + + +def ssd_bboxes_encode(boxes): + """ + Labels anchors with ground truth inputs. + + Args: + boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls]. + + Returns: + gt_loc: location ground truth with shape [num_anchors, 4]. + gt_label: class ground truth with shape [num_anchors, 1]. + num_matched_boxes: number of positives in an image. + """ + + def jaccard_with_anchors(bbox): + """Compute jaccard score a box and the anchors.""" + # Intersection bbox and volume. + ymin = np.maximum(y1, bbox[0]) + xmin = np.maximum(x1, bbox[1]) + ymax = np.minimum(y2, bbox[2]) + xmax = np.minimum(x2, bbox[3]) + w = np.maximum(xmax - xmin, 0.) + h = np.maximum(ymax - ymin, 0.) + + # Volumes. + inter_vol = h * w + union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol + jaccard = inter_vol / union_vol + return np.squeeze(jaccard) + + pre_scores = np.zeros((config.num_ssd_boxes), dtype=np.float32) + t_boxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32) + t_label = np.zeros((config.num_ssd_boxes), dtype=np.int64) + for bbox in boxes: + label = int(bbox[4]) + scores = jaccard_with_anchors(bbox) + idx = np.argmax(scores) + scores[idx] = 2.0 + mask = (scores > matching_threshold) + mask = mask & (scores > pre_scores) + pre_scores = np.maximum(pre_scores, scores * mask) + t_label = mask * label + (1 - mask) * t_label + for i in range(4): + t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] + + index = np.nonzero(t_label) + + # Transform to ltrb. + bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32) + bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 + bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] + + # Encode features. + bboxes_t = bboxes[index] + default_boxes_t = default_boxes[index] + bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0]) + bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.prior_scaling[1] + bboxes[index] = bboxes_t + + num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) + return bboxes, t_label.astype(np.int32), num_match + + +def ssd_bboxes_decode(boxes): + """Decode predict boxes to [y, x, h, w]""" + boxes_t = boxes.copy() + default_boxes_t = default_boxes.copy() + boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] + boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4] + + bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32) + + bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 + bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 + + return np.clip(bboxes, 0, 1) + + +def intersect(box_a, box_b): + """Compute the intersect of two sets of boxes.""" + max_yx = np.minimum(box_a[:, 2:4], box_b[2:4]) + min_yx = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes.""" + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * + (box_a[:, 3] - box_a[:, 1])) + area_b = ((box_b[2] - box_b[0]) * + (box_b[3] - box_b[1])) + union = area_a + area_b - inter + return inter / union diff --git a/model_zoo/ssd/src/coco_eval.py b/model_zoo/ssd/src/coco_eval.py new file mode 100644 index 00000000000..eb366180897 --- /dev/null +++ b/model_zoo/ssd/src/coco_eval.py @@ -0,0 +1,127 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Coco metrics utils""" + +import os +import json +import numpy as np +from .config import config +from .box_utils import ssd_bboxes_decode + + +def apply_nms(all_boxes, all_scores, thres, max_boxes): + """Apply NMS to bboxes.""" + y1 = all_boxes[:, 0] + x1 = all_boxes[:, 1] + y2 = all_boxes[:, 2] + x2 = all_boxes[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + + order = all_scores.argsort()[::-1] + keep = [] + + while order.size > 0: + i = order[0] + keep.append(i) + + if len(keep) >= max_boxes: + break + + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thres)[0] + + order = order[inds + 1] + return keep + + +def metrics(pred_data): + """Calculate mAP of predicted bboxes.""" + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + num_classes = config.num_classes + + coco_root = config.coco_root + data_type = config.val_data_type + + #Classes need to train or test. + val_cls = config.coco_classes + val_cls_dict = {} + for i, cls in enumerate(val_cls): + val_cls_dict[i] = cls + + anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) + coco_gt = COCO(anno_json) + classs_dict = {} + cat_ids = coco_gt.loadCats(coco_gt.getCatIds()) + for cat in cat_ids: + classs_dict[cat["name"]] = cat["id"] + + predictions = [] + img_ids = [] + + for sample in pred_data: + pred_boxes = sample['boxes'] + box_scores = sample['box_scores'] + img_id = sample['img_id'] + h, w = sample['image_shape'] + + pred_boxes = ssd_bboxes_decode(pred_boxes) + final_boxes = [] + final_label = [] + final_score = [] + img_ids.append(img_id) + + for c in range(1, num_classes): + class_box_scores = box_scores[:, c] + score_mask = class_box_scores > config.min_score + class_box_scores = class_box_scores[score_mask] + class_boxes = pred_boxes[score_mask] * [h, w, h, w] + + if score_mask.any(): + nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes) + class_boxes = class_boxes[nms_index] + class_box_scores = class_box_scores[nms_index] + + final_boxes += class_boxes.tolist() + final_score += class_box_scores.tolist() + final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores) + + for loc, label, score in zip(final_boxes, final_label, final_score): + res = {} + res['image_id'] = img_id + res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]] + res['score'] = score + res['category_id'] = label + predictions.append(res) + with open('predictions.json', 'w') as f: + json.dump(predictions, f) + + coco_dt = coco_gt.loadRes('predictions.json') + E = COCOeval(coco_gt, coco_dt, iouType='bbox') + E.params.imgIds = img_ids + E.evaluate() + E.accumulate() + E.summarize() + return E.stats[0] diff --git a/model_zoo/ssd/src/config.py b/model_zoo/ssd/src/config.py new file mode 100644 index 00000000000..683b8de31fd --- /dev/null +++ b/model_zoo/ssd/src/config.py @@ -0,0 +1,78 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "img_shape": [300, 300], + "num_ssd_boxes": 1917, + "neg_pre_positive": 3, + "match_thershold": 0.5, + "nms_thershold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learing rate settings + "global_step": 0, + "lr_init": 0.001, + "lr_end_rate": 0.001, + "warmup_epochs": 2, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [3, 6, 6, 6, 6, 6], + "extras_in_channels": [256, 576, 1280, 512, 256, 256], + "extras_out_channels": [576, 1280, 512, 256, 256, 128], + "extras_srides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [19, 10, 5, 3, 2, 1], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (16, 32, 64, 100, 150, 300), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + + # if coco used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "", +}) diff --git a/model_zoo/ssd/src/dataset.py b/model_zoo/ssd/src/dataset.py new file mode 100644 index 00000000000..19c66fc5985 --- /dev/null +++ b/model_zoo/ssd/src/dataset.py @@ -0,0 +1,289 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SSD dataset""" + +from __future__ import division + +import os +import cv2 +import numpy as np + +import mindspore.dataset as de +import mindspore.dataset.transforms.vision.c_transforms as C +from mindspore.mindrecord import FileWriter +from .config import config +from .box_utils import jaccard_numpy, ssd_bboxes_encode + + +def _rand(a=0., b=1.): + """Generate random.""" + return np.random.rand() * (b - a) + a + + +def random_sample_crop(image, boxes): + """Random Crop the image and boxes""" + height, width, _ = image.shape + min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) + + if min_iou is None: + return image, boxes + + # max trails (50) + for _ in range(50): + image_t = image + + w = _rand(0.3, 1.0) * width + h = _rand(0.3, 1.0) * height + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = _rand() * (width - w) + top = _rand() * (height - h) + + rect = np.array([int(top), int(left), int(top+h), int(left+w)]) + overlap = jaccard_numpy(boxes, rect) + + # dropout some boxes + drop_mask = overlap > 0 + if not drop_mask.any(): + continue + + if overlap[drop_mask].min() < min_iou: + continue + + image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :] + + centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 + + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 * drop_mask + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # take only matching gt boxes + boxes_t = boxes[mask, :].copy() + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) + boxes_t[:, :2] -= rect[:2] + boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) + boxes_t[:, 2:4] -= rect[:2] + + return image_t, boxes_t + return image, boxes + + +def preprocess_fn(img_id, image, box, is_training): + """Preprocess function for dataset.""" + def _infer_data(image, input_shape): + img_h, img_w, _ = image.shape + input_h, input_w = input_shape + + image = cv2.resize(image, (input_w, input_h)) + + #When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + return img_id, image, np.array((img_h, img_w), np.float32) + + def _data_aug(image, box, is_training, image_size=(300, 300)): + """Data augmentation function.""" + ih, iw, _ = image.shape + w, h = image_size + + if not is_training: + return _infer_data(image, image_size) + + # Random crop + box = box.astype(np.float32) + image, box = random_sample_crop(image, box) + ih, iw, _ = image.shape + + # Resize image + image = cv2.resize(image, (w, h)) + + # Flip image or not + flip = _rand() < .5 + if flip: + image = cv2.flip(image, 1, dst=None) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + box[:, [0, 2]] = box[:, [0, 2]] / ih + box[:, [1, 3]] = box[:, [1, 3]] / iw + + if flip: + box[:, [1, 3]] = 1 - box[:, [3, 1]] + + box, label, num_match = ssd_bboxes_encode(box) + return image, box, label, num_match + return _data_aug(image, box, is_training, image_size=config.img_shape) + + +def create_coco_label(is_training): + """Get image path and annotation from COCO.""" + from pycocotools.coco import COCO + + coco_root = config.coco_root + data_type = config.val_data_type + if is_training: + data_type = config.train_data_type + + #Classes need to train or test. + train_cls = config.coco_classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + + anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) + + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + + image_ids = coco.getImgIds() + images = [] + image_path_dict = {} + image_anno_dict = {} + + for img_id in image_ids: + image_info = coco.loadImgs(img_id) + file_name = image_info[0]["file_name"] + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + iscrowd = False + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] + if class_name in train_cls: + x_min, x_max = bbox[0], bbox[0] + bbox[2] + y_min, y_max = bbox[1], bbox[1] + bbox[3] + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + + if not is_training and iscrowd: + continue + if len(annos) >= 1: + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(annos) + + return images, image_path_dict, image_anno_dict + + +def anno_parser(annos_str): + """Parse annotation from string to list.""" + annos = [] + for anno_str in annos_str: + anno = list(map(int, anno_str.strip().split(','))) + annos.append(anno) + return annos + + +def filter_valid_data(image_dir, anno_path): + """Filter valid image file, which both in image_dir and anno_path.""" + images = [] + image_path_dict = {} + image_anno_dict = {} + if not os.path.isdir(image_dir): + raise RuntimeError("Path given is not valid.") + if not os.path.isfile(anno_path): + raise RuntimeError("Annotation file is not valid.") + + with open(anno_path, "rb") as f: + lines = f.readlines() + for img_id, line in enumerate(lines): + line_str = line.decode("utf-8").strip() + line_split = str(line_str).split(' ') + file_name = line_split[0] + image_path = os.path.join(image_dir, file_name) + if os.path.isfile(image_path): + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = anno_parser(line_split[1:]) + + return images, image_path_dict, image_anno_dict + + +def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): + """Create MindRecord file.""" + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + if dataset == "coco": + images, image_path_dict, image_anno_dict = create_coco_label(is_training) + else: + images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path) + + ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ssd_json, "ssd_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, + is_training=True, num_parallel_workers=4): + """Creatr SSD dataset with MindDataset.""" + ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, + shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) + decode = C.Decode() + ds = ds.map(input_columns=["image"], operations=decode) + change_swap_op = C.HWC2CHW() + normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) + color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training)) + if is_training: + output_columns = ["image", "box", "label", "num_match"] + trans = [color_adjust_op, normalize_op, change_swap_op] + else: + output_columns = ["img_id", "image", "image_shape"] + trans = [normalize_op, change_swap_op] + ds = ds.map(input_columns=["img_id", "image", "annotation"], + output_columns=output_columns, columns_order=output_columns, + operations=compose_map_func, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/ssd/src/init_params.py b/model_zoo/ssd/src/init_params.py new file mode 100644 index 00000000000..3ab164219c8 --- /dev/null +++ b/model_zoo/ssd/src/init_params.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parameters utils""" + +from mindspore import Tensor +from mindspore.common.initializer import initializer, TruncatedNormal + +def init_net_param(network, initialize_mode='TruncatedNormal'): + """Init the parameters in net.""" + params = network.trainable_params() + for p in params: + if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: + if initialize_mode == 'TruncatedNormal': + p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype())) + else: + p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype()) + + +def load_backbone_params(network, param_dict): + """Init the parameters from pre-train model, default is mobilenetv2.""" + for _, param in net.parameters_and_names(): + param_name = param.name.replace('network.backbone.', '') + name_split = param_name.split('.') + if 'features_1' in param_name: + param_name = param_name.replace('features_1', 'features') + if 'features_2' in param_name: + param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:]) + if param_name in param_dict: + param.set_parameter_data(param_dict[param_name].data) diff --git a/model_zoo/ssd/src/lr_schedule.py b/model_zoo/ssd/src/lr_schedule.py new file mode 100644 index 00000000000..4df26b39056 --- /dev/null +++ b/model_zoo/ssd/src/lr_schedule.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Learning rate schedule""" + +import math +import numpy as np + + +def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """ + generate learning rate array + + Args: + global_step(int): total steps of the training + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(float): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_end + \ + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate diff --git a/mindspore/model_zoo/ssd.py b/model_zoo/ssd/src/ssd.py similarity index 74% rename from mindspore/model_zoo/ssd.py rename to model_zoo/ssd/src/ssd.py index b69942cd5c1..d2fb64531ec 100644 --- a/mindspore/model_zoo/ssd.py +++ b/model_zoo/ssd/src/ssd.py @@ -14,25 +14,17 @@ # ============================================================================ """SSD net based MobilenetV2.""" + import mindspore.common.dtype as mstype import mindspore as ms import mindspore.nn as nn -from mindspore import context +from mindspore import Parameter, context, Tensor from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.communication.management import get_group_size from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.initializer import initializer -from mindspore.ops.operations import TensorAdd -from mindspore import Parameter - - -def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): - weight_shape = (out_channel, in_channel, kernel_size, kernel_size) - weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() - return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, - padding=0, pad_mode=pad_mod, weight_init=weight) def _make_divisible(v, divisor, min_value=None): @@ -46,6 +38,55 @@ def _make_divisible(v, divisor, min_value=None): return new_v +def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): + return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, + padding=0, pad_mode=pad_mod, has_bias=True) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): + depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad) + conv = _conv2d(in_channel, out_channel, kernel_size=1) + return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = 0 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', + padding=padding) + else: + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding) + layers = [conv, _bn(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + class DepthwiseConv(nn.Cell): """ Depthwise Convolution warpper definition. @@ -64,6 +105,7 @@ class DepthwiseConv(nn.Cell): Examples: >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) """ + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): super(DepthwiseConv, self).__init__() self.has_bias = has_bias @@ -91,42 +133,9 @@ class DepthwiseConv(nn.Cell): return output -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - if groups == 1: - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', - padding=padding) - else: - conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] - self.features = nn.SequentialCell(layers) - - def construct(self, x): - output = self.features(x) - return output - - class InvertedResidual(nn.Cell): """ - Mobilenetv2 residual block definition. + Residual block definition. Args: inp (int): Input channel. @@ -140,7 +149,7 @@ class InvertedResidual(nn.Cell): Examples: >>> ResidualBlock(3, 256, 1, 1) """ - def __init__(self, inp, oup, stride, expand_ratio): + def __init__(self, inp, oup, stride, expand_ratio, last_relu=False): super(InvertedResidual, self).__init__() assert stride in [1, 2] @@ -155,17 +164,21 @@ class InvertedResidual(nn.Cell): ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), # pw-linear nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), - nn.BatchNorm2d(oup), + _bn(oup), ]) self.conv = nn.SequentialCell(layers) - self.add = TensorAdd() + self.add = P.TensorAdd() self.cast = P.Cast() + self.last_relu = last_relu + self.relu = nn.ReLU6() def construct(self, x): identity = x x = self.conv(x) if self.use_res_connect: - return self.add(identity, x) + x = self.add(identity, x) + if self.last_relu: + x = self.relu(x) return x @@ -174,14 +187,14 @@ class FlattenConcat(nn.Cell): Concatenate predictions into a single tensor. Args: - config (Class): The default config of SSD. + config (dict): The default config of SSD. Returns: Tensor, flatten predictions. """ def __init__(self, config): super(FlattenConcat, self).__init__() - self.num_ssd_boxes = config.NUM_SSD_BOXES + self.num_ssd_boxes = config.num_ssd_boxes self.concat = P.Concat(axis=1) self.transpose = P.Transpose() def construct(self, inputs): @@ -199,7 +212,7 @@ class MultiBox(nn.Cell): Multibox conv layers. Each multibox layer contains class conf scores and localization predictions. Args: - config (Class): The default config of SSD. + config (dict): The default config of SSD. Returns: Tensor, localization predictions. @@ -207,17 +220,17 @@ class MultiBox(nn.Cell): """ def __init__(self, config): super(MultiBox, self).__init__() - num_classes = config.NUM_CLASSES - out_channels = config.EXTRAS_OUT_CHANNELS - num_default = config.NUM_DEFAULT + num_classes = config.num_classes + out_channels = config.extras_out_channels + num_default = config.num_default loc_layers = [] cls_layers = [] for k, out_channel in enumerate(out_channels): - loc_layers += [_conv2d(out_channel, 4 * num_default[k], - kernel_size=3, stride=1, pad_mod='same')] - cls_layers += [_conv2d(out_channel, num_classes * num_default[k], - kernel_size=3, stride=1, pad_mod='same')] + loc_layers += [_last_conv2d(out_channel, 4 * num_default[k], + kernel_size=3, stride=1, pad_mod='same', pad=0)] + cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k], + kernel_size=3, stride=1, pad_mod='same', pad=0)] self.multi_loc_layers = nn.layer.CellList(loc_layers) self.multi_cls_layers = nn.layer.CellList(cls_layers) @@ -238,7 +251,7 @@ class SSD300(nn.Cell): Args: backbone (Cell): Backbone Network. - config (Class): The default config of SSD. + config (dict): The default config of SSD. Returns: Tensor, localization predictions. @@ -246,25 +259,26 @@ class SSD300(nn.Cell): Examples:backbone SSD300(backbone=resnet34(num_classes=None), - config=ConfigSSDResNet34()). + config=config). """ def __init__(self, backbone, config, is_training=True): super(SSD300, self).__init__() self.backbone = backbone - in_channels = config.EXTRAS_IN_CHANNELS - out_channels = config.EXTRAS_OUT_CHANNELS - ratios = config.EXTRAS_RATIO - strides = config.EXTRAS_STRIDES + in_channels = config.extras_in_channels + out_channels = config.extras_out_channels + ratios = config.extras_ratio + strides = config.extras_srides residual_list = [] for i in range(2, len(in_channels)): - residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i]) + residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], + expand_ratio=ratios[i], last_relu=True) residual_list.append(residual) self.multi_residual = nn.layer.CellList(residual_list) self.multi_box = MultiBox(config) self.is_training = is_training if not is_training: - self.softmax = P.Softmax() + self.activation = P.Sigmoid() def construct(self, x): layer_out_13, output = self.backbone(x) @@ -275,77 +289,42 @@ class SSD300(nn.Cell): multi_feature += (feature,) pred_loc, pred_label = self.multi_box(multi_feature) if not self.is_training: - pred_label = self.softmax(pred_label) + pred_label = self.activation(pred_label) return pred_loc, pred_label -class LocalizationLoss(nn.Cell): +class SigmoidFocalClassificationLoss(nn.Cell): """" - Computes the localization loss with SmoothL1Loss. - - Returns: - Tensor, box regression loss. - """ - def __init__(self): - super(LocalizationLoss, self).__init__() - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.loss = nn.SmoothL1Loss() - self.expand_dims = P.ExpandDims() - self.less = P.Less() - - def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes): - mask = F.cast(self.less(0, gt_label), mstype.float32) - mask = self.expand_dims(mask, -1) - smooth_l1 = self.loss(gt_loc, pred_loc) * mask - box_loss = self.reduce_sum(smooth_l1, 1) - return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1)) - - -class ClassificationLoss(nn.Cell): - """" - Computes the classification loss with hard example mining. + Sigmoid focal-loss for classification. Args: - config (Class): The default config of SSD. + gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0 + alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25 Returns: - Tensor, classification loss. + Tensor, the focal loss. """ - def __init__(self, config): - super(ClassificationLoss, self).__init__() - self.num_classes = config.NUM_CLASSES - self.num_boxes = config.NUM_SSD_BOXES - self.neg_pre_positive = config.NEG_PRE_POSITIVE - self.minimum = P.Minimum() - self.less = P.Less() - self.sort = P.TopK() - self.tile = P.Tile() - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.expand_dims = P.ExpandDims() - self.sort_descend = P.TopK(True) - self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + def __init__(self, gamma=2.0, alpha=0.25): + super(SigmoidFocalClassificationLoss, self).__init__() + self.sigmiod_cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.sigmoid = P.Sigmoid() + self.pow = P.Pow() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.gamma = gamma + self.alpha = alpha - def construct(self, pred_label, gt_label, num_matched_boxes): - gt_label = F.cast(gt_label, mstype.int32) - mask = F.cast(self.less(0, gt_label), mstype.float32) - gt_label_shape = F.shape(gt_label) - pred_label = F.reshape(pred_label, (-1, self.num_classes)) - gt_label = F.reshape(gt_label, (-1,)) - cross_entropy = self.cross_entropy(pred_label, gt_label) - cross_entropy = F.reshape(cross_entropy, gt_label_shape) - - # Hard example mining - num_matched_boxes = F.reshape(num_matched_boxes, (-1,)) - neg_masked_cross_entropy = F.cast(cross_entropy * (1- mask), mstype.float16) - _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes) - _, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes) - num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes) - tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes)) - top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32) - class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1) - return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0) + def construct(self, logits, label): + label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value) + sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label) + sigmoid = self.sigmoid(logits) + label = F.cast(label, mstype.float32) + p_t = label * sigmoid + (1 - label) * (1 - sigmoid) + modulating_factor = self.pow(1 - p_t, self.gamma) + alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha) + focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy + return focal_loss class SSDWithLossCell(nn.Cell): @@ -354,7 +333,7 @@ class SSDWithLossCell(nn.Cell): Args: network (Cell): The training network. - config (Class): SSD config. + config (dict): SSD config. Returns: Tensor, the loss of the network. @@ -362,14 +341,29 @@ class SSDWithLossCell(nn.Cell): def __init__(self, network, config): super(SSDWithLossCell, self).__init__() self.network = network - self.class_loss = ClassificationLoss(config) - self.box_loss = LocalizationLoss() + self.less = P.Less() + self.tile = P.Tile() + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.expand_dims = P.ExpandDims() + self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha) + self.loc_loss = nn.SmoothL1Loss() def construct(self, x, gt_loc, gt_label, num_matched_boxes): pred_loc, pred_label = self.network(x) - loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes) - loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes) - return loss_cls + loss_loc + mask = F.cast(self.less(0, gt_label), mstype.float32) + num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32)) + + # Localization Loss + mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4)) + smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc + loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1) + + # Classification Loss + loss_cls = self.class_loss(pred_label, gt_label) + loss_cls = self.reduce_sum(loss_cls, (1, 2)) + + return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) class TrainingWrapper(nn.Cell): @@ -415,7 +409,6 @@ class TrainingWrapper(nn.Cell): return F.depend(loss, self.optimizer(grads)) - class SSDWithMobileNetV2(nn.Cell): """ MobileNetV2 architecture for SSD backbone. diff --git a/example/ssd_coco2017/train.py b/model_zoo/ssd/train.py similarity index 64% rename from example/ssd_coco2017/train.py rename to model_zoo/ssd/train.py index 9347bf61c8b..27f0e7ad0fe 100644 --- a/example/ssd_coco2017/train.py +++ b/model_zoo/ssd/train.py @@ -13,83 +13,38 @@ # limitations under the License. # ============================================================================ -"""train SSD and get checkpoint files.""" +"""Train SSD and get checkpoint files.""" import os -import math import argparse -import numpy as np import mindspore.nn as nn from mindspore import context, Tensor from mindspore.communication.management import init from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train import Model, ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.common.initializer import initializer +from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 +from src.config import config +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image +from src.lr_schedule import get_lr +from src.init_params import init_net_param -from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 -from config import ConfigSSD -from dataset import create_ssd_dataset, data_to_mindrecord_byte_image - - -def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): - """ - generate learning rate array - - Args: - global_step(int): total steps of the training - lr_init(float): init learning rate - lr_end(float): end learning rate - lr_max(float): max learning rate - warmup_epochs(int): number of warmup epochs - total_epochs(int): total epoch of training - steps_per_epoch(int): steps of one epoch - - Returns: - np.array, learning rate array - """ - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - warmup_steps = steps_per_epoch * warmup_epochs - for i in range(total_steps): - if i < warmup_steps: - lr = lr_init + (lr_max - lr_init) * i / warmup_steps - else: - lr = lr_end + (lr_max - lr_end) * \ - (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. - if lr < 0.0: - lr = 0.0 - lr_each_step.append(lr) - - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - - return learning_rate - - -def init_net_param(network, initialize_mode='XavierUniform'): - """Init the parameters in net.""" - params = network.trainable_params() - for p in params: - if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: - p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype())) def main(): parser = argparse.ArgumentParser(description="SSD training") parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " - "Mindrecord, default is false.") - parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") + "Mindrecord, default is False.") + parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.") + parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") - parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") + parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") - parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") + parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") args_opt = parser.parse_args() @@ -111,27 +66,26 @@ def main(): # It will generate mindrecord file in args_opt.mindrecord_dir, # and the file name is ssd.mindrecord0, 1, ... file_num. - config = ConfigSSD() prefix = "ssd.mindrecord" - mindrecord_dir = config.MINDRECORD_DIR + mindrecord_dir = config.mindrecord_dir mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") if not os.path.exists(mindrecord_file): if not os.path.isdir(mindrecord_dir): os.makedirs(mindrecord_dir) if args_opt.dataset == "coco": - if os.path.isdir(config.COCO_ROOT): + if os.path.isdir(config.coco_root): print("Create Mindrecord.") data_to_mindrecord_byte_image("coco", True, prefix) print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: - print("COCO_ROOT not exits.") + print("coco_root not exits.") else: - if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") data_to_mindrecord_byte_image("other", True, prefix) print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: - print("IMAGE_DIR or ANNO_PATH not exits.") + print("image_dir or anno_path not exits.") if not args_opt.only_create_dataset: loss_scale = float(args_opt.loss_scale) @@ -143,7 +97,8 @@ def main(): dataset_size = dataset.get_dataset_size() print("Create dataset done!") - ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config) + backbone = ssd_mobilenet_v2() + ssd = SSD300(backbone=backbone, config=config) net = SSDWithLossCell(ssd, config) init_net_param(net) @@ -157,12 +112,13 @@ def main(): param_dict = load_checkpoint(args_opt.pre_trained) load_param_into_net(net, param_dict) - lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, - lr_init=0, lr_end=0, lr_max=args_opt.lr, - warmup_epochs=max(350 // 20, 1), - total_epochs=350, + lr = Tensor(get_lr(global_step=config.global_step, + lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs=config.warmup_epochs, + total_epochs=args_opt.epoch_size, steps_per_epoch=dataset_size)) - opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, + config.momentum, config.weight_decay, loss_scale) net = TrainingWrapper(net, opt, loss_scale) callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]