!22675 update ghostnet retinanet_resnet101/152 with master in r1.3

Merge pull request !22675 from Shawny/add_1.3
i-robot 2021-09-01 01:52:00 +00:00 committed by Gitee
commit 5bc6619f4f
37 changed files with 11213 additions and 435 deletions

View File

@ -0,0 +1,256 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [GhostNet Description](#ghostnet-description)
    - [Overview](#overview)
    - [Paper](#paper)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
            - [Running on Ascend](#running-on-ascend)
        - [Result](#result)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage-1)
            - [Running on Ascend](#running-on-ascend-1)
        - [Result](#result-1)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Result](#result-2)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# GhostNet Description

## Overview

GhostNet was proposed by Huawei's Noah's Ark Lab in 2020. It introduces a new Ghost module designed to generate more feature maps from cheap operations. Starting from a set of intrinsic feature maps, the authors apply a series of linear transformations to produce, at very low cost, many "ghost" feature maps that can fully reveal the information underlying the intrinsic features. The Ghost module is plug-and-play: stacking Ghost modules yields the Ghost bottleneck, from which the lightweight GhostNet is built. At the same accuracy, GhostNet is faster and requires less computation than state-of-the-art networks.

The following is an example of training GhostNet with MindSpore on the ImageNet2012 dataset; a sketch of the Ghost module itself follows below.
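To make the idea concrete, here is a minimal sketch of a Ghost module in MindSpore, condensed from the `GhostModule` class in src/ghostnet.py later in this diff (stride, SE, and activation options are omitted for brevity):

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class GhostModuleSketch(nn.Cell):
    """Simplified Ghost module: an ordinary 1x1 convolution produces the intrinsic
    feature maps, a cheap depthwise convolution derives the extra "ghost" feature
    maps from them, and the two sets are concatenated along the channel axis."""
    def __init__(self, num_in, num_out, ratio=2, dw_size=3):
        super(GhostModuleSketch, self).__init__()
        init_channels = num_out // ratio            # intrinsic feature maps
        new_channels = init_channels * (ratio - 1)  # "ghost" feature maps
        self.primary_conv = nn.SequentialCell([
            nn.Conv2d(num_in, init_channels, kernel_size=1),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(),
        ])
        # a depthwise convolution (group == channels) is the cheap linear operation
        self.cheap_operation = nn.SequentialCell([
            nn.Conv2d(init_channels, new_channels, kernel_size=dw_size, pad_mode='pad',
                      padding=dw_size // 2, group=init_channels),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(),
        ])
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        return self.concat((x1, x2))
```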
## Paper

1. [Paper](https://arxiv.org/pdf/1911.11907.pdf): Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu. "GhostNet: More Features From Cheap Operations"
# Model Architecture

The overall network architecture of GhostNet is described at this [link](https://arxiv.org/pdf/1911.11907.pdf).
# Dataset

Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size: 1,000 classes, 224*224 color images
    - Train: 1,281,167 images
    - Test: 50,000 images
- Data format: JPEG
    - Note: Data is processed in dataset.py.
- Download the dataset; the directory structure is as follows:
```text
└─dataset
 ├─ilsvrc                # train dataset
 └─validation_preprocess # evaluation dataset
```
# Environment Requirements

- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start

After installing MindSpore via the official website, you can follow the steps below for training and evaluation:

- Running on Ascend
```Shell
# distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# run evaluation example
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
# Script Description

## Script and Sample Code
```text
└──ghostnet
├── README.md
├── scripts
        ├── run_distribute_train.sh        # launch Ascend distributed training (8 devices)
        ├── run_eval.sh                    # launch Ascend evaluation
        └── run_standalone_train.sh        # launch Ascend standalone training (1 device)
    ├── src
        ├── config.py                      # parameter configuration
        ├── dataset.py                     # data preprocessing
        ├── CrossEntropySmooth.py          # loss definition for the ImageNet2012 dataset
        ├── lr_generator.py                # generate the learning rate for each step
        └── ghostnet.py                    # GhostNet network
    ├── eval.py                            # evaluation network
    └── train.py                           # training network
```
## Script Parameters

Parameters for both training and evaluation can be configured in config.py.

- Configuration of GhostNet and the ImageNet2012 dataset.
```Python
"num_classes": 1000, # 数据集类数
"batch_size": 128, # 输入张量的批次大小
"epoch_size": 500, # 训练周期大小
"warmup_epochs": 20, # 热身周期数
"lr_init": 0.1, # 基础学习率
"lr_max": 0.4, # 最大学习率
'lr_end': 1e-6, # 最终学习率
'lr_decay_mode': 'cosine', # 用于生成学习率的衰减模式
"momentum": 0.9, # 动量优化器
"weight_decay": 4e-5, # 权重衰减
"label_smooth": 0.1, # 标签平滑因子
"loss_scale": 128, # 损失等级
"use_label_smooth": True, # 标签平滑
"label_smooth_factor": 0.1, # 标签平滑因子
"save_checkpoint": True, # 是否保存检查点
"save_checkpoint_epochs": 20, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
"keep_checkpoint_max": 10, # 只保存最后一个keep_checkpoint_max检查点
"save_checkpoint_path": "./", # 检查点相对于执行路径的保存路径
```
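For reference, train.py later in this pull request turns these values into a per-step learning-rate array through `get_lr` from src/lr_generator.py. A sketch of the call, mirroring train.py (1251 is the steps-per-epoch value reported in the training logs below):

```python
from mindspore import Tensor
from src.config import config
from src.lr_generator import get_lr

# one learning-rate value per training step: linear warmup, then cosine decay
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
            warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size,
            steps_per_epoch=1251)
lr = Tensor(lr)
```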
# Training Process

## Usage

### Running on Ascend
```Shell
# distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
```
For distributed training, an HCCL configuration file in JSON format needs to be created in advance.

For details, follow the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
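For reference, a rank table for a single server generated by hccl_tools is a JSON file roughly of the following shape (a sketch shortened to two devices, with placeholder IPs; the tool fills in the real values):

```json
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "10.0.0.1",
        "device": [
            {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
            {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"}
        ],
        "host_nic_ip": "reserve"
    }],
    "status": "completed"
}
```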
Training results are stored in the example path, in folders whose names start with "train" or "train_parallel". You can find checkpoint files and results like the following in the logs under this path.

## Result

- Training GhostNet with the ImageNet2012 dataset
```text
# distributed training result (8P)
epoch: 1 step: 1251, loss is 5.001419
epoch time: 457012.100 ms, per step time: 365.317 ms
epoch: 2 step: 1251, loss is 4.275552
epoch time: 280175.784 ms, per step time: 223.961 ms
epoch: 3 step: 1251, loss is 4.0788813
epoch time: 280134.943 ms, per step time: 223.929 ms
epoch: 4 step: 1251, loss is 4.0310946
epoch time: 280161.342 ms, per step time: 223.950 ms
epoch: 5 step: 1251, loss is 3.7326777
epoch time: 280178.602 ms, per step time: 223.964 ms
...
```
# Evaluation Process

## Usage

### Running on Ascend
```Shell
# evaluation
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
```Shell
# evaluation example
bash run_eval.sh /data/dataset/ImageNet/imagenet_original ghostnet-500_1251.ckpt
```
Checkpoints can be generated during training.

## Result

Evaluation results are stored in the example path, in a folder named "eval". You can find the following results in the log under this path:

- Evaluating GhostNet with the ImageNet2012 dataset
```text
result: {'top_5_accuracy': 0.9162371134020618, 'top_1_accuracy': 0.739368556701031}
ckpt = /home/lzu/ghost_Mindspore/scripts/device0/ghostnet-500_1251.ckpt
```
# Inference Process

## Export MindIR

```shell
python export.py --checkpoint_path [CKPT_PATH] --file_format [FILE_FORMAT] --device_target [DEVICE_TARGET]
```

The parameter checkpoint_path is required, and `FILE_FORMAT` must be chosen from ["AIR", "MINDIR"].
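For example, with the checkpoint produced by the training run above (export.py hardcodes the output file name `ghost`, so this writes ghost.mindir to the current directory):

```shell
python export.py --checkpoint_path ghostnet-500_1251.ckpt --file_format MINDIR
```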
## Result

The exported ".mindir" file can be viewed in the current directory.
# Model Description

## Performance

### Evaluation Performance

| Parameters                 | Ascend 910                                      |
| -------------------------- | ----------------------------------------------- |
| Model Version              | GhostNet                                        |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded Date              | 2021-06-22                                      |
| MindSpore Version          | 1.2.0                                           |
| Dataset                    | ImageNet2012                                    |
| Training Parameters        | epoch=500, steps per epoch=1251, batch_size=128 |
| Optimizer                  | Momentum                                        |
| Loss Function              | Softmax cross entropy                           |
| Outputs                    | probability                                     |
| Loss                       | 1.7887309                                       |
| Speed                      | 223.92 ms/step (8 devices)                      |
| Total time                 | 39 hours                                        |
| Parameters (M)             | 5.18                                            |
| Checkpoint for fine tuning | 42.05M (.ckpt file)                             |
| Scripts                    | [link](https://gitee.com/alreadyhad/mindspore/tree/master/model_zoo/research/cv/ghostnet) |
# Description of Random Situation

In dataset.py, the seed is set inside the "create_dataset" function; a random seed is also used in train.py.
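Concretely, train.py fixes the global seed near the top of the script:

```python
from mindspore.common import set_seed

set_seed(1)  # fixed global seed, as set in train.py
```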
# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -1,145 +0,0 @@
# Contents
- [GhostNet Description](#ghostnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [GhostNet Description](#contents)
The GhostNet architecture is based on a Ghost module structure which generates more features from cheap operations. Based on a set of intrinsic feature maps, a series of cheap operations are applied to generate many ghost feature maps that can fully reveal the information underlying the intrinsic features.
[Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Han_GhostNet_More_Features_From_Cheap_Operations_CVPR_2020_paper.pdf): Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu. GhostNet: More Features from Cheap Operations. CVPR 2020.
## [Model architecture](#contents)
The overall network architecture of GhostNet is shown below:
[Link](https://openaccess.thecvf.com/content_CVPR_2020/papers/Han_GhostNet_More_Features_From_Cheap_Operations_CVPR_2020_paper.pdf)
## [Dataset](#contents)
Dataset used: [Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/)
- Dataset size: 7049 color images in 37 classes
- Train: 3680 images
- Test: 3369 images
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
## [Environment Requirements](#contents)
- Hardware (Ascend/GPU)
- Prepare hardware environment with Ascend or GPU.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/r1.3/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/r1.3/index.html)
## [Script description](#contents)
### [Script and sample code](#contents)
```python
├── GhostNet
├── Readme.md # descriptions about ghostnet
├── src
│ ├──config.py # parameter configuration
│ ├──dataset.py # creating dataset
│ ├──launch.py # start python script
│ ├──lr_generator.py # learning rate config
│ ├──ghostnet.py # GhostNet architecture
│ ├──ghostnet600.py # GhostNet-600M architecture
├── eval.py # evaluation script
├── mindspore_hub_conf.py # export model for hub
```
## [Training process](#contents)
To Be Done
## [Eval process](#contents)
### Usage
After installing MindSpore via the official website, you can start evaluation as follows:
### Launch
```bash
# infer example
Ascend: python eval.py --model [ghostnet/ghostnet-600] --dataset_path ~/Pets/test.mindrecord --platform Ascend --checkpoint_path [CHECKPOINT_PATH]
GPU: python eval.py --model [ghostnet/ghostnet-600] --dataset_path ~/Pets/test.mindrecord --platform GPU --checkpoint_path [CHECKPOINT_PATH]
```
> checkpoint can be produced in training process.
### Result
```bash
result: {'acc': 0.8113927500681385} ckpt= ./ghostnet_nose_1x_pets.ckpt
result: {'acc': 0.824475333878441} ckpt= ./ghostnet_1x_pets.ckpt
result: {'acc': 0.8691741618969746} ckpt= ./ghostnet600M_pets.ckpt
```
## [Model Description](#contents)
### [Performance](#contents)
#### Evaluation Performance
##### GhostNet on ImageNet2012
| Parameters | | |
| -------------------------- | -------------------------------------- |---------------------------------- |
| Model Version | GhostNet |GhostNet-600|
| uploaded Date | 09/08/2020 (month/day/year) | 09/08/2020 (month/day/year) |
| MindSpore Version | 0.6.0-alpha |0.6.0-alpha |
| Dataset | ImageNet2012 | ImageNet2012|
| Parameters (M) | 5.2 | 11.9 |
| FLOPs (M) | 142 | 591 |
| Accuracy (Top1) | 73.9 |80.2 |
##### GhostNet on Oxford-IIIT Pet
| Parameters | | |
| -------------------------- | -------------------------------------- |---------------------------------- |
| Model Version | GhostNet |GhostNet-600|
| uploaded Date | 09/08/2020 (month/day/year) | 09/08/2020 (month/day/year) |
| MindSpore Version | 0.6.0-alpha |0.6.0-alpha |
| Dataset | Oxford-IIIT Pet | Oxford-IIIT Pet|
| Parameters (M) | 3.9 | 10.6 |
| FLOPs (M) | 140 | 590 |
| Accuracy (Top1) | 82.4 |86.9 |
##### Comparison with other methods on Oxford-IIIT Pet
|Model|FLOPs (M)|Latency (ms)*|Accuracy (Top1)|
|-|-|-|-|
|MobileNetV2-1x|300|28.2|78.5|
|Ghost-1x w\o SE|138|19.1|81.1|
|Ghost-1x|140|25.3|82.4|
|Ghost-600|590|-|86.9|
*The latency is measured on Huawei Kirin 990 chip under single-threaded mode with batch size 1.
## [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,56 +21,25 @@ from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
from src.dataset import create_dataset
from src.config import config_ascend, config_gpu
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
from src.ghostnet600 import ghostnet_600m
from src.ghostnet import ghostnet_1x
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--platform', type=str, default=None, help='run platform')
parser.add_argument('--model', type=str, default=None, help='ghostnet')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
if __name__ == '__main__':
config_platform = None
if args_opt.platform == "Ascend":
config_platform = config_ascend
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.platform == "GPU":
config_platform = config_gpu
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
else:
raise ValueError("Unsupported platform.")
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
if args_opt.model == 'ghostnet':
net = ghostnet_1x(num_classes=config_platform.num_classes)
elif args_opt.model == 'ghostnet_nose':
net = ghostnet_nose_1x(num_classes=config_platform.num_classes)
elif args_opt.model == 'ghostnet-600':
net = ghostnet_600m(num_classes=config_platform.num_classes)
net = ghostnet_1x()
if args_opt.platform == "Ascend":
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
config=config_platform,
platform=args_opt.platform,
batch_size=config_platform.batch_size,
model=args_opt.model)
dataset = create_dataset(dataset_path=args_opt.data_url, do_train=False)
step_size = dataset.get_dataset_size()
if args_opt.checkpoint_path:
@ -78,6 +47,6 @@ if __name__ == '__main__':
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" export MINDIR """
import argparse as arg
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, export, load_checkpoint
from src.ghostnet import ghostnet_1x
from src.config import config
if __name__ == '__main__':
parser = arg.ArgumentParser(description='GhostNet export')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device where the code will be implemented')
parser.add_argument('--device_id', type=int, default=0, help='device id')
parser.add_argument('--file_format', type=str, choices=['AIR', 'MINDIR'], default='MINDIR',
help='file format')
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == 'Ascend':
context.set_context(device_id=args.device_id)
ckpt_dir = args.checkpoint_path
net = ghostnet_1x(num_classes=config.num_classes)
load_checkpoint(ckpt_dir, net=net)
net.set_train(False)
input_data = Tensor(np.zeros([1, 3, 224, 224]), ms.float32)
export(net, input_data, file_name='ghost', file_format=args.file_format)

View File

@ -0,0 +1,90 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH PRETRAINED_CKPT_PATH](optional)"
echo "For example: bash run_distribute_train.sh hccl_8p_01234567_127.0.0.1.json /path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 3 ]
then
PATH3=$(get_real_path $3)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATA_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
DATA_PATH=$2
export DATA_PATH=${DATA_PATH}
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cp ../*.py ./device$i
cp *.sh ./device$i
cp -r ../src ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
if [ $# == 2 ]
then
python train.py --run_distribute=True --data_url=$PATH2 &> train.log &
fi
if [ $# == 3 ]
then
python train.py --run_distribute=True --data_url=$PATH2 --pre_trained=$PATH3 &> train.log &
fi
cd ../
done

View File

@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DATA_PATH CHECKPOINT_PATH "
echo "For example: bash run.sh /path/dataset ghostnet-500_1251.ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --data_url=$PATH1 --checkpoint_path=$PATH2 &> eval.log &
cd ..

View File

@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DATA_PATH PRETRAINED_CKPT_PATH(optional)"
echo "For example: bash run_standalone_train.sh /path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
PATH2=$(get_real_path $2)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 2 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
python train.py --data_url=$PATH1 &> train.log &
fi
if [ $# == 2 ]
then
python train.py --data_url=$PATH1 --pre_trained=$PATH2 &> train.log &
fi
cd ..

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
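With the values train.py passes in (smooth_factor=0.1, num_classes=1000), the smoothed one-hot targets built by the constructor above work out as follows:

```python
# quick check of the smoothed target values implied by the code above
smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                       # 0.9 for the ground-truth class
off_value = 1.0 * smooth_factor / (num_classes - 1)  # ~0.0001001 for every other class
```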

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,38 +17,23 @@ network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config_ascend = ed({
"num_classes": 37,
"image_height": 224,
"image_width": 224,
"batch_size": 256,
"epoch_size": 200,
"warmup_epochs": 4,
"lr": 0.4,
config = ed({
"num_classes": 1000,
"batch_size": 128,
"epoch_size": 500,
"warmup_epochs": 20,
"lr_init": 0.1,
"lr_max": 0.4,
'lr_end': 1e-6,
'lr_decay_mode': 'cosine',
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"loss_scale": 128,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
})
config_gpu = ed({
"num_classes": 37,
"image_height": 224,
"image_width": 224,
"batch_size": 3,
"epoch_size": 370,
"warmup_epochs": 4,
"lr": 0.4,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 500,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_epochs": 20,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
})

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,99 +12,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
"""Data operations, will be used in train.py and eval.py"""
import os
from src.config import config
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.vision.py_transforms as P
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.dataset.transforms.vision import Inter
import mindspore.dataset.vision.c_transforms as C
from mindspore.communication.management import get_rank, get_group_size
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=100, model='ghsotnet'):
def create_dataset(dataset_path, do_train, target="Ascend"):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided into (default=None).
repeat_num(int): the repeat times of dataset. Default: 1.
Returns:
dataset
"""
if platform == "Ascend":
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if rank_size == 1:
data_set = ds.MindDataset(
dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU":
if do_train:
from mindspore.communication.management import get_rank, get_group_size
data_set = ds.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
data_set = ds.MindDataset(
dataset_path, num_parallel_workers=8, shuffle=True)
if not do_train:
dataset_path = os.path.join(dataset_path, 'val')
else:
raise ValueError("Unsupported platform.")
dataset_path = os.path.join(dataset_path, 'train')
if target == "Ascend":
device_num, rank_id = _get_rank_info()
resize_height = config.image_height
buffer_size = 1000
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
resize_crop_op = C.RandomCropDecodeResize(
resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
color_op = C.RandomColorAdjust(
brightness=0.4, contrast=0.4, saturation=0.4)
rescale_op = C.Rescale(1 / 255.0, 0)
normalize_op = C.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
change_swap_op = C.HWC2CHW()
# define python operations
decode_p = P.Decode()
if model == 'ghostnet-600':
s = 274
c = 240
else:
s = 256
c = 224
resize_p = P.Resize(s, interpolation=Inter.BICUBIC)
center_crop_p = P.CenterCrop(c)
totensor = P.ToTensor()
normalize_p = P.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
composeop = P.ComposeOp(
[decode_p, resize_p, center_crop_p, totensor, normalize_p])
if do_train:
trans = [resize_crop_op, horizontal_flip_op, color_op,
rescale_op, normalize_op, change_swap_op]
trans = [
C.RandomCropDecodeResize(224),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
]
else:
trans = composeop()
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(224),
]
trans += [
C.Normalize(mean=mean, std=std),
C.HWC2CHW(),
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(input_columns="image", operations=trans,
num_parallel_workers=8)
data_set = data_set.map(input_columns="label_list",
operations=type_cast_op, num_parallel_workers=8)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=buffer_size)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
ds = ds.batch(config.batch_size, drop_remainder=True)
return ds
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id
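For reference, train.py further down in this diff consumes the rewritten create_dataset as follows:

```python
# as called from train.py in this pull request
dataset = create_dataset(dataset_path=args_opt.data_url, do_train=True, target="Ascend")
step_size = dataset.get_dataset_size()  # 1251 steps per epoch for ImageNet2012 at batch_size=128 on 8 devices
```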

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -46,6 +46,7 @@ class MyHSigmoid(nn.Cell):
self.relu6 = nn.ReLU6()
def construct(self, x):
""" construct """
return self.relu6(x + 3.) * 0.16666667
@ -74,6 +75,7 @@ class Activation(nn.Cell):
raise NotImplementedError
def construct(self, x):
""" construct """
return self.act(x)
@ -95,6 +97,7 @@ class GlobalAvgPooling(nn.Cell):
self.mean = P.ReduceMean(keep_dims=keep_dims)
def construct(self, x):
""" construct """
x = self.mean(x, (2, 3))
return x
@ -127,6 +130,7 @@ class SE(nn.Cell):
self.mul = P.Mul()
def construct(self, x):
""" construct of SE module """
out = self.pool(x)
out = self.conv_reduce(out)
out = self.act1(out)
@ -173,6 +177,7 @@ class ConvUnit(nn.Cell):
self.act = Activation(act_type) if use_act else None
def construct(self, x):
""" construct of conv unit """
out = self.conv(x)
out = self.bn(out)
if self.use_act:
@ -209,12 +214,14 @@ class GhostModule(nn.Cell):
new_channels = init_channels * (ratio - 1)
self.primary_conv = ConvUnit(num_in, init_channels, kernel_size=kernel_size, stride=stride, padding=padding,
num_groups=1, use_act=use_act, act_type='relu')
self.cheap_operation = ConvUnit(init_channels, new_channels, kernel_size=dw_size, stride=1, padding=dw_size//2,
num_groups=init_channels, use_act=use_act, act_type='relu')
num_groups=1, use_act=use_act, act_type=act_type)
self.cheap_operation = ConvUnit(init_channels, new_channels, kernel_size=dw_size, stride=1,
padding=dw_size // 2, num_groups=init_channels,
use_act=use_act, act_type=act_type)
self.concat = P.Concat(axis=1)
def construct(self, x):
""" ghost module construct """
x1 = self.primary_conv(x)
x2 = self.cheap_operation(x1)
return self.concat((x1, x2))
@ -269,10 +276,10 @@ class GhostBottleneck(nn.Cell):
ConvUnit(num_in, num_out, kernel_size=1, stride=1,
padding=0, num_groups=1, use_act=False),
])
self.add = P.Add()
self.add = P.TensorAdd()
def construct(self, x):
r"""construct of ghostnet"""
""" construct of ghostnet """
shortcut = x
out = self.ghost1(x)
if self.use_dw:
@ -318,7 +325,7 @@ class GhostNet(nn.Cell):
>>> GhostNet(num_classes=1000)
"""
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0.):
super(GhostNet, self).__init__()
self.cfgs = model_cfgs['cfg']
self.inplanes = 16
@ -365,7 +372,7 @@ class GhostNet(nn.Cell):
self._initialize_weights()
def construct(self, x):
r"""construct of GhostNet"""
""" construct of GhostNet """
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
@ -403,21 +410,21 @@ class GhostNet(nn.Cell):
for _, m in self.cells_and_names():
if isinstance(m, (nn.Conv2d)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape).astype("float32")))
m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(
m.gamma.set_data(
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data(
m.beta.set_data(
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal(
m.weight.set_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,8 +16,7 @@
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
@ -47,9 +46,6 @@ def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, st
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
return lr_each_step
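After this change, get_lr returns the full per-step array and the caller no longer slices it at global_step. For the 'cosine' decay mode configured in config.py, the generated schedule has roughly the following shape (an illustrative sketch under that assumption, not copied from the unchanged body of get_lr):

```python
import math

def lr_at_step(i, lr_init, lr_end, lr_max, warmup_steps, total_steps):
    """Assumed schedule shape: linear warmup to lr_max, then cosine decay to lr_end."""
    if i < warmup_steps:
        return lr_init + (lr_max - lr_init) * i / warmup_steps
    progress = (i - warmup_steps) / (total_steps - warmup_steps)
    return lr_end + (lr_max - lr_end) * 0.5 * (1.0 + math.cos(math.pi * progress))
```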

View File

@ -0,0 +1,141 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train.
"""
import os
import argparse
import ast
import mindspore.common.initializer as weight_init
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
from mindspore.common import set_seed
from mindspore.nn.optim.momentum import Momentum
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.dataset import create_dataset
from src.config import config
from src.ghostnet import ghostnet_1x
parser = argparse.ArgumentParser(description='Image classification--GhostNet')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
args_opt = parser.parse_args()
set_seed(1)
if __name__ == '__main__':
# init context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
save_graphs=False)
if args_opt.run_distribute:
device_id = int(os.getenv('DEVICE_ID'))
rank_size = int(os.environ.get("RANK_SIZE", 1))
print(rank_size)
device_num = rank_size
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
args_opt.rank = get_rank()
# select for master rank save ckpt or all rank save, compatible for model parallel
args_opt.rank_save_ckpt_flag = 0
if args_opt.is_save_on_master:
if args_opt.rank == 0:
args_opt.rank_save_ckpt_flag = 1
else:
args_opt.rank_save_ckpt_flag = 1
# define net
net = ghostnet_1x(num_classes=config.num_classes)
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)
local_data_path = args_opt.data_url
print('Download data:')
dataset = create_dataset(dataset_path=local_data_path,
do_train=True,
target="Ascend")
step_size = dataset.get_dataset_size()
print('steps:', step_size)
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
cell.weight.shape,
cell.weight.dtype))
# init lr
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end,
lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size, steps_per_epoch=step_size)
lr = Tensor(lr)
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
opt = Momentum(net.trainable_params(), lr, config.momentum, loss_scale=config.loss_scale,
weight_decay=config.weight_decay)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
metrics={'top_1_accuracy', 'top_5_accuracy'},
amp_level="O3", keep_batchnorm_fp32=False)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
if args_opt.rank_save_ckpt_flag:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="ghostnet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size())

View File

@ -2,7 +2,7 @@
<!-- TOC -->
- <span id="content">[Retinanet 描述](#-Retinanet-描述)</span>
- [Retinanet 描述](#retinanet描述)
- [Model Architecture](#模型架构)
- [Dataset](#数据集)
- [Environment Requirements](#环境要求)
@ -14,9 +14,16 @@
- [Running](#运行)
- [Result](#结果)
- [Evaluation Process](#评估过程)
- [Usage](#usage)
- [Running](#running)
- [Result](#outcome)
- [Usage](#用-法)
- [Running](#运-行)
- [Result](#结-果)
- [Model Export](#模型导出)
- [Usage](#具体用法)
- [Running](#运行命令)
- [Inference Process](#推理过程)
- [Usage](#用途)
- [Running](#运行方式)
- [Result](#运行结果)
- [Model Description](#模型说明)
- [Performance](#性能)
- [Training Performance](#训练性能)
@ -26,7 +33,7 @@
<!-- /TOC -->
## [Retinanet Description](#content)

## [RetinaNet Description](#content)

RetinaNet originates from the 2018 Facebook AI Research paper "Focal Loss for Dense Object Detection". The paper's main contribution is Focal Loss, which addresses class imbalance and gives rise to RetinaNet, a one-stage object detector whose accuracy surpasses the classic two-stage Faster-RCNN.
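For context, Focal Loss reshapes the standard cross entropy so that well-classified examples are down-weighted. With p_t the predicted probability of the ground-truth class, the paper defines it as (paper defaults: gamma = 2, alpha_t = 0.25):

```latex
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t)
```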
@ -60,10 +67,10 @@ MSCOCO2017
- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
## [Script Description](#content)
@ -73,6 +80,7 @@ MSCOCO2017
.
└─Retinanet_resnet101
├─README.md
├─ascend310_infer            # source code for 310 inference
├─scripts
├─run_single_train.sh      # standalone training on Ascend (1 device)
├─run_distribute_train.sh  # distributed training on Ascend (8 devices)
@ -168,17 +176,17 @@ MSCOCO2017
# 8-device parallel training example:
Create RANK_TABLE_FILE
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
# standalone training example:
sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
bash run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
```
> Note:

For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.3/distributed_training_ascend.html); for how to obtain device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_training_ascend.html); for how to obtain device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
#### Running
@ -196,12 +204,12 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional)
# 8-device parallel training example (run in the retinanet directory):
sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(epoch size of pretraining)
For example: sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322
bash scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(epoch size of pretraining)
For example: bash scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322

# standalone training example (run in the retinanet directory):
sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
bash scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
```
@ -226,15 +234,15 @@ epoch time: 314138.455 ms, per step time: 685.892 ms
### [Evaluation Process](#content)

#### <span id="usage">Usage</span>
#### Usage

You can run evaluation with python or shell scripts. The usage of the shell script is as follows:
```eval
sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
bash scripts/run_eval.sh [DATASET] [DEVICE_ID]
```
#### <span id="running">运行</span>
#### 运 行
```eval
# validation example
@ -243,12 +251,12 @@ sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
Ascend: python eval.py
The checkpoint path is set in config.
shell:
Ascend: sh scripts/run_eval.sh coco 0
Ascend: bash scripts/run_eval.sh coco 0
```
> Checkpoints can be produced during training.

#### <span id="outcome">Result</span>
#### Result

The results are stored in the example path; you can view them in `eval.log`.
@ -271,6 +279,83 @@ sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
mAP: 0.3710347196613514
```
### [Model Export](#content)

#### Usage

Before exporting the model, modify the checkpoint_path item in config.py; its value is the path of the checkpoint.

```shell
python export.py --run_platform [RUN_PLATFORM] --file_format [EXPORT_FORMAT] --file_name [FILE_NAME]
```

`EXPORT_FORMAT` must be chosen from ["AIR", "MINDIR"].
#### Running

```run
python export.py
```
- Export MindIR on ModelArts

```Modelarts
Example of exporting MindIR on ModelArts
# (1) Choose either a (modify the yaml file parameters) or b (create a training job on ModelArts and modify parameters).
#       a. Set "enable_modelarts=True"
#          Set "file_name=retinanet"
#          Set "file_format=MINDIR"
#          Set "checkpoint_path=/cache/data/checkpoint/checkpoint file name"
#       b. Add the "enable_modelarts=True" parameter on the ModelArts UI.
#          Set the parameters required by method a on the ModelArts UI.
#          Note: path parameters do not need quotation marks.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (3) Set the code path "/path/retinanet" on the ModelArts UI.
# (4) Set the startup file of the model, "export.py", on the ModelArts UI.
# (5) On the ModelArts UI, set the data path of the model ".../MindRecord_COCO" (select the MindRecord_COCO folder path),
#     the MindIR output path "Output file path", and the log path "Job log path".
```
### [Inference Process](#content)

#### Usage

Before inference, the model needs to be exported on an Ascend 910 environment, and images whose iscrowd attribute is true must be excluded; the ids of the remaining images are saved in the ascend310_infer directory.
You also need to modify the coco_root, val_data_type, and instances_set items in config.py: they are, respectively, the directory of the coco dataset, the directory name of the dataset used for inference, and the annotation file used to compute accuracy after inference. instances_set is assembled with val_data_type, so make sure the file is correct and exists.
```shell
# Ascend310 inference
sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE] [DEVICE_ID]
```
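A sketch of what the config.py items mentioned above might look like (key names from the text; paths and values are placeholders, not taken from the repository):

```python
# illustrative placeholders only
"coco_root": "/data/coco2017",                     # directory of the coco dataset
"val_data_type": "val2017",                        # directory name of the dataset used for inference
"instances_set": "annotations/instances_{}.json",  # annotation file, assembled with val_data_type
```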
#### Running

```run
bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```
#### Result

Inference results are saved in the current directory; results like the following can be found in the acc.log file.
```mAP
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.369
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.520
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.404
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.146
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.391
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.535
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.316
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.431
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.433
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.162
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.459
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.633
mAP: 0.36858371862143824
```
## [Model Description](#content)

### [Performance](#content)
@ -312,4 +397,10 @@ mAP: 0.3710347196613514
# [ModelZoo Homepage](#内容)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

# FAQ

Refer first to the [ModelZoo FAQ](https://gitee.com/mindspore/mindspore/tree/master/model_zoo#FAQ) for some common public problems.

- **Q: What should I do if out of memory occurs when using PYNATIVE_MODE?** **A**: Out of memory is usually because PYNATIVE_MODE needs more memory. Setting the batch size to 16 reduces memory consumption and allows the network to be trained.

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -1,4 +1,5 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,16 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
from src.ghostnet600 import ghostnet_600m
def create_network(name, *args, **kwargs):
if name == 'ghostnet':
return ghostnet_1x(*args, **kwargs)
if name == 'ghostnet_nose':
return ghostnet_nose_1x(*args, **kwargs)
if name == 'ghostnet-600':
return ghostnet_600m(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
if [ ! -d out ]; then
mkdir out
fi
cd out || exit
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

File diff suppressed because it is too large

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include "../inc/utils.h"
#include "include/dataset/execute.h"
#include "include/dataset/transforms.h"
#include "include/dataset/vision.h"
#include "include/dataset/vision_ascend.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Context;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::Graph;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::DataType;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
DEFINE_string(model_path, "", "model path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode");
DEFINE_string(op_select_impl_mode, "high_precision", "op impl mode");
DEFINE_string(buffer_optimize_mode, "off_optimize", "buffer optimize mode");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_model_path).empty()) {
std::cout << "Invalid model" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(FLAGS_device_id);
ascend310_info->SetPrecisionMode(FLAGS_precision_mode);
ascend310_info->SetOpSelectImplMode(FLAGS_op_select_impl_mode);
ascend310_info->SetBufferOptimizeMode(FLAGS_buffer_optimize_mode);
context->MutableDeviceInfo().push_back(ascend310_info);
Graph graph;
Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
if (ret != kSuccess) {
std::cout << "Load model failed." << std::endl;
return 1;
}
Model model;
ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> modelInputs = model.GetInputs();
auto all_files = GetAllFiles(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
auto decode = Decode();
auto resize = Resize({600, 600});
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
auto hwc2chw = HWC2CHW();
mindspore::dataset::Execute transform({decode, resize, normalize, hwc2chw});
std::map<double, double> costTime_map;
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start;
struct timeval end;
double startTime_ms;
double endTime_ms;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
mindspore::MSTensor image = ReadFileToTensor(all_files[i]);
transform(image, &image);
inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
image.Data().get(), image.DataSize());
gettimeofday(&start, NULL);
model.Predict(inputs, &outputs);
gettimeofday(&end, NULL);
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
WriteResult(all_files[i], outputs);
}
double average = 0.0;
int infer_cnt = 0;
char tmpCh[256] = {0};
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
infer_cnt++;
}
average = average/infer_cnt;
snprintf(tmpCh, sizeof(tmpCh), "NN inference cost average time: %4.3f ms of infer_count %d\n", average, infer_cnt);
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
file_stream << tmpCh;
file_stream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
if (lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
  std::cout << "dirName is not a valid directory!" << std::endl;
  return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -12,42 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export file"""
"""export for retinanet"""
import argparse
import numpy as np
from mindspore import dtype as mstype
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.retinahead import retinahead
from src.backbone import resnet101
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.retinahead import retinahead, retinanetInferWithDecoder
from src.config import config
from src.box_utils import default_boxes
from src.backbone import resnet101
parser = argparse.ArgumentParser(description="retinanet_resnet101 export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="retinanet_resnet101", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='retinanet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                    help="run platform, only support Ascend.")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--file_name", type=str, default="retinanet", help="output file name.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
if __name__ == "__main__":
network = retinahead(backbone=resnet101(80), config=config, is_training=False)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
network.set_train(False)
shape = [args.batch_size, 3] + [600, 600]
backbone = resnet101(config.num_classes)
net = retinahead(backbone, config)
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
param_dict = load_checkpoint(config.checkpoint_path)
net.init_parameters_data()
load_param_into_net(net, param_dict)
net.set_train(False)
shape = [args_opt.batch_size, 3] + config.img_shape
input_data = Tensor(np.zeros(shape), mstype.float32)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)

View File

@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for retinanet"""
import os
import argparse
import numpy as np
from PIL import Image
from src.coco_eval import metrics
from src.config import config
parser = argparse.ArgumentParser(description='retinanet evaluation')
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
parser.add_argument("--img_id_file", type=str, required=True, help="image id file.")
args = parser.parse_args()
def get_pred(result_path, img_id):
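    # Inference writes two .bin files per image: <img_id>_0.bin holds boxes, <img_id>_1.bin holds class scores.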
boxes_file = os.path.join(result_path, img_id + '_0.bin')
scores_file = os.path.join(result_path, img_id + '_1.bin')
boxes = np.fromfile(boxes_file, dtype=np.float32).reshape(67995, 4)
scores = np.fromfile(scores_file, dtype=np.float32).reshape(67995, config.num_classes)
return boxes, scores
def get_img_size(file_name):
img = Image.open(file_name)
return img.size
def get_img_id(img_id_file):
    """Read image ids, one integer per line."""
    with open(img_id_file) as f:
        lines = f.readlines()
    ids = []
    for line in lines:
        ids.append(int(line))
    return ids
def cal_acc(result_path, img_path, img_id_file):
"""Calculate acc"""
ids = get_img_id(img_id_file)
imgs = os.listdir(img_path)
pred_data = []
for img in imgs:
img_id = img.split('.')[0]
if int(img_id) not in ids:
continue
boxes, box_scores = get_pred(result_path, img_id)
w, h = get_img_size(os.path.join(img_path, img))
img_shape = np.array((h, w), dtype=np.float32)
pred_data.append({"boxes": boxes,
"box_scores": box_scores,
"img_id": int(img_id),
"image_shape": img_shape})
mAP = metrics(pred_data)
print(f"mAP: {mAP}")
if __name__ == '__main__':
cal_acc(args.result_path, args.img_path, args.img_id_file)

View File

@ -1,4 +0,0 @@
numpy
easydict
opencv-python
pycocotools

View File

@ -0,0 +1,103 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
device_id=0
if [ $# == 3 ]; then
device_id=$3
fi
echo $model
echo $data_path
echo $device_id
export ASCEND_HOME=/usr/local/Ascend/
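# Use the ascend-toolkit layout if installed; otherwise fall back to the flat ATC/ACLlib layout below.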
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer || exit
if [ -f "Makefile" ]; then
make clean
fi
bash build.sh &> build.log
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
cd - || exit
}
function infer()
{
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --model_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
}
function cal_acc()
{
python ../postprocess.py --result_path=result_Files --img_path=$data_path --img_id_file=../ascend310_infer/image_id.txt &> acc.log
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi
}
compile_app
infer
cal_acc

View File

@ -246,7 +246,8 @@ class TrainingWrapper(nn.Cell):
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
self.optimizer(grads)
return loss
class retinanetInferWithDecoder(nn.Cell):

View File

@ -2,7 +2,7 @@
<!-- TOC -->
- <span id="content">[Retinanet 描述](#-Retinanet-描述)</span>
- [Retinanet 描述](#retinanet描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
@ -14,9 +14,16 @@
- [Running](#运行)
- [Results](#结果)
- [Evaluation Process](#评估过程)
- [Usage](#usage)
- [Running](#running)
- [Results](#outcome)
- [Usage](#用-法)
- [Running](#运-行)
- [Results](#结-果)
- [Model Export](#模型导出)
- [Usage](#用途)
- [How to Run](#运行方式)
- [Inference Process](#推理过程)
- [Usage](#用-途)
- [Run Command](#运行命令)
- [Run Results](#运行结果)
- [Model Description](#模型说明)
- [Performance](#性能)
- [Training Performance](#训练性能)
@ -26,7 +33,7 @@
<!-- /TOC -->
## [Retinanet Description](#content)
The RetinaNet algorithm comes from the 2018 Facebook AI Research paper "Focal Loss for Dense Object Detection". The paper's key contribution is Focal Loss, which addresses class imbalance; with it the authors built RetinaNet, a one-stage object detector whose accuracy surpasses the classic two-stage Faster-RCNN.
@ -60,10 +67,10 @@ MSCOCO2017
- Hardware (Ascend)
- Prepare the hardware environment with an Ascend processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
## [Script Description](#content)
@ -168,17 +175,17 @@ MSCOCO2017
# Example of 8-device parallel training:
Create a RANK_TABLE_FILE
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
# Example of single-device training:
sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
bash run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
```
> Note:
For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.3/distributed_training_ascend.html); for how to obtain device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_training_ascend.html); for how to obtain device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
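As a sketch (assuming the hccl_tools.py script and its --device_num flag from the link above; verify against your checkout), an 8-device rank table can be generated like this:

```shell
# Generates a JSON rank table for devices 0-7; pass the printed path
# to run_distribute_train.sh as RANK_TABLE_FILE.
python model_zoo/utils/hccl_tools/hccl_tools.py --device_num "[0,8)"
```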
#### Running
@ -196,12 +203,12 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional)
# Example of 8-device parallel training (run in the retinanet directory):
sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path to the created RANK_TABLE_FILE) PRE_TRAINED(path to the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(epoch size of the pretrained checkpoint)
For example: sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322
bash scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path to the created RANK_TABLE_FILE) PRE_TRAINED(path to the pretrained checkpoint) PRE_TRAINED_EPOCH_SIZE(epoch size of the pretrained checkpoint)
For example: bash scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322
# Example of single-device training (run in the retinanet directory):
sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
bash scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
```
@ -226,15 +233,15 @@ epoch time: 444237.851 ms, per step time: 484.976 ms
### [Evaluation Process](#content)
#### <span id="usage">Usage</span>
#### Usage
You can run evaluation with python or shell scripts. The usage of the shell script is as follows:
```eval
sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
bash scripts/run_eval.sh [DATASET] [DEVICE_ID]
```
#### <span id="running">运行</span>
#### 运 行
```eval运行
# Evaluation example
@ -243,12 +250,12 @@ sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
Ascend: python eval.py
The checkpoint path is set in config
shell:
Ascend: sh scripts/run_eval.sh coco 0
Ascend: bash scripts/run_eval.sh coco 0
```
> Checkpoints can be generated during training.
#### <span id="outcome">Results</span>
#### Results
The evaluation results are stored in the example path; you can view them in `eval.log`.
@ -271,6 +278,82 @@ sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
mAP: 0.3571988469737286
```
### [Model Export](#content)
#### Usage
Before exporting the model, set the checkpoint_path item in config.py to the path of the checkpoint to be exported.
```shell
python export.py --run_platform [RUN_PLATFORM] --file_format [EXPORT_FORMAT] --file_name [FILE_NAME]
```
`EXPORT_FORMAT` can be "AIR" or "MINDIR".
#### How to Run
```运行
python export.py
```
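A fuller invocation, using the argument names from the export.py added in this commit (the checkpoint itself is read from checkpoint_path in config.py, not from the command line), might look like the sketch below:

```shell
# Illustrative values; adjust device_id and file_name as needed.
python export.py --run_platform Ascend --device_id 0 --batch_size 1 --file_name retinanet --file_format MINDIR
```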
- Exporting MindIR on ModelArts
```Modelarts
An example of exporting MindIR on ModelArts
# (1) Choose either a (modify parameters in the yaml file) or b (modify parameters on the ModelArts UI when creating the training job).
# a. Set "enable_modelarts=True"
# Set "file_name=retinanet"
# Set "file_format=MINDIR"
# Set "checkpoint_path=/cache/data/checkpoint/checkpoint file name"
# b. Add the "enable_modelarts=True" parameter on the ModelArts UI.
# Set the parameters required by method a on the ModelArts UI.
# Note: path parameters do not need quotation marks.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path "/path/retinanet" on the ModelArts UI.
# (4) Set the startup file "export.py" on the ModelArts UI.
# (5) On the ModelArts UI, set the data path ".../MindRecord_COCO" (select the MindRecord_COCO folder path),
# the MindIR output path "Output file path", and the job log path "Job log path".
```
### [Inference Process](#content)
#### Usage
Before inference, the model must first be exported in an Ascend 910 environment. During inference, images whose iscrowd attribute is true must be excluded; the ids of the remaining images are stored in the ascend310_infer directory.
You also need to set the coco_root, val_data_type and instances_set items in config.py: respectively the COCO dataset directory, the directory name of the dataset used for inference, and the annotation file used to compute accuracy after inference. instances_set is spliced together with val_data_type, so make sure the resulting file is correct and exists; a quick check is sketched below.
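A sanity check of those items might look like this (paths are illustrative, and the exact joining of instances_set and val_data_type should be verified against config.py):

```shell
# Confirm the annotation file derived from coco_root and val_data_type exists.
coco_root=/data/coco2017
val_data_type=val2017
ls "${coco_root}/annotations/instances_${val_data_type}.json"
```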
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
#### Run Command
```运行
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
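For instance (all paths are placeholders), the script compiles the 310 application, writes raw outputs to result_Files, and then calls postprocess.py to compute mAP:

```shell
bash run_infer_310.sh ./retinanet.mindir /data/coco2017/val2017 0
```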
#### Run Results
The inference results are saved in the current directory; results similar to the following can be found in the acc.log file.
```mAP
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.356
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.499
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.396
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.145
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.380
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.506
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.308
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.446
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.457
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.179
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.483
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.647
mAP: 0.35625723922139957
```
## [Model Description](#content)
### [Performance](#content)
@ -312,4 +395,4 @@ mAP: 0.3571988469737286
# [ModelZoo Homepage](#内容)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ ! -d out ]; then
mkdir out
fi
cd out || exit
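# Point cmake at the installed mindspore-ascend package so its headers and libmindspore.so can be found.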
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

File diff suppressed because it is too large

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include "../inc/utils.h"
#include "include/dataset/execute.h"
#include "include/dataset/transforms.h"
#include "include/dataset/vision.h"
#include "include/dataset/vision_ascend.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Context;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::Graph;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::DataType;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
DEFINE_string(model_path, "", "model path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode");
DEFINE_string(op_select_impl_mode, "high_precision", "op impl mode");
DEFINE_string(buffer_optimize_mode, "off_optimize", "buffer optimize mode");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_model_path).empty()) {
std::cout << "Invalid model" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(FLAGS_device_id);
ascend310_info->SetPrecisionMode(FLAGS_precision_mode);
ascend310_info->SetOpSelectImplMode(FLAGS_op_select_impl_mode);
ascend310_info->SetBufferOptimizeMode(FLAGS_buffer_optimize_mode);
context->MutableDeviceInfo().push_back(ascend310_info);
Graph graph;
Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
if (ret != kSuccess) {
std::cout << "Load model failed." << std::endl;
return 1;
}
Model model;
ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> modelInputs = model.GetInputs();
auto all_files = GetAllFiles(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
auto decode = Decode();
auto resize = Resize({640, 640});
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
auto hwc2chw = HWC2CHW();
mindspore::dataset::Execute transform({decode, resize, normalize, hwc2chw});
std::map<double, double> costTime_map;
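// Maps per-image inference start time (ms) to end time (ms); used after the loop to average latency.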
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start;
struct timeval end;
double startTime_ms;
double endTime_ms;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
mindspore::MSTensor image = ReadFileToTensor(all_files[i]);
transform(image, &image);
inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
image.Data().get(), image.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
  std::cout << "Predict " << all_files[i] << " failed." << std::endl;
  return 1;
}
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
WriteResult(all_files[i], outputs);
}
double average = 0.0;
int infer_cnt = 0;
char tmpCh[256] = {0};
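// Each costTime_map entry is one inference: key = start time (ms), value = end time (ms).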
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
infer_cnt++;
}
average = average / infer_cnt;
snprintf(tmpCh, sizeof(tmpCh), "NN inference cost average time: %4.3f ms of infer_count %d\n", average, infer_cnt);
std::cout << "NN inference cost average time: " << average << " ms of infer_count " << infer_cnt << std::endl;
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
file_stream << tmpCh;
file_stream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
if (outputFile == nullptr) {
  std::cout << "Failed to open " << outFileName << " for writing" << std::endl;
  continue;
}
fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);  // write outputSize bytes
fclose(outputFile);
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
if (lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
  std::cout << "dirName is not a valid directory!" << std::endl;
  return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -12,43 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export file"""
"""export for retinanet"""
import argparse
import numpy as np
from mindspore import dtype as mstype
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.retinahead import retinahead
from src.backbone import resnet152
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.retinahead import retinahead, retinanetInferWithDecoder
from src.config import config
from src.box_utils import default_boxes
from src.backbone import resnet152
parser = argparse.ArgumentParser(description="retinanet_resnet152 export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="retinanet_resnet152", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='retinanet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                    help="run platform, only support Ascend.")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--file_name", type=str, default="retinanet", help="output file name.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
if __name__ == "__main__":
network = retinahead(backbone=resnet152(80), config=config, is_training=False)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
network.set_train(False)
shape = [args.batch_size, 3] + [640, 640]
backbone = resnet152(config.num_classes)
net = retinahead(backbone, config)
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
param_dict = load_checkpoint(config.checkpoint_path)
net.init_parameters_data()
load_param_into_net(net, param_dict)
net.set_train(False)
shape = [args_opt.batch_size, 3] + config.img_shape
input_data = Tensor(np.zeros(shape), mstype.float32)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)

View File

@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for retinanet"""
import os
import argparse
import numpy as np
from PIL import Image
from src.coco_eval import metrics
from src.config import config
parser = argparse.ArgumentParser(description='retinanet evaluation')
parser.add_argument("--result_path", type=str, required=True, help="result file path.")
parser.add_argument("--img_path", type=str, required=True, help="image file path.")
parser.add_argument("--img_id_file", type=str, required=True, help="image id file.")
args = parser.parse_args()
def get_pred(result_path, img_id):
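    # Inference writes two .bin files per image: <img_id>_0.bin holds boxes, <img_id>_1.bin holds class scores.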
boxes_file = os.path.join(result_path, img_id + '_0.bin')
scores_file = os.path.join(result_path, img_id + '_1.bin')
boxes = np.fromfile(boxes_file, dtype=np.float32).reshape(76725, 4)
scores = np.fromfile(scores_file, dtype=np.float32).reshape(76725, config.num_classes)
return boxes, scores
def get_img_size(file_name):
img = Image.open(file_name)
return img.size
def get_img_id(img_id_file):
    """Read image ids, one integer per line."""
    with open(img_id_file) as f:
        lines = f.readlines()
    ids = []
    for line in lines:
        ids.append(int(line))
    return ids
def cal_acc(result_path, img_path, img_id_file):
"""Calculate acc"""
ids = get_img_id(img_id_file)
imgs = os.listdir(img_path)
pred_data = []
for img in imgs:
img_id = img.split('.')[0]
if int(img_id) not in ids:
continue
boxes, box_scores = get_pred(result_path, img_id)
w, h = get_img_size(os.path.join(img_path, img))
img_shape = np.array((h, w), dtype=np.float32)
pred_data.append({"boxes": boxes,
"box_scores": box_scores,
"img_id": int(img_id),
"image_shape": img_shape})
mAP = metrics(pred_data)
print(f"mAP: {mAP}")
if __name__ == '__main__':
cal_acc(args.result_path, args.img_path, args.img_id_file)

View File

@ -1,4 +0,0 @@
numpy
pycocotools
easydict
opencv-python

View File

@ -0,0 +1,104 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
device_id=0
if [ $# == 3 ]; then
device_id=$3
fi
echo $model
echo $data_path
echo $device_id
export ASCEND_HOME=/usr/local/Ascend/
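# Use the ascend-toolkit layout if installed; otherwise fall back to the flat ATC/ACLlib layout below.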
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer || exit
if [ -f "Makefile" ]; then
make clean
fi
bash build.sh &> build.log
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
cd - || exit
}
function infer()
{
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --model_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
}
function cal_acc()
{
python ../postprocess.py --result_path=result_Files --img_path=$data_path --img_id_file=../ascend310_infer/image_id.txt &> acc.log
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi
}
compile_app
infer
cal_acc

View File

@ -246,7 +246,8 @@ class TrainingWrapper(nn.Cell):
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
self.optimizer(grads)
return loss
class retinanetInferWithDecoder(nn.Cell):