diff --git a/model_zoo/official/cv/yolov5/README.md b/model_zoo/official/cv/yolov5/README.md
new file mode 100644
index 00000000000..b74a6ba01c1
--- /dev/null
+++ b/model_zoo/official/cv/yolov5/README.md
@@ -0,0 +1,365 @@
+# Contents
+
+- [YOLOv5 Description](#YOLOv5-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Training](#training)
+    - [Testing Process](#testing-process)
+        - [Evaluation](#testing)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+    - [Convert Process](#convert-process)
+        - [Convert](#convert)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Evaluation Performance](#evaluation-performance)
+        - [Inference Performance](#inference-performance)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [YOLOv5 Description](#contents)
+
+YOLOv5 is a state-of-the-art detector that is faster (FPS) and more accurate (MS COCO AP50...95 and AP50) than all available alternative detectors.
+A large number of features were verified for YOLOv5, and those that improve the accuracy of both the classifier and the detector were selected.
+These features can serve as best practices for future studies and development.
+
+[Code](https://github.com/ultralytics/yolov5)
+
+# [Model Architecture](#contents)
+
+YOLOv5 uses a CSP backbone with a Focus module, an SPP additional module, a PANet path-aggregation neck, and the YOLOv5 (anchor-based) head as its architecture.
+
+# [Dataset](#contents)
+
+Dataset support: [MS COCO] or a dataset with the same format as MS COCO
+Annotation support: [MS COCO] or annotations in the same format as MS COCO
+
+- The directory structure is as follows; the directory and file names are user-defined:
+
+    ```shell
+        ├── dataset
+            ├── YOLOv5
+                ├── annotations
+                │   ├─ train.json
+                │   └─ val.json
+                ├─ images
+                    ├─ train
+                    │    └─images
+                    │       ├─picture1.jpg
+                    │       ├─ ...
+                    │       └─picturen.jpg
+                    └─ val
+                        └─images
+                            ├─picture1.jpg
+                            ├─ ...
+                            └─picturen.jpg
+    ```
+
+We suggest users start with the MS COCO dataset to experience this model;
+other datasets must use the same format as MS COCO.
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+    - Prepare a hardware environment with Ascend processors.
+- Framework + - [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + +# [Quick Start](#contents) + +After installing MindSpore via the official website, you can start training and evaluation as follows: + +``` shell +# The parameter of training_shape define image shape for network, default is [640, 640], +``` + +```shell +#run training example(1p) by python command +python train.py \ + --data_dir=./dataset/xxx \ + --is_distributed=0 \ + --lr=0.01 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --training_shape=640 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +``` + +```shell +# standalone training example(1p) by shell script +sh run_standalone_train.sh dataset/xxx +``` + +```shell +# For Ascend device, distributed training example(8p) by shell script +sh run_distribute_train.sh dataset/xxx rank_table_8p.json +``` + +```python +# run evaluation by python command +python eval.py \ + --data_dir=./dataset/xxx \ + --pretrained=yolov5.ckpt \ + --testing_shape=640 > log.txt 2>&1 & +``` + +```python +# run evaluation by shell script +sh run_eval.sh dataset/xxx checkpoint/xxx.ckpt +``` + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +```python +└─yolov5 + ├─README.md + ├─mindspore_hub_conf.md # config for mindspore hub + ├─scripts + ├─run_standalone_train.sh # launch standalone training(1p) in ascend + ├─run_distribute_train.sh # launch distributed training(8p) in ascend + └─run_eval.sh # launch evaluating in ascend + ├─src + ├─__init__.py # python init file + ├─config.py # parameter configuration + ├─yolov5_backbone.py # backbone of network + ├─distributed_sampler.py # iterator of dataset + ├─initializer.py # initializer of parameters + ├─logger.py # log function + ├─loss.py # loss function + ├─lr_scheduler.py # generate learning rate + ├─transforms.py # Preprocess data + ├─util.py # util function + ├─yolo.py # yolov5 network + ├─yolo_dataset.py # create dataset for YOLOV5 + + ├─eval.py # evaluate val results + ├─export.py # convert mindspore model to air model + └─train.py # train net +``` + +## [Script Parameters](#contents) + +Major parameters train.py as follows: + +```shell +optional arguments: + -h, --help show this help message and exit + --device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend" + --data_dir DATA_DIR Train dataset directory. + --per_batch_size PER_BATCH_SIZE + Batch size for Training. Default: 8. + --pretrained_backbone PRETRAINED_BACKBONE + The backbone file of yolov5. Default: "". + --resume_yolov5 RESUME_YOLOV5 + The ckpt file of YOLOv5, which used to fine tune. + Default: "" + --lr_scheduler LR_SCHEDULER + Learning rate scheduler, options: exponential, + cosine_annealing. Default: exponential + --lr LR Learning rate. Default: 0.01 + --lr_epochs LR_EPOCHS + Epoch of changing of lr changing, split with ",". + Default: 220,250 + --lr_gamma LR_GAMMA Decrease lr by a factor of exponential lr_scheduler. + Default: 0.1 + --eta_min ETA_MIN Eta_min in cosine_annealing scheduler. Default: 0 + --T_max T_MAX T-max in cosine_annealing scheduler. Default: 320 + --max_epoch MAX_EPOCH + Max epoch num to train the model. Default: 320 + --warmup_epochs WARMUP_EPOCHS + Warmup epochs. 
Default: 0 + --weight_decay WEIGHT_DECAY + Weight decay factor. Default: 0.0005 + --momentum MOMENTUM Momentum. Default: 0.9 + --loss_scale LOSS_SCALE + Static loss scale. Default: 1024 + --label_smooth LABEL_SMOOTH + Whether to use label smooth in CE. Default:0 + --label_smooth_factor LABEL_SMOOTH_FACTOR + Smooth strength of original one-hot. Default: 0.1 + --log_interval LOG_INTERVAL + Logging interval steps. Default: 100 + --ckpt_path CKPT_PATH + Checkpoint save location. Default: outputs/ + --ckpt_interval CKPT_INTERVAL + Save checkpoint interval. Default: None + --is_save_on_master IS_SAVE_ON_MASTER + Save ckpt on master or all rank, 1 for master, 0 for + all ranks. Default: 1 + --is_distributed IS_DISTRIBUTED + Distribute train or not, 1 for yes, 0 for no. Default: + 1 + --rank RANK Local rank of distributed. Default: 0 + --group_size GROUP_SIZE + World size of device. Default: 1 + --need_profiler NEED_PROFILER + Whether use profiler. 0 for no, 1 for yes. Default: 0 + --training_shape TRAINING_SHAPE + Fix training shape. Default: "" + --resize_rate RESIZE_RATE + Resize rate for multi-scale training. Default: None +``` + +## [Training Process](#contents) + +### Training + +```python +python train.py \ + --data_dir=/dataset/xxx \ + --is_distributed=0 \ + --lr=0.01 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --training_shape=640 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +``` + +The python command above will run in the background, you can view the results through the file log.txt. + +After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows: + +```shell +# grep "loss:" train/log.txt +2021-05-13 20:50:25,617:INFO:epoch[0], iter[100], loss:loss:2648.764910, fps:61.59 imgs/sec, lr:1.7226087948074564e-05 +2021-05-13 20:50:39,821:INFO:epoch[0], iter[200], loss:loss:764.535622, fps:56.33 imgs/sec, lr:3.4281620173715055e-05 +2021-05-13 20:50:53,287:INFO:epoch[0], iter[300], loss:loss:494.950782, fps:59.47 imgs/sec, lr:5.1337152399355546e-05 +2021-05-13 20:51:06,138:INFO:epoch[0], iter[400], loss:loss:393.339678, fps:62.25 imgs/sec, lr:6.839268462499604e-05 +2021-05-13 20:51:17,985:INFO:epoch[0], iter[500], loss:loss:329.976604, fps:67.57 imgs/sec, lr:8.544822048861533e-05 +2021-05-13 20:51:29,359:INFO:epoch[0], iter[600], loss:loss:294.734397, fps:70.37 imgs/sec, lr:0.00010250374907627702 +2021-05-13 20:51:40,634:INFO:epoch[0], iter[700], loss:loss:281.497078, fps:70.98 imgs/sec, lr:0.00011955928493989632 +2021-05-13 20:51:52,307:INFO:epoch[0], iter[800], loss:loss:264.300707, fps:68.54 imgs/sec, lr:0.0001366148208035156 +2021-05-13 20:52:05,479:INFO:epoch[0], iter[900], loss:loss:261.971103, fps:60.76 imgs/sec, lr:0.0001536703493911773 +2021-05-13 20:52:17,362:INFO:epoch[0], iter[1000], loss:loss:264.591175, fps:67.33 imgs/sec, lr:0.00017072587797883898 +... +``` + +### Distributed Training + +For Ascend device, distributed training example(8p) by shell script + +```shell +sh run_distribute_train.sh dataset/coco2017 rank_table_8p.json +``` + +The above shell script will run distribute training in the background. You can view the results through the file train_parallel[X]/log.txt. The loss value will be achieved as follows: + +```shell +# distribute training result(8p) +... 
+2021-05-13 21:08:41,992:INFO:epoch[0], iter[600], loss:247.577421, fps:469.29 imgs/sec, lr:0.0001640283880988136 +2021-05-13 21:08:56,291:INFO:epoch[0], iter[700], loss:235.298894, fps:447.67 imgs/sec, lr:0.0001913209562189877 +2021-05-13 21:09:10,431:INFO:epoch[0], iter[800], loss:239.481037, fps:452.78 imgs/sec, lr:0.00021861353889107704 +2021-05-13 21:09:23,517:INFO:epoch[0], iter[900], loss:232.826709, fps:489.15 imgs/sec, lr:0.0002459061215631664 +2021-05-13 21:09:36,407:INFO:epoch[0], iter[1000], loss:224.734599, fps:496.65 imgs/sec, lr:0.0002731987042352557 +2021-05-13 21:09:49,072:INFO:epoch[0], iter[1100], loss:232.334771, fps:505.34 imgs/sec, lr:0.0003004912578035146 +2021-05-13 21:10:03,597:INFO:epoch[0], iter[1200], loss:242.001476, fps:440.69 imgs/sec, lr:0.00032778384047560394 +2021-05-13 21:10:18,237:INFO:epoch[0], iter[1300], loss:225.391021, fps:437.20 imgs/sec, lr:0.0003550764231476933 +2021-05-13 21:10:33,027:INFO:epoch[0], iter[1400], loss:228.738176, fps:432.76 imgs/sec, lr:0.0003823690058197826 +2021-05-13 21:10:47,424:INFO:epoch[0], iter[1500], loss:225.712950, fps:444.54 imgs/sec, lr:0.0004096615593880415 +2021-05-13 21:11:02,077:INFO:epoch[0], iter[1600], loss:221.249353, fps:436.77 imgs/sec, lr:0.00043695414206013083 +2021-05-13 21:11:16,631:INFO:epoch[0], iter[1700], loss:222.449119, fps:439.89 imgs/sec, lr:0.00046424672473222017 +... +``` + +## [Evaluation Process](#contents) + +### Valid + +```python +python eval.py \ + --data_dir=./dataset/coco2017 \ + --pretrained=yolov5.ckpt \ + --testing_shape=640 > log.txt 2>&1 & +OR +sh run_eval.sh dataset/coco2017 checkpoint/yolov5.ckpt +``` + +The above python command will run in the background. You can view the results through the file "log.txt". The mAP of the test dataset will be as follows: + +```shell +# log.txt +=============coco eval reulst========= + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.372 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.574 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.403 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.426 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.480 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.302 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.504 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.560 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.399 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.619 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.674 +``` + +## [Convert Process](#contents) + +### Convert + +If you want to infer the network on Ascend 310, you should convert the model to AIR: + +```python +python export.py [BATCH_SIZE] [PRETRAINED_BACKBONE] +``` + +# [Model Description](#contents) + +## [Performance](#contents) + +### Evaluation Performance + +YOLOv5 on 118K images(The annotation and data format must be the same as coco2017) + +| Parameters | YOLOv5s | +| -------------------------- | ----------------------------------------------------------- | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G | +| uploaded Date | 5/14/2021 (month/day/year) | +| MindSpore Version | 1.0.0-alpha | +| Dataset | 11.8K images | +| Training Parameters | epoch=320, batch_size=8, lr=0.01, momentum=0.9 | +| 
Optimizer                  | Momentum                                                     |
+| Loss Function              | Sigmoid Cross Entropy with logits, GIoU Loss                 |
+| outputs                    | heatmaps                                                     |
+| Loss                       | 53                                                           |
+| Speed                      | 1p 55 img/s, 8p 440 img/s (shape=640)                        |
+| Total time                 | 80h                                                          |
+| Checkpoint for Fine tuning | 58M (.ckpt file)                                             |
+| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/ |
+
+### Inference Performance
+
+YOLOv5 on 5K images (the annotation and data format must be the same as COCO val2017)
+
+| Parameters                 | YOLOv5s                                                      |
+| -------------------------- | ------------------------------------------------------------ |
+| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory, 755G             |
+| uploaded Date              | 5/14/2021 (month/day/year)                                   |
+| MindSpore Version          | 1.2.0                                                        |
+| Dataset                    | 5K images                                                    |
+| batch_size                 | 1                                                            |
+| outputs                    | box position, scores, and probability                        |
+| Accuracy                   | mAP=36.8~37.2% (shape=640)                                   |
+| Model for inference        | 58M (.ckpt file)                                             |
+
+# [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the `create_dataset` function.
+In var_init.py, we set the seed for weight initialization.
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/cv/yolov5/eval.py b/model_zoo/official/cv/yolov5/eval.py
new file mode 100644
index 00000000000..1d694e7152c
--- /dev/null
+++ b/model_zoo/official/cv/yolov5/eval.py
@@ -0,0 +1,396 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""YoloV5 eval."""
+import os
+import ast
+import argparse
+import datetime
+import time
+import sys
+from collections import defaultdict
+
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+
+from mindspore import Tensor
+from mindspore.context import ParallelMode
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore as ms
+
+from src.yolo import YOLOV5s
+from src.logger import get_logger
+from src.yolo_dataset import create_yolo_dataset
+from src.config import ConfigYOLOV5
+
+parser = argparse.ArgumentParser('mindspore coco testing')
+
+# device related
+parser.add_argument('--device_target', type=str, default='Ascend',
+                    help='device where the code will be implemented. (Default: Ascend)')
+
+# dataset related
+parser.add_argument('--data_dir', type=str, default='', help='train data dir')
+parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
+
+# network related
+parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
+
+# logging related
+parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
+
+# detect_related
+parser.add_argument('--nms_thresh', type=float, default=0.6, help='threshold for NMS')
+parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
+parser.add_argument('--testing_shape', type=str, default='', help='shape for test')
+parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
+parser.add_argument('--multi_label', type=ast.literal_eval, default=True, help='whether to use multi label')
+parser.add_argument('--multi_label_thresh', type=float, default=0.1, help='threshold to throw low quality boxes')
+
+args, _ = parser.parse_known_args()
+
+args.data_root = os.path.join(args.data_dir, 'val2017')
+args.ann_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
+
+
+class Redirct:
+    def __init__(self):
+        self.content = ""
+
+    def write(self, content):
+        self.content += content
+
+    def flush(self):
+        self.content = ""
+
+
+class DetectionEngine:
+    """Detection engine."""
+
+    def __init__(self, args_detection):
+        self.ignore_threshold = args_detection.ignore_threshold
+        self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
+                       'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
+                       'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
+                       'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+                       'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+                       'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+                       'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+                       'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+                       'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
+                       'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+        self.num_classes = len(self.labels)
+        self.results = {}
+        self.file_path = ''
+        self.save_prefix = args_detection.outputs_dir
+        self.ann_file = args_detection.ann_file
+        self._coco = COCO(self.ann_file)
+        self._img_ids = list(sorted(self._coco.imgs.keys()))
+        self.det_boxes = []
+        self.nms_thresh = args_detection.nms_thresh
+        self.multi_label = args_detection.multi_label
+        self.multi_label_thresh = args_detection.multi_label_thresh
+        # COCO category ids of the 80 detection classes (not contiguous in COCO)
+        self.coco_catIds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27,
+                            28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53,
+                            54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+                            81, 82, 84, 85, 86, 87, 88, 89, 90]
+
+    def do_nms_for_results(self):
+        """Get result boxes."""
+        for img_id in self.results:
+            for clsi in self.results[img_id]:
+                dets = self.results[img_id][clsi]
+                dets = np.array(dets)
+                keep_index = self._diou_nms(dets, thresh=self.nms_thresh)
+
+                keep_box = [{'image_id': int(img_id),
+                             'category_id': int(clsi),
+                             'bbox': list(dets[i][:4].astype(float)),
+                             'score': dets[i][4].astype(float)}
+                            for i in keep_index]
+                self.det_boxes.extend(keep_box)
+
+    def _nms(self, predicts, threshold):
+        """Calculate NMS."""
+        # convert xywh -> xmin ymin xmax ymax
+        x1 = predicts[:, 0]
+        y1 = predicts[:, 1]
+        x2 = x1 + predicts[:, 2]
+        y2 = y1 + predicts[:, 3]
+        scores = predicts[:, 4]
+
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[::-1]
+
+        reserved_boxes = []
+        while order.size > 0:
+            i = order[0]
+            reserved_boxes.append(i)
+            max_x1 = np.maximum(x1[i], x1[order[1:]])
+            max_y1 = np.maximum(y1[i], y1[order[1:]])
+            min_x2 = np.minimum(x2[i], x2[order[1:]])
+            min_y2 = np.minimum(y2[i], y2[order[1:]])
+
+            intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
+            intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
+            intersect_area = intersect_w * intersect_h
+            ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
+
+            indexes = np.where(ovr <= threshold)[0]
+            order = order[indexes + 1]
+        return reserved_boxes
+
+    def _diou_nms(self, dets, thresh=0.5):
+        """Calculate DIoU NMS; boxes come in as xywh and are converted to xmin ymin xmax ymax."""
+        x1 = dets[:, 0]
+        y1 = dets[:, 1]
+        x2 = x1 + dets[:, 2]
+        y2 = y1 + dets[:, 3]
+        scores = dets[:, 4]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[::-1]
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            center_x1 = (x1[i] + x2[i]) / 2
+            center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
+            center_y1 = (y1[i] + y2[i]) / 2
+            center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
+            inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
+            out_max_x = np.maximum(x2[i], x2[order[1:]])
+            out_max_y = np.maximum(y2[i], y2[order[1:]])
+            out_min_x = np.minimum(x1[i], x1[order[1:]])
+            out_min_y = np.minimum(y1[i], y1[order[1:]])
+            outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
+            diou = ovr - inter_diag / outer_diag
+            diou = np.clip(diou, -1, 1)
+            inds = np.where(diou <= thresh)[0]
+            order = order[inds + 1]
+        return keep
+
+    def write_result(self):
+        """Save result to file."""
+        import json
+        t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
+        try:
+            self.file_path = self.save_prefix + '/predict' + t + '.json'
+            f = open(self.file_path, 'w')
+            json.dump(self.det_boxes, f)
+        except IOError as e:
+            raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
+        else:
+            f.close()
+            return self.file_path
+
+    def get_eval_result(self):
+        """Get eval result."""
+        coco_gt = COCO(self.ann_file)
+        coco_dt = coco_gt.loadRes(self.file_path)
+        coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        rdct = Redirct()
+        stdout = sys.stdout
+        sys.stdout = rdct
+        coco_eval.summarize()
+        sys.stdout = stdout
+        return rdct.content
+
+    def detect(self, outputs, batch, image_shape, image_id):
+        """Detect boxes."""
+        outputs_num = len(outputs)
+        # output [|32, 52, 52, 3, 85| ]
+        for batch_id in range(batch):
+            for out_id in range(outputs_num):
+                # 32, 52, 52, 3, 85
+                out_item = outputs[out_id]
+                # 52, 52, 3, 85
+                out_item_single = out_item[batch_id, :]
+                # get number of items in one head, [B, gx, gy, anchors, 5+80]
+                dimensions = out_item_single.shape[:-1]
+                out_num = 1
+                for d in dimensions:
+                    out_num *= d
+                ori_w, ori_h = image_shape[batch_id]
+                img_id = int(image_id[batch_id])
+                x = out_item_single[..., 0] * ori_w
+                y = out_item_single[..., 1] * ori_h
+                w = out_item_single[..., 2] * ori_w
+                h = out_item_single[..., 3] * ori_h
+
+                conf = out_item_single[..., 4:5]
+                cls_emb = out_item_single[..., 5:]
+                cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
+                x = x.reshape(-1)
+                y = y.reshape(-1)
+                w = w.reshape(-1)
+                h = h.reshape(-1)
+                x_top_left = x - w / 2.
+                y_top_left = y - h / 2.
+                cls_emb = cls_emb.reshape(-1, self.num_classes)
+                if self.multi_label:
+                    conf = conf.reshape(-1, 1)
+                    # class-aware confidence for every (box, class) pair
+                    confidence = cls_emb * conf
+                    flag = cls_emb > self.multi_label_thresh
+                    flag = flag.nonzero()
+                    for index in range(len(flag[0])):
+                        i = flag[0][index]
+                        j = flag[1][index]
+                        confi = confidence[i][j]
+                        if confi < self.ignore_threshold:
+                            continue
+                        if img_id not in self.results:
+                            self.results[img_id] = defaultdict(list)
+                        x_lefti = max(0, x_top_left[i])
+                        y_lefti = max(0, y_top_left[i])
+                        wi = min(w[i], ori_w)
+                        hi = min(h[i], ori_h)
+                        clsi = j
+                        # transform catId to match coco
+                        coco_clsi = self.coco_catIds[clsi]
+                        self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
+                else:
+                    cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
+                    conf = conf.reshape(-1)
+                    cls_argmax = cls_argmax.reshape(-1)
+
+                    # create all False
+                    flag = np.random.random(cls_emb.shape) > sys.maxsize
+                    for i in range(flag.shape[0]):
+                        c = cls_argmax[i]
+                        flag[i, c] = True
+                    confidence = cls_emb[flag] * conf
+
+                    for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence,
+                                                                     cls_argmax):
+                        if confi < self.ignore_threshold:
+                            continue
+                        if img_id not in self.results:
+                            self.results[img_id] = defaultdict(list)
+                        x_lefti = max(0, x_lefti)
+                        y_lefti = max(0, y_lefti)
+                        wi = min(wi, ori_w)
+                        hi = min(hi, ori_h)
+                        # transform catId to match coco
+                        coco_clsi = self.coco_catIds[clsi]
+                        self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
+
+
+def convert_testing_shape(args_testing_shape):
+    """Convert testing shape to list."""
+    testing_shape = [int(args_testing_shape), int(args_testing_shape)]
+    return testing_shape
+
+
+if __name__ == "__main__":
+    start_time = time.time()
+    device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
+
+    # logger
+    args.outputs_dir = os.path.join(args.log_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
+    args.logger = get_logger(args.outputs_dir, rank_id)
+
+    context.reset_auto_parallel_context()
+    parallel_mode = ParallelMode.STAND_ALONE
+    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
+
+    args.logger.info('Creating Network....')
+    network = YOLOV5s(is_training=False)
+
+    args.logger.info(args.pretrained)
+    if os.path.isfile(args.pretrained):
+        param_dict = load_checkpoint(args.pretrained)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith('moments.'):
+                continue
+            elif key.startswith('yolo_network.'):
+                param_dict_new[key[13:]] = values
+            else:
+                param_dict_new[key] = values
+        load_param_into_net(network, param_dict_new)
+        args.logger.info('load_model {} success'.format(args.pretrained))
+    else:
+        args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
+        raise FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
+
+    data_root = args.data_root
+    ann_file = args.ann_file
+
+    config = ConfigYOLOV5()
+    if args.testing_shape:
+        config.test_img_shape = convert_testing_shape(args.testing_shape)
+
+    ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
+                                        max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
+                                        config=config)
+
+    args.logger.info('testing shape : {}'.format(config.test_img_shape))
+    args.logger.info('total {} images to eval'.format(data_size))
+
+    network.set_train(False)
+
+    # init detection engine
+    detection = DetectionEngine(args)
+
+    input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
+    args.logger.info('Start inference....')
+    for image_index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
+        image = data["image"].asnumpy()
+        image = np.concatenate((image[..., ::2, ::2], image[..., 1::2, ::2],
+                                image[..., ::2, 1::2], image[..., 1::2, 1::2]), axis=1)
+        image = Tensor(image)
+        image_shape_ = data["image_shape"]
+        image_id_ = data["img_id"]
+        prediction = network(image, input_shape)
+        output_big, output_me, output_small = prediction
+        output_big = output_big.asnumpy()
+        output_me = output_me.asnumpy()
+        output_small = output_small.asnumpy()
+        image_id_ = image_id_.asnumpy()
+        image_shape_ = image_shape_.asnumpy()
+        detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
+        if image_index % 1000 == 0:
+            args.logger.info('Processing... {:.2f}% '.format(image_index * args.per_batch_size / data_size * 100))
+
+    args.logger.info('Calculating mAP...')
+    detection.do_nms_for_results()
+    result_file_path = detection.write_result()
+    args.logger.info('result file path: {}'.format(result_file_path))
+    eval_result = detection.get_eval_result()
+
+    cost_time = time.time() - start_time
+    args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
+    args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))
diff --git a/model_zoo/official/cv/yolov5/export.py b/model_zoo/official/cv/yolov5/export.py
new file mode 100644
index 00000000000..e8282402f48
--- /dev/null
+++ b/model_zoo/official/cv/yolov5/export.py
@@ -0,0 +1,50 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import argparse +import numpy as np + +import mindspore +from mindspore import context, Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.yolo import YOLOV5s + +parser = argparse.ArgumentParser(description='yolov5 export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--testing_shape", type=int, default=640, help="test shape") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="yolov5", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", + help="device target") +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) + +if __name__ == "__main__": + ts_shape = args.testing_shape + + network = YOLOV5s(is_training=False) + network.set_train(False) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(network, param_dict) + + input_data = Tensor(np.zeros([args.batch_size, 3, ts_shape, ts_shape]), mindspore.float32) + + export(network, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/yolov5/mindspore_hub_conf.py b/model_zoo/official/cv/yolov5/mindspore_hub_conf.py new file mode 100644 index 00000000000..0f9a18543a4 --- /dev/null +++ b/model_zoo/official/cv/yolov5/mindspore_hub_conf.py @@ -0,0 +1,22 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""hub config.""" +from src.yolo import YOLOV5s + +def create_network(name, *args, **kwargs): + if name == "yolov5s": + yolov5s_net = YOLOV5s(is_training=True) + return yolov5s_net + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/yolov5/scripts/run_distribute_train.sh b/model_zoo/official/cv/yolov5/scripts/run_distribute_train.sh new file mode 100644 index 00000000000..82214ae19b9 --- /dev/null +++ b/model_zoo/official/cv/yolov5/scripts/run_distribute_train.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +RANK_TABLE_FILE=$(get_real_path $2) +echo $DATASET_PATH +echo $RANK_TABLE_FILE + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $RANK_TABLE_FILE ] +then + echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file" +exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$RANK_TABLE_FILE + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py \ + --data_dir=$DATASET_PATH \ + --is_distributed=1 \ + --lr=0.02 \ + --T_max=300 \ + --max_epoch=300 \ + --warmup_epochs=20 \ + --per_batch_size=16 \ + --training_shape=640 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/yolov5/scripts/run_eval.sh b/model_zoo/official/cv/yolov5/scripts/run_eval.sh new file mode 100644 index 00000000000..a092a24945d --- /dev/null +++ b/model_zoo/official/cv/yolov5/scripts/run_eval.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +DATASET_PATH=$(get_real_path $1) +CHECKPOINT_PATH=$(get_real_path $2) +echo $DATASET_PATH +echo $CHECKPOINT_PATH + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $CHECKPOINT_PATH ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start inferring for device $DEVICE_ID" +python eval.py \ + --data_dir=$DATASET_PATH \ + --pretrained=$CHECKPOINT_PATH \ + --testing_shape=640 > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/yolov5/scripts/run_standalone_train.sh b/model_zoo/official/cv/yolov5/scripts/run_standalone_train.sh new file mode 100644 index 00000000000..4db3ec380e4 --- /dev/null +++ b/model_zoo/official/cv/yolov5/scripts/run_standalone_train.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 1 ] +then + echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +echo $DATASET_PATH + + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log + +python train.py \ + --data_dir=$DATASET_PATH \ + --is_distributed=0 \ + --lr=0.01 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --training_shape=640 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +cd .. \ No newline at end of file diff --git a/model_zoo/official/cv/yolov5/src/__init__.py b/model_zoo/official/cv/yolov5/src/__init__.py new file mode 100644 index 00000000000..6228b713269 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/model_zoo/official/cv/yolov5/src/config.py b/model_zoo/official/cv/yolov5/src/config.py new file mode 100644 index 00000000000..9f9777bcc99 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/config.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Config parameters for yolov5 models.""" + + +class ConfigYOLOV5: + """ + Config parameters for the yolov5. + + Examples: + ConfigYOLOV5() + """ + # train_param + # data augmentation related + hue = 0.015 + saturation = 1.5 + value = 0.4 + jitter = 0.3 + + resize_rate = 10 + multi_scale = [[320, 320], [352, 352], [384, 384], [416, 416], [448, 448], + [480, 480], [512, 512], [544, 544], [576, 576], [608, 608], + [640, 640], [672, 672], [704, 704], [736, 736], [768, 768]] + num_classes = 80 + max_box = 150 + + # confidence under ignore_threshold means no object when training + ignore_threshold = 0.7 + + # h->w + anchor_scales = [(12, 16), + (19, 36), + (40, 28), + (36, 75), + (76, 55), + (72, 146), + (142, 110), + (192, 243), + (459, 401)] + out_channel = 3 * (num_classes + 5) + + # test_param + test_img_shape = [640, 640] diff --git a/model_zoo/official/cv/yolov5/src/distributed_sampler.py b/model_zoo/official/cv/yolov5/src/distributed_sampler.py new file mode 100644 index 00000000000..5c4a507b9a3 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/distributed_sampler.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Yolo dataset distributed sampler.""" +from __future__ import division +import math +import numpy as np + + +class DistributedSampler: + """Distributed sampler.""" + def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + print("***********Setting world_size to 1 since it is not passed in ******************") + num_replicas = 1 + if rank is None: + print("***********Setting rank to 0 since it is not passed in ******************") + rank = 0 + self.dataset_size = dataset_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size) + # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset + indices = indices.tolist() + self.epoch += 1 + # change to list type + else: + indices = list(range(self.dataset_size)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/model_zoo/official/cv/yolov5/src/initializer.py b/model_zoo/official/cv/yolov5/src/initializer.py new file mode 100644 index 00000000000..237dcd7c882 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/initializer.py @@ -0,0 +1,202 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parameter init.""" +import math +from functools import reduce +import numpy as np +from mindspore.common import initializer as init +from mindspore.common.initializer import Initializer as MeInitializer +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.nn as nn +from .util import load_backbone + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. 
+ The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def _assignment(arr, num): + """Assign the value of 'num' and 'arr'.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + + +def _calculate_correct_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(array) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 
+ + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + """ + fan = _calculate_correct_fan(arr, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, arr.shape) + + +def _calculate_fan_in_and_fan_out(arr): + """Calculate fan in and fan out.""" + dimensions = len(arr.shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") + + num_input_fmaps = arr.shape[1] + num_output_fmaps = arr.shape[0] + receptive_field_size = 1 + if dimensions > 2: + receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:]) + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +class KaimingUniform(MeInitializer): + """Kaiming uniform initializer.""" + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingUniform, self).__init__() + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + + def _initialize(self, arr): + tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) + _assignment(arr, tmp) + + +def default_recurisive_init(custom_cell): + """Initialize parameter.""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype)) + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + cell.bias.set_data(init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype)) + elif isinstance(cell, nn.Dense): + cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype)) + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + cell.bias.set_data(init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype)) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass + +def load_yolov5_params(args, network): + """Load yolov5 backbone parameter from checkpoint.""" + if args.pretrained_backbone: + network = load_backbone(network, args.pretrained_backbone, args) + args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone)) + + if args.resume_yolov5: + param_dict = load_checkpoint(args.resume_yolov5) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.'): + continue + elif key.startswith('yolo_network.'): + param_dict_new[key[13:]] = values + args.logger.info('in resume {}'.format(key)) + else: + param_dict_new[key] = values + args.logger.info('in resume {}'.format(key)) + + args.logger.info('resume finished') + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(args.resume_yolov5)) diff --git a/model_zoo/official/cv/yolov5/src/logger.py b/model_zoo/official/cv/yolov5/src/logger.py new file mode 100644 index 00000000000..d9f924ad96c --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/logger.py @@ -0,0 +1,80 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Custom Logger.""" +import os +import sys +import logging +from datetime import datetime + + +class LOGGER(logging.Logger): + """ + Logger. + + Args: + logger_name: String. Logger name. + rank: Integer. Rank id. + """ + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + self.rank = rank + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """Setup logging file.""" + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + """Get Logger.""" + logger = LOGGER('YOLOV5', rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/yolov5/src/loss.py b/model_zoo/official/cv/yolov5/src/loss.py new file mode 100644 index 00000000000..75f08f7f810 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/loss.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""YOLOV5 loss.""" +from mindspore.ops import operations as P +import mindspore.nn as nn + +class ConfidenceLoss(nn.Cell): + """Loss for confidence.""" + def __init__(self): + super(ConfidenceLoss, self).__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, predict_confidence, ignore_mask): + confidence_loss = self.cross_entropy(predict_confidence, object_mask) + confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask + confidence_loss = self.reduce_sum(confidence_loss, ()) + return confidence_loss + + +class ClassLoss(nn.Cell): + """Loss for classification.""" + def __init__(self): + super(ClassLoss, self).__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.reduce_sum = P.ReduceSum() + + def construct(self, object_mask, predict_class, class_probs): + class_loss = object_mask * self.cross_entropy(predict_class, class_probs) + class_loss = self.reduce_sum(class_loss, ()) + return class_loss diff --git a/model_zoo/official/cv/yolov5/src/lr_scheduler.py b/model_zoo/official/cv/yolov5/src/lr_scheduler.py new file mode 100644 index 00000000000..ceb4e22b384 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/lr_scheduler.py @@ -0,0 +1,180 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Learning rate scheduler.""" +import math +from collections import Counter + +import numpy as np + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """Linear learning rate.""" + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """Warmup step learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate V2.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + last_lr = 0 + last_epoch_V1 = 0 + + T_max_V2 = int(max_epoch*1/3) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + if i < total_steps*2/3: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + last_lr = lr + last_epoch_V1 = last_epoch + else: + base_lr = last_lr + last_epoch = last_epoch-last_epoch_V1 + lr = eta_min + (base_lr - eta_min) * (1. 
+ math.cos(math.pi * last_epoch / T_max_V2)) / 2 + + lr_each_step.append(lr) + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Warmup cosine annealing learning rate.""" + start_sample_epoch = 60 + step_sample = 2 + tobe_sampled_epoch = 60 + end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch + max_sampled_epoch = max_epoch+tobe_sampled_epoch + T_max = max_sampled_epoch + + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + + for i in range(total_sampled_steps): + last_epoch = i // steps_per_epoch + if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): + continue + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + assert total_steps == len(lr_each_step) + return np.array(lr_each_step).astype(np.float32) + + +def get_lr(args): + """generate learning rate.""" + if args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_V2': + lr = warmup_cosine_annealing_lr_V2(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_sample': + lr = warmup_cosine_annealing_lr_sample(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + else: + raise NotImplementedError(args.lr_scheduler) + return lr diff --git a/model_zoo/official/cv/yolov5/src/transforms.py b/model_zoo/official/cv/yolov5/src/transforms.py new file mode 100644 index 00000000000..61609f44c35 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/transforms.py @@ -0,0 +1,621 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Preprocess dataset.""" +import random +import threading +import copy + +import numpy as np +from PIL import Image +import cv2 +import mindspore.dataset.vision.py_transforms as PV + +def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + +def bbox_iou(bbox_a, bbox_b, offset=0): + """Calculate Intersection-Over-Union(IOU) of two bounding boxes. + + Parameters + ---------- + bbox_a : numpy.ndarray + An ndarray with shape :math:`(N, 4)`. + bbox_b : numpy.ndarray + An ndarray with shape :math:`(M, 4)`. 
+ offset : float or int, default is 0 + The ``offset`` is used to control the whether the width(or height) is computed as + (right - left + ``offset``). + Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. + + Returns + ------- + numpy.ndarray + An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of + bounding boxes in `bbox_a` and `bbox_b`. + + """ + if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: + raise IndexError("Bounding boxes axis 1 must have at least length 4") + + tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) + br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) + + area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) + area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) + area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) + return area_i / (area_a[:, None] + area_b - area_i) + + +def statistic_normalize_img(img, statistic_norm): + """Statistic normalize images.""" + # img: RGB + if isinstance(img, Image.Image): + img = np.array(img) + img = img/255. + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + if statistic_norm: + img = (img - mean) / std + return img + + +def get_interp_method(interp, sizes=()): + """ + Get the interpolation method for resize functions. + The major purpose of this function is to wrap a random interp method selection + and a auto-estimation method. + + Note: + When shrinking an image, it will generally look best with AREA-based + interpolation, whereas, when enlarging an image, it will generally look best + with Bicubic or Bilinear. + + Args: + interp (int): Interpolation method for all resizing operations. + + - 0: Nearest Neighbors Interpolation. + - 1: Bilinear interpolation. + - 2: Bicubic interpolation over 4x4 pixel neighborhood. + - 3: Nearest Neighbors. Originally it should be Area-based, as we cannot find Area-based, + so we use NN instead. Area-based (resampling using pixel area relation). + It may be a preferred method for image decimation, as it gives moire-free results. + But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default). + - 4: Lanczos interpolation over 8x8 pixel neighborhood. + - 9: Cubic for enlarge, area for shrink, bilinear for others. + - 10: Random select from interpolation method mentioned above. + + sizes (tuple): Format should like (old_height, old_width, new_height, new_width), + if None provided, auto(9) will return Area(2) anyway. Default: () + + Returns: + int, interp method from 0 to 4. 
+ """ + if interp == 9: + if sizes: + assert len(sizes) == 4 + oh, ow, nh, nw = sizes + if nh > oh and nw > ow: + return 2 + if nh < oh and nw < ow: + return 0 + return 1 + return 2 + if interp == 10: + return random.randint(0, 4) + if interp not in (0, 1, 2, 3, 4): + raise ValueError('Unknown interp method %d' % interp) + return interp + + +def pil_image_reshape(interp): + """Reshape pil image.""" + reshape_type = { + 0: Image.NEAREST, + 1: Image.BILINEAR, + 2: Image.BICUBIC, + 3: Image.NEAREST, + 4: Image.LANCZOS, + } + return reshape_type[interp] + + +def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, max_boxes, label_smooth, + label_smooth_factor=0.1, iou_threshold=0.213): + """ + Introduction + ------------ + 对训练数据的ground truth box进行预处理 + Parameters + ---------- + true_boxes: ground truth box 形状为[boxes, 5], x_min, y_min, x_max, y_max, class_id + """ + anchors = np.array(anchors) + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + # input_shape = np.array([in_shape, in_shape], dtype='int32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + # trans to box center point + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + # input_shape is [h, w] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + # true_boxes = [xywh] + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + # grid_shape [h, w] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + # y_true [gridy, gridx] + # 这里扩充维度是为了后面应用广播计算每个图中所有box的anchor互相之间的iou + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + # 因为之前对box做了padding, 因此需要去除全0行 + valid_mask = boxes_wh[..., 0] > 0 + wh = boxes_wh[valid_mask] + if wh.size != 0: + # 为了应用广播扩充维度 + wh = np.expand_dims(wh, -2) + # wh 的shape为[box_num, 1, 2] + boxes_max = wh / 2. + boxes_min = -boxes_max + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + #topk iou + # 找出和ground truth box的iou最大的anchor box, + # 然后将对应不同比例的负责该ground turth box 的位置置为ground truth box坐标 + + topk = 4 + topk_flag = iou.argsort() + topk_flag = topk_flag >= topk_flag.shape[1] - topk + flag = topk_flag.nonzero() + for index in range(len(flag[0])): + t = flag[0][index] + n = flag[1][index] + if iou[t][n] < iou_threshold: + continue + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor / (num_classes - 1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. 
+ #best anchor for gt + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor / (num_classes - 1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. + + # pad_gt_boxes for avoiding dynamic shape + pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + # gt_box [boxes, [x,y,w,h]] + gt_box0 = gt_box0[mask0 == 1] + # gt_box0: get all boxes which have object + if gt_box0.shape[0] < max_boxes: + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + else: + pad_gt_box0 = gt_box0[:max_boxes] + # gt_box0.shape[0]: total number of boxes in gt_box0 + # top N of pad_gt_box0 is real box, and after are pad by zero + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + if gt_box1.shape[0] < max_boxes: + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + else: + pad_gt_box1 = gt_box1[:max_boxes] + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + + gt_box2 = gt_box2[mask2 == 1] + if gt_box2.shape[0] < max_boxes: + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + else: + pad_gt_box2 = gt_box2[:max_boxes] + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + +class PreprocessTrueBox: + def __init__(self, config): + self.anchor_scales = config.anchor_scales + self.num_classes = config.num_classes + self.max_box = config.max_box + self.label_smooth = config.label_smooth + self.label_smooth_factor = config.label_smooth_factor + + def __call__(self, anno, input_shape): + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=input_shape, + num_classes=self.num_classes, max_boxes=self.max_box, + label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor) + return anno, np.array(bbox_true_1), np.array(bbox_true_2), np.array(bbox_true_3), \ + np.array(gt_box1), np.array(gt_box2), np.array(gt_box3) + + +def _reshape_data(image, image_size): + """Reshape image.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + ori_w, ori_h = image.size + ori_image_shape = np.array([ori_w, ori_h], np.int32) + # original image shape fir:H sec:W + h, w = image_size + interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) + image = image.resize((w, h), pil_image_reshape(interp)) + image_data = statistic_normalize_img(image, statistic_norm=True) + if len(image_data.shape) == 2: + image_data = np.expand_dims(image_data, axis=-1) + image_data = np.concatenate([image_data, image_data, image_data], axis=-1) + image_data = image_data.astype(np.float32) + return image_data, ori_image_shape + + +def color_distortion(img, hue, sat, val, device_num): + """Color distortion.""" + hue = _rand(-hue, hue) + sat 
= _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) + val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) + if device_num != 1: + cv2.setNumThreads(1) + x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) + x = x / 255. + x[..., 0] += hue + x[..., 0][x[..., 0] > 1] -= 1 + x[..., 0][x[..., 0] < 0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x > 1] = 1 + x[x < 0] = 0 + x = x * 255. + x = x.astype(np.uint8) + image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) + return image_data + + +def filp_pil_image(img): + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def convert_gray_to_color(img): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + return img + + +def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): + iou = bbox_iou(box, crop_box) + return min_iou <= iou.min() and max_iou >= iou.max() + + +def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): + """Choose candidate by constraints.""" + if use_constraints: + constraints = ( + (0.1, None), + (0.3, None), + (0.5, None), + (0.7, None), + (0.9, None), + (None, 1), + ) + else: + constraints = ( + (None, None), + ) + # add default candidate + candidates = [(0, 0, input_w, input_h)] + for constraint in constraints: + min_iou, max_iou = constraint + min_iou = -np.inf if min_iou is None else min_iou + max_iou = np.inf if max_iou is None else max_iou + + for _ in range(max_trial): + # box_data should have at least one box + new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) + scale = _rand(0.5, 2) + + if new_ar < 1: + nh = int(scale * input_h) + nw = int(nh * new_ar) + else: + nw = int(scale * input_w) + nh = int(nw / new_ar) + + dx = int(_rand(0, input_w - nw)) + dy = int(_rand(0, input_h - nh)) + + if box.size > 0: + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + + crop_box = np.array((0, 0, input_w, input_h)) + if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): + continue + else: + candidates.append((dx, dy, nw, nh)) + else: + raise Exception("!!! annotation box is less than 1") + return candidates + + +def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, + image_h, flip, box, box_data, allow_outside_center, max_boxes): + """Calculate correct boxes.""" + while candidates: + if len(candidates) > 1: + # ignore default candidate which do not crop + candidate = candidates.pop(np.random.randint(1, len(candidates))) + else: + candidate = candidates.pop(np.random.randint(0, len(candidates))) + dx, dy, nw, nh = candidate + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + if flip: + t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] + + if allow_outside_center: + pass + else: + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, + (t_box[:, 1] + t_box[:, 3]) / 2. 
<= input_h)] + + # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + # recorrect w,h not higher than input size + t_box[:, 2][t_box[:, 2] > input_w] = input_w + t_box[:, 3][t_box[:, 3] > input_h] = input_h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + # discard invalid box: w or h smaller than 1 pixel + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] + + if t_box.shape[0] > 0: + # break if number of find t_box + box_data[: len(t_box)] = t_box + return box_data, candidate + return np.zeros(shape=[max_boxes, 5], dtype=np.float64), (0, 0, nw, nh) + + +def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, + anchors, num_classes, max_trial=10, device_num=1): + """Crop an image randomly with bounding box constraints. + + This data augmentation is used in training of + Single Shot Multibox Detector [#]_. More details can be found in + data augmentation section of the original paper. + .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, + Scott Reed, Cheng-Yang Fu, Alexander C. Berg. + SSD: Single Shot MultiBox Detector. ECCV 2016.""" + + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + image_w, image_h = image.size + input_h, input_w = image_input_size + + np.random.shuffle(box) + if len(box) > max_boxes: + box = box[:max_boxes] + flip = _rand() < .5 + box_data = np.zeros((max_boxes, 5)) + + candidates = _choose_candidate_by_constraints(use_constraints=False, + max_trial=max_trial, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + jitter=jitter, + box=box) + box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + flip=flip, + box=box, + box_data=box_data, + allow_outside_center=True, + max_boxes=max_boxes) + dx, dy, nw, nh = candidate + interp = get_interp_method(interp=10) + image = image.resize((nw, nh), pil_image_reshape(interp)) + # place image, gray color as back graoud + new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + if flip: + image = filp_pil_image(image) + + image = np.array(image) + image = convert_gray_to_color(image) + image_data = color_distortion(image, hue, sat, val, device_num) + return image_data, box_data + + +def preprocess_fn(image, box, config, input_size, device_num): + """Preprocess data function.""" + config_anchors = config.anchor_scales + anchors = np.array([list(x) for x in config_anchors]) + max_boxes = config.max_box + num_classes = config.num_classes + jitter = config.jitter + hue = config.hue + sat = config.saturation + val = config.value + image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, + image_input_size=input_size, max_boxes=max_boxes, + num_classes=num_classes, anchors=anchors, device_num=device_num) + return image, anno + + +def reshape_fn(image, img_id, config): + input_size = config.test_img_shape + image, ori_image_shape = _reshape_data(image, image_size=input_size) + return image, ori_image_shape, img_id + + +class MultiScaleTrans: + """Multi scale transform.""" + def __init__(self, config, device_num): + self.config = config + self.seed = 0 + self.size_list = [] + self.resize_rate = config.resize_rate + self.dataset_size = config.dataset_size + self.size_dict = {} + self.seed_num = int(1e6) + self.seed_list = 
self.generate_seed_list(seed_num=self.seed_num) + self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) + self.device_num = device_num + self.anchor_scales = config.anchor_scales + self.num_classes = config.num_classes + self.max_box = config.max_box + self.label_smooth = config.label_smooth + self.label_smooth_factor = config.label_smooth_factor + + def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): + seed_list = [] + random.seed(init_seed) + for _ in range(seed_num): + seed = random.randint(seed_range[0], seed_range[1]) + seed_list.append(seed) + return seed_list + + def __call__(self, img, anno, input_size, mosaic_flag): + if mosaic_flag[0] == 0: + img = PV.Decode()(img) + img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) + return img, anno, np.array(img.shape[0:2]) + + +def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): + """Preprocess true box for multi-thread.""" + i = 0 + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1[result_index + i] = bbox_true_1 + batch_bbox_true_2[result_index + i] = bbox_true_2 + batch_bbox_true_3[result_index + i] = bbox_true_3 + batch_gt_box1[result_index + i] = gt_box1 + batch_gt_box2[result_index + i] = gt_box2 + batch_gt_box3[result_index + i] = gt_box3 + i = i + 1 + + +def batch_preprocess_true_box(annos, config, input_shape): + """Preprocess true box with multi-thread.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + threads = [] + + step = 4 + for index in range(0, len(annos), step): + for _ in range(step): + batch_bbox_true_1.append(None) + batch_bbox_true_2.append(None) + batch_bbox_true_3.append(None) + batch_gt_box1.append(None) + batch_gt_box2.append(None) + batch_gt_box3.append(None) + step_anno = annos[index: index + step] + t = threading.Thread(target=thread_batch_preprocess_true_box, + args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) + + +def batch_preprocess_true_box_single(annos, config, input_shape): + """Preprocess true boxes.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1.append(bbox_true_1) + batch_bbox_true_2.append(bbox_true_2) + batch_bbox_true_3.append(bbox_true_3) + batch_gt_box1.append(gt_box1) + batch_gt_box2.append(gt_box2) + batch_gt_box3.append(gt_box3) + 
+ return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) diff --git a/model_zoo/official/cv/yolov5/src/util.py b/model_zoo/official/cv/yolov5/src/util.py new file mode 100644 index 00000000000..3e439232133 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/util.py @@ -0,0 +1,188 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" +from mindspore.train.serialization import load_checkpoint +import mindspore.nn as nn +import mindspore.common.dtype as mstype + +from .yolo import YoloLossBlock + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f', tb_writer=None): + self.name = name + self.fmt = fmt + self.reset() + self.tb_writer = tb_writer + self.cur_step = 1 + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if self.tb_writer is not None: + self.tb_writer.add_scalar(self.name, self.val, self.cur_step) + self.cur_step += 1 + + def __str__(self): + fmtstr = '{name}:{avg' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + + +def load_backbone(net, ckpt_path, args): + """Load cspdarknet53 backbone checkpoint.""" + param_dict = load_checkpoint(ckpt_path) + yolo_backbone_prefix = 'feature_map.backbone' + darknet_backbone_prefix = 'backbone' + find_param = [] + not_found_param = [] + net.init_parameters_data() + for name, cell in net.cells_and_names(): + if name.startswith(yolo_backbone_prefix): + name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) + if isinstance(cell, (nn.Conv2d, nn.Dense)): + darknet_weight = '{}.weight'.format(name) + darknet_bias = '{}.bias'.format(name) + if darknet_weight in param_dict: + cell.weight.set_data(param_dict[darknet_weight].data) + find_param.append(darknet_weight) + else: + not_found_param.append(darknet_weight) + if darknet_bias in param_dict: + cell.bias.set_data(param_dict[darknet_bias].data) + find_param.append(darknet_bias) + else: + not_found_param.append(darknet_bias) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + darknet_moving_mean = '{}.moving_mean'.format(name) + darknet_moving_variance = '{}.moving_variance'.format(name) + darknet_gamma = '{}.gamma'.format(name) + darknet_beta = '{}.beta'.format(name) + if darknet_moving_mean in param_dict: + cell.moving_mean.set_data(param_dict[darknet_moving_mean].data) + find_param.append(darknet_moving_mean) + else: + not_found_param.append(darknet_moving_mean) + if darknet_moving_variance in param_dict: + cell.moving_variance.set_data(param_dict[darknet_moving_variance].data) + find_param.append(darknet_moving_variance) + 
else: + not_found_param.append(darknet_moving_variance) + if darknet_gamma in param_dict: + cell.gamma.set_data(param_dict[darknet_gamma].data) + find_param.append(darknet_gamma) + else: + not_found_param.append(darknet_gamma) + if darknet_beta in param_dict: + cell.beta.set_data(param_dict[darknet_beta].data) + find_param.append(darknet_beta) + else: + not_found_param.append(darknet_beta) + + args.logger.info('================found_param {}========='.format(len(find_param))) + args.logger.info(find_param) + args.logger.info('================not_found_param {}========='.format(len(not_found_param))) + args.logger.info(not_found_param) + args.logger.info('=====load {} successfully ====='.format(ckpt_path)) + + return net + + +def default_wd_filter(x): + """default weight decay filter.""" + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + return False + if parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + if parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + + return True + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + +class ShapeRecord: + """Log image shape.""" + def __init__(self): + self.shape_record = { + 416: 0, + 448: 0, + 480: 0, + 512: 0, + 544: 0, + 576: 0, + 608: 0, + 640: 0, + 672: 0, + 704: 0, + 736: 0, + 'total': 0 + } + + def set(self, shape): + if len(shape) > 1: + shape = shape[0] + shape = int(shape) + self.shape_record[shape] += 1 + self.shape_record['total'] += 1 + + def show(self, logger): + for key in self.shape_record: + rate = self.shape_record[key] / float(self.shape_record['total']) + logger.info('shape {}: {:.2f}%'.format(key, rate*100)) + + +def keep_loss_fp32(network): + """Keep loss of network with float32""" + for _, cell in network.cells_and_names(): + if isinstance(cell, (YoloLossBlock,)): + cell.to_float(mstype.float32) diff --git a/model_zoo/official/cv/yolov5/src/yolo.py b/model_zoo/official/cv/yolov5/src/yolo.py new file mode 100644 index 00000000000..b26c3128515 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/yolo.py @@ -0,0 +1,459 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
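`get_param_groups` above splits the trainable parameters so that biases and BatchNorm gamma/beta are excluded from weight decay. A hedged sketch of how those groups might be fed to a MindSpore optimizer follows (annotation only, not part of the patch; `network`, the learning-rate array and the hyper-parameters are placeholders):

```python
# Sketch only: wire src.util.get_param_groups into an optimizer so that
# bias/gamma/beta parameters skip weight decay.
import mindspore.nn as nn
from mindspore import Tensor
from src.util import get_param_groups

def build_optimizer(network, lr_each_step, momentum=0.9, weight_decay=0.0005):
    # get_param_groups returns [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay}];
    # the second group falls back to the weight_decay passed to the optimizer.
    params = get_param_groups(network)
    return nn.Momentum(params=params,
                       learning_rate=Tensor(lr_each_step),
                       momentum=momentum,
                       weight_decay=weight_decay)
```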
+# ============================================================================ +"""YOLOv5 based on DarkNet.""" +import mindspore as ms +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import context +from mindspore.context import ParallelMode +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_group_size +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C + +from src.yolov5_backbone import YOLOv5Backbone, Conv, C3 +from src.config import ConfigYOLOV5 +from src.loss import ConfidenceLoss, ClassLoss + + +class YOLOv5(nn.Cell): + def __init__(self, backbone, out_channel): + super(YOLOv5, self).__init__() + self.out_channel = out_channel + self.backbone = backbone + + self.conv1 = Conv(512, 256, k=1, s=1)#10 + self.C31 = C3(512, 256, n=1, shortcut=False)#11 + self.conv2 = Conv(256, 128, k=1, s=1) + self.C32 = C3(256, 128, n=1, shortcut=False)#13 + self.conv3 = Conv(128, 128, k=3, s=2) + self.C33 = C3(256, 256, n=1, shortcut=False)#15 + self.conv4 = Conv(256, 256, k=3, s=2) + self.C34 = C3(512, 512, n=1, shortcut=False)#17 + + self.backblock1 = YoloBlock(128, 255) + self.backblock2 = YoloBlock(256, 255) + self.backblock3 = YoloBlock(512, 255) + + self.concat = P.Concat(axis=1) + + def construct(self, x): + """ + input_shape of x is (batch_size, 3, h, w) + feature_map1 is (batch_size, backbone_shape[2], h/8, w/8) + feature_map2 is (batch_size, backbone_shape[3], h/16, w/16) + feature_map3 is (batch_size, backbone_shape[4], h/32, w/32) + """ + img_hight = P.Shape()(x)[2] * 2 + img_width = P.Shape()(x)[3] * 2 + + backbone4, backbone6, backbone9 = self.backbone(x) + + cv1 = self.conv1(backbone9)#10 + ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(cv1) + concat1 = self.concat((ups1, backbone6)) + bcsp1 = self.C31(concat1)#13 + cv2 = self.conv2(bcsp1) + ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(cv2)#15 + concat2 = self.concat((ups2, backbone4)) + bcsp2 = self.C32(concat2)#17 + cv3 = self.conv3(bcsp2) + + concat3 = self.concat((cv3, cv2)) + bcsp3 = self.C33(concat3)#20 + cv4 = self.conv4(bcsp3) + concat4 = self.concat((cv4, cv1)) + bcsp4 = self.C34(concat4)#23 + small_object_output = self.backblock1(bcsp2) # h/8, w/8 + medium_object_output = self.backblock2(bcsp3) # h/16, w/16 + big_object_output = self.backblock3(bcsp4) # h/32, w/32 + return small_object_output, medium_object_output, big_object_output + + +class YoloBlock(nn.Cell): + """ + YoloBlock for YOLOv5. + + Args: + in_channels: Integer. Input channel. + out_channels: Integer. Output channel. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). + + Examples: + YoloBlock(12, 255) + + """ + def __init__(self, in_channels, out_channels): + super(YoloBlock, self).__init__() + + self.cv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, has_bias=True) + + def construct(self, x): + """construct method""" + + out = self.cv(x) + return out + + +class DetectionBlock(nn.Cell): + """ + YOLOv5 detection Network. It will finally output the detection result. + + Args: + scale: Character. + config: ConfigYOLOV5, Configuration instance. + is_training: Bool, Whether train or not, default True. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). 
+ + Examples: + DetectionBlock(scale='l',stride=32) + """ + + def __init__(self, scale, config=ConfigYOLOV5(), is_training=True): + super(DetectionBlock, self).__init__() + self.config = config + if scale == 's': + idx = (0, 1, 2) + self.scale_x_y = 1.2 + self.offset_x_y = 0.1 + elif scale == 'm': + idx = (3, 4, 5) + self.scale_x_y = 1.1 + self.offset_x_y = 0.05 + elif scale == 'l': + idx = (6, 7, 8) + self.scale_x_y = 1.05 + self.offset_x_y = 0.025 + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) + self.num_anchors_per_scale = 3 + self.num_attrib = 4+1+self.config.num_classes + self.lambda_coord = 1 + + self.sigmoid = nn.Sigmoid() + self.reshape = P.Reshape() + self.tile = P.Tile() + self.concat = P.Concat(axis=-1) + self.conf_training = is_training + + def construct(self, x, input_shape): + """construct method""" + num_batch = P.Shape()(x)[0] + grid_size = P.Shape()(x)[2:4] + + # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib] + prediction = P.Reshape()(x, (num_batch, + self.num_anchors_per_scale, + self.num_attrib, + grid_size[0], + grid_size[1])) + prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2)) + + range_x = range(grid_size[1]) + range_y = range(grid_size[0]) + grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32) + grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32) + # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid + # [batch, gridx, gridy, 1, 1] + grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1)) + grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1)) + # Shape is [grid_size[0], grid_size[1], 1, 2] + grid = self.concat((grid_x, grid_y)) + + box_xy = prediction[:, :, :, :, :2] + box_wh = prediction[:, :, :, :, 2:4] + box_confidence = prediction[:, :, :, :, 4:5] + box_probs = prediction[:, :, :, :, 5:] + + # gridsize1 is x + # gridsize0 is y + box_xy = (self.scale_x_y * self.sigmoid(box_xy) - self.offset_x_y + grid) / \ + P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32) + # box_wh is w->h + box_wh = P.Exp()(box_wh) * self.anchors / input_shape + box_confidence = self.sigmoid(box_confidence) + box_probs = self.sigmoid(box_probs) + + if self.conf_training: + return prediction, box_xy, box_wh + return self.concat((box_xy, box_wh, box_confidence, box_probs)) + + +class Iou(nn.Cell): + """Calculate the iou of boxes""" + def __init__(self): + super(Iou, self).__init__() + self.min = P.Minimum() + self.max = P.Maximum() + + def construct(self, box1, box2): + """ + box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h] + box2: gt_box [batch, 1, 1, 1, maxbox, 4] + convert to topLeft and rightDown + """ + box1_xy = box1[:, :, :, :, :, :2] + box1_wh = box1[:, :, :, :, :, 2:4] + box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft + box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown + + box2_xy = box2[:, :, :, :, :, :2] + box2_wh = box2[:, :, :, :, :, 2:4] + box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0) + box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0) + + intersect_mins = self.max(box1_mins, box2_mins) + intersect_maxs = self.min(box1_maxs, box2_maxs) + intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0)) + # P.squeeze: for effiecient slice + intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * 
\ + P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2]) + box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2]) + box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2]) + iou = intersect_area / (box1_area + box2_area - intersect_area) + # iou : [batch, gx, gy, anchors, maxboxes] + return iou + + +class YoloLossBlock(nn.Cell): + """ + Loss block cell of YOLOV5 network. + """ + def __init__(self, scale, config=ConfigYOLOV5()): + super(YoloLossBlock, self).__init__() + self.config = config + if scale == 's': + # anchor mask + idx = (0, 1, 2) + elif scale == 'm': + idx = (3, 4, 5) + elif scale == 'l': + idx = (6, 7, 8) + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) + self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32) + self.concat = P.Concat(axis=-1) + self.iou = Iou() + self.reduce_max = P.ReduceMax(keep_dims=False) + self.confidence_loss = ConfidenceLoss() + self.class_loss = ClassLoss() + + self.reduce_sum = P.ReduceSum() + self.giou = Giou() + + def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape): + """ + prediction : origin output from yolo + pred_xy: (sigmoid(xy)+grid)/grid_size + pred_wh: (exp(wh)*anchors)/input_shape + y_true : after normalize + gt_box: [batch, maxboxes, xyhw] after normalize + """ + object_mask = y_true[:, :, :, :, 4:5] + class_probs = y_true[:, :, :, :, 5:] + true_boxes = y_true[:, :, :, :, :4] + + grid_shape = P.Shape()(prediction)[1:3] + grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32) + + pred_boxes = self.concat((pred_xy, pred_wh)) + true_wh = y_true[:, :, :, :, 2:4] + true_wh = P.Select()(P.Equal()(true_wh, 0.0), + P.Fill()(P.DType()(true_wh), + P.Shape()(true_wh), 1.0), + true_wh) + true_wh = P.Log()(true_wh / self.anchors * input_shape) + # 2-w*h for large picture, use small scale, since small obj need more precise + box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4] + + gt_shape = P.Shape()(gt_box) + gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2])) + + # add one more dimension for broadcast + iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) + # gt_box is x,y,h,w after normalize + # [batch, grid[0], grid[1], num_anchor, num_gt] + best_iou = self.reduce_max(iou, -1) + # [batch, grid[0], grid[1], num_anchor] + + # ignore_mask IOU too small + ignore_mask = best_iou < self.ignore_threshold + ignore_mask = P.Cast()(ignore_mask, ms.float32) + ignore_mask = P.ExpandDims()(ignore_mask, -1) + # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume. 
+ # so we turn off its gradient + ignore_mask = F.stop_gradient(ignore_mask) + + confidence_loss = self.confidence_loss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask) + class_loss = self.class_loss(object_mask, prediction[:, :, :, :, 5:], class_probs) + + object_mask_me = P.Reshape()(object_mask, (-1, 1)) # [8, 72, 72, 3, 1] + box_loss_scale_me = P.Reshape()(box_loss_scale, (-1, 1)) + pred_boxes_me = xywh2x1y1x2y2(pred_boxes) + pred_boxes_me = P.Reshape()(pred_boxes_me, (-1, 4)) + true_boxes_me = xywh2x1y1x2y2(true_boxes) + true_boxes_me = P.Reshape()(true_boxes_me, (-1, 4)) + ciou = self.giou(pred_boxes_me, true_boxes_me) + ciou_loss = object_mask_me * box_loss_scale_me * (1 - ciou) + ciou_loss_me = self.reduce_sum(ciou_loss, ()) + loss = ciou_loss_me * 4 + confidence_loss + class_loss + batch_size = P.Shape()(prediction)[0] + return loss / batch_size + + +class YOLOV5s(nn.Cell): + """ + YOLOV5 network. + + Args: + is_training: Bool. Whether train or not. + + Returns: + Cell, cell instance of YOLOV5 neural network. + + Examples: + YOLOV5s(True) + """ + + def __init__(self, is_training): + super(YOLOV5s, self).__init__() + self.config = ConfigYOLOV5() + + # YOLOv5 network + self.feature_map = YOLOv5(backbone=YOLOv5Backbone(), + out_channel=self.config.out_channel) + + # prediction on the default anchor boxes + self.detect_1 = DetectionBlock('l', is_training=is_training) + self.detect_2 = DetectionBlock('m', is_training=is_training) + self.detect_3 = DetectionBlock('s', is_training=is_training) + + def construct(self, x, input_shape): + small_object_output, medium_object_output, big_object_output = self.feature_map(x) + output_big = self.detect_1(big_object_output, input_shape) + output_me = self.detect_2(medium_object_output, input_shape) + output_small = self.detect_3(small_object_output, input_shape) + # big is the final output which has smallest feature map + return output_big, output_me, output_small + + +class YoloWithLossCell(nn.Cell): + """YOLOV5 loss.""" + def __init__(self, network): + super(YoloWithLossCell, self).__init__() + self.yolo_network = network + self.config = ConfigYOLOV5() + self.loss_big = YoloLossBlock('l', self.config) + self.loss_me = YoloLossBlock('m', self.config) + self.loss_small = YoloLossBlock('s', self.config) + self.tenser_to_array = P.TupleToArray() + + + def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): + input_shape = F.shape(x)[2:4] + input_shape = F.cast(self.tenser_to_array(input_shape) * 2, ms.float32) + + yolo_out = self.yolo_network(x, input_shape) + loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) + loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) + loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) + return loss_l + loss_m + loss_s * 0.2 + + +class TrainingWrapper(nn.Cell): + """Training wrapper.""" + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("gradients_mean") + if 
auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + if self.reducer_flag: + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) + + +class Giou(nn.Cell): + """Calculating giou""" + def __init__(self): + super(Giou, self).__init__() + self.cast = P.Cast() + self.reshape = P.Reshape() + self.min = P.Minimum() + self.max = P.Maximum() + self.concat = P.Concat(axis=1) + self.mean = P.ReduceMean() + self.div = P.RealDiv() + self.eps = 0.000001 + + def construct(self, box_p, box_gt): + """construct method""" + box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * (box_p[..., 3:4] - box_p[..., 1:2]) + box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * (box_gt[..., 3:4] - box_gt[..., 1:2]) + x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1]) + x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3]) + y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2]) + y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4]) + intersection = (y_2 - y_1) * (x_2 - x_1) + xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1]) + xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3]) + yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2]) + yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4]) + c_area = (xc_2 - xc_1) * (yc_2 - yc_1) + union = box_p_area + box_gt_area - intersection + union = union + self.eps + c_area = c_area + self.eps + iou = self.div(self.cast(intersection, ms.float32), self.cast(union, ms.float32)) + res_mid0 = c_area - union + res_mid1 = self.div(self.cast(res_mid0, ms.float32), self.cast(c_area, ms.float32)) + giou = iou - res_mid1 + giou = C.clip_by_value(giou, -1.0, 1.0) + return giou + +def xywh2x1y1x2y2(box_xywh): + boxes_x1 = box_xywh[..., 0:1] - box_xywh[..., 2:3] / 2 + boxes_y1 = box_xywh[..., 1:2] - box_xywh[..., 3:4] / 2 + boxes_x2 = box_xywh[..., 0:1] + box_xywh[..., 2:3] / 2 + boxes_y2 = box_xywh[..., 1:2] + box_xywh[..., 3:4] / 2 + boxes_x1y1x2y2 = P.Concat(-1)((boxes_x1, boxes_y1, boxes_x2, boxes_y2)) + + return boxes_x1y1x2y2 diff --git a/model_zoo/official/cv/yolov5/src/yolo_dataset.py b/model_zoo/official/cv/yolov5/src/yolo_dataset.py new file mode 100644 index 00000000000..75234400b0e --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/yolo_dataset.py @@ -0,0 +1,291 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
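The `Giou` cell above computes GIoU as `iou - (c_area - union) / c_area`, where `c_area` is the smallest axis-aligned box enclosing both inputs. A small numpy sketch (annotation only, not part of the patch) walks through the arithmetic on corner-format boxes; unlike the cell, it clamps the intersection at zero so it also covers disjoint boxes:

```python
# Sketch only: GIoU on two [x1, y1, x2, y2] boxes with plain numpy.
import numpy as np

def giou_numpy(box_p, box_gt, eps=1e-6):
    area_p = (box_p[2] - box_p[0]) * (box_p[3] - box_p[1])
    area_gt = (box_gt[2] - box_gt[0]) * (box_gt[3] - box_gt[1])
    iw = min(box_p[2], box_gt[2]) - max(box_p[0], box_gt[0])
    ih = min(box_p[3], box_gt[3]) - max(box_p[1], box_gt[1])
    inter = max(iw, 0.0) * max(ih, 0.0)
    union = area_p + area_gt - inter + eps
    # smallest enclosing box
    cw = max(box_p[2], box_gt[2]) - min(box_p[0], box_gt[0])
    ch = max(box_p[3], box_gt[3]) - min(box_p[1], box_gt[1])
    c_area = cw * ch + eps
    iou = inter / union
    return iou - (c_area - union) / c_area

# Two 2x2 boxes offset by one pixel: IoU = 1/7 ≈ 0.143, GIoU ≈ 0.143 - 2/9 ≈ -0.079.
print(giou_numpy(np.array([0., 0., 2., 2.]), np.array([1., 1., 3., 3.])))
```

This is the quantity that `YoloLossBlock` penalizes through its `1 - ciou` box-regression term.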
+# ============================================================================ +"""YOLOV5 dataset.""" +import os +import multiprocessing +import random +import numpy as np +import cv2 +from PIL import Image +from pycocotools.coco import COCO +import mindspore.dataset as de +import mindspore.dataset.vision.c_transforms as CV +from src.distributed_sampler import DistributedSampler +from src.transforms import reshape_fn, MultiScaleTrans, PreprocessTrueBox + + +min_keypoints_per_image = 10 +GENERATOR_PARALLEL_WORKER = 8 + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + +def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + +def has_valid_annotation(anno): + """Check annotation file.""" + # if it's empty, there is no annotation + if not anno: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different criteria for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + +class COCOYoloDataset: + """YOLOV5 Dataset for COCO.""" + def __init__(self, root, ann_file, remove_images_without_annotations=True, + filter_crowd_anno=True, is_training=True): + self.coco = COCO(ann_file) + self.root = root + self.img_ids = list(sorted(self.coco.imgs.keys())) + self.filter_crowd_anno = filter_crowd_anno + self.is_training = is_training + self.mosaic = True + # filter images without any annotations + if remove_images_without_annotations: + img_ids = [] + for img_id in self.img_ids: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + img_ids.append(img_id) + self.img_ids = img_ids + + self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()} + + self.cat_ids_to_continuous_ids = { + v: i for i, v in enumerate(self.coco.getCatIds()) + } + self.continuous_ids_cat_ids = { + v: k for k, v in self.cat_ids_to_continuous_ids.items() + } + self.count = 0 + + def _mosaic_preprocess(self, index, input_size): + labels4 = [] + s = 384 + self.mosaic_border = [-s // 2, -s // 2] + yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] + indices = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] + for i, img_ids_index in enumerate(indices): + coco = self.coco + img_id = self.img_ids[img_ids_index] + img_path = coco.loadImgs(img_id)[0]["file_name"] + img = Image.open(os.path.join(self.root, img_path)).convert("RGB") + img = np.array(img) + h, w = img.shape[:2] + + if i == 0: # top left + img4 = np.full((s * 2, s * 2, img.shape[2]), 128, dtype=np.uint8) # base image with 4 tiles + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, 
y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + + padw = x1a - x1b + padh = y1a - y1b + + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + # filter crowd annotations + if self.filter_crowd_anno: + annos = [anno for anno in target if anno["iscrowd"] == 0] + else: + annos = [anno for anno in target] + + target = {} + boxes = [anno["bbox"] for anno in annos] + target["bboxes"] = boxes + + classes = [anno["category_id"] for anno in annos] + classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] + target["labels"] = classes + + bboxes = target['bboxes'] + labels = target['labels'] + out_target = [] + + for bbox, label in zip(bboxes, labels): + tmp = [] + # convert to [x_min y_min x_max y_max] + bbox = self._convetTopDown(bbox) + tmp.extend(bbox) + tmp.append(int(label)) + # tmp [x_min y_min x_max y_max, label] + out_target.append(tmp) # 这里out_target是label的实际宽高,对应于图片中的实际度量 + + labels = out_target.copy() + labels = np.array(labels) + out_target = np.array(out_target) + + labels[:, 0] = out_target[:, 0] + padw + labels[:, 1] = out_target[:, 1] + padh + labels[:, 2] = out_target[:, 2] + padw + labels[:, 3] = out_target[:, 3] + padh + labels4.append(labels) + + if labels4: + labels4 = np.concatenate(labels4, 0) + np.clip(labels4[:, :4], 0, 2 * s, out=labels4[:, :4]) # use with random_perspective + flag = np.array([1]) + return img4, labels4, input_size, flag + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + (img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints", + generated by the image's annotation. img is a PIL image. 
+ """ + coco = self.coco + img_id = self.img_ids[index] + img_path = coco.loadImgs(img_id)[0]["file_name"] + if not self.is_training: + img = Image.open(os.path.join(self.root, img_path)).convert("RGB") + return img, img_id + + input_size = [640, 640] + if self.mosaic and random.random() < 0.5: + return self._mosaic_preprocess(index, input_size) + img = np.fromfile(os.path.join(self.root, img_path), dtype='int8') + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + # filter crowd annotations + if self.filter_crowd_anno: + annos = [anno for anno in target if anno["iscrowd"] == 0] + else: + annos = [anno for anno in target] + + target = {} + boxes = [anno["bbox"] for anno in annos] + target["bboxes"] = boxes + + classes = [anno["category_id"] for anno in annos] + classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] + target["labels"] = classes + + bboxes = target['bboxes'] + labels = target['labels'] + out_target = [] + for bbox, label in zip(bboxes, labels): + tmp = [] + # convert to [x_min y_min x_max y_max] + bbox = self._convetTopDown(bbox) + tmp.extend(bbox) + tmp.append(int(label)) + # tmp [x_min y_min x_max y_max, label] + out_target.append(tmp) + flag = np.array([0]) + return img, out_target, input_size, flag + + def __len__(self): + return len(self.img_ids) + + def _convetTopDown(self, bbox): + x_min = bbox[0] + y_min = bbox[1] + w = bbox[2] + h = bbox[3] + return [x_min, y_min, x_min+w, y_min+h] + + +def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank, + config=None, is_training=True, shuffle=True): + """Create dataset for YOLOV5.""" + cv2.setNumThreads(0) + de.config.set_enable_shared_mem(True) + if is_training: + filter_crowd = True + remove_empty_anno = True + else: + filter_crowd = False + remove_empty_anno = False + + yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd, + remove_images_without_annotations=remove_empty_anno, is_training=is_training) + distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle) + yolo_dataset.size = len(distributed_sampler) + hwc_to_chw = CV.HWC2CHW() + + config.dataset_size = len(yolo_dataset) + cores = multiprocessing.cpu_count() + num_parallel_workers = int(cores / device_num) + if is_training: + multi_scale_trans = MultiScaleTrans(config, device_num) + yolo_dataset.transforms = multi_scale_trans + + dataset_column_names = ["image", "annotation", "input_size", "mosaic_flag"] + output_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3", + "gt_box1", "gt_box2", "gt_box3"] + map1_out_column_names = ["image", "annotation", "size"] + map2_in_column_names = ["annotation", "size"] + map2_out_column_names = ["annotation", "bbox1", "bbox2", "bbox3", + "gt_box1", "gt_box2", "gt_box3"] + + ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler, + python_multiprocessing=True, num_parallel_workers=min(4, num_parallel_workers)) + ds = ds.map(operations=multi_scale_trans, input_columns=dataset_column_names, + output_columns=map1_out_column_names, column_order=map1_out_column_names, + num_parallel_workers=min(12, num_parallel_workers), python_multiprocessing=True) + ds = ds.map(operations=PreprocessTrueBox(config), input_columns=map2_in_column_names, + output_columns=map2_out_column_names, column_order=output_column_names, + num_parallel_workers=min(4, num_parallel_workers), python_multiprocessing=False) + mean = [m * 255 for m in [0.485, 0.456, 
0.406]] + std = [s * 255 for s in [0.229, 0.224, 0.225]] + ds = ds.map([CV.Normalize(mean, std), + hwc_to_chw], num_parallel_workers=min(4, num_parallel_workers)) + + def concatenate(images): + images = np.concatenate((images[..., ::2, ::2], images[..., 1::2, ::2], + images[..., ::2, 1::2], images[..., 1::2, 1::2]), axis=0) + return images + ds = ds.map(operations=concatenate, input_columns="image", num_parallel_workers=min(4, num_parallel_workers)) + ds = ds.batch(batch_size, num_parallel_workers=min(4, num_parallel_workers), drop_remainder=True) + else: + ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], + sampler=distributed_sampler) + compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config)) + ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"], + output_columns=["image", "image_shape", "img_id"], + column_order=["image", "image_shape", "img_id"], + num_parallel_workers=8) + ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(max_epoch) + return ds, len(yolo_dataset) diff --git a/model_zoo/official/cv/yolov5/src/yolov5_backbone.py b/model_zoo/official/cv/yolov5/src/yolov5_backbone.py new file mode 100644 index 00000000000..9f3f6fa5618 --- /dev/null +++ b/model_zoo/official/cv/yolov5/src/yolov5_backbone.py @@ -0,0 +1,201 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
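The `concatenate` map inside `create_yolo_dataset` above performs the "focus" slicing on the host: each CHW image is split into its four pixel-parity sub-images and stacked along the channel axis, so a 3×640×640 frame becomes 12×320×320 before it reaches the network. That is why `Focusv2` in the backbone file below only needs a plain convolution. A short numpy shape check (annotation only, not part of the patch):

```python
# Sketch only: the host-side "focus" slicing used in create_yolo_dataset.
import numpy as np

image = np.zeros((3, 640, 640), dtype=np.float32)
focused = np.concatenate((image[..., ::2, ::2], image[..., 1::2, ::2],
                          image[..., ::2, 1::2], image[..., 1::2, 1::2]), axis=0)
print(focused.shape)  # (12, 320, 320)
```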
+# ============================================================================
+"""YOLOv5 backbone."""
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+
+class Concat(nn.Cell):
+    # Concatenate a tuple of tensors along a dimension
+    def __init__(self, dimension=1):
+        super(Concat, self).__init__()
+        self.d = dimension
+        self.concat = P.Concat(self.d)
+
+    def construct(self, x):
+        return self.concat(x)
+
+
+class Bottleneck(nn.Cell):
+    # Standard bottleneck
+    def __init__(self, c1, c2, shortcut=True, e=0.5):  # ch_in, ch_out, shortcut, expansion
+        super(Bottleneck, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_, c2, 3, 1)
+        self.add = shortcut and c1 == c2
+
+    def construct(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Cell):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):  # ch_in, ch_out, number, shortcut, expansion
+        super(BottleneckCSP, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = nn.Conv2d(c1, c_, 1, 1, has_bias=False)
+        self.cv3 = nn.Conv2d(c_, c_, 1, 1, has_bias=False)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
+        self.bn = nn.BatchNorm2d(2 * c_, momentum=0.9, eps=1e-5)  # applied to cat(cv2, cv3)
+        self.act = nn.LeakyReLU(0.1)
+        self.m = nn.SequentialCell([Bottleneck(c_, c_, shortcut, e=1.0) for _ in range(n)])
+        self.concat = P.Concat(1)
+
+    def construct(self, x):
+        y1 = self.cv3(self.m(self.cv1(x)))
+        y2 = self.cv2(x)
+        concat2 = self.concat((y1, y2))
+        return self.cv4(self.act(self.bn(concat2)))
+
+
+class C3(nn.Cell):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):  # ch_in, ch_out, number, shortcut, expansion
+        super(C3, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
+        self.m = nn.SequentialCell([Bottleneck(c_, c_, shortcut, e=1.0) for _ in range(n)])
+        self.concat = P.Concat(1)
+
+    def construct(self, x):
+        y1 = self.m(self.cv1(x))
+        y2 = self.cv2(x)
+        concat2 = self.concat((y1, y2))
+        return self.cv3(concat2)
+
+
+class SPP(nn.Cell):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, c1, c2, k=(5, 9, 13)):
+        super(SPP, self).__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+
+        # pooling kernel sizes follow k
+        self.maxpool1 = nn.MaxPool2d(kernel_size=k[0], stride=1, pad_mode='same')
+        self.maxpool2 = nn.MaxPool2d(kernel_size=k[1], stride=1, pad_mode='same')
+        self.maxpool3 = nn.MaxPool2d(kernel_size=k[2], stride=1, pad_mode='same')
+        self.concat = P.Concat(1)
+
+    def construct(self, x):
+        x = self.cv1(x)
+        m1 = self.maxpool1(x)
+        m2 = self.maxpool2(x)
+        m3 = self.maxpool3(x)
+        concatm = self.concat((x, m1, m2, m3))
+        return self.cv2(concatm)
+
+
+class Focus(nn.Cell):
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, act=True):
+        super(Focus, self).__init__()
+        self.conv = Conv(c1 * 4, c2, k, s, p, act)
+        self.concat = P.Concat(1)
+
+    def construct(self, x):
+        w = P.Shape()(x)[2]
+        h = P.Shape()(x)[3]
+        concat4 = self.concat((x[..., 0:w:2, 0:h:2], x[..., 1:w:2, 0:h:2], x[..., 0:w:2, 1:h:2], x[..., 1:w:2, 1:h:2]))
+        return self.conv(concat4)
+
+
+class Focusv2(nn.Cell):
+    # Focus wh information into c-space; the 2x2 space-to-depth slicing is done
+    # on the host in create_yolo_dataset, so this cell only applies the conv
+    def __init__(self, c1, c2,
k=1, s=1, p=None, act=True): + super(Focusv2, self).__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, act) + + def construct(self, x): + return self.conv(x) + + +class SiLU(nn.Cell): + def __init__(self): + super(SiLU, self).__init__() + self.sigmoid = P.Sigmoid() + + def construct(self, x): + return x * self.sigmoid(x) + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +class Conv(nn.Cell): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, + dilation=1, + alpha=0.1, + momentum=0.97, + eps=1e-3, + pad_mode="same", + act=True): # ch_in, ch_out, kernel, stride, padding + super(Conv, self).__init__() + self.padding = autopad(k, p) + self.pad_mode = None + if self.padding == 0: + self.pad_mode = 'same' + elif self.padding == 1: + self.pad_mode = 'pad' + self.conv = nn.Conv2d(c1, c2, k, s, padding=self.padding, pad_mode=self.pad_mode, has_bias=False) + self.bn = nn.BatchNorm2d(c2, momentum=momentum, eps=eps) + self.act = SiLU() if act is True else (act if isinstance(act, nn.Cell) else P.Identity()) + + def construct(self, x): + return self.act(self.bn(self.conv(x))) + + +class YOLOv5Backbone(nn.Cell): + + def __init__(self): + super(YOLOv5Backbone, self).__init__() + + # self.outchannel = 1024 + # self.concat = P.Concat(axis=1) + # self.add = P.TensorAdd() + + self.focusv2 = Focusv2(3, 32, k=3, s=1) + self.conv1 = Conv(32, 64, k=3, s=2) + self.C31 = C3(64, 64, n=1) + self.conv2 = Conv(64, 128, k=3, s=2) + self.C32 = C3(128, 128, n=3) + self.conv3 = Conv(128, 256, k=3, s=2) + self.C33 = C3(256, 256, n=3) + self.conv4 = Conv(256, 512, k=3, s=2) + self.spp = SPP(512, 512, k=[5, 9, 13]) + self.C34 = C3(512, 512, n=1, shortcut=False) + + def construct(self, x): + """construct method""" + fcs = self.focusv2(x) + cv1 = self.conv1(fcs) + bcsp1 = self.C31(cv1) + cv2 = self.conv2(bcsp1) + bcsp2 = self.C32(cv2) + cv3 = self.conv3(bcsp2) + bcsp3 = self.C33(cv3) + cv4 = self.conv4(bcsp3) + spp1 = self.spp(cv4) + bcsp4 = self.C34(spp1) + return bcsp2, bcsp3, bcsp4 diff --git a/model_zoo/official/cv/yolov5/train.py b/model_zoo/official/cv/yolov5/train.py new file mode 100644 index 00000000000..2d076563d59 --- /dev/null +++ b/model_zoo/official/cv/yolov5/train.py @@ -0,0 +1,265 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""YOLOv5 train."""
+import os
+import time
+import argparse
+import datetime
+import mindspore as ms
+from mindspore.context import ParallelMode
+from mindspore.nn.optim.momentum import Momentum
+from mindspore import Tensor
+from mindspore import context
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.train.callback import ModelCheckpoint, RunContext
+from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+
+from src.yolo import YOLOV5s, YoloWithLossCell, TrainingWrapper
+from src.logger import get_logger
+from src.util import AverageMeter, get_param_groups
+from src.lr_scheduler import get_lr
+from src.yolo_dataset import create_yolo_dataset
+from src.initializer import default_recurisive_init, load_yolov5_params
+from src.config import ConfigYOLOV5
+ms.set_seed(1)
+
+
+def parse_args(cloud_args=None):
+    """Parse train arguments."""
+    parser = argparse.ArgumentParser('mindspore coco training')
+
+    # device related
+    parser.add_argument('--device_target', type=str, default='Ascend',
+                        help='device where the code will be implemented.')
+
+    # dataset related
+    parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
+    parser.add_argument('--per_batch_size', default=8, type=int, help='Batch size for Training. Default: 8')
+
+    # network related
+    parser.add_argument('--pretrained_backbone', default='', type=str,
+                        help='The backbone file of YOLOv5. Default: "".')
+    parser.add_argument('--resume_yolov5', default='', type=str,
+                        help='The ckpt file of YOLOv5, which is used to fine-tune. Default: ""')
+
+    # optimizer and lr related
+    parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
+                        help='Learning rate scheduler, options: exponential, cosine_annealing. Default: cosine_annealing')
+    parser.add_argument('--lr', default=0.013, type=float, help='Learning rate. Default: 0.013')
+    parser.add_argument('--lr_epochs', type=str, default='220,250',
+                        help='Epochs at which lr changes, split with ",". Default: 220,250')
+    parser.add_argument('--lr_gamma', type=float, default=0.1,
+                        help='Decay factor of the exponential lr_scheduler. Default: 0.1')
+    parser.add_argument('--eta_min', type=float, default=0., help='Eta_min in cosine_annealing scheduler. Default: 0')
+    parser.add_argument('--T_max', type=int, default=300, help='T_max in cosine_annealing scheduler. Default: 300')
+    parser.add_argument('--max_epoch', type=int, default=300, help='Max epoch num to train the model. Default: 300')
+    parser.add_argument('--warmup_epochs', default=20, type=float, help='Warmup epochs. Default: 20')
+    parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay factor. Default: 0.0005')
+    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9')
+
+    # loss related
+    parser.add_argument('--loss_scale', type=int, default=1024, help='Static loss scale. Default: 1024')
+    parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smooth in CE. Default: 0')
+    parser.add_argument('--label_smooth_factor', type=float, default=0.1,
+                        help='Smooth strength of original one-hot. Default: 0.1')
+
+    # logging related
+    parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
+    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
+    parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
+
+    parser.add_argument('--is_save_on_master', type=int, default=1,
+                        help='Save ckpt on master rank or all ranks, 1 for master, 0 for all ranks. Default: 1')
+
+    # distributed related
+    parser.add_argument('--is_distributed', type=int, default=1,
+                        help='Distributed training or not, 1 for yes, 0 for no. Default: 1')
+    parser.add_argument('--rank', type=int, default=0, help='Local rank for distributed training. Default: 0')
+    parser.add_argument('--group_size', type=int, default=1, help='World size (number of devices). Default: 1')
+
+    # roma obs
+    parser.add_argument('--train_url', type=str, default="", help='train url')
+    # profiler init
+    parser.add_argument('--need_profiler', type=int, default=0,
+                        help='Whether to use the profiler. 0 for no, 1 for yes. Default: 0')
+
+    # reset default config
+    parser.add_argument('--training_shape', type=str, default="", help='Fix training shape. Default: ""')
+    parser.add_argument('--resize_rate', type=int, default=10,
+                        help='Resize rate for multi-scale training. Default: 10')
+
+    args, _ = parser.parse_known_args()
+    args = merge_args(args, cloud_args)
+    if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
+        args.T_max = args.max_epoch
+
+    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
+    args.data_root = os.path.join(args.data_dir, 'train2017')
+    args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
+
+    devid = int(os.getenv('DEVICE_ID', '0'))
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.device_target, save_graphs=False, device_id=devid)
+    # init distributed
+    if args.is_distributed:
+        if args.device_target == "Ascend":
+            init()
+        else:
+            init("nccl")
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+
+    # select for master rank save ckpt or all rank save, compatible for model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+    args.logger.save_args(args)
+
+    return args
+
+
+def merge_args(args, cloud_args):
+    """Merge cloud_args into args, keeping the original argument types."""
+    args_dict = vars(args)
+    if isinstance(cloud_args, dict):
+        for key in cloud_args.keys():
+            val = cloud_args[key]
+            if key in args_dict and val:
+                arg_type = type(args_dict[key])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key] = val
+    return args
+
+
+def convert_training_shape(args_training_shape):
+    training_shape = [int(args_training_shape), int(args_training_shape)]
+    return training_shape
+
+
+def train(cloud_args=None):
+    args = parse_args(cloud_args)
+    loss_meter = AverageMeter('loss')
+
+    context.reset_auto_parallel_context()
+    parallel_mode = ParallelMode.STAND_ALONE
+    degree = 1
+    if args.is_distributed:
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        degree = get_group_size()
+    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
+
+    network = YOLOV5s(is_training=True)
+    # default is kaiming-normal
+    default_recurisive_init(network)
+    load_yolov5_params(args, network)
+
+    network = YoloWithLossCell(network)
+    config = ConfigYOLOV5()
+
+    config.label_smooth = args.label_smooth
+    config.label_smooth_factor = args.label_smooth_factor
+
+    if
args.training_shape: + config.multi_scale = [convert_training_shape(args.training_shape)] + if args.resize_rate: + config.resize_rate = args.resize_rate + + ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True, + batch_size=args.per_batch_size, max_epoch=args.max_epoch, + device_num=args.group_size, rank=args.rank, config=config) + args.logger.info('Finish loading dataset') + + args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size) + + if not args.ckpt_interval: + args.ckpt_interval = args.steps_per_epoch + + lr = get_lr(args) + + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + network = TrainingWrapper(network, opt, args.loss_scale // 2) + network.set_train() + + if args.rank_save_ckpt_flag: + # checkpoint save + ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, + keep_checkpoint_max=ckpt_max_num) + save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/') + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=save_ckpt_path, + prefix='{}'.format(args.rank)) + cb_params = _InternalCallbackParam() + cb_params.train_network = network + cb_params.epoch_num = ckpt_max_num + cb_params.cur_epoch_num = 1 + run_context = RunContext(cb_params) + ckpt_cb.begin(run_context) + + old_progress = -1 + t_end = time.time() + data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1) + + for i, data in enumerate(data_loader): + images = data["image"] + input_shape = images.shape[2:4] + images = Tensor.from_numpy(images) + batch_y_true_0 = Tensor.from_numpy(data['bbox1']) + batch_y_true_1 = Tensor.from_numpy(data['bbox2']) + batch_y_true_2 = Tensor.from_numpy(data['bbox3']) + batch_gt_box0 = Tensor.from_numpy(data['gt_box1']) + batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) + batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) + input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) + loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, + batch_gt_box2, input_shape) + loss_meter.update(loss.asnumpy()) + + if args.rank_save_ckpt_flag: + # ckpt progress + cb_params.cur_step_num = i + 1 # current step number + cb_params.batch_num = i + 2 + ckpt_cb.step_end(run_context) + + if i % args.log_interval == 0: + time_used = time.time() - t_end + epoch = int(i / args.steps_per_epoch) + fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used + if args.rank == 0: + args.logger.info( + 'epoch[{}], iter[{}], {}, fps:{:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i])) + t_end = time.time() + loss_meter.reset() + old_progress = i + + if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag: + cb_params.cur_epoch_num += 1 + + args.logger.info('==========end training===============') + +if __name__ == "__main__": + train()