forked from mindspore-Ecosystem/mindspore

add retinanet_resnet101

This commit is contained in:
parent 7454ac8ecd
commit 04f4423b03

@@ -0,0 +1,315 @@
# Contents

<!-- TOC -->

- <span id="content">[Retinanet Description](#retinanet-description)</span>
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
        - [Running](#running)
        - [Result](#result)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage)
        - [Running](#running)
        - [Result](#outcome)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

## [Retinanet Description](#content)

RetinaNet comes from the 2017 Facebook AI Research paper Focal Loss for Dense Object Detection. The paper's main contribution is the focal loss, which tackles the extreme foreground-background class imbalance of dense detection; built on it, RetinaNet is a one-stage object detector whose accuracy surpasses the classic two-stage Faster R-CNN.

[Paper](https://arxiv.org/pdf/1708.02002.pdf)
Lin T Y, Goyal P, Girshick R, et al. Focal Loss for Dense Object Detection[C]// 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017: 2999-3007.
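
For reference, here is a minimal NumPy sketch of the focal loss idea (not this repository's training implementation; the `gamma` and `alpha` defaults mirror `src/config.py`, and the function name is illustrative):

```python
import numpy as np

def focal_loss(p, y, gamma=2.0, alpha=0.75):
    """Binary focal loss FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    p: predicted foreground probabilities, shape (N,); y: 0/1 labels, shape (N,).
    """
    p_t = np.where(y == 1, p, 1.0 - p)              # probability of the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(np.clip(p_t, 1e-7, 1.0))
```

The `(1 - p_t)**gamma` factor down-weights easy, well-classified examples, which is what lets a one-stage detector train on the overwhelming number of easy background anchors.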

## [Model Architecture](#content)

The overall network architecture of RetinaNet is shown here:

[Link](https://arxiv.org/pdf/1708.02002.pdf)

## [Dataset](#content)

For the dataset, refer to the paper.

MSCOCO2017

- Dataset size: 19.3G, 123287 color images in 80 classes

    - Training: 19.3G, 118287 images

    - Testing: 1814.3M, 5000 images

- Data format: RGB images

- Note: the data is processed in src/dataset.py

## [Environment Requirements](#content)

- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Script Description](#content)

### [Script and Sample Code](#content)

```shell
.
└─Retinanet_resnet101
  ├─README.md
  ├─scripts
    ├─run_single_train.sh          # single-device training on Ascend
    ├─run_distribute_train.sh      # 8-device parallel training on Ascend
    ├─run_eval.sh                  # evaluation on Ascend
  ├─src
    ├─backbone.py                  # backbone network definition
    ├─bottleneck.py                # network neck (FPN) definition
    ├─config.py                    # parameter configuration
    ├─dataset.py                   # data preprocessing
    ├─retinahead.py                # prediction head definition
    ├─init_params.py               # parameter initialization
    ├─lr_generator.py              # learning-rate generator
    ├─coco_eval.py                 # COCO dataset evaluation
    ├─box_utils.py                 # prior (anchor) box utilities
    ├─__init__.py                  # package init
  ├─train.py                       # training script
  └─eval.py                        # evaluation script
```

### [Script Parameters](#content)

```python
The main parameters used in the train.py and config.py scripts are:
"img_shape": [600, 600],                           # image size
"num_retinanet_boxes": 67995,                      # total number of prior boxes
"match_thershold": 0.5,                            # matching threshold
"softnms_sigma": 0.5,                              # Soft-NMS sigma value
"nms_thershold": 0.6,                              # non-maximum-suppression threshold
"min_score": 0.1,                                  # minimum score
"max_boxes": 100,                                  # maximum number of detection boxes
"global_step": 0,                                  # global step
"lr_init": 1e-6,                                   # initial learning rate
"lr_end_rate": 5e-3,                               # ratio of the final learning rate to the maximum learning rate
"warmup_epochs1": 2,                               # number of epochs in warmup stage 1
"warmup_epochs2": 5,                               # number of epochs in warmup stage 2
"warmup_epochs3": 23,                              # number of epochs in warmup stage 3
"warmup_epochs4": 60,                              # number of epochs in warmup stage 4
"warmup_epochs5": 160,                             # number of epochs in warmup stage 5
"momentum": 0.9,                                   # momentum
"weight_decay": 1.5e-4,                            # weight decay rate
"num_default": [9, 9, 9, 9, 9],                    # number of prior boxes per grid cell
"extras_out_channels": [256, 256, 256, 256, 256],  # output channels of the feature layers
"feature_size": [75, 38, 19, 10, 5],               # feature map sizes
"aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],  # prior box aspect ratios
"steps": (8, 16, 32, 64, 128),                     # prior box strides
"anchor_size": (32, 64, 128, 256, 512),            # prior box sizes
"prior_scaling": (0.1, 0.2),                       # scaling factors used when encoding box regression targets
"gamma": 2.0,                                      # focal loss parameter
"alpha": 0.75,                                     # focal loss parameter
"mindrecord_dir": "/opr/root/data/MindRecord_COCO",  # path to the MindRecord files
"coco_root": "/opr/root/data/",                    # path to the COCO dataset
"train_data_type": "train2017",                    # folder name of the training images
"val_data_type": "val2017",                        # folder name of the validation images
"instances_set": "annotations_trainval2017/annotations/instances_{}.json",  # annotation file path
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',  # COCO dataset classes
                 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                 'kite', 'baseball bat', 'baseball glove', 'skateboard',
                 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                 'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,                                 # number of dataset classes
"voc_root": "",                                    # VOC dataset path
"voc_dir": "",                                     # VOC original data path
"image_dir": "",                                   # image path
"anno_path": "",                                   # annotation file path
"save_checkpoint": True,                           # whether to save checkpoints
"save_checkpoint_epochs": 1,                       # interval (in epochs) between checkpoint saves
"keep_checkpoint_max": 1,                          # maximum number of checkpoints to keep
"save_checkpoint_path": "./model",                 # checkpoint save path
"finish_epoch": 0,                                 # number of epochs already finished
"checkpoint_path": "/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt"  # checkpoint path used for evaluation
```
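
These values live in an `easydict.EasyDict` in `src/config.py`, so the scripts read and override them as attributes; a small sketch:

```python
from src.config import config

print(config.img_shape)             # [600, 600]
print(config.gamma, config.alpha)   # focal loss parameters: 2.0 0.75

# Values can be overridden before building the network; for example,
# eval.py redirects coco_root when evaluating a VOC-style dataset:
config.coco_root = config.voc_root
```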

### [Training Process](#content)

#### Usage

You can start training with a python or shell script. Usage of the shell scripts:

- Ascend:

```shell
# 8-device parallel training example:

Create a RANK_TABLE_FILE
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)

# single-device training example:

sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)

```

> Note:

For RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); for how to obtain the device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

#### Running

```shell
# training examples

python:
    The data path and the path for storing the MindRecord files are set in config.py

# single-device training example:

    python train.py

shell:
    Ascend:

# 8-device parallel training example (run from the retinanet directory):

    sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(number of pretrained epochs)

    e.g.: sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322

# single-device training example (run from the retinanet directory):

    sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322

```

#### Result

Training results are stored in the example path. Checkpoints are saved under `./model`, and the training log is written to `./log.txt`. An excerpt of the training log:

```text
epoch: 397 step: 458, loss is 0.6153226
lr:[0.000598]
epoch time: 313364.642 ms, per step time: 684.202 ms
epoch: 398 step: 458, loss is 0.5491791
lr:[0.000544]
epoch time: 313486.094 ms, per step time: 684.467 ms
epoch: 399 step: 458, loss is 0.51681435
lr:[0.000511]
epoch time: 313514.348 ms, per step time: 684.529 ms
epoch: 400 step: 458, loss is 0.4305706
lr:[0.000500]
epoch time: 314138.455 ms, per step time: 685.892 ms
```

### [Evaluation Process](#content)

#### <span id="usage">Usage</span>

You can start evaluation with a python or shell script. Usage of the shell script:

```shell
sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
```

#### <span id="running">Running</span>

```shell
# evaluation example

python:
    Ascend: python eval.py
    The checkpoint path is set in config.py
shell:
    Ascend: sh scripts/run_eval.sh coco 0
```

> Checkpoints are produced during training.

#### <span id="outcome">Result</span>

The results are stored in the example path; you can view them in `eval.log`.

```text
Average Precision (AP) @[ IoU=0.50:0.95 | area= all   | maxDets=100 ] = 0.371
Average Precision (AP) @[ IoU=0.50      | area= all   | maxDets=100 ] = 0.517
Average Precision (AP) @[ IoU=0.75      | area= all   | maxDets=100 ] = 0.408
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.143
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.394
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.547
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all   | maxDets=  1 ] = 0.318
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all   | maxDets= 10 ] = 0.455
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all   | maxDets=100 ] = 0.464
Average Recall    (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.172
Average Recall    (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.489
Average Recall    (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.680

========================================

mAP: 0.3710347196613514
```

## [Model Description](#content)

### [Performance](#content)

#### Training Performance

| Parameters                 | Ascend                                |
| -------------------------- | ------------------------------------- |
| Model name                 | Retinanet                             |
| Environment                | Huawei Cloud ModelArts                |
| Uploaded date              | 10/03/2021                            |
| MindSpore version          | 1.0.1                                 |
| Dataset                    | 123287 images                         |
| Batch_size                 | 32                                    |
| Training parameters        | src/config.py                         |
| Optimizer                  | Momentum                              |
| Loss function              | Focal loss                            |
| Final loss                 | 0.43                                  |
| Accuracy (8p)              | mAP[0.3710]                           |
| Total training time (8p)   | 34h50m20s                             |
| Script                     | [Retinanet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Retinanet_resnet101) |

#### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | :-------------------------- |
| Model name          | Retinanet                   |
| Environment         | Huawei Cloud ModelArts      |
| Uploaded date       | 10/03/2021                  |
| MindSpore version   | 1.0.1                       |
| Dataset             | 5k images                   |
| Batch_size          | 1                           |
| Accuracy            | mAP[0.3710]                 |
| Total time          | 10 minutes and 50 seconds   |

## [Description of Random Situation](#content)

In the `dataset.py` script, we set a random seed inside the `create_dataset` function. We also set a random seed in the `train.py` script.
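
A minimal sketch of the kind of seeding involved (the exact seed values and call sites are those in `create_dataset` and `train.py`; this snippet only illustrates the mechanism):

```python
import numpy as np
import mindspore.dataset as de

np.random.seed(1)      # NumPy-based augmentations (e.g. the random crop in dataset.py)
de.config.set_seed(1)  # shuffling inside MindSpore dataset pipelines
```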

## [ModelZoo Homepage](#content)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for retinanet"""

import os
import argparse
import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.retinahead import retinahead, retinanetInferWithDecoder
from src.backbone import resnet101
from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
from src.config import config
from src.coco_eval import metrics
from src.box_utils import default_boxes


def retinanet_eval(dataset_path, ckpt_path):
    """retinanet evaluation."""
    batch_size = 1
    ds = create_retinanet_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
    backbone = resnet101(config.num_classes)
    net = retinahead(backbone, config)
    net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)

    net.set_train(False)
    i = batch_size
    total = ds.get_dataset_size() * batch_size
    start = time.time()
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for data in ds.create_dict_iterator(output_numpy=True):
        img_id = data['img_id']
        img_np = data['image']
        image_shape = data['image_shape']

        output = net(Tensor(img_np))
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                              "box_scores": output[1].asnumpy()[batch_idx],
                              "img_id": int(np.squeeze(img_id[batch_idx])),
                              "image_shape": image_shape[batch_idx]})
        percent = round(i / total * 100., 2)

        print(f' {percent} [{i}/{total}]', end='\r')
        i += batch_size
    cost_time = int((time.time() - start) * 1000)
    print(f' 100% [{total}/{total}] cost {cost_time} ms')
    mAP = metrics(pred_data)
    print("\n========================================\n")
    print(f"mAP: {mAP}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='retinanet evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"),
                        help="run platform, only support Ascend and GPU.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)

    prefix = "retinanet_eval.mindrecord"
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if args_opt.dataset == "voc":
        config.coco_root = config.voc_root
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if args_opt.dataset == "coco":
            if os.path.isdir(config.coco_root):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", False, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("coco_root does not exist.")
        elif args_opt.dataset == "voc":
            if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
                print("Create Mindrecord.")
                voc_data_to_mindrecord(mindrecord_dir, False, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("voc_root or voc_dir does not exist.")
        else:
            if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", False, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("IMAGE_DIR or ANNO_PATH does not exist.")

    print("Start Eval!")
    retinanet_eval(mindrecord_file, config.checkpoint_path)

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export file"""

import argparse
import numpy as np

from mindspore import dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

from src.retinahead import retinahead
from src.backbone import resnet101
from src.config import config

parser = argparse.ArgumentParser(description="retinanet_resnet101 export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="retinanet_resnet101", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == "__main__":
    network = retinahead(backbone=resnet101(80), config=config, is_training=False)

    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)

    network.set_train(False)

    shape = [args.batch_size, 3] + [600, 600]
    input_data = Tensor(np.zeros(shape), mstype.float32)

    export(network, input_data, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "=============================================================================================================="

if [ $# != 5 ] && [ $# != 7 ]
then
    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi

# Before starting distributed training, first create the mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export RANK_SIZE=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$6
PRE_TRAINED_EPOCH_SIZE=$7
export RANK_TABLE_FILE=$5

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf LOG$i
    mkdir ./LOG$i
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cp -r ./scripts ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    if [ $# == 5 ]
    then
        python train.py \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    fi

    if [ $# == 7 ]
    then
        python train.py \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --pre_trained=$PRE_TRAINED \
        --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    fi

    cd ../
done

@@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [DATASET] [DEVICE_ID]"
    exit 1
fi

DATASET=$1
echo $DATASET

export DEVICE_NUM=1
export DEVICE_ID=$2
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit

if [ -d "eval$2" ];
then
    rm -rf ./eval$2
fi

mkdir ./eval$2
cp ./*.py ./eval$2
cp -r ./src ./eval$2
cd ./eval$2 || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py \
    --dataset=$DATASET \
    --device_id=$2 > log.txt 2>&1 &
cd ..

@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "=============================================================================================================="

if [ $# != 4 ] && [ $# != 6 ]
then
    echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi

# Before starting single-device training, first create the mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export DEVICE_ID=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$5
PRE_TRAINED_EPOCH_SIZE=$6

rm -rf LOG$1
mkdir ./LOG$1
cp ./*.py ./LOG$1
cp -r ./src ./LOG$1
cd ./LOG$1 || exit
echo "start training for device $1"
env > env.log
if [ $# == 4 ]
then
    python train.py \
    --distribute=False \
    --lr=$LR \
    --dataset=$DATASET \
    --device_num=1 \
    --device_id=$DEVICE_ID \
    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

# Pass the pre-trained checkpoint through when it is given; this mirrors the
# two-branch structure of run_distribute_train.sh (the original script accepted
# the optional arguments but did not forward them).
if [ $# == 6 ]
then
    python train.py \
    --distribute=False \
    --lr=$LR \
    --dataset=$DATASET \
    --device_num=1 \
    --device_id=$DEVICE_ID \
    --pre_trained=$PRE_TRAINED \
    --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi
cd ../

@@ -0,0 +1,226 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BackBone file"""

import mindspore.nn as nn
from mindspore.ops import operations as P


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.97,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with Batchnorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): channel group. Convolution is 1 while Depthwise is input channel. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = 0
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
                         padding=padding)
        layers = [conv, _bn(out_planes), nn.ReLU()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
        self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
                                    has_bn=True, activation='relu')

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
                                                    kernel_size=1, stride=stride,
                                                    pad_mode='same', padding=0, has_bn=True, activation='relu')
        self.add = P.TensorAdd()
        self.relu = P.ReLU()

    def construct(self, x):
        """construct"""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class resnet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images belong to.
    Returns:
        Tensor, output tensor.

    Examples:
        >>> resnet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(resnet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        C1 = self.maxpool(x)

        C2 = self.layer1(C1)
        C3 = self.layer2(C2)
        C4 = self.layer3(C3)
        C5 = self.layer4(C4)
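        # C3, C4 and C5 have strides 8, 16 and 32 with respect to the input
        # image; they are the feature maps consumed by the FPN neck in bottleneck.py.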
        return C3, C4, C5


def resnet101(num_classes):
    return resnet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  num_classes)


def resnet152(num_classes):
    return resnet(ResidualBlock,
                  [3, 8, 36, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  num_classes)

@@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bottleneck"""

import mindspore.nn as nn
from mindspore.ops import operations as P


class FPN(nn.Cell):
    """FPN"""
    def __init__(self, config, backbone, is_training=True):
        super(FPN, self).__init__()

        self.backbone = backbone
        feature_size = config.feature_size
        self.P5_1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, pad_mode='same')
        self.P_upsample1 = P.ResizeNearestNeighbor((feature_size[1], feature_size[1]))
        self.P5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')

        self.P4_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, pad_mode='same')
        self.P_upsample2 = P.ResizeNearestNeighbor((feature_size[0], feature_size[0]))
        self.P4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')

        self.P3_1 = nn.Conv2d(512, 256, kernel_size=1, stride=1, pad_mode='same')
        self.P3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')

        self.P6_0 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, pad_mode='same')

        self.P7_1 = nn.ReLU()
        self.P7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=2, pad_mode='same')

        self.is_training = is_training
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, x):
        """construct"""
        C3, C4, C5 = self.backbone(x)
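
        # Top-down pathway with lateral connections: 1x1 convs project C3-C5 to
        # 256 channels, nearest-neighbor upsampling carries coarser features down,
        # and 3x3 convs smooth the merged maps. P6/P7 are extra stride-2 levels
        # computed from C5, as in the RetinaNet paper.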
        P5 = self.P5_1(C5)
        P5_upsampled = self.P_upsample1(P5)
        P5 = self.P5_2(P5)

        P4 = self.P4_1(C4)
        P4 = P5_upsampled + P4
        P4_upsampled = self.P_upsample2(P4)
        P4 = self.P4_2(P4)

        P3 = self.P3_1(C3)
        P3 = P4_upsampled + P3
        P3 = self.P3_2(P3)

        P6 = self.P6_0(C5)

        P7 = self.P7_1(P6)
        P7 = self.P7_2(P7)
        multi_feature = (P3, P4, P5, P6, P7)

        return multi_feature

@@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bbox utils"""

import math
import itertools as it
import numpy as np
from .config import config


class GeneratDefaultBoxes():
    """
    Generate default boxes for retinanet, following the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4]; the last dimension is [y, x, h, w].
    `self.default_boxes_ltrb` has the same shape as `self.default_boxes`; the last dimension is [y1, x1, y2, x2].
    """

    def __init__(self):
        fk = config.img_shape[0] / np.array(config.steps)
        scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
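        # Three octave scales (2^0, 2^(1/3), 2^(2/3)) combined with the three
        # aspect ratios in config.aspect_ratios give 9 anchors per location,
        # matching config.num_default = [9, 9, 9, 9, 9].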
        anchor_size = np.array(config.anchor_size)
        self.default_boxes = []
        for idex, feature_size in enumerate(config.feature_size):
            base_size = anchor_size[idex] / config.img_shape[0]
            size1 = base_size * scales[0]
            size2 = base_size * scales[1]
            size3 = base_size * scales[2]
            all_sizes = []
            for aspect_ratio in config.aspect_ratios[idex]:
                w1, h1 = size1 * math.sqrt(aspect_ratio), size1 / math.sqrt(aspect_ratio)
                all_sizes.append((h1, w1))
                w2, h2 = size2 * math.sqrt(aspect_ratio), size2 / math.sqrt(aspect_ratio)
                all_sizes.append((h2, w2))
                w3, h3 = size3 * math.sqrt(aspect_ratio), size3 / math.sqrt(aspect_ratio)
                all_sizes.append((h3, w3))

            assert len(all_sizes) == config.num_default[idex]

            for i, j in it.product(range(feature_size), repeat=2):
                for h, w in all_sizes:
                    cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
                    self.default_boxes.append([cy, cx, h, w])

        def to_ltrb(cy, cx, h, w):
            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

        # For IoU calculation
        self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
        self.default_boxes = np.array(self.default_boxes, dtype='float32')


default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
default_boxes = GeneratDefaultBoxes().default_boxes
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_thershold


def retinanet_bboxes_encode(boxes):
    """
    Labels anchors with ground truth inputs.

    Args:
        boxes: ground truth with shape [N, 5]; each row stores [ymin, xmin, ymax, xmax, cls].

    Returns:
        gt_loc: location ground truth with shape [num_anchors, 4].
        gt_label: class ground truth with shape [num_anchors, 1].
        num_matched_boxes: number of positives in an image.
    """

    def jaccard_with_anchors(bbox):
        """Compute the jaccard score between a box and the anchors."""
        # Intersection bbox and volume.
        ymin = np.maximum(y1, bbox[0])
        xmin = np.maximum(x1, bbox[1])
        ymax = np.minimum(y2, bbox[2])
        xmax = np.minimum(x2, bbox[3])
        w = np.maximum(xmax - xmin, 0.)
        h = np.maximum(ymax - ymin, 0.)

        # Volumes.
        inter_vol = h * w
        union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
        jaccard = inter_vol / union_vol
        return np.squeeze(jaccard)

    pre_scores = np.zeros((config.num_retinanet_boxes), dtype=np.float32)
    t_boxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
    t_label = np.zeros((config.num_retinanet_boxes), dtype=np.int64)
    for bbox in boxes:
        label = int(bbox[4])
        scores = jaccard_with_anchors(bbox)
        idx = np.argmax(scores)
        scores[idx] = 2.0
        mask = (scores > matching_threshold)
        mask = mask & (scores > pre_scores)
        pre_scores = np.maximum(pre_scores, scores * mask)
        t_label = mask * label + (1 - mask) * t_label
        for i in range(4):
            t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]

    index = np.nonzero(t_label)

    # Transform from ltrb to center form [cy, cx, h, w].
    bboxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
    bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
    bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]

    # Encode features.
    bboxes_t = bboxes[index]
    default_boxes_t = default_boxes[index]
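    # SSD-style target encoding: center offsets are normalized by the anchor
    # size and prior_scaling[0]; width/height use a log ratio scaled by prior_scaling[1].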
    bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
    tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)
    bboxes_t[:, 2:4] = np.log(tmp) / config.prior_scaling[1]
    bboxes[index] = bboxes_t

    num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
    return bboxes, t_label.astype(np.int32), num_match


def retinanet_bboxes_decode(boxes):
    """Decode predicted boxes to [y1, x1, y2, x2]."""
    boxes_t = boxes.copy()
    default_boxes_t = default_boxes.copy()
    boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
    boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]

    bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)

    bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
    bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2

    return np.clip(bboxes, 0, 1)


def intersect(box_a, box_b):
    """Compute the intersection of two sets of boxes."""
    max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
    min_yx = np.maximum(box_a[:, :2], box_b[:2])
    inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
    return inter[:, 0] * inter[:, 1]


def jaccard_numpy(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes."""
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) *
              (box_a[:, 3] - box_a[:, 1]))
    area_b = ((box_b[2] - box_b[0]) *
              (box_b[3] - box_b[1]))
    union = area_a + area_b - inter
    return inter / union

@@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Coco metrics utils"""

import os
import json
import numpy as np
from .config import config


def apply_softnms(dets, scores, sigma=0.5, method=2, thresh=0.001, Nt=0.1):
    '''
    The Soft-NMS implementation in python.
    :param dets: the predicted bboxes
    :param scores: the scores of the predicted bboxes
    :param sigma: the sigma of the gaussian score-decay method
    :param method: the policy for decaying bbox scores in Soft-NMS (1: linear, 2: gaussian, otherwise hard NMS)
    :param thresh: the score threshold below which a bbox is dropped
    :param Nt: the IoU threshold used by the linear and hard-NMS branches
    :return: the indices of the bboxes kept after Soft-NMS
    '''
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]

    areas = (y2 - y1 + 1.) * (x2 - x1 + 1.)
    orders = scores.argsort()[::-1]
    keep = []

    while orders.size > 0:

        i = orders[0]
        keep.append(i)

        for j in orders[1:]:

            xx1 = np.maximum(x1[i], x1[j])
            yy1 = np.maximum(y1[i], y1[j])
            xx2 = np.minimum(x2[i], x2[j])
            yy2 = np.minimum(y2[i], y2[j])
            w = np.maximum(xx2 - xx1 + 1., 0.)
            h = np.maximum(yy2 - yy1 + 1., 0.)

            inter = w * h
            overlap = inter / (areas[i] + areas[j] - inter)

            if method == 1:  # linear
                if overlap > Nt:
                    weight = 1 - overlap
                else:
                    weight = 1
            elif method == 2:  # gaussian
                weight = np.exp(-(overlap * overlap) / sigma)
            else:  # original NMS
                if overlap > Nt:
                    weight = 0
                else:
                    weight = 1

            scores[j] = weight * scores[j]

            if scores[j] < thresh:
                orders = np.delete(orders, np.where(orders == j))

        orders = np.delete(orders, 0)

    return keep


def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    y1 = all_boxes[:, 0]
    x1 = all_boxes[:, 1]
    y2 = all_boxes[:, 2]
    x2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep


def metrics(pred_data):
    """Calculate mAP of predicted bboxes."""
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    num_classes = config.num_classes

    coco_root = config.coco_root
    data_type = config.val_data_type

    # Classes needed for training or testing.
    val_cls = config.coco_classes
    val_cls_dict = {}
    for i, cls in enumerate(val_cls):
        val_cls_dict[i] = cls

    anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
    coco_gt = COCO(anno_json)
    classs_dict = {}
    cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
    for cat in cat_ids:
        classs_dict[cat["name"]] = cat["id"]

    predictions = []
    img_ids = []

    for sample in pred_data:
        pred_boxes = sample['boxes']
        box_scores = sample['box_scores']
        img_id = sample['img_id']
        h, w = sample['image_shape']

        final_boxes = []
        final_label = []
        final_score = []
        img_ids.append(img_id)

        for c in range(1, num_classes):
            class_box_scores = box_scores[:, c]
            score_mask = class_box_scores > config.min_score
            class_box_scores = class_box_scores[score_mask]
            class_boxes = pred_boxes[score_mask] * [h, w, h, w]

            if score_mask.any():
                # nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
                # apply_softnms(dets, scores, sigma=0.5, method=2, thresh=0.001, Nt=0.1)
                nms_index = apply_softnms(class_boxes, class_box_scores, config.softnms_sigma)
                class_boxes = class_boxes[nms_index]
                class_box_scores = class_box_scores[nms_index]

                final_boxes += class_boxes.tolist()
                final_score += class_box_scores.tolist()
                final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)

        for loc, label, score in zip(final_boxes, final_label, final_score):
            res = {}
            res['image_id'] = img_id
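            # Convert [y1, x1, y2, x2] boxes to COCO's [x, y, width, height] format.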
            res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
            res['score'] = score
            res['category_id'] = label
            predictions.append(res)
    with open('predictions.json', 'w') as f:
        json.dump(predictions, f)

    coco_dt = coco_gt.loadRes('predictions.json')
    E = COCOeval(coco_gt, coco_dt, iouType='bbox')
    E.params.imgIds = img_ids
    E.evaluate()
    E.accumulate()
    E.summarize()
    return E.stats[0]

@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Config parameters for retinanet models."""

from easydict import EasyDict as ed

config = ed({
    "img_shape": [600, 600],
    "num_retinanet_boxes": 67995,
    "match_thershold": 0.5,
    "softnms_sigma": 0.5,
    "nms_thershold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "global_step": 0,
    "lr_init": 1e-6,
    "lr_end_rate": 5e-3,
    "warmup_epochs1": 2,
    "warmup_epochs2": 5,
    "warmup_epochs3": 23,
    "warmup_epochs4": 60,
    "warmup_epochs5": 160,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [9, 9, 9, 9, 9],
    "extras_out_channels": [256, 256, 256, 256, 256],
    "feature_size": [75, 38, 19, 10, 5],
    "aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],
    "steps": (8, 16, 32, 64, 128),
    "anchor_size": (32, 64, 128, 256, 512),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,

    # `mindrecord_dir` and `coco_root` are better given as absolute paths.
    "mindrecord_dir": "/opr/root/data/MindRecord_COCO",
    "coco_root": "/opr/root/data/",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "anno/instances_{}.json",
    "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                     'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                     'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                     'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                     'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                     'kite', 'baseball bat', 'baseball glove', 'skateboard',
                     'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                     'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                     'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                     'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                     'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                     'refrigerator', 'book', 'clock', 'vase', 'scissors',
                     'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,
    # The annotation.json location of the voc validation dataset.
    "voc_root": "",
    # voc original dataset.
    "voc_dir": "",
    # if coco or voc is used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": "",
    "save_checkpoint": True,
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 30,
    "save_checkpoint_path": "./model",
    "finish_epoch": 0,
    "checkpoint_path": "/opr/root/reretina/retinanet2/LOG0/model/retinanet-400_458.ckpt"
})

@@ -0,0 +1,454 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""retinanet dataset"""

from __future__ import division

import os
import json
import xml.etree.ElementTree as et
import numpy as np
import cv2

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from .config import config
from .box_utils import jaccard_numpy, retinanet_bboxes_encode


def _rand(a=0., b=1.):
    """Generate a random number in [a, b)."""
    return np.random.rand() * (b - a) + a


def get_imageId_from_fileName(filename, id_iter=0):
    """Get an image id from a file name; fall back to id_iter when the name is not numeric."""
    filename = os.path.splitext(filename)[0]
    if filename.isdigit():
        return int(filename)
    return id_iter


def random_sample_crop(image, boxes):
    """Randomly crop the image and boxes."""
    height, width, _ = image.shape
    min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])

    if min_iou is None:
        return image, boxes

    # max trials (50)
    for _ in range(50):
        image_t = image

        w = _rand(0.3, 1.0) * width
        h = _rand(0.3, 1.0) * height

        # aspect ratio constraint b/t .5 & 2
        if h / w < 0.5 or h / w > 2:
            continue

        left = _rand() * (width - w)
        top = _rand() * (height - h)

        rect = np.array([int(top), int(left), int(top + h), int(left + w)])
        overlap = jaccard_numpy(boxes, rect)

        # drop boxes that do not overlap the crop
        drop_mask = overlap > 0
        if not drop_mask.any():
            continue

        if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
            continue

        image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]

        centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0

        m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
        m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])

        # mask in which both m1 and m2 are true
        mask = m1 * m2 * drop_mask

        # have any valid boxes? try again if not
        if not mask.any():
            continue

        # take only the matching gt boxes
        boxes_t = boxes[mask, :].copy()

        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
        boxes_t[:, :2] -= rect[:2]
        boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
        boxes_t[:, 2:4] -= rect[:2]

        return image_t, boxes_t
    return image, boxes


def preprocess_fn(img_id, image, box, is_training):
    """Preprocess function for the dataset."""
    cv2.setNumThreads(2)

    def _infer_data(image, input_shape):
        img_h, img_w, _ = image.shape
        input_h, input_w = input_shape

        image = cv2.resize(image, (input_w, input_h))

        # When the image has a single channel
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        return img_id, image, np.array((img_h, img_w), np.float32)

    def _data_aug(image, box, is_training, image_size=(600, 600)):
        """Data augmentation function."""
        ih, iw, _ = image.shape
        w, h = image_size

        if not is_training:
            return _infer_data(image, image_size)

        # Random crop
        box = box.astype(np.float32)
        image, box = random_sample_crop(image, box)
        ih, iw, _ = image.shape

        # Resize image
        image = cv2.resize(image, (w, h))

        # Flip image or not
        flip = _rand() < .5
        if flip:
            image = cv2.flip(image, 1, dst=None)

        # When the image has a single channel
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        box[:, [0, 2]] = box[:, [0, 2]] / ih
        box[:, [1, 3]] = box[:, [1, 3]] / iw

        if flip:
            box[:, [1, 3]] = 1 - box[:, [3, 1]]
|
||||
|
||||
box, label, num_match = retinanet_bboxes_encode(box)
|
||||
return image, box, label, num_match
|
||||
|
||||
return _data_aug(image, box, is_training, image_size=config.img_shape)
|
||||
|
||||
|
||||
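# Usage sketch (added; shapes and values are illustrative only):
#   img = np.zeros((480, 640, 3), np.uint8)
#   box = np.array([[50, 40, 200, 180, 3]], np.float32)  # [ymin, xmin, ymax, xmax, label]
#   image, loc, label, num_match = preprocess_fn(1, img, box, is_training=True)
#   # image: (600, 600, 3); loc/label are encoded against the 67995 priors from config
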
def create_voc_label(is_training):
    """Get image path and annotation from VOC."""
    voc_dir = config.voc_dir
    cls_map = {name: i for i, name in enumerate(config.coco_classes)}
    sub_dir = 'train' if is_training else 'eval'
    voc_dir = os.path.join(voc_dir, sub_dir)
    if not os.path.isdir(voc_dir):
        raise ValueError(f'Cannot find {sub_dir} dataset path.')

    image_dir = anno_dir = voc_dir
    if os.path.isdir(os.path.join(voc_dir, 'Images')):
        image_dir = os.path.join(voc_dir, 'Images')
    if os.path.isdir(os.path.join(voc_dir, 'Annotations')):
        anno_dir = os.path.join(voc_dir, 'Annotations')

    if not is_training:
        data_dir = config.voc_root
        json_file = os.path.join(data_dir, config.instances_set.format(sub_dir))
        file_dir = os.path.split(json_file)[0]
        if not os.path.isdir(file_dir):
            os.makedirs(file_dir)
        json_dict = {"images": [], "type": "instances", "annotations": [],
                     "categories": []}
        bnd_id = 1

    image_files_dict = {}
    image_anno_dict = {}
    images = []
    for anno_file in os.listdir(anno_dir):
        print(anno_file)
        if not anno_file.endswith('xml'):
            continue
        tree = et.parse(os.path.join(anno_dir, anno_file))
        root_node = tree.getroot()
        file_name = root_node.find('filename').text
        img_id = get_imageId_from_fileName(file_name)
        image_path = os.path.join(image_dir, file_name)
        print(image_path)
        if not os.path.isfile(image_path):
            print(f'Cannot find image {file_name} according to annotations.')
            continue

        labels = []
        for obj in root_node.iter('object'):
            cls_name = obj.find('name').text
            if cls_name not in cls_map:
                print(f'Label "{cls_name}" not in "{config.coco_classes}"')
                continue
            bnd_box = obj.find('bndbox')
            x_min = int(bnd_box.find('xmin').text) - 1
            y_min = int(bnd_box.find('ymin').text) - 1
            x_max = int(bnd_box.find('xmax').text) - 1
            y_max = int(bnd_box.find('ymax').text) - 1
            labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]])

            if not is_training:
                o_width = abs(x_max - x_min)
                o_height = abs(y_max - y_min)
                ann = {'area': o_width * o_height, 'iscrowd': 0,
                       'image_id': img_id,
                       'bbox': [x_min, y_min, o_width, o_height],
                       'category_id': cls_map[cls_name], 'id': bnd_id,
                       'ignore': 0,
                       'segmentation': []}
                json_dict['annotations'].append(ann)
                bnd_id = bnd_id + 1

        if labels:
            images.append(img_id)
            image_files_dict[img_id] = image_path
            image_anno_dict[img_id] = np.array(labels)

        if not is_training:
            size = root_node.find("size")
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            image = {'file_name': file_name, 'height': height, 'width': width,
                     'id': img_id}
            json_dict['images'].append(image)

    if not is_training:
        for cls_name, cid in cls_map.items():
            cat = {'supercategory': 'none', 'id': cid, 'name': cls_name}
            json_dict['categories'].append(cat)
        with open(json_file, 'w') as json_fp:
            json_fp.write(json.dumps(json_dict))

    return images, image_files_dict, image_anno_dict


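# Layout note (added; inferred from the code above): config.voc_dir is expected
# to look like voc_dir/{train,eval}/{Images,Annotations}/..., falling back to
# the split directory itself when Images/ or Annotations/ is absent.
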
def create_coco_label(is_training):
    """Get image path and annotation from COCO."""
    from pycocotools.coco import COCO

    coco_root = config.coco_root
    data_type = config.val_data_type
    if is_training:
        data_type = config.train_data_type

    # Classes to train or test on.
    train_cls = config.coco_classes
    train_cls_dict = {}
    for i, cls in enumerate(train_cls):
        train_cls_dict[cls] = i

    anno_json = os.path.join(coco_root, config.instances_set.format(data_type))

    coco = COCO(anno_json)
    class_dict = {}
    cat_ids = coco.loadCats(coco.getCatIds())
    for cat in cat_ids:
        class_dict[cat["id"]] = cat["name"]

    image_ids = coco.getImgIds()
    images = []
    image_path_dict = {}
    image_anno_dict = {}

    for img_id in image_ids:
        image_info = coco.loadImgs(img_id)
        file_name = image_info[0]["file_name"]
        anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = coco.loadAnns(anno_ids)
        image_path = os.path.join(coco_root, data_type, file_name)
        annos = []
        iscrowd = False
        for label in anno:
            bbox = label["bbox"]
            class_name = class_dict[label["category_id"]]
            iscrowd = iscrowd or label["iscrowd"]
            if class_name in train_cls:
                x_min, x_max = bbox[0], bbox[0] + bbox[2]
                y_min, y_max = bbox[1], bbox[1] + bbox[3]
                annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])

        # Skip crowd images during evaluation; keep images with at least one box.
        if not is_training and iscrowd:
            continue
        if annos:
            images.append(img_id)
            image_path_dict[img_id] = image_path
            image_anno_dict[img_id] = np.array(annos)

    return images, image_path_dict, image_anno_dict


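# Format note (added): COCO stores a bbox as [x, y, width, height]; the loop
# above converts it to the [ymin, xmin, ymax, xmax, class_index] rows that the
# rest of the pipeline (and the mindrecord schema below) expects.
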
def anno_parser(annos_str):
    """Parse annotation from string to list."""
    annos = []
    for anno_str in annos_str:
        anno = list(map(int, anno_str.strip().split(',')))
        annos.append(anno)
    return annos


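# Example (added; follows directly from the parsing above and in
# filter_valid_data): each line of the custom annotation file is
# "<file_name> <box> <box> ..." where a box is "ymin,xmin,ymax,xmax,class", so
#   anno_parser(["10,20,110,120,1", "30,40,60,90,2"])
#   -> [[10, 20, 110, 120, 1], [30, 40, 60, 90, 2]]
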
def filter_valid_data(image_dir, anno_path):
    """Filter valid image files, i.e. those present in both image_dir and anno_path."""
    images = []
    image_path_dict = {}
    image_anno_dict = {}
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as f:
        lines = f.readlines()
    for img_id, line in enumerate(lines):
        line_str = line.decode("utf-8").strip()
        line_split = str(line_str).split(' ')
        file_name = line_split[0]
        image_path = os.path.join(image_dir, file_name)
        if os.path.isfile(image_path):
            images.append(img_id)
            image_path_dict[img_id] = image_path
            image_anno_dict[img_id] = anno_parser(line_split[1:])

    return images, image_path_dict, image_anno_dict


def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="retinanet.mindrecord", file_num=8):
    """Create MindRecord files from a VOC-style dataset."""
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    images, image_path_dict, image_anno_dict = create_voc_label(is_training)

    retinanet_json = {
        "img_id": {"type": "int32", "shape": [1]},
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 5]},
    }
    writer.add_schema(retinanet_json, "retinanet_json")

    for img_id in images:
        image_path = image_path_dict[img_id]
        with open(image_path, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[img_id], dtype=np.int32)
        img_id = np.array([img_id], dtype=np.int32)
        row = {"img_id": img_id, "image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="retina2.mindrecord", file_num=4):
    """Create MindRecord file."""
    mindrecord_dir = config.mindrecord_dir
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    if dataset == "coco":
        images, image_path_dict, image_anno_dict = create_coco_label(is_training)
    else:
        images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path)

    retinanet_json = {
        "img_id": {"type": "int32", "shape": [1]},
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 5]},
    }
    writer.add_schema(retinanet_json, "retinanet_json")

    for img_id in images:
        image_path = image_path_dict[img_id]
        with open(image_path, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[img_id], dtype=np.int32)
        img_id = np.array([img_id], dtype=np.int32)
        row = {"img_id": img_id, "image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0,
                             is_training=True, num_parallel_workers=64):
    """Create the retinanet dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
                        shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
    decode = C.Decode()
    ds = ds.map(operations=decode, input_columns=["image"])
    change_swap_op = C.HWC2CHW()
    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
    compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))
    if is_training:
        output_columns = ["image", "box", "label", "num_match"]
        trans = [color_adjust_op, normalize_op, change_swap_op]
    else:
        output_columns = ["img_id", "image", "image_shape"]
        trans = [normalize_op, change_swap_op]
    ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"],
                output_columns=output_columns, column_order=output_columns,
                python_multiprocessing=is_training,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training,
                num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds


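# Usage sketch (added; the path and batch size are illustrative):
#   ds = create_retinanet_dataset("/data/mindrecord/retina2.mindrecord0",
#                                 batch_size=32, repeat_num=1)
#   for batch in ds.create_dict_iterator(output_numpy=True):
#       images = batch["image"]  # float32, (32, 3, 600, 600) after HWC2CHW
#       break
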
def create_mindrecord(dataset="coco", prefix="retinanet.mindrecord", is_training=True):
    """Create (or reuse) the mindrecord files and return the first one."""
    print("Start create dataset!")

    # It will generate mindrecord files in config.mindrecord_dir,
    # and the file names are retinanet.mindrecord0, 1, ... file_num.

    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if dataset == "coco":
            if os.path.isdir(config.coco_root):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("coco_root does not exist.")
        elif dataset == "voc":
            if os.path.isdir(config.voc_dir):
                print("Create Mindrecord.")
                voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("voc_dir does not exist.")
        else:
            if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("image_dir or anno_path does not exist.")
    return mindrecord_file
@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameters utils"""

from mindspore.common.initializer import initializer, TruncatedNormal


def init_net_param(network, initialize_mode='TruncatedNormal'):
    """Init the parameters in net."""
    params = network.trainable_params()
    for p in params:
        if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            if initialize_mode == 'TruncatedNormal':
                p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
            else:
                # set_data expects a tensor-like value, so the named initializer
                # must be materialized first (the original passed three args).
                p.set_data(initializer(initialize_mode, p.data.shape, p.data.dtype))


def filter_checkpoint_parameter(param_dict):
    """Remove the prediction-head parameters that should not be loaded."""
    for key in list(param_dict.keys()):
        if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
            del param_dict[key]
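
# Usage sketch (added; `net` is any constructed training cell, the checkpoint
# path is hypothetical):
#   init_net_param(net)                            # TruncatedNormal on weights
#   param_dict = load_checkpoint("resnet101.ckpt")
#   filter_checkpoint_parameter(param_dict)        # drop the prediction heads
#   load_param_into_net(net, param_dict)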
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Learning rate schedule"""

import math
import numpy as np


def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2, warmup_epochs3, warmup_epochs4,
           warmup_epochs5, total_epochs, steps_per_epoch):
    """
    generate learning rate array

    Args:
        global_step(int): step to resume from; the returned array starts there
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs1..warmup_epochs5(int): epochs of the five warmup stages
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps1 = steps_per_epoch * warmup_epochs1
    warmup_steps2 = warmup_steps1 + steps_per_epoch * warmup_epochs2
    warmup_steps3 = warmup_steps2 + steps_per_epoch * warmup_epochs3
    warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4
    warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5
    for i in range(total_steps):
        if i < warmup_steps1:
            lr = lr_init * (warmup_steps1 - i) / (warmup_steps1) + (lr_max * 1e-4) * i / (warmup_steps1 * 3)
        elif warmup_steps1 <= i < warmup_steps2:
            lr = 1e-5 * (warmup_steps2 - i) / (warmup_steps2 - warmup_steps1) + (lr_max * 1e-3) * (
                i - warmup_steps1) / (warmup_steps2 - warmup_steps1)
        elif warmup_steps2 <= i < warmup_steps3:
            lr = 1e-4 * (warmup_steps3 - i) / (warmup_steps3 - warmup_steps2) + (lr_max * 1e-2) * (
                i - warmup_steps2) / (warmup_steps3 - warmup_steps2)
        elif warmup_steps3 <= i < warmup_steps4:
            lr = 1e-3 * (warmup_steps4 - i) / (warmup_steps4 - warmup_steps3) + (lr_max * 1e-1) * (
                i - warmup_steps3) / (warmup_steps4 - warmup_steps3)
        elif warmup_steps4 <= i < warmup_steps5:
            lr = 1e-2 * (warmup_steps5 - i) / (warmup_steps5 - warmup_steps4) + lr_max * (i - warmup_steps4) / (
                warmup_steps5 - warmup_steps4)
        else:
            lr = lr_end + \
                 (lr_max - lr_end) * \
                 (1. + math.cos(math.pi * (i - warmup_steps5) / (total_steps - warmup_steps5))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
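
# Usage sketch (added; warmup/lr values follow src/config.py defaults, the
# dataset size is illustrative):
#   lr = get_lr(global_step=0, lr_init=1e-6, lr_end=5e-3 * 0.1, lr_max=0.1,
#               warmup_epochs1=2, warmup_epochs2=5, warmup_epochs3=23,
#               warmup_epochs4=60, warmup_epochs5=160,
#               total_epochs=500, steps_per_epoch=458)
#   # lr.shape == (500 * 458,): five staged ramps, then cosine decay toward lr_end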
@ -0,0 +1,286 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""retinanet based resnet."""

import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C

from .bottleneck import FPN


class FlattenConcat(nn.Cell):
    """
    Concatenate predictions into a single tensor.

    Args:
        config (dict): The default config of retinanet.

    Returns:
        Tensor, flattened predictions.
    """

    def __init__(self, config):
        super(FlattenConcat, self).__init__()
        self.num_retinanet_boxes = config.num_retinanet_boxes
        self.concat = P.Concat(axis=1)
        self.transpose = P.Transpose()

    def construct(self, inputs):
        output = ()
        batch_size = F.shape(inputs[0])[0]
        for x in inputs:
            # NCHW -> NHWC, then flatten each feature map per sample
            x = self.transpose(x, (0, 2, 3, 1))
            output += (F.reshape(x, (batch_size, -1)),)
        res = self.concat(output)
        return F.reshape(res, (batch_size, self.num_retinanet_boxes, -1))


def ClassificationModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', num_classes=81,
                        feature_size=256):
    """Four shared 3x3 conv layers followed by the per-anchor class prediction layer."""
    # The original hardcoded kernel_size=3 and pad_mode='same'; the parameters
    # are wired through here with identical defaults.
    conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv5 = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=kernel_size, stride=stride,
                      pad_mode=pad_mod)
    return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])


def RegressionModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', feature_size=256):
    """Four shared 3x3 conv layers followed by the per-anchor box regression layer."""
    conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv5 = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])


class MultiBox(nn.Cell):
    """
    Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.

    Args:
        config (dict): The default config of retinanet.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """

    def __init__(self, config):
        super(MultiBox, self).__init__()

        out_channels = config.extras_out_channels
        num_default = config.num_default
        loc_layers = []
        cls_layers = []
        for k, out_channel in enumerate(out_channels):
            loc_layers += [RegressionModel(in_channel=out_channel, num_anchors=num_default[k])]
            cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k])]

        self.multi_loc_layers = nn.layer.CellList(loc_layers)
        self.multi_cls_layers = nn.layer.CellList(cls_layers)
        self.flatten_concat = FlattenConcat(config)

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        for i in range(len(self.multi_loc_layers)):
            loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
            cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
        return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)


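# Shape note (added; numbers follow src/config.py and the defaults above): with
# 9 anchors per cell on each of the 5 feature levels and 600x600 inputs, the
# flattened outputs are loc (batch, 67995, 4) and cls (batch, 67995, 81).
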
class SigmoidFocalClassificationLoss(nn.Cell):
    """
    Sigmoid focal-loss for classification.

    Args:
        gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0
        alpha (float): Hyper-parameter to balance the positive and negative examples. Default: 0.25

    Returns:
        Tensor, the focal loss.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.sigmoid = P.Sigmoid()
        self.pow = P.Pow()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.gamma = gamma
        self.alpha = alpha

    def construct(self, logits, label):
        label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
        sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label)
        sigmoid = self.sigmoid(logits)
        label = F.cast(label, mstype.float32)
        p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
        modulating_factor = self.pow(1 - p_t, self.gamma)
        alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
        focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
        return focal_loss


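# Formula note (added): per anchor and class the block above computes
#   FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
# with p_t the predicted probability of the ground-truth label, as in
# "Focal Loss for Dense Object Detection" (Lin et al.); gamma down-weights
# well-classified examples so training can focus on hard ones.
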
class retinahead(nn.Cell):
    """retinahead: FPN features followed by the multibox prediction heads."""
    def __init__(self, backbone, config, is_training=True):
        super(retinahead, self).__init__()

        self.fpn = FPN(backbone=backbone, config=config)
        self.multi_box = MultiBox(config)
        self.is_training = is_training
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, inputs):
        features = self.fpn(inputs)
        pred_loc, pred_label = self.multi_box(features)
        return pred_loc, pred_label


class retinanetWithLossCell(nn.Cell):
    """
    Provide retinanet training loss through network.

    Args:
        network (Cell): The training network.
        config (dict): retinanet config.

    Returns:
        Tensor, the loss of the network.
    """

    def __init__(self, network, config):
        super(retinanetWithLossCell, self).__init__()
        self.network = network
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
        self.loc_loss = nn.SmoothL1Loss()

    def construct(self, x, gt_loc, gt_label, num_matched_boxes):
        """construct"""
        pred_loc, pred_label = self.network(x)
        # Positive-anchor mask: gt_label > 0 marks anchors matched to a gt box
        mask = F.cast(self.less(0, gt_label), mstype.float32)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

        # Localization loss, computed on positive anchors only
        mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
        smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
        loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)

        # Classification (focal) loss, computed on all anchors
        loss_cls = self.class_loss(pred_label, gt_label)
        loss_cls = self.reduce_sum(loss_cls, (1, 2))

        # Normalize by the number of matched boxes
        return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)


class TrainingWrapper(nn.Cell):
    """
    Encapsulation class of retinanet network training.

    Append an optimizer to the training network so that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        # Fill the sensitivity (initial gradient) with the loss-scale value
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))


class retinanetInferWithDecoder(nn.Cell):
    """
    retinanet inference wrapper to decode the bbox locations.

    Args:
        network (Cell): the origin retinanet infer network without bbox decoder.
        default_boxes (Tensor): the default_boxes from the anchor generator.
        config (dict): retinanet config.

    Returns:
        Tensor, the locations for bbox after decoding, representing (y0, x0, y1, x1).
        Tensor, the prediction labels.
    """
    def __init__(self, network, default_boxes, config):
        super(retinanetInferWithDecoder, self).__init__()
        self.network = network
        self.default_boxes = default_boxes
        self.prior_scaling_xy = config.prior_scaling[0]
        self.prior_scaling_wh = config.prior_scaling[1]

    def construct(self, x):
        """construct"""
        pred_loc, pred_label = self.network(x)

        default_bbox_xy = self.default_boxes[..., :2]
        default_bbox_wh = self.default_boxes[..., 2:]
        pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
        pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh

        # Convert (center, size) to corner coordinates and clip to [0, 1]
        pred_xy_0 = pred_xy - pred_wh / 2.0
        pred_xy_1 = pred_xy + pred_wh / 2.0
        pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1))
        pred_xy = P.Maximum()(pred_xy, 0)
        pred_xy = P.Minimum()(pred_xy, 1)
        return pred_xy, pred_label
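
# Decoding note (added): the construct above inverts the usual anchor encoding:
#   center = t_center * prior_scaling_xy * prior_size + prior_center
#   size   = exp(t_size * prior_scaling_wh) * prior_size
# then converts (center, size) to corner coordinates clipped to [0, 1].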
@ -0,0 +1,154 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Train retinanet and get checkpoint files."""

import os
import argparse
import ast

import mindspore
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed

from src.retinahead import retinanetWithLossCell, TrainingWrapper, retinahead
from src.backbone import resnet101
from src.config import config
from src.dataset import create_retinanet_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter

set_seed(1)


class Monitor(Callback):
    """
    Monitor the learning rate of each step.

    Args:
        lr_init (numpy array): the learning rate array for the whole run.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy())
    """

    def __init__(self, lr_init=None):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num - 1]), flush=True)


def main():
    parser = argparse.ArgumentParser(description="retinanet training")
    parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
                        help="If set to True, only create Mindrecord, default is False.")
    parser.add_argument("--distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default is False.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter weight parameters, default is False.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                        help="run platform, only support Ascend.")
    args_opt = parser.parse_args()

    if args_opt.run_platform == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
        if args_opt.distribute:
            if os.getenv("DEVICE_ID", "not_set").isdigit():
                context.set_context(device_id=int(os.getenv("DEVICE_ID")))
            init()
            device_num = args_opt.device_num
            rank = get_rank()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
        else:
            rank = 0
            device_num = 1
            context.set_context(device_id=args_opt.device_id)

    else:
        raise ValueError("Unsupported platform.")

    mindrecord_file = create_mindrecord(args_opt.dataset, "retina2.mindrecord", True)

    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)

        # When creating MindDataset, use the first mindrecord file, e.g. retina2.mindrecord0.
        dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
                                           batch_size=args_opt.batch_size, device_num=device_num, rank=rank)

        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")

        backbone = resnet101(config.num_classes)
        retinanet = retinahead(backbone, config)
        net = retinanetWithLossCell(retinanet, config)
        net.to_float(mindspore.float16)
        init_net_param(net)

        if args_opt.pre_trained:
            if args_opt.pre_trained_epoch_size <= 0:
                raise KeyError("pre_trained_epoch_size must be greater than 0.")
            param_dict = load_checkpoint(args_opt.pre_trained)
            if args_opt.filter_weight:
                filter_checkpoint_parameter(param_dict)
            load_param_into_net(net, param_dict)

        lr = Tensor(get_lr(global_step=config.global_step,
                           lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
                           warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
                           warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
                           warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
                           steps_per_epoch=dataset_size))
        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                          config.momentum, config.weight_decay, loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)
        model = Model(net)
        print("Start train retinanet, the first epoch will be slower because of graph compilation.")
        cb = [TimeMonitor(), LossMonitor()]
        cb += [Monitor(lr_init=lr.asnumpy())]
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
        if args_opt.distribute:
            # Only save checkpoints on rank 0 when training on multiple devices
            if rank == 0:
                cb += [ckpt_cb]
            model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
        else:
            cb += [ckpt_cb]
            model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)


if __name__ == '__main__':
    main()
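
# Launch sketch (added; a plain single-device run, paths are illustrative):
#   python train.py --dataset coco --lr 0.1 --epoch_size 500 --batch_size 32 \
#       --device_id 0 --pre_trained /path/to/pretrained.ckpt --pre_trained_epoch_size 1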