add retinanet_resnet101

zcc 2021-03-24 10:31:32 +08:00
parent 7454ac8ecd
commit 04f4423b03
17 changed files with 2420 additions and 0 deletions

README.md
@@ -0,0 +1,315 @@
# Contents
<!-- TOC -->
- <span id="content">[RetinaNet Description](#retinanet-description)</span>
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
        - [Running](#running)
        - [Result](#result)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage)
        - [Running](#running)
        - [Result](#outcome)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
## [RetinaNet Description](#content)
RetinaNet was introduced in the 2017 paper "Focal Loss for Dense Object Detection" from Facebook AI Research. The paper's main contribution is the focal loss, which addresses the extreme foreground-background class imbalance encountered when training dense detectors. Built on the focal loss, RetinaNet is a one-stage object detector whose accuracy surpasses that of classic two-stage detectors such as Faster R-CNN.
[Paper](https://arxiv.org/pdf/1708.02002.pdf)
Lin T. Y., Goyal P., Girshick R., et al. "Focal Loss for Dense Object Detection." 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017: 2999-3007.
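For intuition, the focal loss down-weights easy examples: FL(p_t) = -α_t (1 - p_t)^γ log(p_t). Below is a minimal NumPy sketch of the binary form, using the γ = 2.0 and α = 0.75 defaults from `src/config.py`; it is illustrative only, not the loss implementation used by the training scripts.
```python
import numpy as np

def focal_loss(pred_prob, target, gamma=2.0, alpha=0.75):
    """Binary focal loss FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    eps = 1e-7                                                # avoid log(0)
    p_t = np.where(target == 1, pred_prob, 1.0 - pred_prob)   # prob of the true class
    alpha_t = np.where(target == 1, alpha, 1.0 - alpha)       # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t + eps)

# Easy examples (p_t close to 1) are strongly down-weighted:
print(focal_loss(np.array([0.9, 0.5]), np.array([1, 1])))
```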
## [Model Architecture](#content)
The overall network architecture of RetinaNet is described in:
[Link](https://arxiv.org/pdf/1708.02002.pdf)
## [Dataset](#content)
Dataset used (refer to the paper): MSCOCO2017
- Dataset size: 19.3 GB, 123,287 color images in 80 classes
    - Training: 19.3 GB, 118,287 images
    - Test: 1,814.3 MB, 5,000 images
- Data format: RGB images
- Note: the data will be processed in src/dataset.py; a usage sketch follows this list.
## [Environment Requirements](#content)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Script Description](#content)
### [Script and Sample Code](#content)
```shell
.
└─Retinanet_resnet101
  ├─README.md
  ├─scripts
    ├─run_single_train.sh                 # single-device training on Ascend
    ├─run_distribute_train.sh             # 8-device distributed training on Ascend
    ├─run_eval.sh                         # evaluation on Ascend
  ├─src
    ├─backbone.py                         # backbone network definition
    ├─bottleneck.py                       # network neck (FPN) definition
    ├─config.py                           # parameter configuration
    ├─dataset.py                          # data preprocessing
    ├─retinahead.py                       # prediction head definition
    ├─init_params.py                      # parameter initialization
    ├─lr_generator.py                     # learning-rate generator
    ├─coco_eval.py                        # COCO dataset evaluation
    ├─box_utils.py                        # anchor (default box) utilities
    ├─__init__.py                         # package init
  ├─train.py                              # training script
  └─eval.py                               # evaluation script
```
### [Script Parameters](#content)
```python
Major parameters used in train.py and config.py:
"img_shape": [600, 600],                            # image size
"num_retinanet_boxes": 67995,                       # total number of default boxes (anchors)
"match_thershold": 0.5,                             # matching threshold
"softnms_sigma": 0.5,                               # Soft-NMS sigma
"nms_thershold": 0.6,                               # non-maximum suppression threshold
"min_score": 0.1,                                   # minimum score threshold
"max_boxes": 100,                                   # maximum number of detection boxes
"global_step": 0,                                   # global step
"lr_init": 1e-6,                                    # initial learning rate
"lr_end_rate": 5e-3,                                # ratio of final learning rate to maximum learning rate
"warmup_epochs1": 2,                                # number of epochs in warmup stage 1
"warmup_epochs2": 5,                                # number of epochs in warmup stage 2
"warmup_epochs3": 23,                               # number of epochs in warmup stage 3
"warmup_epochs4": 60,                               # number of epochs in warmup stage 4
"warmup_epochs5": 160,                              # number of epochs in warmup stage 5
"momentum": 0.9,                                    # optimizer momentum
"weight_decay": 1.5e-4,                             # weight decay rate
"num_default": [9, 9, 9, 9, 9],                     # number of default boxes per grid cell
"extras_out_channels": [256, 256, 256, 256, 256],   # output channels of the feature layers
"feature_size": [75, 38, 19, 10, 5],                # feature map sizes
"aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],  # anchor aspect ratios
"steps": (8, 16, 32, 64, 128),                      # anchor strides
"anchor_size": (32, 64, 128, 256, 512),             # anchor sizes
"prior_scaling": (0.1, 0.2),                        # scaling factors applied when encoding/decoding box centers and sizes
"gamma": 2.0,                                       # focal loss parameter
"alpha": 0.75,                                      # focal loss parameter
"mindrecord_dir": "/opr/root/data/MindRecord_COCO", # MindRecord file path
"coco_root": "/opr/root/data/",                     # COCO dataset path
"train_data_type": "train2017",                     # folder name of the training images
"val_data_type": "val2017",                         # folder name of the validation images
"instances_set": "annotations_trainval2017/annotations/instances_{}.json",  # annotation file path
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',  # COCO dataset classes
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81, # 数据集类别数
"voc_root": "", # voc数据集路径
"voc_dir": "",
"image_dir": "", # 图像路径
"anno_path": "", # 标签文件路径
"save_checkpoint": True, # 保存checkpoint
"save_checkpoint_epochs": 1, # 保存checkpoint epoch数
"keep_checkpoint_max":1, # 保存checkpoint的最大数量
"save_checkpoint_path": "./model", # 保存checkpoint的路径
"finish_epoch":0, # 已经运行完成的 epoch 数
"checkpoint_path":"/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt" # 用于验证的checkpoint路径
```
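The configuration object is an `EasyDict`, so entries are read and overridden as attributes. A quick sketch (the path below is a placeholder, not a real location):
```python
from src.config import config

print(config.img_shape)             # [600, 600]
config.coco_root = "/path/to/coco"  # override a default before training
```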
### [Training Process](#content)
#### Usage
You can start training with python or shell scripts. The usage of the shell scripts is as follows:
- Ascend:
```shell
# 8-device distributed training example (create a RANK_TABLE_FILE first):
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
# single-device training example:
sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
```
> Note:
For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); for how to obtain the device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
#### Running
```shell
# training examples
python:
    # the data path and the MindRecord output path are set in config.py
    # single-device training example:
    python train.py
shell:
    Ascend:
    # 8-device distributed training example (run under the retinanet directory):
    sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(number of pretrained epochs)
    # for example: sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322
    # single-device training example (run under the retinanet directory):
    sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
```
#### Result
Training results are stored in the example path. Checkpoints are stored under `./model`, and the training log is written to `./log.txt`. An excerpt of the training log:
```text
epoch: 397 step: 458, loss is 0.6153226
lr:[0.000598]
epoch time: 313364.642 ms, per step time: 684.202 ms
epoch: 398 step: 458, loss is 0.5491791
lr:[0.000544]
epoch time: 313486.094 ms, per step time: 684.467 ms
epoch: 399 step: 458, loss is 0.51681435
lr:[0.000511]
epoch time: 313514.348 ms, per step time: 684.529 ms
epoch: 400 step: 458, loss is 0.4305706
lr:[0.000500]
epoch time: 314138.455 ms, per step time: 685.892 ms
```
### [Evaluation Process](#content)
#### <span id="usage">Usage</span>
You can start evaluation with python or shell scripts. The usage of the shell script is as follows:
```shell
sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
```
#### <span id="running">Running</span>
```shell
# evaluation example
python:
    Ascend: python eval.py
    # the checkpoint path is set in config.py
shell:
    Ascend: sh scripts/run_eval.sh coco 0
```
> Checkpoints can be produced during training.
#### <span id="outcome">Result</span>
The results are stored in the example path; you can view them in `eval.log`.
``` mAP
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.371
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.517
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.408
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.143
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.394
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.547
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.318
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.455
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.464
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.172
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.489
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.680
========================================
mAP: 0.3710347196613514
```
## [Model Description](#content)
### [Performance](#content)
#### Training Performance
| Parameters                 | Ascend                                |
| -------------------------- | ------------------------------------- |
| Model Name                 | RetinaNet                             |
| Environment                | Huawei Cloud ModelArts                |
| Uploaded Date              | 10/03/2021                            |
| MindSpore Version          | 1.0.1                                 |
| Dataset                    | 123,287 images                        |
| Batch_size                 | 32                                    |
| Training Parameters        | src/config.py                         |
| Optimizer                  | Momentum                              |
| Loss Function              | Focal loss                            |
| Final Loss                 | 0.43                                  |
| Accuracy (8p)              | mAP[0.3710]                           |
| Total Training Time (8p)   | 34h50m20s                             |
| Script                     | [RetinaNet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Retinanet_resnet101) |
#### Inference Performance
| Parameters          | Ascend                      |
| ------------------- | :-------------------------- |
| Model Name          | RetinaNet                   |
| Environment         | Huawei Cloud ModelArts      |
| Uploaded Date       | 10/03/2021                  |
| MindSpore Version   | 1.0.1                       |
| Dataset             | 5k images                   |
| Batch_size          | 1                           |
| Accuracy            | mAP[0.3710]                 |
| Total Time          | 10 min 50 s                 |
## [Description of Random Situation](#content)
In the `dataset.py` script, we set a random seed inside the `create_dataset` function. We also set a random seed in the `train.py` script.
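For reference, seeds are typically fixed like this (a sketch; the value 1 is arbitrary and not necessarily what the scripts use):
```python
import numpy as np
import mindspore.dataset as de

seed = 1                  # arbitrary value for illustration
np.random.seed(seed)      # NumPy randomness used in data augmentation
de.config.set_seed(seed)  # MindSpore dataset shuffling
```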
## [ModelZoo Homepage](#content)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for retinanet"""
import os
import argparse
import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.retinahead import retinahead, retinanetInferWithDecoder
from src.backbone import resnet101
from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
from src.config import config
from src.coco_eval import metrics
from src.box_utils import default_boxes
def retinanet_eval(dataset_path, ckpt_path):
"""retinanet evaluation."""
batch_size = 1
ds = create_retinanet_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
backbone = resnet101(config.num_classes)
net = retinahead(backbone, config)
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path)
net.init_parameters_data()
load_param_into_net(net, param_dict)
net.set_train(False)
i = batch_size
total = ds.get_dataset_size() * batch_size
start = time.time()
pred_data = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
for data in ds.create_dict_iterator(output_numpy=True):
img_id = data['img_id']
img_np = data['image']
image_shape = data['image_shape']
output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"img_id": int(np.squeeze(img_id[batch_idx])),
"image_shape": image_shape[batch_idx]})
percent = round(i / total * 100., 2)
print(f' {str(percent)} [{i}/{total}]', end='\r')
i += batch_size
cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data)
print("\n========================================\n")
print(f"mAP: {mAP}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='retinanet evaluation')
parser.add_argument("--device_id", type=int, default=3, help="Device id, default is 0.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"),
help="run platform, only support Ascend and GPU.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
prefix = "retinanet_eval.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if args_opt.dataset == "voc":
config.coco_root = config.voc_root
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
elif args_opt.dataset == "voc":
if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
print("Create Mindrecord.")
voc_data_to_mindrecord(mindrecord_dir, False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("voc_root or voc_dir not exits.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
print("Start Eval!")
retinanet_eval(mindrecord_file, config.checkpoint_path)

export.py
@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export file"""
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.retinahead import retinahead
from src.backbone import resnet101
from src.config import config
parser = argparse.ArgumentParser(description="retinanet_resnet101 export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="retinanet_resnet101", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
network = retinahead(backbone=resnet101(80), config=config, is_training=False)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
network.set_train(False)
shape = [args.batch_size, 3] + [600, 600]
input_data = Tensor(np.zeros(shape), mstype.float32)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)

scripts/run_distribute_train.sh
@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="
if [ $# != 5 ] && [ $# != 7 ]
then
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1
fi
# Before starting distributed training, create the MindRecord files first.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$6
PRE_TRAINED_EPOCH_SIZE=$7
export RANK_TABLE_FILE=$5
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cp -r ./scripts ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
if [ $# == 5 ]
then
python train.py \
--distribute=True \
--lr=$LR \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi
if [ $# == 7 ]
then
python train.py \
--distribute=True \
--lr=$LR \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--pre_trained=$PRE_TRAINED \
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi
cd ../
done

scripts/run_eval.sh
@@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [DATASET] [DEVICE_ID]"
exit 1
fi
DATASET=$1
echo $DATASET
export DEVICE_NUM=1
export DEVICE_ID=$2
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
if [ -d "eval$2" ];
then
rm -rf ./eval$2
fi
mkdir ./eval$2
cp ./*.py ./eval$2
cp -r ./src ./eval$2
cd ./eval$2 || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py \
--dataset=$DATASET \
--device_id=$2 > log.txt 2>&1 &
cd ..

scripts/run_single_train.sh
@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="
if [ $# != 4 ] && [ $# != 6 ]
then
echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1
fi
# Before starting single-device training, create the MindRecord files first.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
export DEVICE_ID=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
rm -rf LOG$1
mkdir ./LOG$1
cp ./*.py ./LOG$1
cp -r ./src ./LOG$1
cd ./LOG$1 || exit
echo "start training for device $1"
env > env.log
python train.py \
--distribute=False \
--lr=$LR \
--dataset=$DATASET \
--device_num=1 \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
cd ../

src/backbone.py
@@ -0,0 +1,226 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BackBone file"""
import mindspore.nn as nn
from mindspore.ops import operations as P
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.97,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Number of channel groups. 1 for ordinary convolution; equal to the input channel count for depthwise convolution. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = 0
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
padding=padding)
layers = [conv, _bn(out_planes), nn.ReLU()]
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
has_bn=True, activation='relu')
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
kernel_size=1, stride=stride,
pad_mode='same', padding=0, has_bn=True, activation='relu')
self.add = P.TensorAdd()
self.relu = P.ReLU()
def construct(self, x):
"""construct"""
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class resnet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(resnet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = ResidualBlock(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = ResidualBlock(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
C1 = self.maxpool(x)
C2 = self.layer1(C1)
C3 = self.layer2(C2)
C4 = self.layer3(C3)
C5 = self.layer4(C4)
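        # Feature strides relative to the input: C3 = 8, C4 = 16, C5 = 32.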
return C3, C4, C5
def resnet101(num_classes):
return resnet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
num_classes)
def resnet152(num_classes):
return resnet(ResidualBlock,
[3, 8, 36, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
num_classes)

src/bottleneck.py
@@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bottleneck"""
import mindspore.nn as nn
from mindspore.ops import operations as P
class FPN(nn.Cell):
"""FPN"""
def __init__(self, config, backbone, is_training=True):
super(FPN, self).__init__()
self.backbone = backbone
feature_size = config.feature_size
self.P5_1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, pad_mode='same')
self.P_upsample1 = P.ResizeNearestNeighbor((feature_size[1], feature_size[1]))
self.P5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
self.P4_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, pad_mode='same')
self.P_upsample2 = P.ResizeNearestNeighbor((feature_size[0], feature_size[0]))
self.P4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
self.P3_1 = nn.Conv2d(512, 256, kernel_size=1, stride=1, pad_mode='same')
self.P3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
self.P6_0 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, pad_mode='same')
self.P7_1 = nn.ReLU()
self.P7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=2, pad_mode='same')
self.is_training = is_training
if not is_training:
self.activation = P.Sigmoid()
def construct(self, x):
"""construct"""
C3, C4, C5 = self.backbone(x)
P5 = self.P5_1(C5)
P5_upsampled = self.P_upsample1(P5)
P5 = self.P5_2(P5)
P4 = self.P4_1(C4)
P4 = P5_upsampled + P4
P4_upsampled = self.P_upsample2(P4)
P4 = self.P4_2(P4)
P3 = self.P3_1(C3)
P3 = P4_upsampled + P3
P3 = self.P3_2(P3)
P6 = self.P6_0(C5)
P7 = self.P7_1(P6)
P7 = self.P7_2(P7)
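        # Pyramid strides: P3 = 8, P4 = 16, P5 = 32, P6 = 64, P7 = 128 (cf. config.steps).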
multi_feature = (P3, P4, P5, P6, P7)
return multi_feature

src/box_utils.py
@@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bbox utils"""
import math
import itertools as it
import numpy as np
from .config import config
class GeneratDefaultBoxes():
"""
    Generate default boxes for retinanet, following the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].
    `self.default_boxes_ltrb` has the same shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
"""
def __init__(self):
fk = config.img_shape[0] / np.array(config.steps)
scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
anchor_size = np.array(config.anchor_size)
self.default_boxes = []
for idex, feature_size in enumerate(config.feature_size):
base_size = anchor_size[idex] / config.img_shape[0]
size1 = base_size * scales[0]
size2 = base_size * scales[1]
size3 = base_size * scales[2]
all_sizes = []
for aspect_ratio in config.aspect_ratios[idex]:
w1, h1 = size1 * math.sqrt(aspect_ratio), size1 / math.sqrt(aspect_ratio)
all_sizes.append((h1, w1))
w2, h2 = size2 * math.sqrt(aspect_ratio), size2 / math.sqrt(aspect_ratio)
all_sizes.append((h2, w2))
w3, h3 = size3 * math.sqrt(aspect_ratio), size3 / math.sqrt(aspect_ratio)
all_sizes.append((h3, w3))
assert len(all_sizes) == config.num_default[idex]
for i, j in it.product(range(feature_size), repeat=2):
for h, w in all_sizes:
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
self.default_boxes.append([cy, cx, h, w])
def to_ltrb(cy, cx, h, w):
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
# For IoU calculation
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
self.default_boxes = np.array(self.default_boxes, dtype='float32')
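# With the default config (9 anchors per cell over feature maps of sizes 75, 38,
# 19, 10 and 5), this yields 9 * (75**2 + 38**2 + 19**2 + 10**2 + 5**2) = 67995
# default boxes, which matches config.num_retinanet_boxes.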
_default_boxes_generator = GeneratDefaultBoxes()  # instantiate once and reuse
default_boxes_ltrb = _default_boxes_generator.default_boxes_ltrb
default_boxes = _default_boxes_generator.default_boxes
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_thershold
def retinanet_bboxes_encode(boxes):
"""
Labels anchors with ground truth inputs.
Args:
        boxes: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].
Returns:
gt_loc: location ground truth with shape [num_anchors, 4].
gt_label: class ground truth with shape [num_anchors, 1].
num_matched_boxes: number of positives in an image.
"""
def jaccard_with_anchors(bbox):
"""Compute jaccard score a box and the anchors."""
# Intersection bbox and volume.
ymin = np.maximum(y1, bbox[0])
xmin = np.maximum(x1, bbox[1])
ymax = np.minimum(y2, bbox[2])
xmax = np.minimum(x2, bbox[3])
w = np.maximum(xmax - xmin, 0.)
h = np.maximum(ymax - ymin, 0.)
# Volumes.
inter_vol = h * w
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
jaccard = inter_vol / union_vol
return np.squeeze(jaccard)
pre_scores = np.zeros((config.num_retinanet_boxes), dtype=np.float32)
t_boxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
t_label = np.zeros((config.num_retinanet_boxes), dtype=np.int64)
for bbox in boxes:
label = int(bbox[4])
scores = jaccard_with_anchors(bbox)
idx = np.argmax(scores)
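        # Force-match the best anchor for this gt box: a score of 2.0 always
        # exceeds the matching threshold and any real IoU in [0, 1].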
scores[idx] = 2.0
mask = (scores > matching_threshold)
mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores * mask)
t_label = mask * label + (1 - mask) * t_label
for i in range(4):
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
index = np.nonzero(t_label)
# Transform to ltrb.
bboxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
# Encode features.
bboxes_t = bboxes[index]
default_boxes_t = default_boxes[index]
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)
bboxes_t[:, 2:4] = np.log(tmp) / config.prior_scaling[1]
bboxes[index] = bboxes_t
num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match
def retinanet_bboxes_decode(boxes):
"""Decode predict boxes to [y, x, h, w]"""
boxes_t = boxes.copy()
default_boxes_t = default_boxes.copy()
boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
return np.clip(bboxes, 0, 1)
def intersect(box_a, box_b):
"""Compute the intersect of two sets of boxes."""
max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
min_yx = np.maximum(box_a[:, :2], box_b[:2])
inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
return inter[:, 0] * inter[:, 1]
def jaccard_numpy(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes."""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) *
(box_a[:, 3] - box_a[:, 1]))
area_b = ((box_b[2] - box_b[0]) *
(box_b[3] - box_b[1]))
union = area_a + area_b - inter
return inter / union

src/coco_eval.py
@@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Coco metrics utils"""
import os
import json
import numpy as np
from .config import config
def apply_softnms(dets, scores, sigma=0.5, method=2, thresh=0.001, Nt=0.1):
    '''
    Soft-NMS implemented in pure Python/NumPy.
    :param dets: predicted bboxes, shape [N, 4]
    :param scores: predicted scores for the bboxes; decayed in place
    :param sigma: Gaussian decay parameter (used when method == 2)
    :param method: score-decay policy: 1 = linear, 2 = Gaussian, otherwise hard NMS
    :param thresh: bboxes whose decayed score falls below this threshold are dropped
    :param Nt: IoU threshold used by the linear and hard-NMS policies
    :return: indices of the bboxes kept after Soft-NMS
    '''
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
areas = (y2 - y1 + 1.) * (x2 - x1 + 1.)
orders = scores.argsort()[::-1]
keep = []
while orders.size > 0:
i = orders[0]
keep.append(i)
for j in orders[1:]:
xx1 = np.maximum(x1[i], x1[j])
yy1 = np.maximum(y1[i], y1[j])
xx2 = np.minimum(x2[i], x2[j])
yy2 = np.minimum(y2[i], y2[j])
w = np.maximum(xx2 - xx1 + 1., 0.)
h = np.maximum(yy2 - yy1 + 1., 0.)
inter = w * h
overlap = inter / (areas[i] + areas[j] - inter)
if method == 1: # linear
if overlap > Nt:
weight = 1 - overlap
else:
weight = 1
elif method == 2: # gaussian
weight = np.exp(-(overlap * overlap) / sigma)
else: # original NMS
if overlap > Nt:
weight = 0
else:
weight = 1
scores[j] = weight * scores[j]
if scores[j] < thresh:
orders = np.delete(orders, np.where(orders == j))
orders = np.delete(orders, 0)
return keep
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
y1 = all_boxes[:, 0]
x1 = all_boxes[:, 1]
y2 = all_boxes[:, 2]
x2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
num_classes = config.num_classes
coco_root = config.coco_root
data_type = config.val_data_type
# Classes need to train or test.
val_cls = config.coco_classes
val_cls_dict = {}
for i, cls in enumerate(val_cls):
val_cls_dict[i] = cls
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
coco_gt = COCO(anno_json)
classs_dict = {}
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
for cat in cat_ids:
classs_dict[cat["name"]] = cat["id"]
predictions = []
img_ids = []
for sample in pred_data:
pred_boxes = sample['boxes']
box_scores = sample['box_scores']
img_id = sample['img_id']
h, w = sample['image_shape']
final_boxes = []
final_label = []
final_score = []
img_ids.append(img_id)
for c in range(1, num_classes):
class_box_scores = box_scores[:, c]
score_mask = class_box_scores > config.min_score
class_box_scores = class_box_scores[score_mask]
class_boxes = pred_boxes[score_mask] * [h, w, h, w]
if score_mask.any():
                # Hard-NMS alternative:
                # nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
nms_index = apply_softnms(class_boxes, class_box_scores, config.softnms_sigma)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
final_boxes += class_boxes.tolist()
final_score += class_box_scores.tolist()
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)
for loc, label, score in zip(final_boxes, final_label, final_score):
res = {}
res['image_id'] = img_id
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
res['score'] = score
res['category_id'] = label
predictions.append(res)
with open('predictions.json', 'w') as f:
json.dump(predictions, f)
coco_dt = coco_gt.loadRes('predictions.json')
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
E.params.imgIds = img_ids
E.evaluate()
E.accumulate()
E.summarize()
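    # E.stats[0] is the COCO AP averaged over IoU thresholds 0.50:0.95 (all areas, maxDets=100).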
return E.stats[0]

src/config.py
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# " ============================================================================
"""Config parameters for retinanet models."""
from easydict import EasyDict as ed
config = ed({
"img_shape": [600, 600],
"num_retinanet_boxes": 67995,
"match_thershold": 0.5,
"softnms_sigma": 0.5,
"nms_thershold": 0.6,
"min_score": 0.1,
"max_boxes": 100,
    # learning rate settings
"global_step": 0,
"lr_init": 1e-6,
"lr_end_rate": 5e-3,
"warmup_epochs1": 2,
"warmup_epochs2": 5,
"warmup_epochs3": 23,
"warmup_epochs4": 60,
"warmup_epochs5": 160,
"momentum": 0.9,
"weight_decay": 1.5e-4,
# network
"num_default": [9, 9, 9, 9, 9],
"extras_out_channels": [256, 256, 256, 256, 256],
"feature_size": [75, 38, 19, 10, 5],
"aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],
"steps": (8, 16, 32, 64, 128),
"anchor_size": (32, 64, 128, 256, 512),
"prior_scaling": (0.1, 0.2),
"gamma": 2.0,
"alpha": 0.75,
# `mindrecord_dir` and `coco_root` are better to use absolute path.
"mindrecord_dir": "/opr/root/data/MindRecord_COCO",
"coco_root": "/opr/root/data/",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instances_set": "anno/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,
# The annotation.json position of voc validation dataset.
"voc_root": "",
# voc original dataset.
"voc_dir": "",
# if coco or voc used, `image_dir` and `anno_path` are useless.
"image_dir": "",
"anno_path": "",
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 30,
"save_checkpoint_path": "./model",
"finish_epoch": 0,
"checkpoint_path": "/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt"
})

src/dataset.py
@@ -0,0 +1,454 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""retinanet dataset"""
from __future__ import division
import os
import json
import xml.etree.ElementTree as et
import numpy as np
import cv2
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from .config import config
from .box_utils import jaccard_numpy, retinanet_bboxes_encode
def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a
def get_imageId_from_fileName(filename):
    """Get image ID from file name."""
    filename = os.path.splitext(filename)[0]
    if filename.isdigit():
        return int(filename)
    raise ValueError("Filename %s cannot be converted to an image id." % filename)
def random_sample_crop(image, boxes):
"""Random Crop the image and boxes"""
height, width, _ = image.shape
min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
if min_iou is None:
return image, boxes
    # max trials (50)
for _ in range(50):
image_t = image
w = _rand(0.3, 1.0) * width
h = _rand(0.3, 1.0) * height
# aspect ratio constraint b/t .5 & 2
if h / w < 0.5 or h / w > 2:
continue
left = _rand() * (width - w)
top = _rand() * (height - h)
rect = np.array([int(top), int(left), int(top + h), int(left + w)])
overlap = jaccard_numpy(boxes, rect)
# dropout some boxes
drop_mask = overlap > 0
if not drop_mask.any():
continue
if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
continue
image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]
centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0
m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
        # mask where both m1 and m2 are true
mask = m1 * m2 * drop_mask
# have any valid boxes? try again if not
if not mask.any():
continue
# take only matching gt boxes
boxes_t = boxes[mask, :].copy()
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
boxes_t[:, :2] -= rect[:2]
boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
boxes_t[:, 2:4] -= rect[:2]
return image_t, boxes_t
return image, boxes
def preprocess_fn(img_id, image, box, is_training):
"""Preprocess function for dataset."""
cv2.setNumThreads(2)
def _infer_data(image, input_shape):
img_h, img_w, _ = image.shape
input_h, input_w = input_shape
image = cv2.resize(image, (input_w, input_h))
        # When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
return img_id, image, np.array((img_h, img_w), np.float32)
def _data_aug(image, box, is_training, image_size=(600, 600)):
"""Data augmentation function."""
ih, iw, _ = image.shape
w, h = image_size
if not is_training:
return _infer_data(image, image_size)
# Random crop
box = box.astype(np.float32)
image, box = random_sample_crop(image, box)
ih, iw, _ = image.shape
# Resize image
image = cv2.resize(image, (w, h))
# Flip image or not
flip = _rand() < .5
if flip:
image = cv2.flip(image, 1, dst=None)
        # When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
box[:, [0, 2]] = box[:, [0, 2]] / ih
box[:, [1, 3]] = box[:, [1, 3]] / iw
if flip:
box[:, [1, 3]] = 1 - box[:, [3, 1]]
box, label, num_match = retinanet_bboxes_encode(box)
return image, box, label, num_match
return _data_aug(image, box, is_training, image_size=config.img_shape)
def create_voc_label(is_training):
"""Get image path and annotation from VOC."""
voc_dir = config.voc_dir
cls_map = {name: i for i, name in enumerate(config.coco_classes)}
sub_dir = 'train' if is_training else 'eval'
voc_dir = os.path.join(voc_dir, sub_dir)
if not os.path.isdir(voc_dir):
raise ValueError(f'Cannot find {sub_dir} dataset path.')
image_dir = anno_dir = voc_dir
if os.path.isdir(os.path.join(voc_dir, 'Images')):
image_dir = os.path.join(voc_dir, 'Images')
if os.path.isdir(os.path.join(voc_dir, 'Annotations')):
anno_dir = os.path.join(voc_dir, 'Annotations')
if not is_training:
data_dir = config.voc_root
json_file = os.path.join(data_dir, config.instances_set.format(sub_dir))
file_dir = os.path.split(json_file)[0]
if not os.path.isdir(file_dir):
os.makedirs(file_dir)
json_dict = {"images": [], "type": "instances", "annotations": [],
"categories": []}
bnd_id = 1
image_files_dict = {}
image_anno_dict = {}
images = []
for anno_file in os.listdir(anno_dir):
print(anno_file)
if not anno_file.endswith('xml'):
continue
tree = et.parse(os.path.join(anno_dir, anno_file))
root_node = tree.getroot()
file_name = root_node.find('filename').text
img_id = get_imageId_from_fileName(file_name)
image_path = os.path.join(image_dir, file_name)
print(image_path)
if not os.path.isfile(image_path):
print(f'Cannot find image {file_name} according to annotations.')
continue
labels = []
for obj in root_node.iter('object'):
cls_name = obj.find('name').text
if cls_name not in cls_map:
print(f'Label "{cls_name}" not in "{config.coco_classes}"')
continue
bnd_box = obj.find('bndbox')
x_min = int(bnd_box.find('xmin').text) - 1
y_min = int(bnd_box.find('ymin').text) - 1
x_max = int(bnd_box.find('xmax').text) - 1
y_max = int(bnd_box.find('ymax').text) - 1
labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]])
if not is_training:
o_width = abs(x_max - x_min)
o_height = abs(y_max - y_min)
ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \
img_id, 'bbox': [x_min, y_min, o_width, o_height], \
'category_id': cls_map[cls_name], 'id': bnd_id, \
'ignore': 0, \
'segmentation': []}
json_dict['annotations'].append(ann)
bnd_id = bnd_id + 1
if labels:
images.append(img_id)
image_files_dict[img_id] = image_path
image_anno_dict[img_id] = np.array(labels)
if not is_training:
size = root_node.find("size")
width = int(size.find('width').text)
height = int(size.find('height').text)
image = {'file_name': file_name, 'height': height, 'width': width,
'id': img_id}
json_dict['images'].append(image)
if not is_training:
for cls_name, cid in cls_map.items():
cat = {'supercategory': 'none', 'id': cid, 'name': cls_name}
json_dict['categories'].append(cat)
json_fp = open(json_file, 'w')
json_str = json.dumps(json_dict)
json_fp.write(json_str)
json_fp.close()
return images, image_files_dict, image_anno_dict
def create_coco_label(is_training):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO
coco_root = config.coco_root
data_type = config.val_data_type
if is_training:
data_type = config.train_data_type
# Classes need to train or test.
train_cls = config.coco_classes
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
image_ids = coco.getImgIds()
images = []
image_path_dict = {}
image_anno_dict = {}
for img_id in image_ids:
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
iscrowd = False
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
iscrowd = iscrowd or label["iscrowd"]
if class_name in train_cls:
x_min, x_max = bbox[0], bbox[0] + bbox[2]
y_min, y_max = bbox[1], bbox[1] + bbox[3]
annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
if not is_training and iscrowd:
continue
if len(annos) >= 1:
images.append(img_id)
image_path_dict[img_id] = image_path
image_anno_dict[img_id] = np.array(annos)
return images, image_path_dict, image_anno_dict
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
images = []
image_path_dict = {}
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for img_id, line in enumerate(lines):
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
images.append(img_id)
image_path_dict[img_id] = image_path
image_anno_dict[img_id] = anno_parser(line_split[1:])
return images, image_path_dict, image_anno_dict
def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="retinanet.mindrecord", file_num=8):
"""Create MindRecord file by image_dir and anno_path."""
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
images, image_path_dict, image_anno_dict = create_voc_label(is_training)
retinanet_json = {
"img_id": {"type": "int32", "shape": [1]},
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(retinanet_json, "retinanet_json")
for img_id in images:
image_path = image_path_dict[img_id]
with open(image_path, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
img_id = np.array([img_id], dtype=np.int32)
row = {"img_id": img_id, "image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="retina2.mindrecord", file_num=4):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco":
images, image_path_dict, image_anno_dict = create_coco_label(is_training)
else:
images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path)
retinanet_json = {
"img_id": {"type": "int32", "shape": [1]},
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(retinanet_json, "retinanet_json")
for img_id in images:
image_path = image_path_dict[img_id]
with open(image_path, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
img_id = np.array([img_id], dtype=np.int32)
row = {"img_id": img_id, "image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0,
is_training=True, num_parallel_workers=64):
"""Creatr retinanet dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"])
change_swap_op = C.HWC2CHW()
normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))
if is_training:
output_columns = ["image", "box", "label", "num_match"]
trans = [color_adjust_op, normalize_op, change_swap_op]
else:
output_columns = ["img_id", "image", "image_shape"]
trans = [normalize_op, change_swap_op]
ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"],
output_columns=output_columns, column_order=output_columns,
python_multiprocessing=is_training,
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds
def create_mindrecord(dataset="coco", prefix="retinanet.mindrecord", is_training=True):
"""create_mindrecord"""
print("Start create dataset!")
# It will generate mindrecord file in config.mindrecord_dir,
# and the file name is retinanet.mindrecord0, 1, ... file_num.
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", is_training, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
elif dataset == "voc":
if os.path.isdir(config.voc_dir):
print("Create Mindrecord.")
voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("voc_dir not exits.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", is_training, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("image_dir or anno_path not exits.")
return mindrecord_file

src/init_params.py
@@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameters utils"""
from mindspore.common.initializer import initializer, TruncatedNormal
def init_net_param(network, initialize_mode='TruncatedNormal'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
if initialize_mode == 'TruncatedNormal':
p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
else:
p.set_data(initializer(initialize_mode, p.data.shape, p.data.dtype))
def filter_checkpoint_parameter(param_dict):
"""remove useless parameters"""
for key in list(param_dict.keys()):
if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
del param_dict[key]
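# --- Hedged usage sketch (illustrative addition, not in the original file) ---
# Typical fine-tuning flow: initialize the net, then load a checkpoint whose
# detection-head weights are dropped so class-count changes do not clash.
# The checkpoint path below is a placeholder.
#
#     from mindspore.train.serialization import load_checkpoint, load_param_into_net
#     init_net_param(net)
#     param_dict = load_checkpoint("/path/to/pretrained.ckpt")  # placeholder path
#     filter_checkpoint_parameter(param_dict)                   # drop multi_loc/multi_cls layers
#     load_param_into_net(net, param_dict)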


@@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate schedule"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2, warmup_epochs3, warmup_epochs4,
warmup_epochs5, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): step at which training (re)starts; earlier entries are skipped
lr_init(float): initial learning rate
lr_end(float): final learning rate
lr_max(float): maximum learning rate
warmup_epochs1..warmup_epochs5(int): length in epochs of each of the five warmup stages
total_epochs(int): total number of training epochs
steps_per_epoch(int): number of steps in one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps1 = steps_per_epoch * warmup_epochs1
warmup_steps2 = warmup_steps1 + steps_per_epoch * warmup_epochs2
warmup_steps3 = warmup_steps2 + steps_per_epoch * warmup_epochs3
warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4
warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5
for i in range(total_steps):
if i < warmup_steps1:
lr = lr_init * (warmup_steps1 - i) / (warmup_steps1) + (lr_max * 1e-4) * i / (warmup_steps1 * 3)
elif warmup_steps1 <= i < warmup_steps2:
lr = 1e-5 * (warmup_steps2 - i) / (warmup_steps2 - warmup_steps1) + (lr_max * 1e-3) * (
i - warmup_steps1) / (warmup_steps2 - warmup_steps1)
elif warmup_steps2 <= i < warmup_steps3:
lr = 1e-4 * (warmup_steps3 - i) / (warmup_steps3 - warmup_steps2) + (lr_max * 1e-2) * (
i - warmup_steps2) / (warmup_steps3 - warmup_steps2)
elif warmup_steps3 <= i < warmup_steps4:
lr = 1e-3 * (warmup_steps4 - i) / (warmup_steps4 - warmup_steps3) + (lr_max * 1e-1) * (
i - warmup_steps3) / (warmup_steps4 - warmup_steps3)
elif warmup_steps4 <= i < warmup_steps5:
lr = 1e-2 * (warmup_steps5 - i) / (warmup_steps5 - warmup_steps4) + lr_max * (i - warmup_steps4) / (
warmup_steps5 - warmup_steps4)
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps5) / (total_steps - warmup_steps5))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
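# --- Hedged usage sketch (illustrative addition, not in the original file) ---
# Small illustrative numbers: five warmup stages of 1 epoch each, 10 epochs in
# total, 2 steps per epoch; the schedule ramps up stage by stage, then decays
# along a cosine curve toward lr_end (5e-4 here mirrors lr_end_rate * lr).
if __name__ == "__main__":
    demo_lr = get_lr(global_step=0, lr_init=1e-6, lr_end=5e-4, lr_max=0.1,
                     warmup_epochs1=1, warmup_epochs2=1, warmup_epochs3=1,
                     warmup_epochs4=1, warmup_epochs5=1, total_epochs=10,
                     steps_per_epoch=2)
    print("total steps:", demo_lr.shape[0])            # 20
    print("first/last lr:", demo_lr[0], demo_lr[-1])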


@@ -0,0 +1,286 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""retinanet based resnet."""
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from .bottleneck import FPN
class FlattenConcat(nn.Cell):
"""
Concatenate predictions into a single tensor.
Args:
config (dict): The default config of retinanet.
Returns:
Tensor, flatten predictions.
"""
def __init__(self, config):
super(FlattenConcat, self).__init__()
self.num_retinanet_boxes = config.num_retinanet_boxes
self.concat = P.Concat(axis=1)
self.transpose = P.Transpose()
def construct(self, inputs):
output = ()
batch_size = F.shape(inputs[0])[0]
for x in inputs:
x = self.transpose(x, (0, 2, 3, 1))
output += (F.reshape(x, (batch_size, -1)),)
res = self.concat(output)
return F.reshape(res, (batch_size, self.num_retinanet_boxes, -1))
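# Shape walkthrough (illustrative): with feature_size [75, 38, 19, 10, 5] and
# 9 anchors per cell, each level contributes H*W*9 boxes, and
# 75*75*9 + 38*38*9 + 19*19*9 + 10*10*9 + 5*5*9 = 67995 = num_retinanet_boxes,
# so the regression branch flattens to (batch, 67995, 4) after the reshape above.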
def ClassificationModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mode='same', num_classes=81,
                        feature_size=256):
conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv5 = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])
def RegressionModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mode='same', feature_size=256):
conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
conv5 = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])
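# Output-shape note (illustrative): for a level feature map of (N, 256, H, W),
# ClassificationModel emits (N, num_anchors * num_classes, H, W) logits and
# RegressionModel emits (N, num_anchors * 4, H, W) offsets; FlattenConcat then
# merges all levels into per-box rows.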
class MultiBox(nn.Cell):
"""
Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.
Args:
config (dict): The default config of retinanet.
Returns:
Tensor, localization predictions.
Tensor, class conf scores.
"""
def __init__(self, config):
super(MultiBox, self).__init__()
out_channels = config.extras_out_channels
num_default = config.num_default
loc_layers = []
cls_layers = []
for k, out_channel in enumerate(out_channels):
loc_layers += [RegressionModel(in_channel=out_channel, num_anchors=num_default[k])]
cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k])]
self.multi_loc_layers = nn.layer.CellList(loc_layers)
self.multi_cls_layers = nn.layer.CellList(cls_layers)
self.flatten_concat = FlattenConcat(config)
def construct(self, inputs):
loc_outputs = ()
cls_outputs = ()
for i in range(len(self.multi_loc_layers)):
loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
class SigmoidFocalClassificationLoss(nn.Cell):
""""
Sigmoid focal-loss for classification.
Args:
gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0
alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25
Returns:
Tensor, the focal loss.
"""
def __init__(self, gamma=2.0, alpha=0.25):
super(SigmoidFocalClassificationLoss, self).__init__()
self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.sigmoid = P.Sigmoid()
self.pow = P.Pow()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.gamma = gamma
self.alpha = alpha
def construct(self, logits, label):
label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label)
sigmoid = self.sigmoid(logits)
label = F.cast(label, mstype.float32)
p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
modulating_factor = self.pow(1 - p_t, self.gamma)
alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
return focal_loss
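# Worked numbers (illustrative): with gamma=2.0, a well-classified example
# (p_t=0.9) keeps only (1 - 0.9)**2 = 0.01 of its cross-entropy, while a hard
# example (p_t=0.1) keeps (1 - 0.1)**2 = 0.81 of it, so training focuses on
# hard cases; alpha then reweights positives versus negatives on top of that.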
class retinahead(nn.Cell):
"""retinahead"""
def __init__(self, backbone, config, is_training=True):
super(retinahead, self).__init__()
self.fpn = FPN(backbone=backbone, config=config)
self.multi_box = MultiBox(config)
self.is_training = is_training
if not is_training:
self.activation = P.Sigmoid()
def construct(self, inputs):
features = self.fpn(inputs)
pred_loc, pred_label = self.multi_box(features)
return pred_loc, pred_label
class retinanetWithLossCell(nn.Cell):
""""
Provide retinanet training loss through network.
Args:
network (Cell): The training network.
config (dict): retinanet config.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, network, config):
super(retinanetWithLossCell, self).__init__()
self.network = network
self.less = P.Less()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
self.loc_loss = nn.SmoothL1Loss()
def construct(self, x, gt_loc, gt_label, num_matched_boxes):
"""construct"""
pred_loc, pred_label = self.network(x)
mask = F.cast(self.less(0, gt_label), mstype.float32)
num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))
# Localization Loss
mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)
# Classification Loss
loss_cls = self.class_loss(pred_label, gt_label)
loss_cls = self.reduce_sum(loss_cls, (1, 2))
return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)
class TrainingWrapper(nn.Cell):
"""
Encapsulation class of retinanet network training.
Append an optimizer to the training network; afterwards the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The backward sensitivity used as the initial gradient, which acts as the loss scale. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class retinanetInferWithDecoder(nn.Cell):
"""
retinanet Infer wrapper to decode the bbox locations.
Args:
network (Cell): the origin retinanet infer network without bbox decoder.
default_boxes (Tensor): the default_boxes from anchor generator
config (dict): retinanet config
Returns:
Tensor, the decoded bbox locations, represented as (y0, x0, y1, x1).
Tensor, the prediction labels.
"""
def __init__(self, network, default_boxes, config):
super(retinanetInferWithDecoder, self).__init__()
self.network = network
self.default_boxes = default_boxes
self.prior_scaling_xy = config.prior_scaling[0]
self.prior_scaling_wh = config.prior_scaling[1]
def construct(self, x):
"""construct"""
pred_loc, pred_label = self.network(x)
default_bbox_xy = self.default_boxes[..., :2]
default_bbox_wh = self.default_boxes[..., 2:]
pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh
pred_xy_0 = pred_xy - pred_wh / 2.0
pred_xy_1 = pred_xy + pred_wh / 2.0
pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1))
pred_xy = P.Maximum()(pred_xy, 0)
pred_xy = P.Minimum()(pred_xy, 1)
return pred_xy, pred_label
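# --- Hedged sanity-check sketch (illustrative addition, not in the original file) ---
# Mirrors retinanetInferWithDecoder.construct with plain NumPy; the default box
# and offsets below are made-up values, and the scaling factors 0.1 / 0.2
# follow config.prior_scaling.
if __name__ == "__main__":
    import numpy as np
    default = np.array([0.5, 0.5, 0.2, 0.2], np.float32)   # (center, center, size, size), normalized
    offsets = np.array([0.1, -0.1, 0.2, 0.2], np.float32)  # predicted loc for one box
    center = offsets[:2] * 0.1 * default[2:] + default[:2]
    size = np.exp(offsets[2:] * 0.2) * default[2:]
    box = np.clip(np.concatenate([center - size / 2.0, center + size / 2.0]), 0.0, 1.0)
    print("decoded (y0, x0, y1, x1):", box)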


@@ -0,0 +1,154 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train retinanet and get checkpoint files."""
import os
import argparse
import ast
import mindspore
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.retinahead import retinanetWithLossCell, TrainingWrapper, retinahead
from src.backbone import resnet101
from src.config import config
from src.dataset import create_retinanet_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter
set_seed(1)
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def step_end(self, run_context):
cb_params = run_context.original_args()
print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num-1]), flush=True)
def main():
parser = argparse.ArgumentParser(description="retinanet training")
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
help="If set it true, only create Mindrecord, default is False.")
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
help="Run distribute, default is False.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter weight parameters, default is False.")
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
help="run platform, only support Ascend.")
args_opt = parser.parse_args()
if args_opt.run_platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if args_opt.distribute:
if os.getenv("DEVICE_ID", "not_set").isdigit():
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
init()
device_num = args_opt.device_num
rank = get_rank()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
else:
rank = 0
device_num = 1
context.set_context(device_id=args_opt.device_id)
else:
raise ValueError("Unsupported platform.")
mindrecord_file = create_mindrecord(args_opt.dataset, "retina2.mindrecord", True)
if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale)
# When creating the MindDataset, use the first mindrecord file, such as retina2.mindrecord0.
dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
backbone = resnet101(config.num_classes)
retinanet = retinahead(backbone, config)
net = retinanetWithLossCell(retinanet, config)
net.to_float(mindspore.float16)
init_net_param(net)
if args_opt.pre_trained:
if args_opt.pre_trained_epoch_size <= 0:
raise KeyError("pre_trained_epoch_size must be greater than 0.")
param_dict = load_checkpoint(args_opt.pre_trained)
if args_opt.filter_weight:
filter_checkpoint_parameter(param_dict)
load_param_into_net(net, param_dict)
lr = Tensor(get_lr(global_step=config.global_step,
lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
config.momentum, config.weight_decay, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
model = Model(net)
print("Start train retinanet, the first epoch will be slower because of the graph compilation.")
cb = [TimeMonitor(), LossMonitor()]
cb += [Monitor(lr_init=lr.asnumpy())]
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
if args_opt.distribute:
if rank == 0:
cb += [ckpt_cb]
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
else:
cb += [ckpt_cb]
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
if __name__ == '__main__':
main()
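# --- Hedged usage examples (illustrative addition, not in the original file) ---
# Single-device run with the defaults defined by the parser above:
#     python train.py --device_id=0 --epoch_size=500 --batch_size=32 --lr=0.1
# Fine-tune from a checkpoint, dropping the detection-head weights:
#     python train.py --pre_trained=/path/to/retinanet.ckpt --pre_trained_epoch_size=100 --filter_weight=True
# (paths and epoch counts are placeholders; distributed runs go through
# scripts/run_distribute_train.sh as listed in the README.)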