diff --git a/model_zoo/research/cv/retinanet_resnet101/README_CN.md b/model_zoo/research/cv/retinanet_resnet101/README_CN.md new file mode 100644 index 00000000000..a433ea2b43c --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/README_CN.md @@ -0,0 +1,315 @@ +# 1. 内容 + + + +- [Retinanet 描述](#-Retinanet-描述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [脚本说明](#脚本说明) + - [脚本和示例代码](#脚本和示例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [用法](#用法) + - [运行](#运行) + - [结果](#结果) + - [评估过程](#评估过程) + - [用法](#usage) + - [运行](#running) + - [结果](#outcome) + - [模型说明](#模型说明) + - [性能](#性能) + - [训练性能](#训练性能) + - [推理性能](#推理性能) +- [随机情况的描述](#随机情况的描述) +- [ModelZoo 主页](#modelzoo-主页) + + + +## [Retinanet 描述](#content) + +RetinaNet算法源自2018年Facebook AI Research的论文 Focal Loss for Dense Object Detection。该论文最大的贡献在于提出了Focal Loss用于解决类别不均衡问题,从而创造了RetinaNet(One Stage目标检测算法)这个精度超越经典Two Stage的Faster-RCNN的目标检测网络。 + +[论文](https://arxiv.org/pdf/1708.02002.pdf) +Lin T Y , Goyal P , Girshick R , et al. Focal Loss for Dense Object Detection[C]// 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017:2999-3007. + +## [模型架构](#content) + +Retinanet的整体网络架构如下所示: + +[链接](https://arxiv.org/pdf/1708.02002.pdf) + +## [数据集](#content) + +数据集可参考文献. + +MSCOCO2017 + +- 数据集大小: 19.3G, 123287张80类彩色图像 + + - 训练:19.3G, 118287张图片 + + - 测试:1814.3M, 5000张图片 + +- 数据格式:RGB图像. + + - 注意:数据将在src/dataset.py 中被处理 + +## [环境要求](#content) + +- 硬件(Ascend) + - 使用Ascend处理器准备硬件环境。 +- 架构 + - [MindSpore](https://www.mindspore.cn/install/en) +- 想要获取更多信息,请检查以下资源: + - [MindSpore 教程](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +## [脚本说明](#content) + +### [脚本和示例代码](#content) + +```shell +. +└─Retinanet_resnet101 + ├─README.md + ├─scripts + ├─run_single_train.sh # 使用Ascend环境单卡训练 + ├─run_distribute_train.sh # 使用Ascend环境八卡并行训练 + ├─run_eval.sh # 使用Ascend环境运行推理脚本 + ├─src + ├─backbone.py # 网络模型定义 + ├─bottleneck.py # 网络颈部定义 + ├─config.py # 参数配置 + ├─dataset.py # 数据预处理 + ├─retinahead.py # 网络预测头部定义 + ├─init_params.py # 参数初始化 + ├─lr_generator.py # 学习率生成函数 + ├─coco_eval # coco数据集评估 + ├─box_utils.py # 先验框设置 + ├─_init_.py # 初始化 + ├─train.py # 网络训练脚本 + └─eval.py # 网络推理脚本 + +``` + +### [脚本参数](#content) + +```python +在train.py和config.py脚本中使用到的主要参数是: +"img_shape": [600, 600], # 图像尺寸 +"num_retinanet_boxes": 67995, # 设置的先验框总数 +"match_thershold": 0.5, # 匹配阈值 +"softnms_sigma": 0.5, # softnms算法σ值 +"nms_thershold": 0.6, # 非极大抑制阈值 +"min_score": 0.1, # 最低得分 +"max_boxes": 100, # 检测框最大数量 +"global_step": 0, # 全局步数 +"lr_init": 1e-6, # 初始学习率 +"lr_end_rate": 5e-3, # 最终学习率与最大学习率的比值 +"warmup_epochs1": 2, # 第一阶段warmup的周期数 +"warmup_epochs2": 5, # 第二阶段warmup的周期数 +"warmup_epochs3": 23, # 第三阶段warmup的周期数 +"warmup_epochs4": 60, # 第四阶段warmup的周期数 +"warmup_epochs5": 160, # 第五阶段warmup的周期数 +"momentum": 0.9, # momentum +"weight_decay": 1.5e-4, # 权重衰减率 +"num_default": [9, 9, 9, 9, 9], # 单个网格中先验框的个数 +"extras_out_channels": [256, 256, 256, 256, 256], # 特征层输出通道数 +"feature_size": [75, 38, 19, 10, 5], # 特征层尺寸 +"aspect_ratios": [(0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0)], # 先验框大小变化比值 +"steps": ( 8, 16, 32, 64, 128), # 先验框设置步长 +"anchor_size":(32, 64, 128, 256, 512), # 先验框尺寸 +"prior_scaling": (0.1, 0.2), # 用于调节回归与回归在loss中占的比值 +"gamma": 2.0, # focal loss中的参数 +"alpha": 0.75, # focal loss中的参数 +"mindrecord_dir": "/opr/root/data/MindRecord_COCO", # mindrecord文件路径 +"coco_root": "/opr/root/data/", # coco数据集路径 +"train_data_type": "train2017", # train图像的文件夹名 
+"val_data_type": "val2017", # val图像的文件夹名 +"instances_set": "annotations_trainval2017/annotations/instances_{}.json", # 标签文件路径 +"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', # coco数据集的种类 + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), +"num_classes": 81, # 数据集类别数 +"voc_root": "", # voc数据集路径 +"voc_dir": "", +"image_dir": "", # 图像路径 +"anno_path": "", # 标签文件路径 +"save_checkpoint": True, # 保存checkpoint +"save_checkpoint_epochs": 1, # 保存checkpoint epoch数 +"keep_checkpoint_max":1, # 保存checkpoint的最大数量 +"save_checkpoint_path": "./model", # 保存checkpoint的路径 +"finish_epoch":0, # 已经运行完成的 epoch 数 +"checkpoint_path":"/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt" # 用于验证的checkpoint路径 +``` + +### [训练过程](#content) + +#### 用法 + +您可以使用python或shell脚本进行训练。shell脚本的用法如下: + +- Ascend: + +```训练 +# 八卡并行训练示例: + +创建 RANK_TABLE_FILE +sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) + +# 单卡训练示例: + +sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) + +``` + +> 注意: + + RANK_TABLE_FILE相关参考资料见[链接](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), 获取device_ip方法详见[链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). 
+ +#### 运行 + +``` 运行 +# 训练示例 + + python: + data和存储mindrecord文件的路径在config里设置 + + # 单卡训练示例: + + python train.py + shell: + Ascend: + + # 八卡并行训练示例(在retinanet目录下运行): + + sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(创建的RANK_TABLE_FILE的地址) PRE_TRAINED(预训练checkpoint地址) PRE_TRAINED_EPOCH_SIZE(预训练EPOCH大小) + 例如:sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322 + + # 单卡训练示例(在retinanet目录下运行): + + sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322 + +``` + +#### 结果 + +训练结果将存储在示例路径中。checkpoint将存储在 `./model` 路径下,训练日志将被记录到 `./log.txt` 中,训练日志部分示例如下: + +``` 训练日志 +epoch: 397 step: 458, loss is 0.6153226 +lr:[0.000598] +epoch time: 313364.642 ms, per step time: 684.202 ms +epoch: 398 step: 458, loss is 0.5491791 +lr:[0.000544] +epoch time: 313486.094 ms, per step time: 684.467 ms +epoch: 399 step: 458, loss is 0.51681435 +lr:[0.000511] +epoch time: 313514.348 ms, per step time: 684.529 ms +epoch: 400 step: 458, loss is 0.4305706 +lr:[0.000500] +epoch time: 314138.455 ms, per step time: 685.892 ms +``` + +### [评估过程](#content) + +#### 用法 + +您可以使用python或shell脚本进行训练。shell脚本的用法如下: + +```eval +sh scripts/run_eval.sh [DATASET] [DEVICE_ID] +``` + +#### 运行 + +```eval运行 +# 验证示例 + + python: + Ascend: python eval.py + checkpoint 的路径在config里设置 + shell: + Ascend: sh scripts/run_eval.sh coco 0 +``` + +> checkpoint 可以在训练过程中产生. + +#### 结果 + +计算结果将存储在示例路径中,您可以在 `eval.log` 查看. + +``` mAP + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.371 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.517 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.408 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.143 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.394 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.547 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.318 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.455 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.464 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.172 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.489 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.680 + +======================================== + +mAP: 0.3710347196613514 +``` + +## [模型说明](#content) + +### [性能](#content) + +#### 训练性能 + +| 参数 | Ascend | +| -------------------------- | ------------------------------------- | +| 模型名称 | Retinanet | +| 运行环境 | 华为云 Modelarts | +| 上传时间 | 10/03/2021 | +| MindSpore 版本 | 1.0.1 | +| 数据集 | 123287 张图片 | +| Batch_size | 32 | +| 训练参数 | src/config.py | +| 优化器 | Momentum | +| 损失函数 | Focal loss | +| 最终损失 | 0.43 | +| 精确度 (8p) | mAP[0.3710] | +| 训练总时间 (8p) | 34h50m20s | +| 脚本 | [Retianet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Retinanet_resnet101) | + +#### 推理性能 + +| 参数 | Ascend | +| ------------------- | :-------------------------- | +| 模型名称 | Retinanet | +| 运行环境 | 华为云 Modelarts | +| 上传时间 | 10/03/2021 | +| MindSpore 版本 | 1.0.1 | +| 数据集 | 5k 张图片 | +| Batch_size | 1 | +| 精确度 | mAP[0.3710] | +| 总时间 | 10 mins and 50 seconds | + +# [随机情况的描述](#内容) + +在 `dataset.py` 脚本中, 我们在 `create_dataset` 函数中设置了随机种子. 我们在 `train.py` 脚本中也设置了随机种子. + +# [ModelZoo 主页](#内容) + +请核对官方 [主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). 
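+
+# [模型导出](#内容)
+
+本次提交中还包含 `export.py`,可以把训练得到的 checkpoint 导出为 MINDIR/AIR/ONNX 格式。下面给出一个示意用法(参数名取自 `export.py` 中的 argparse 定义,checkpoint 路径仅为占位符,请按实际路径替换):
+
+```shell
+python export.py --ckpt_file [CHECKPOINT_PATH] --file_name retinanet_resnet101 --file_format MINDIR --device_id 0
+```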
\ No newline at end of file diff --git a/model_zoo/research/cv/retinanet_resnet101/eval.py b/model_zoo/research/cv/retinanet_resnet101/eval.py new file mode 100644 index 00000000000..3fa31a1b34f --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/eval.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Evaluation for retinanet""" + +import os +import argparse +import time +import numpy as np +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.retinahead import retinahead, retinanetInferWithDecoder +from src.backbone import resnet101 +from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.config import config +from src.coco_eval import metrics +from src.box_utils import default_boxes + + +def retinanet_eval(dataset_path, ckpt_path): + """retinanet evaluation.""" + batch_size = 1 + ds = create_retinanet_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) + backbone = resnet101(config.num_classes) + net = retinahead(backbone, config) + net = retinanetInferWithDecoder(net, Tensor(default_boxes), config) + print("Load Checkpoint!") + param_dict = load_checkpoint(ckpt_path) + net.init_parameters_data() + load_param_into_net(net, param_dict) + + net.set_train(False) + i = batch_size + total = ds.get_dataset_size() * batch_size + start = time.time() + pred_data = [] + print("\n========================================\n") + print("total images num: ", total) + print("Processing, please wait a moment.") + for data in ds.create_dict_iterator(output_numpy=True): + img_id = data['img_id'] + img_np = data['image'] + image_shape = data['image_shape'] + + output = net(Tensor(img_np)) + for batch_idx in range(img_np.shape[0]): + pred_data.append({"boxes": output[0].asnumpy()[batch_idx], + "box_scores": output[1].asnumpy()[batch_idx], + "img_id": int(np.squeeze(img_id[batch_idx])), + "image_shape": image_shape[batch_idx]}) + percent = round(i / total * 100., 2) + + print(f' {str(percent)} [{i}/{total}]', end='\r') + i += batch_size + cost_time = int((time.time() - start) * 1000) + print(f' 100% [{total}/{total}] cost {cost_time} ms') + mAP = metrics(pred_data) + print("\n========================================\n") + print(f"mAP: {mAP}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='retinanet evaluation') + parser.add_argument("--device_id", type=int, default=3, help="Device id, default is 0.") + parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") + parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"), + help="run platform, only support Ascend and GPU.") + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, 
device_id=args_opt.device_id) + + prefix = "retinanet_eval.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if args_opt.dataset == "voc": + config.coco_root = config.voc_root + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + elif args_opt.dataset == "voc": + if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_root or voc_dir not exits.") + else: + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + print("Start Eval!") + retinanet_eval(mindrecord_file, config.checkpoint_path) diff --git a/model_zoo/research/cv/retinanet_resnet101/export.py b/model_zoo/research/cv/retinanet_resnet101/export.py new file mode 100644 index 00000000000..16a512cb780 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/export.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Export file""" + +import argparse +import numpy as np + +from mindspore import dtype as mstype +from mindspore import context, Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.retinahead import retinahead +from src.backbone import resnet101 +from src.config import config + +parser = argparse.ArgumentParser(description="retinanet_resnet101 export") +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="retinanet_resnet101", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format") +parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", + help="device target") +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) + +if __name__ == "__main__": + network = retinahead(backbone=resnet101(80), config=config, is_training=False) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(network, param_dict) + + network.set_train(False) + + shape = [args.batch_size, 3] + [600, 600] + input_data = Tensor(np.zeros(shape), mstype.float32) + + export(network, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/research/cv/retinanet_resnet101/scripts/run_distribute_train.sh b/model_zoo/research/cv/retinanet_resnet101/scripts/run_distribute_train.sh new file mode 100644 index 00000000000..ebb3a80dee9 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/scripts/run_distribute_train.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /opt/retinanet-500_458.ckpt(optional) 200(optional)" +echo "It is better to use absolute path." 
+echo "=================================================================================================================" + +if [ $# != 5 ] && [ $# != 7 ] +then + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ +[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" + exit 1 +fi + +# Before start distribute train, first create mindrecord files. +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit +python train.py --only_create_dataset=True + +echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt" + +export RANK_SIZE=$1 +EPOCH_SIZE=$2 +LR=$3 +DATASET=$4 +PRE_TRAINED=$6 +PRE_TRAINED_EPOCH_SIZE=$7 +export RANK_TABLE_FILE=$5 + +for((i=0;i env.log + if [ $# == 5 ] + then + python train.py \ + --distribute=True \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + + if [ $# == 7 ] + then + python train.py \ + --distribute=True \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --pre_trained=$PRE_TRAINED \ + --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + + cd ../ +done diff --git a/model_zoo/research/cv/retinanet_resnet101/scripts/run_eval.sh b/model_zoo/research/cv/retinanet_resnet101/scripts/run_eval.sh new file mode 100644 index 00000000000..111de06e7ad --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/scripts/run_eval.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval.sh [DATASET] [DEVICE_ID]" +exit 1 +fi + +DATASET=$1 +echo $DATASET + + +export DEVICE_NUM=1 +export DEVICE_ID=$2 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit + +if [ -d "eval$2" ]; +then + rm -rf ./eval$2 +fi + +mkdir ./eval$2 +cp ./*.py ./eval$2 +cp -r ./src ./eval$2 +cd ./eval$2 || exit +env > env.log +echo "start inferring for device $DEVICE_ID" +python eval.py \ + --dataset=$DATASET \ + --device_id=$2 > log.txt 2>&1 & +cd .. diff --git a/model_zoo/research/cv/retinanet_resnet101/scripts/run_single_train.sh b/model_zoo/research/cv/retinanet_resnet101/scripts/run_single_train.sh new file mode 100644 index 00000000000..67f87fd91e8 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/scripts/run_single_train.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458.ckpt(optional) 200(optional)" +echo "It is better to use absolute path." +echo "=================================================================================================================" + +if [ $# != 4 ] && [ $# != 6 ] +then + echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \ +[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" + exit 1 +fi + +# Before start single train, first create mindrecord files. +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit +python train.py --only_create_dataset=True + +echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt" + +export DEVICE_ID=$1 +EPOCH_SIZE=$2 +LR=$3 +DATASET=$4 + +rm -rf LOG$1 +mkdir ./LOG$1 +cp ./*.py ./LOG$1 +cp -r ./src ./LOG$1 +cd ./LOG$1 || exit +echo "start training for device $1" +env > env.log +python train.py \ +--distribute=False \ +--lr=$LR \ +--dataset=$DATASET \ +--device_num=1 \ +--device_id=$DEVICE_ID \ +--epoch_size=$EPOCH_SIZE > log.txt 2>&1 & +cd ../ + diff --git a/model_zoo/research/cv/retinanet_resnet101/src/__init__.py b/model_zoo/research/cv/retinanet_resnet101/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/research/cv/retinanet_resnet101/src/backbone.py b/model_zoo/research/cv/retinanet_resnet101/src/backbone.py new file mode 100644 index 00000000000..9f2b822d771 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/backbone.py @@ -0,0 +1,226 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""BackBone file""" + +import mindspore.nn as nn +from mindspore.ops import operations as P + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.97, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. 
Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = 0 + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', + padding=padding) + layers = [conv, _bn(out_planes), nn.ReLU()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + + channel = out_channel // self.expansion + self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) + self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) + self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0, + has_bn=True, activation='relu') + + self.down_sample = False + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel, + kernel_size=1, stride=stride, + pad_mode='same', padding=0, has_bn=True, activation='relu') + self.add = P.TensorAdd() + self.relu = P.ReLU() + + def construct(self, x): + """construct""" + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class resnet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. 
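+        (Note: for this detection backbone, construct() actually returns the C3, C4 and C5
+        feature maps as a tuple; they are consumed by the FPN defined in src/bottleneck.py.)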
+ + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes): + super(resnet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = ResidualBlock(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + + for _ in range(1, layer_num): + resnet_block = ResidualBlock(out_channel, out_channel, stride=1) + layers.append(resnet_block) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + C1 = self.maxpool(x) + + C2 = self.layer1(C1) + C3 = self.layer2(C2) + C4 = self.layer3(C3) + C5 = self.layer4(C4) + return C3, C4, C5 + + +def resnet101(num_classes): + return resnet(ResidualBlock, + [3, 4, 23, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + num_classes) + + +def resnet152(num_classes): + return resnet(ResidualBlock, + [3, 8, 36, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + num_classes) diff --git a/model_zoo/research/cv/retinanet_resnet101/src/bottleneck.py b/model_zoo/research/cv/retinanet_resnet101/src/bottleneck.py new file mode 100644 index 00000000000..f66908f6573 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/bottleneck.py @@ -0,0 +1,71 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Bottleneck""" + +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class FPN(nn.Cell): + """FPN""" + def __init__(self, config, backbone, is_training=True): + super(FPN, self).__init__() + + self.backbone = backbone + feature_size = config.feature_size + self.P5_1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, pad_mode='same') + self.P_upsample1 = P.ResizeNearestNeighbor((feature_size[1], feature_size[1])) + self.P5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same') + + self.P4_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, pad_mode='same') + self.P_upsample2 = P.ResizeNearestNeighbor((feature_size[0], feature_size[0])) + self.P4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same') + + self.P3_1 = nn.Conv2d(512, 256, kernel_size=1, stride=1, pad_mode='same') + self.P3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same') + + self.P6_0 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, pad_mode='same') + + self.P7_1 = nn.ReLU() + self.P7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=2, pad_mode='same') + + self.is_training = is_training + if not is_training: + self.activation = P.Sigmoid() + + def construct(self, x): + """construct""" + C3, C4, C5 = self.backbone(x) + + P5 = self.P5_1(C5) + P5_upsampled = self.P_upsample1(P5) + P5 = self.P5_2(P5) + + P4 = self.P4_1(C4) + P4 = P5_upsampled + P4 + P4_upsampled = self.P_upsample2(P4) + P4 = self.P4_2(P4) + + P3 = self.P3_1(C3) + P3 = P4_upsampled + P3 + P3 = self.P3_2(P3) + + P6 = self.P6_0(C5) + + P7 = self.P7_1(P6) + P7 = self.P7_2(P7) + multi_feature = (P3, P4, P5, P6, P7) + + return multi_feature diff --git a/model_zoo/research/cv/retinanet_resnet101/src/box_utils.py b/model_zoo/research/cv/retinanet_resnet101/src/box_utils.py new file mode 100644 index 00000000000..9795da62931 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/box_utils.py @@ -0,0 +1,166 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Bbox utils""" + +import math +import itertools as it +import numpy as np +from .config import config + + +class GeneratDefaultBoxes(): + """ + Generate Default boxes for retinanet, follows the order of (W, H, archor_sizes). + `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. + `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. 
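+    With the default config (feature sizes [75, 38, 19, 10, 5] and 9 anchors per location),
+    this generates 9 * (75^2 + 38^2 + 19^2 + 10^2 + 5^2) = 67995 boxes, i.e. config.num_retinanet_boxes.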
+ """ + + def __init__(self): + fk = config.img_shape[0] / np.array(config.steps) + scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]) + anchor_size = np.array(config.anchor_size) + self.default_boxes = [] + for idex, feature_size in enumerate(config.feature_size): + base_size = anchor_size[idex] / config.img_shape[0] + size1 = base_size * scales[0] + size2 = base_size * scales[1] + size3 = base_size * scales[2] + all_sizes = [] + for aspect_ratio in config.aspect_ratios[idex]: + w1, h1 = size1 * math.sqrt(aspect_ratio), size1 / math.sqrt(aspect_ratio) + all_sizes.append((h1, w1)) + w2, h2 = size2 * math.sqrt(aspect_ratio), size2 / math.sqrt(aspect_ratio) + all_sizes.append((h2, w2)) + w3, h3 = size3 * math.sqrt(aspect_ratio), size3 / math.sqrt(aspect_ratio) + all_sizes.append((h3, w3)) + + assert len(all_sizes) == config.num_default[idex] + + for i, j in it.product(range(feature_size), repeat=2): + for h, w in all_sizes: + cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] + self.default_boxes.append([cy, cx, h, w]) + + def to_ltrb(cy, cx, h, w): + return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 + + # For IoU calculation + self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') + self.default_boxes = np.array(self.default_boxes, dtype='float32') + + +default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb +default_boxes = GeneratDefaultBoxes().default_boxes +y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) +vol_anchors = (x2 - x1) * (y2 - y1) +matching_threshold = config.match_thershold + + +def retinanet_bboxes_encode(boxes): + """ + Labels anchors with ground truth inputs. + + Args: + boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls]. + + Returns: + gt_loc: location ground truth with shape [num_anchors, 4]. + gt_label: class ground truth with shape [num_anchors, 1]. + num_matched_boxes: number of positives in an image. + """ + + def jaccard_with_anchors(bbox): + """Compute jaccard score a box and the anchors.""" + # Intersection bbox and volume. + ymin = np.maximum(y1, bbox[0]) + xmin = np.maximum(x1, bbox[1]) + ymax = np.minimum(y2, bbox[2]) + xmax = np.minimum(x2, bbox[3]) + w = np.maximum(xmax - xmin, 0.) + h = np.maximum(ymax - ymin, 0.) + + # Volumes. + inter_vol = h * w + union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol + jaccard = inter_vol / union_vol + return np.squeeze(jaccard) + + pre_scores = np.zeros((config.num_retinanet_boxes), dtype=np.float32) + t_boxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32) + t_label = np.zeros((config.num_retinanet_boxes), dtype=np.int64) + for bbox in boxes: + label = int(bbox[4]) + scores = jaccard_with_anchors(bbox) + idx = np.argmax(scores) + scores[idx] = 2.0 + mask = (scores > matching_threshold) + mask = mask & (scores > pre_scores) + pre_scores = np.maximum(pre_scores, scores * mask) + t_label = mask * label + (1 - mask) * t_label + for i in range(4): + t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] + + index = np.nonzero(t_label) + + # Transform to ltrb. + bboxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32) + bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 + bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] + + # Encode features. 
+ bboxes_t = bboxes[index] + default_boxes_t = default_boxes[index] + bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0]) + tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001) + bboxes_t[:, 2:4] = np.log(tmp) / config.prior_scaling[1] + bboxes[index] = bboxes_t + + num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) + return bboxes, t_label.astype(np.int32), num_match + + +def retinanet_bboxes_decode(boxes): + """Decode predict boxes to [y, x, h, w]""" + boxes_t = boxes.copy() + default_boxes_t = default_boxes.copy() + boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] + boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4] + + bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32) + + bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 + bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 + + return np.clip(bboxes, 0, 1) + + +def intersect(box_a, box_b): + """Compute the intersect of two sets of boxes.""" + max_yx = np.minimum(box_a[:, 2:4], box_b[2:4]) + min_yx = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes.""" + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * + (box_a[:, 3] - box_a[:, 1])) + area_b = ((box_b[2] - box_b[0]) * + (box_b[3] - box_b[1])) + union = area_a + area_b - inter + return inter / union diff --git a/model_zoo/research/cv/retinanet_resnet101/src/coco_eval.py b/model_zoo/research/cv/retinanet_resnet101/src/coco_eval.py new file mode 100644 index 00000000000..d4e2666f7da --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/coco_eval.py @@ -0,0 +1,196 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Coco metrics utils""" + +import os +import json +import numpy as np +from .config import config + + +def apply_softnms(dets, scores, sigma=0.5, method=2, thresh=0.001, Nt=0.1): + ''' + the soft nms implement using python + :param dets: the pred_bboxes + :param method: the policy of decay pred_bbox score in soft nms + :param thresh: the threshold + :param Nt: Nt + :return: the index of pred_bbox after soft nms + ''' + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + + areas = (y2 - y1 + 1.) * (x2 - x1 + 1.) + orders = scores.argsort()[::-1] + keep = [] + + while orders.size > 0: + + i = orders[0] + keep.append(i) + + for j in orders[1:]: + + xx1 = np.maximum(x1[i], x1[j]) + yy1 = np.maximum(y1[i], y1[j]) + xx2 = np.minimum(x2[i], x2[j]) + yy2 = np.minimum(y2[i], y2[j]) + w = np.maximum(xx2 - xx1 + 1., 0.) + h = np.maximum(yy2 - yy1 + 1., 0.) 
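+            # Soft-NMS score decay: rather than discarding box j outright, its score is scaled by a
+            # weight -- linear (1 - overlap) for method 1, Gaussian exp(-overlap^2 / sigma) for
+            # method 2 (the default used by metrics()), or hard 0/1 suppression otherwise.
+            # Boxes whose decayed score falls below `thresh` are dropped from further consideration.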
+ + inter = w * h + overlap = inter / (areas[i] + areas[j] - inter) + + if method == 1: # linear + + if overlap > Nt: + + weight = 1 - overlap + + else: + + weight = 1 + + elif method == 2: # gaussian + + weight = np.exp(-(overlap * overlap) / sigma) + + else: # original NMS + + if overlap > Nt: + + weight = 0 + + else: + + weight = 1 + + scores[j] = weight * scores[j] + + if scores[j] < thresh: + orders = np.delete(orders, np.where(orders == j)) + + orders = np.delete(orders, 0) + + return keep + + +def apply_nms(all_boxes, all_scores, thres, max_boxes): + """Apply NMS to bboxes.""" + y1 = all_boxes[:, 0] + x1 = all_boxes[:, 1] + y2 = all_boxes[:, 2] + x2 = all_boxes[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + + order = all_scores.argsort()[::-1] + keep = [] + + while order.size > 0: + i = order[0] + keep.append(i) + + if len(keep) >= max_boxes: + break + + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thres)[0] + + order = order[inds + 1] + return keep + + +def metrics(pred_data): + """Calculate mAP of predicted bboxes.""" + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + num_classes = config.num_classes + + coco_root = config.coco_root + data_type = config.val_data_type + + # Classes need to train or test. + val_cls = config.coco_classes + val_cls_dict = {} + for i, cls in enumerate(val_cls): + val_cls_dict[i] = cls + + anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) + coco_gt = COCO(anno_json) + classs_dict = {} + cat_ids = coco_gt.loadCats(coco_gt.getCatIds()) + for cat in cat_ids: + classs_dict[cat["name"]] = cat["id"] + + predictions = [] + img_ids = [] + + for sample in pred_data: + pred_boxes = sample['boxes'] + box_scores = sample['box_scores'] + img_id = sample['img_id'] + h, w = sample['image_shape'] + + final_boxes = [] + final_label = [] + final_score = [] + img_ids.append(img_id) + + for c in range(1, num_classes): + class_box_scores = box_scores[:, c] + score_mask = class_box_scores > config.min_score + class_box_scores = class_box_scores[score_mask] + class_boxes = pred_boxes[score_mask] * [h, w, h, w] + + if score_mask.any(): + # nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes) + # apply_softnms( dets, scores,method=2, thresh=0.001, Nt=0.1, sigma=0.5 ) + nms_index = apply_softnms(class_boxes, class_box_scores, config.softnms_sigma) + class_boxes = class_boxes[nms_index] + class_box_scores = class_box_scores[nms_index] + + final_boxes += class_boxes.tolist() + final_score += class_box_scores.tolist() + final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores) + + for loc, label, score in zip(final_boxes, final_label, final_score): + res = {} + res['image_id'] = img_id + res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]] + res['score'] = score + res['category_id'] = label + predictions.append(res) + with open('predictions.json', 'w') as f: + json.dump(predictions, f) + + coco_dt = coco_gt.loadRes('predictions.json') + E = COCOeval(coco_gt, coco_dt, iouType='bbox') + E.params.imgIds = img_ids + E.evaluate() + E.accumulate() + E.summarize() + return E.stats[0] diff --git a/model_zoo/research/cv/retinanet_resnet101/src/config.py 
b/model_zoo/research/cv/retinanet_resnet101/src/config.py new file mode 100644 index 00000000000..b16b259434d --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/config.py @@ -0,0 +1,87 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# " ============================================================================ + +"""Config parameters for retinanet models.""" + +from easydict import EasyDict as ed + +config = ed({ + "img_shape": [600, 600], + "num_retinanet_boxes": 67995, + "match_thershold": 0.5, + "softnms_sigma": 0.5, + "nms_thershold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learing rate settings + "global_step": 0, + "lr_init": 1e-6, + "lr_end_rate": 5e-3, + "warmup_epochs1": 2, + "warmup_epochs2": 5, + "warmup_epochs3": 23, + "warmup_epochs4": 60, + "warmup_epochs5": 160, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [9, 9, 9, 9, 9], + "extras_out_channels": [256, 256, 256, 256, 256], + "feature_size": [75, 38, 19, 10, 5], + "aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)], + "steps": (8, 16, 32, 64, 128), + "anchor_size": (32, 64, 128, 256, 512), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "mindrecord_dir": "/opr/root/data/MindRecord_COCO", + "coco_root": "/opr/root/data/", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "anno/instances_{}.json", + "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_root": "", + # voc original dataset. + "voc_dir": "", + # if coco or voc used, `image_dir` and `anno_path` are useless. 
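+    # For dataset="other", anno_path is expected to be a text file where each line is
+    # "<image file name> <box_1> <box_2> ...", each box being comma-separated integers,
+    # presumably in the same [ymin, xmin, ymax, xmax, label] order used for COCO
+    # (see filter_valid_data/anno_parser in src/dataset.py).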
+ "image_dir": "", + "anno_path": "", + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 30, + "save_checkpoint_path": "./model", + "finish_epoch": 0, + "checkpoint_path": "/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt" +}) diff --git a/model_zoo/research/cv/retinanet_resnet101/src/dataset.py b/model_zoo/research/cv/retinanet_resnet101/src/dataset.py new file mode 100644 index 00000000000..5596f2a4243 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/dataset.py @@ -0,0 +1,454 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""retinanet dataset""" + +from __future__ import division + +import os +import json +import xml.etree.ElementTree as et +import numpy as np +import cv2 + +import mindspore.dataset as de +import mindspore.dataset.vision.c_transforms as C +from mindspore.mindrecord import FileWriter +from .config import config +from .box_utils import jaccard_numpy, retinanet_bboxes_encode + + +def _rand(a=0., b=1.): + """Generate random.""" + return np.random.rand() * (b - a) + a + + +def get_imageId_from_fileName(filename): + """Get imageID from fileName""" + filename = os.path.splitext(filename)[0] + if filename.isdigit(): + return int(filename) + return id_iter + + +def random_sample_crop(image, boxes): + """Random Crop the image and boxes""" + height, width, _ = image.shape + min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) + + if min_iou is None: + return image, boxes + + # max trails (50) + for _ in range(50): + image_t = image + + w = _rand(0.3, 1.0) * width + h = _rand(0.3, 1.0) * height + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = _rand() * (width - w) + top = _rand() * (height - h) + + rect = np.array([int(top), int(left), int(top + h), int(left + w)]) + overlap = jaccard_numpy(boxes, rect) + + # dropout some boxes + drop_mask = overlap > 0 + if not drop_mask.any(): + continue + + if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2): + continue + + image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :] + + centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 + + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 * drop_mask + + # have any valid boxes? 
try again if not + if not mask.any(): + continue + + # take only matching gt boxes + boxes_t = boxes[mask, :].copy() + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) + boxes_t[:, :2] -= rect[:2] + boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) + boxes_t[:, 2:4] -= rect[:2] + + return image_t, boxes_t + return image, boxes + + +def preprocess_fn(img_id, image, box, is_training): + """Preprocess function for dataset.""" + cv2.setNumThreads(2) + + def _infer_data(image, input_shape): + img_h, img_w, _ = image.shape + input_h, input_w = input_shape + + image = cv2.resize(image, (input_w, input_h)) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + return img_id, image, np.array((img_h, img_w), np.float32) + + def _data_aug(image, box, is_training, image_size=(600, 600)): + """Data augmentation function.""" + ih, iw, _ = image.shape + w, h = image_size + + if not is_training: + return _infer_data(image, image_size) + + # Random crop + box = box.astype(np.float32) + image, box = random_sample_crop(image, box) + ih, iw, _ = image.shape + + # Resize image + image = cv2.resize(image, (w, h)) + + # Flip image or not + flip = _rand() < .5 + if flip: + image = cv2.flip(image, 1, dst=None) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + box[:, [0, 2]] = box[:, [0, 2]] / ih + box[:, [1, 3]] = box[:, [1, 3]] / iw + + if flip: + box[:, [1, 3]] = 1 - box[:, [3, 1]] + + box, label, num_match = retinanet_bboxes_encode(box) + return image, box, label, num_match + + return _data_aug(image, box, is_training, image_size=config.img_shape) + + +def create_voc_label(is_training): + """Get image path and annotation from VOC.""" + voc_dir = config.voc_dir + cls_map = {name: i for i, name in enumerate(config.coco_classes)} + sub_dir = 'train' if is_training else 'eval' + voc_dir = os.path.join(voc_dir, sub_dir) + if not os.path.isdir(voc_dir): + raise ValueError(f'Cannot find {sub_dir} dataset path.') + + image_dir = anno_dir = voc_dir + if os.path.isdir(os.path.join(voc_dir, 'Images')): + image_dir = os.path.join(voc_dir, 'Images') + if os.path.isdir(os.path.join(voc_dir, 'Annotations')): + anno_dir = os.path.join(voc_dir, 'Annotations') + + if not is_training: + data_dir = config.voc_root + json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) + file_dir = os.path.split(json_file)[0] + if not os.path.isdir(file_dir): + os.makedirs(file_dir) + json_dict = {"images": [], "type": "instances", "annotations": [], + "categories": []} + bnd_id = 1 + + image_files_dict = {} + image_anno_dict = {} + images = [] + for anno_file in os.listdir(anno_dir): + print(anno_file) + if not anno_file.endswith('xml'): + continue + tree = et.parse(os.path.join(anno_dir, anno_file)) + root_node = tree.getroot() + file_name = root_node.find('filename').text + img_id = get_imageId_from_fileName(file_name) + image_path = os.path.join(image_dir, file_name) + print(image_path) + if not os.path.isfile(image_path): + print(f'Cannot find image {file_name} according to annotations.') + continue + + labels = [] + for obj in root_node.iter('object'): + cls_name = obj.find('name').text + if cls_name not in cls_map: + print(f'Label "{cls_name}" not in "{config.coco_classes}"') + continue + bnd_box = obj.find('bndbox') + x_min = int(bnd_box.find('xmin').text) - 1 + y_min = 
int(bnd_box.find('ymin').text) - 1 + x_max = int(bnd_box.find('xmax').text) - 1 + y_max = int(bnd_box.find('ymax').text) - 1 + labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) + + if not is_training: + o_width = abs(x_max - x_min) + o_height = abs(y_max - y_min) + ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \ + img_id, 'bbox': [x_min, y_min, o_width, o_height], \ + 'category_id': cls_map[cls_name], 'id': bnd_id, \ + 'ignore': 0, \ + 'segmentation': []} + json_dict['annotations'].append(ann) + bnd_id = bnd_id + 1 + + if labels: + images.append(img_id) + image_files_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(labels) + + if not is_training: + size = root_node.find("size") + width = int(size.find('width').text) + height = int(size.find('height').text) + image = {'file_name': file_name, 'height': height, 'width': width, + 'id': img_id} + json_dict['images'].append(image) + + if not is_training: + for cls_name, cid in cls_map.items(): + cat = {'supercategory': 'none', 'id': cid, 'name': cls_name} + json_dict['categories'].append(cat) + json_fp = open(json_file, 'w') + json_str = json.dumps(json_dict) + json_fp.write(json_str) + json_fp.close() + + return images, image_files_dict, image_anno_dict + + +def create_coco_label(is_training): + """Get image path and annotation from COCO.""" + from pycocotools.coco import COCO + + coco_root = config.coco_root + data_type = config.val_data_type + if is_training: + data_type = config.train_data_type + + # Classes need to train or test. + train_cls = config.coco_classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + + anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) + + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + + image_ids = coco.getImgIds() + images = [] + image_path_dict = {} + image_anno_dict = {} + + for img_id in image_ids: + image_info = coco.loadImgs(img_id) + file_name = image_info[0]["file_name"] + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + iscrowd = False + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] + if class_name in train_cls: + x_min, x_max = bbox[0], bbox[0] + bbox[2] + y_min, y_max = bbox[1], bbox[1] + bbox[3] + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + + if not is_training and iscrowd: + continue + if len(annos) >= 1: + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(annos) + + return images, image_path_dict, image_anno_dict + + +def anno_parser(annos_str): + """Parse annotation from string to list.""" + annos = [] + for anno_str in annos_str: + anno = list(map(int, anno_str.strip().split(','))) + annos.append(anno) + return annos + + +def filter_valid_data(image_dir, anno_path): + """Filter valid image file, which both in image_dir and anno_path.""" + images = [] + image_path_dict = {} + image_anno_dict = {} + if not os.path.isdir(image_dir): + raise RuntimeError("Path given is not valid.") + if not os.path.isfile(anno_path): + raise RuntimeError("Annotation file is not valid.") + + with open(anno_path, "rb") as f: + lines = f.readlines() + for img_id, line in enumerate(lines): + line_str = 
line.decode("utf-8").strip() + line_split = str(line_str).split(' ') + file_name = line_split[0] + image_path = os.path.join(image_dir, file_name) + if os.path.isfile(image_path): + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = anno_parser(line_split[1:]) + + return images, image_path_dict, image_anno_dict + + +def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="retinanet.mindrecord", file_num=8): + """Create MindRecord file by image_dir and anno_path.""" + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + images, image_path_dict, image_anno_dict = create_voc_label(is_training) + + retinanet_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(retinanet_json, "retinanet_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="retina2.mindrecord", file_num=4): + """Create MindRecord file.""" + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + if dataset == "coco": + images, image_path_dict, image_anno_dict = create_coco_label(is_training) + else: + images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path) + + retinanet_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(retinanet_json, "retinanet_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0, + is_training=True, num_parallel_workers=64): + """Creatr retinanet dataset with MindDataset.""" + ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, + shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) + decode = C.Decode() + ds = ds.map(operations=decode, input_columns=["image"]) + change_swap_op = C.HWC2CHW() + normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training)) + if is_training: + output_columns = ["image", "box", "label", "num_match"] + trans = [color_adjust_op, normalize_op, change_swap_op] + else: + output_columns = ["img_id", "image", "image_shape"] + trans = [normalize_op, change_swap_op] + ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"], + output_columns=output_columns, column_order=output_columns, + python_multiprocessing=is_training, + 
+                num_parallel_workers=num_parallel_workers)
+    ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training,
+                num_parallel_workers=num_parallel_workers)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    ds = ds.repeat(repeat_num)
+    return ds
+
+
+def create_mindrecord(dataset="coco", prefix="retinanet.mindrecord", is_training=True):
+    """Create the MindRecord file if it does not already exist and return the path of its first shard."""
+    print("Start create dataset!")
+
+    # It will generate the mindrecord file in config.mindrecord_dir,
+    # and the file name is retinanet.mindrecord0, 1, ... file_num.
+
+    mindrecord_dir = config.mindrecord_dir
+    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
+    if not os.path.exists(mindrecord_file):
+        if not os.path.isdir(mindrecord_dir):
+            os.makedirs(mindrecord_dir)
+        if dataset == "coco":
+            if os.path.isdir(config.coco_root):
+                print("Create Mindrecord.")
+                data_to_mindrecord_byte_image("coco", is_training, prefix)
+                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+            else:
+                print("coco_root does not exist.")
+        elif dataset == "voc":
+            if os.path.isdir(config.voc_dir):
+                print("Create Mindrecord.")
+                voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
+                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+            else:
+                print("voc_dir does not exist.")
+        else:
+            if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
+                print("Create Mindrecord.")
+                data_to_mindrecord_byte_image("other", is_training, prefix)
+                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+            else:
+                print("image_dir or anno_path does not exist.")
+    return mindrecord_file
diff --git a/model_zoo/research/cv/retinanet_resnet101/src/init_params.py b/model_zoo/research/cv/retinanet_resnet101/src/init_params.py
new file mode 100644
index 00000000000..51185243816
--- /dev/null
+++ b/model_zoo/research/cv/retinanet_resnet101/src/init_params.py
@@ -0,0 +1,35 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Parameters utils"""
+
+from mindspore.common.initializer import initializer, TruncatedNormal
+
+
+def init_net_param(network, initialize_mode='TruncatedNormal'):
+    """Init the parameters in net."""
+    params = network.trainable_params()
+    for p in params:
+        if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
+            if initialize_mode == 'TruncatedNormal':
+                p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
+            else:
+                p.set_data(initializer(initialize_mode, p.data.shape, p.data.dtype))
+
+
+def filter_checkpoint_parameter(param_dict):
+    """Remove the multibox head parameters from a loaded checkpoint."""
+    for key in list(param_dict.keys()):
+        if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
+            del param_dict[key]
diff --git a/model_zoo/research/cv/retinanet_resnet101/src/lr_schedule.py b/model_zoo/research/cv/retinanet_resnet101/src/lr_schedule.py
new file mode 100644
index 00000000000..0a2e6ce9e37
--- /dev/null
+++ b/model_zoo/research/cv/retinanet_resnet101/src/lr_schedule.py
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Learning rate schedule"""
+
+import math
+import numpy as np
+
+
+def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2, warmup_epochs3, warmup_epochs4,
+           warmup_epochs5, total_epochs, steps_per_epoch):
+    """
+    generate learning rate array
+
+    The schedule ramps the learning rate up through five warmup stages and then
+    decays it from lr_max to lr_end along a cosine curve.
+
+    Args:
+        global_step(int): step to resume from; the returned array starts at this step
+        lr_init(float): init learning rate
+        lr_end(float): end learning rate
+        lr_max(float): max learning rate
+        warmup_epochs1 ~ warmup_epochs5(int): number of epochs in each of the five warmup stages
+        total_epochs(int): total epoch of training
+        steps_per_epoch(int): steps of one epoch
+
+    Returns:
+        np.array, learning rate array
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps1 = steps_per_epoch * warmup_epochs1
+    warmup_steps2 = warmup_steps1 + steps_per_epoch * warmup_epochs2
+    warmup_steps3 = warmup_steps2 + steps_per_epoch * warmup_epochs3
+    warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4
+    warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5
+    for i in range(total_steps):
+        if i < warmup_steps1:
+            lr = lr_init * (warmup_steps1 - i) / (warmup_steps1) + (lr_max * 1e-4) * i / (warmup_steps1 * 3)
+        elif warmup_steps1 <= i < warmup_steps2:
+            lr = 1e-5 * (warmup_steps2 - i) / (warmup_steps2 - warmup_steps1) + (lr_max * 1e-3) * (
+                i - warmup_steps1) / (warmup_steps2 - warmup_steps1)
+        elif warmup_steps2 <= i < warmup_steps3:
+            lr = 1e-4 * (warmup_steps3 - i) / (warmup_steps3 - warmup_steps2) + (lr_max * 1e-2) * (
+                i - warmup_steps2) / (warmup_steps3 - warmup_steps2)
+        elif warmup_steps3 <= i < warmup_steps4:
+            lr = 1e-3 * (warmup_steps4 - i) / (warmup_steps4 - warmup_steps3) + (lr_max * 1e-1) * (
+                i - warmup_steps3) / (warmup_steps4 - warmup_steps3)
+        elif warmup_steps4 <= i < warmup_steps5:
+            lr = 1e-2
* (warmup_steps5 - i) / (warmup_steps5 - warmup_steps4) + lr_max * (i - warmup_steps4) / ( + warmup_steps5 - warmup_steps4) + else: + lr = lr_end + \ + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps5) / (total_steps - warmup_steps5))) / 2. + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate diff --git a/model_zoo/research/cv/retinanet_resnet101/src/retinahead.py b/model_zoo/research/cv/retinanet_resnet101/src/retinahead.py new file mode 100644 index 00000000000..b62bc8a6ac1 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/src/retinahead.py @@ -0,0 +1,286 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""retinanet based resnet.""" + +import mindspore.common.dtype as mstype +import mindspore as ms +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.context import ParallelMode +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_group_size +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C + +from .bottleneck import FPN + + +class FlattenConcat(nn.Cell): + """ + Concatenate predictions into a single tensor. + + Args: + config (dict): The default config of retinanet. + + Returns: + Tensor, flatten predictions. 
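+
+    Note:
+        The concatenated prediction is reshaped to (batch_size, config.num_retinanet_boxes, -1),
+        so each default box contributes one row of the flattened output.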
+ """ + + def __init__(self, config): + super(FlattenConcat, self).__init__() + self.num_retinanet_boxes = config.num_retinanet_boxes + self.concat = P.Concat(axis=1) + self.transpose = P.Transpose() + + def construct(self, inputs): + output = () + batch_size = F.shape(inputs[0])[0] + for x in inputs: + x = self.transpose(x, (0, 2, 3, 1)) + output += (F.reshape(x, (batch_size, -1)),) + res = self.concat(output) + return F.reshape(res, (batch_size, self.num_retinanet_boxes, -1)) + + +def ClassificationModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', num_classes=81, + feature_size=256): + conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=3, pad_mode='same') + conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv5 = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, pad_mode='same') + return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5]) + + +def RegressionModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', feature_size=256): + conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=3, pad_mode='same') + conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same') + conv5 = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, pad_mode='same') + return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5]) + + +class MultiBox(nn.Cell): + """ + Multibox conv layers. Each multibox layer contains class conf scores and localization predictions. + + Args: + config (dict): The default config of retinanet. + + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + """ + + def __init__(self, config): + super(MultiBox, self).__init__() + + out_channels = config.extras_out_channels + num_default = config.num_default + loc_layers = [] + cls_layers = [] + for k, out_channel in enumerate(out_channels): + loc_layers += [RegressionModel(in_channel=out_channel, num_anchors=num_default[k])] + cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k])] + + self.multi_loc_layers = nn.layer.CellList(loc_layers) + self.multi_cls_layers = nn.layer.CellList(cls_layers) + self.flatten_concat = FlattenConcat(config) + + def construct(self, inputs): + loc_outputs = () + cls_outputs = () + for i in range(len(self.multi_loc_layers)): + loc_outputs += (self.multi_loc_layers[i](inputs[i]),) + cls_outputs += (self.multi_cls_layers[i](inputs[i]),) + return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) + + +class SigmoidFocalClassificationLoss(nn.Cell): + """" + Sigmoid focal-loss for classification. + + Args: + gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0 + alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25 + + Returns: + Tensor, the focal loss. 
+ """ + + def __init__(self, gamma=2.0, alpha=0.25): + super(SigmoidFocalClassificationLoss, self).__init__() + self.sigmiod_cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.sigmoid = P.Sigmoid() + self.pow = P.Pow() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.gamma = gamma + self.alpha = alpha + + def construct(self, logits, label): + label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value) + sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label) + sigmoid = self.sigmoid(logits) + label = F.cast(label, mstype.float32) + p_t = label * sigmoid + (1 - label) * (1 - sigmoid) + modulating_factor = self.pow(1 - p_t, self.gamma) + alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha) + focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy + return focal_loss + + +class retinahead(nn.Cell): + """retinahead""" + def __init__(self, backbone, config, is_training=True): + super(retinahead, self).__init__() + + self.fpn = FPN(backbone=backbone, config=config) + self.multi_box = MultiBox(config) + self.is_training = is_training + if not is_training: + self.activation = P.Sigmoid() + + def construct(self, inputs): + features = self.fpn(inputs) + pred_loc, pred_label = self.multi_box(features) + return pred_loc, pred_label + + +class retinanetWithLossCell(nn.Cell): + """" + Provide retinanet training loss through network. + + Args: + network (Cell): The training network. + config (dict): retinanet config. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, network, config): + super(retinanetWithLossCell, self).__init__() + self.network = network + self.less = P.Less() + self.tile = P.Tile() + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.expand_dims = P.ExpandDims() + self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha) + self.loc_loss = nn.SmoothL1Loss() + + def construct(self, x, gt_loc, gt_label, num_matched_boxes): + """construct""" + pred_loc, pred_label = self.network(x) + mask = F.cast(self.less(0, gt_label), mstype.float32) + num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32)) + + # Localization Loss + mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4)) + smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc + loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1) + + # Classification Loss + loss_cls = self.class_loss(pred_label, gt_label) + loss_cls = self.reduce_sum(loss_cls, (1, 2)) + + return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) + + +class TrainingWrapper(nn.Cell): + """ + Encapsulation class of retinanet network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. 
+ """ + + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = ms.ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("gradients_mean") + if auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) + + +class retinanetInferWithDecoder(nn.Cell): + """ + retinanet Infer wrapper to decode the bbox locations. + + Args: + network (Cell): the origin retinanet infer network without bbox decoder. + default_boxes (Tensor): the default_boxes from anchor generator + config (dict): retinanet config + Returns: + Tensor, the locations for bbox after decoder representing (y0,x0,y1,x1) + Tensor, the prediction labels. + + """ + def __init__(self, network, default_boxes, config): + super(retinanetInferWithDecoder, self).__init__() + self.network = network + self.default_boxes = default_boxes + self.prior_scaling_xy = config.prior_scaling[0] + self.prior_scaling_wh = config.prior_scaling[1] + + def construct(self, x): + """construct""" + pred_loc, pred_label = self.network(x) + + default_bbox_xy = self.default_boxes[..., :2] + default_bbox_wh = self.default_boxes[..., 2:] + pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy + pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh + + pred_xy_0 = pred_xy - pred_wh / 2.0 + pred_xy_1 = pred_xy + pred_wh / 2.0 + pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1)) + pred_xy = P.Maximum()(pred_xy, 0) + pred_xy = P.Minimum()(pred_xy, 1) + return pred_xy, pred_label diff --git a/model_zoo/research/cv/retinanet_resnet101/train.py b/model_zoo/research/cv/retinanet_resnet101/train.py new file mode 100644 index 00000000000..6146d0e55c5 --- /dev/null +++ b/model_zoo/research/cv/retinanet_resnet101/train.py @@ -0,0 +1,154 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Train retinanet and get checkpoint files.""" + +import os +import argparse +import ast +import mindspore +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.communication.management import init, get_rank +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback +from mindspore.train import Model +from mindspore.context import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed +from src.retinahead import retinanetWithLossCell, TrainingWrapper, retinahead +from src.backbone import resnet101 +from src.config import config +from src.dataset import create_retinanet_dataset, create_mindrecord +from src.lr_schedule import get_lr +from src.init_params import init_net_param, filter_checkpoint_parameter + + +set_seed(1) +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + def step_end(self, run_context): + cb_params = run_context.original_args() + print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num-1]), flush=True) + +def main(): + parser = argparse.ArgumentParser(description="retinanet training") + parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, + help="If set it true, only create Mindrecord, default is False.") + parser.add_argument("--distribute", type=ast.literal_eval, default=False, + help="Run distribute, default is False.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.") + parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") + parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") + parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") + parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") + parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") + parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.") + parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") + parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, + help="Filter weight parameters, default is False.") + parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"), + help="run platform, only support Ascend.") + args_opt = parser.parse_args() + + if args_opt.run_platform == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if args_opt.distribute: + if os.getenv("DEVICE_ID", "not_set").isdigit(): + context.set_context(device_id=int(os.getenv("DEVICE_ID"))) + init() + device_num = args_opt.device_num + rank = get_rank() + 
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, + device_num=device_num) + else: + rank = 0 + device_num = 1 + context.set_context(device_id=args_opt.device_id) + + else: + raise ValueError("Unsupported platform.") + + mindrecord_file = create_mindrecord(args_opt.dataset, "retina2.mindrecord", True) + + if not args_opt.only_create_dataset: + loss_scale = float(args_opt.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0. + dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1, + batch_size=args_opt.batch_size, device_num=device_num, rank=rank) + + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + + backbone = resnet101(config.num_classes) + retinanet = retinahead(backbone, config) + net = retinanetWithLossCell(retinanet, config) + net.to_float(mindspore.float16) + init_net_param(net) + + if args_opt.pre_trained: + if args_opt.pre_trained_epoch_size <= 0: + raise KeyError("pre_trained_epoch_size must be greater than 0.") + param_dict = load_checkpoint(args_opt.pre_trained) + if args_opt.filter_weight: + filter_checkpoint_parameter(param_dict) + load_param_into_net(net, param_dict) + + lr = Tensor(get_lr(global_step=config.global_step, + lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2, + warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4, + warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size, + steps_per_epoch=dataset_size)) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, + config.momentum, config.weight_decay, loss_scale) + net = TrainingWrapper(net, opt, loss_scale) + model = Model(net) + print("Start train retinanet, the first epoch will be slower because of the graph compilation.") + cb = [TimeMonitor(), LossMonitor()] + cb += [Monitor(lr_init=lr.asnumpy())] + config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck) + if args_opt.distribute: + if rank == 0: + cb += [ckpt_cb] + model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) + else: + cb += [ckpt_cb] + model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) + +if __name__ == '__main__': + main()
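+
+# Illustrative invocations (all argument values are examples, not prescriptions):
+#   single device:  python train.py --device_id=0 --dataset=coco --epoch_size=500 --batch_size=32 --lr=0.1
+#   distributed:    run one process per device with --distribute=True --device_num=8,
+#                   with DEVICE_ID exported in each process environment.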