forked from mindspore-Ecosystem/mindspore
parent 83b56cac85
commit 4f777e2ac5
|
@ -0,0 +1,239 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [glore_res50描述](#glore_res50描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [特性](#特性)
|
||||
- [混合精度](#混合精度)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [评估过程](#评估过程)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#ModelZoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# glore_res50描述
|
||||
|
||||
## 概述
|
||||
|
||||
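Convolutional neural networks are good at extracting local relations, but they are inefficient at modeling relations between regions at a global scale, which usually requires stacking many layers. Global modeling and reasoning across regions benefits many computer vision tasks. To enable global reasoning, researchers from Facebook Research, the National University of Singapore, and the 360 AI Institute proposed a graph-based global reasoning module, the Global Reasoning Unit, which can be inserted into the network models of many tasks. glore_res200 is an image classification network that evenly inserts 2 and 3 global reasoning units into Stage2 and Stage3 of ResNet200, respectively.

The following is an example of training glore_res50 on the ImageNet2012 dataset with MindSpore. For details on glore_res50, see [paper 1](https://arxiv.org/pdf/1811.12814v1.pdf).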
|
||||
|
||||
## 论文
|
||||
|
||||
1. [Paper](https://arxiv.org/pdf/1811.12814v1.pdf): Yunpeng Chen, Marcus Rohrbach, Zhicheng Yan, Shuicheng Yan, Jiashi Feng, Yannis Kalantidis. "Graph-Based Global Reasoning Networks"
|
||||
|
||||
# 模型架构
|
||||
|
||||
The overall network architecture of glore_res is described here:
[Link](https://arxiv.org/pdf/1811.12814v1.pdf)
|
||||
|
||||
# 数据集
|
||||
|
||||
Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size: 1000 classes, 224*224 color images
  - Training set: 1,281,167 images
  - Test set: 50,000 images
- Data format: JPEG
  - Note: data is processed in dataset.py.
- Download the dataset. The directory structure is as follows:

```text
└─imagenet_original
    ├─train               # training dataset
    └─val                 # evaluation dataset
```
|
||||
|
||||
# 特性
|
||||
|
||||
## 混合精度
|
||||
|
||||
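[Mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training uses both single-precision and half-precision data to speed up the training of deep neural networks while keeping the accuracy that single-precision training can reach. It increases computation speed and reduces memory usage, and it also makes it possible to train larger models or use larger batch sizes on specific hardware.
Taking FP16 operators as an example: if the input data type is FP32, the MindSpore backend automatically lowers the precision to process the data. Users can enable INFO logging and search for "reduce precision" to see which operators had their precision reduced.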
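
The `loss_scale` value of 1024 in `src/config.py` relates to a well-known FP16 issue: very small gradients underflow to zero. The sketch below is not MindSpore code. It uses only the Python standard library to round values to half precision, and `tiny_grad` is a hypothetical gradient value chosen for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


LOSS_SCALE = 1024   # same value as "loss_scale" in src/config.py
tiny_grad = 1e-8    # hypothetical gradient, below the smallest FP16 subnormal

unscaled = to_fp16(tiny_grad)              # underflows to 0.0 in FP16
scaled = to_fp16(tiny_grad * LOSS_SCALE)   # survives as an FP16 subnormal
recovered = scaled / LOSS_SCALE            # unscale in higher precision

print(unscaled, recovered)
```

Scaling the loss scales every gradient by the same factor through the chain rule, so dividing by the scale before the optimizer update restores the original magnitude with only a small rounding error.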
|
||||
|
||||
# 环境要求
|
||||
|
||||
- Hardware (Ascend)
- Framework
  - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
  - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
  - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
After installing MindSpore from the official website, you can follow the steps below to train and evaluate:

- Running on Ascend

```text
# distributed training
Usage: sh run_distribute_train.sh [DATA_PATH] [DEVICE_NUM]

# standalone training
Usage: sh run_standalone_train.sh [DATA_PATH] [DEVICE_ID]

# run evaluation example
Usage: sh run_eval.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH]
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```shell
.
└──glore_res50
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh            # launch distributed training on Ascend (8 devices)
    ├── run_eval.sh                        # launch evaluation on Ascend
    ├── run_standalone_train.sh            # launch standalone training on Ascend (1 device)
  ├── src
    ├── __init__.py
    ├── autoaugment.py                     # AutoAugment components and classes
    ├── config.py                          # parameter configuration
    ├── dataset.py                         # data preprocessing
    ├── glore_resnet50.py                  # glore_resnet50 network definition
    ├── loss.py                            # loss definition for the ImageNet2012 dataset
    ├── save_callback.py                   # run inference during training and save the best-accuracy parameters
    └── lr_generator.py                    # generate the learning rate for each step
  ├── eval.py                              # evaluate the network
  ├── export.py                            # export a checkpoint to AIR, ONNX or MINDIR
  └── train.py                             # train the network
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
Both training and evaluation parameters can be configured in config.py.

- Configuration for glore_res50 on the ImageNet2012 dataset.

```text
"class_num":1000,                # number of dataset classes
"batch_size":128,                # batch size of the input tensor
"loss_scale":1024,               # loss scale
"momentum":0.9,                  # momentum of the optimizer
"weight_decay":1e-4,             # weight decay
"epoch_size":120,                # only applies to training; fixed to 1 for inference
"pretrained": False,             # whether to load pretrained weights
"pretrain_epoch_size": 0,        # number of epochs already trained before the pretrained checkpoint was saved; the actual number of training epochs is epoch_size minus pretrain_epoch_size
"save_checkpoint":True,          # whether to save checkpoints
"save_checkpoint_epochs":5,      # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max":5,         # keep only the last keep_checkpoint_max checkpoints
"save_checkpoint_path":"./",     # checkpoint save path, relative to the execution path
"warmup_epochs":5,               # number of warmup epochs
"lr_decay_mode":"poly",          # decay mode used to generate the learning rate
"use_label_smooth":True,         # whether to use label smoothing
"label_smooth_factor":0.1,       # label smoothing factor
"weight_init": "xavier_uniform", # weight initialization, one of "he_normal", "he_uniform", "xavier_uniform"
"use_autoaugment": True,         # whether to apply AutoAugment
"lr_init":0,                     # initial learning rate
"lr_max":0.6,                    # maximum learning rate
"lr_end":0.0,                    # minimum learning rate
```
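
To make the learning-rate parameters more concrete, here is a minimal sketch of a schedule with linear warmup followed by polynomial decay. It is an assumption about the general shape only: `src/lr_generator.py` is not shown here, so the function name `poly_warmup_lr` and the `power` parameter are hypothetical. The numeric arguments are the values from `src/config.py` and the 1251 steps per epoch reported for the 8-device run.

```python
def poly_warmup_lr(lr_init, lr_max, lr_end, warmup_epochs, total_epochs,
                   steps_per_epoch, power=2.0):
    """Return one learning rate per training step.

    Assumed shape: linear warmup from lr_init to lr_max, then polynomial
    decay from lr_max to lr_end. `power` is a hypothetical parameter.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    decay_steps = max(total_steps - warmup_steps, 1)
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # linear ramp during warmup
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            # polynomial decay towards lr_end
            progress = (step - warmup_steps) / decay_steps
            lr = lr_end + (lr_max - lr_end) * (1.0 - progress) ** power
        lrs.append(lr)
    return lrs


schedule = poly_warmup_lr(lr_init=0.0, lr_max=0.6, lr_end=0.0,
                          warmup_epochs=5, total_epochs=120, steps_per_epoch=1251)
print(len(schedule), schedule[0], max(schedule))
```

The peak is reached on the first step after warmup, and the rate then falls smoothly towards `lr_end` over the remaining epochs.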
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```text
# distributed training
Usage: sh run_distribute_train.sh [DATA_PATH] [DEVICE_NUM]

# standalone training
Usage: sh run_standalone_train.sh [DATA_PATH] [DEVICE_ID]
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance.

For details, see the instructions in [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
||||
### 结果
|
||||
|
||||
- Training glore_res50 on the ImageNet2012 dataset

```text
# distributed training results (8P)
epoch:1 step:1251, loss is 5.721338
epoch:2 step:1251, loss is 4.8941164
epoch:3 step:1251, loss is 4.3002024
epoch:4 step:1251, loss is 3.862403
epoch:5 step:1251, loss is 3.5204496
...
```
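
Log lines in this format can be checked programmatically. The snippet below is a small sketch and not part of the repository: the helper name `parse_losses` is hypothetical, and the regular expression assumes the exact `epoch:N step:N, loss is X` layout shown in the sample log.

```python
import re

# Matches lines such as "epoch:1 step:1251, loss is 5.721338"
LOG_PATTERN = re.compile(r"epoch:(\d+) step:(\d+), loss is ([0-9.]+)")


def parse_losses(log_text):
    """Return a list of (epoch, step, loss) tuples found in the log text."""
    return [(int(epoch), int(step), float(loss))
            for epoch, step, loss in LOG_PATTERN.findall(log_text)]


sample_log = """\
epoch:1 step:1251, loss is 5.721338
epoch:2 step:1251, loss is 4.8941164
epoch:3 step:1251, loss is 4.3002024
"""

records = parse_losses(sample_log)
losses = [loss for _, _, loss in records]
is_decreasing = all(a > b for a, b in zip(losses, losses[1:]))
print(records[-1], is_decreasing)
```

A check like `is_decreasing` is a quick way to confirm that training is converging before waiting for a full evaluation run.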
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```bash
# evaluation
Usage: sh run_eval.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH]
```

```bash
# evaluation example
sh run_eval.sh ~/dataset/imagenet 0 ~/ckpt/glore_res50_120-1251.ckpt
```
|
||||
|
||||
### 结果
|
||||
|
||||
Evaluation results are saved in the example path, in a folder named "eval". You can find the following result in the log under this path:

- Evaluating glore_res50 on the ImageNet2012 dataset

```text
{'Accuracy': 0.7844638020833334}
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### ImageNet2012上的glore_res50
|
||||
|
||||
| Parameter                  | Ascend 910                             |
| -------------------------- | -------------------------------------- |
| Model version              | glore_res50                            |
| Resource                   | Ascend 910; CPU: 2.60GHz, 192 cores; memory: 755G |
| Upload date                | 2021-03-21                             |
| MindSpore version          | r1.1                                   |
| Dataset                    | ImageNet2012                           |
| Training parameters        | epoch=120, steps per epoch=1251, batch_size = 128 |
| Optimizer                  | Momentum                               |
| Loss function              | Softmax cross entropy                  |
| Output                     | Probability                            |
| Loss                       | 1.8464266                              |
| Speed                      | 263.483 ms/step (8 devices)            |
| Total time                 | 10.98 hours                            |
| Parameters (M)             | 30.5                                   |
| Checkpoint for fine-tuning | 233.46M (.ckpt file)                   |
| Scripts                    | [Link](https://gitee.com/mindspore/mindspore/tree/glore_res50_r1.1/model_zoo/research/cv/glore_res50) |
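
The figures in this table can be cross-checked with a little arithmetic. The sketch below assumes that the incomplete final batch is dropped (the dataset code uses `drop_remainder=True`) and that the reported speed is the average over the whole run.

```python
TRAIN_IMAGES = 1_281_167     # ImageNet2012 training images
BATCH_SIZE = 128             # per-device batch size from config.py
DEVICES = 8                  # 8-device distributed run
EPOCHS = 120
SECONDS_PER_STEP = 0.263483  # 263.483 ms/step from the table

# Each step consumes BATCH_SIZE * DEVICES images; the remainder is dropped.
steps_per_epoch = TRAIN_IMAGES // (BATCH_SIZE * DEVICES)
total_hours = EPOCHS * steps_per_epoch * SECONDS_PER_STEP / 3600

print(steps_per_epoch, round(total_hours, 2))
```

The result is 1251 steps per epoch and roughly 11 hours in total, consistent with the reported steps per epoch and total training time.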
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
A random seed is set in train.py.
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
################################eval glore_resnet50################################
|
||||
python eval.py
|
||||
"""
|
||||
import os
|
||||
import ast
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.glore_resnet50 import glore_resnet50
|
||||
from src.dataset import create_eval_dataset
|
||||
from src.loss import CrossEntropySmooth, SoftmaxCrossEntropyExpand
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification with glore_resnet50')
|
||||
parser.add_argument('--use_glore', type=ast.literal_eval, default=True, help='Enable GloreUnit')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, help='Train output in modelarts')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
parser.add_argument('--ckpt_url', type=str, default=None)
|
||||
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=True)
|
||||
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.is_modelarts:
    import moxing as mox

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

if __name__ == '__main__':
    target = args_opt.device_target
    # init context
    device_id = args_opt.device_id
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
                        device_id=device_id)

    # dataset
    eval_dataset_path = os.path.join(args_opt.data_url, 'val')
    if args_opt.is_modelarts:
        mox.file.copy_parallel(src_url=args_opt.data_url, dst_url='/cache/dataset')
        eval_dataset_path = '/cache/dataset/'
    predict_data = create_eval_dataset(dataset_path=eval_dataset_path, repeat_num=1, batch_size=config.batch_size)
    step_size = predict_data.get_dataset_size()
    if step_size == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

    # define net
    net = glore_resnet50(class_num=config.class_num, use_glore=args_opt.use_glore)

    # load checkpoint
    param_dict = load_checkpoint(args_opt.ckpt_url)
    load_param_into_net(net, param_dict)

    # define loss, model
    if config.use_label_smooth:
        loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
                                  num_classes=config.class_num)
    else:
        loss = SoftmaxCrossEntropyExpand(sparse=True)
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    print("============== Starting Testing ==============")
    print("ckpt path : {}".format(args_opt.ckpt_url))
    print("data path : {}".format(eval_dataset_path))
    acc = model.eval(predict_data)
    print("==============Acc: {} ==============".format(acc))
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.config import config
|
||||
from src.glore_resnet50 import glore_resnet50
|
||||
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--file_name", type=str, default="glore_resnet50", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument("--ckpt_url", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    net = glore_resnet50(class_num=config.class_num)

    # config defines no checkpoint_path entry; the checkpoint comes from --ckpt_url
    assert args.ckpt_url is not None, "args.ckpt_url is None."
    param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.ones([args.batch_size, 3, 224, 224]), mstype.float32)
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,79 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distribute_train.sh DATA_PATH RANK_SIZE"
echo "For example: bash run_distribute_train.sh /path/dataset 8"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
DATA_PATH=$1
|
||||
export DATA_PATH=${DATA_PATH}
|
||||
RANK_SIZE=$2
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# start devices 1..RANK_SIZE-1; device 0 is started separately below
for((i=1;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cd ./device$i
    mkdir src
    cd ../
    cp ../*.py ./device$i
    cp ../src/*.py ./device$i/src
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py --data_url "$1" --is_modelarts False --run_distribute True > train$i.log 2>&1 &
    echo "$i finish"
    cd ../
done
|
||||
rm -rf device0
|
||||
mkdir device0
|
||||
cd ./device0
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ../*.py ./device0
|
||||
cp ../src/*.py ./device0/src
|
||||
cd ./device0
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
echo "start training for device 0"
|
||||
env > env0.log
|
||||
python train.py --data_url $1 --is_modelarts False --run_distribute True > train0.log 2>&1 &
|
||||
|
||||
echo "training in the background."
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh DATA_PATH DEVICE_ID CKPT_PATH"
echo "For example: bash run_eval.sh /path/dataset 0 /path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
DATA_PATH=$1
|
||||
DEVICE_ID=$2
|
||||
export DATA_PATH=${DATA_PATH}
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf eval/
|
||||
mkdir eval
|
||||
export DEVICE_ID=$2
|
||||
export RANK_ID=$2
|
||||
env > env0.log
|
||||
# With `set -e`, a failing command ends the script before a later `$?` check runs,
# so the command is tested directly in the `if` condition.
if python3 eval.py --data_url "$1" --is_modelarts False --device_id "$2" --ckpt_url "$3" > ./eval/eval.log 2>&1; then
    echo "evaling success"
else
    echo "evaling failed"
    exit 2
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_standalone_train.sh DATA_PATH DEVICE_ID"
echo "For example: bash run_standalone_train.sh /path/dataset 0"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
DATA_PATH=$1
|
||||
DEVICE_ID=$2
|
||||
export DATA_PATH=${DATA_PATH}
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
cd ../
|
||||
rm -rf train
|
||||
mkdir train
|
||||
export DEVICE_ID=$2
|
||||
export RANK_ID=$2
|
||||
env > env0.log
|
||||
echo "Standalone train begin."
|
||||
# With `set -e`, a failing command ends the script before a later `$?` check runs,
# so the command is tested directly in the `if` condition.
if python3 train.py --data_url "$1" --is_modelarts False --run_distribute False --device_id "$2" > ./train/train_alone.log 2>&1; then
    echo "training success"
else
    echo "training failed"
    exit 2
fi
|
||||
echo "finish"
|
|
@ -0,0 +1,191 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define autoaugment"""
|
||||
import os
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
# define Auto Augmentation operators
|
||||
PARAMETER_MAX = 10
|
||||
|
||||
|
||||
def float_parameter(level, maxval):
    return float(level) * maxval / PARAMETER_MAX


def int_parameter(level, maxval):
    return int(level * maxval / PARAMETER_MAX)


def shear_x(level):
    v = float_parameter(level, 0.3)
    return c_transforms.RandomChoice(
        [c_vision.RandomAffine(degrees=0, shear=(-v, -v)), c_vision.RandomAffine(degrees=0, shear=(v, v))])


def shear_y(level):
    v = float_parameter(level, 0.3)
    return c_transforms.RandomChoice(
        [c_vision.RandomAffine(degrees=0, shear=(0, 0, -v, -v)), c_vision.RandomAffine(degrees=0, shear=(0, 0, v, v))])


def translate_x(level):
    v = float_parameter(level, 150 / 331)
    return c_transforms.RandomChoice(
        [c_vision.RandomAffine(degrees=0, translate=(-v, -v)), c_vision.RandomAffine(degrees=0, translate=(v, v))])


def translate_y(level):
    v = float_parameter(level, 150 / 331)
    return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(0, 0, -v, -v)),
                                      c_vision.RandomAffine(degrees=0, translate=(0, 0, v, v))])


def color_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColor(degrees=(v, v))


def rotate_impl(level):
    v = int_parameter(level, 30)
    return c_transforms.RandomChoice(
        [c_vision.RandomRotation(degrees=(-v, -v)), c_vision.RandomRotation(degrees=(v, v))])


def solarize_impl(level):
    level = int_parameter(level, 256)
    v = 256 - level
    return c_vision.RandomSolarize(threshold=(0, v))


def posterize_impl(level):
    level = int_parameter(level, 4)
    v = 4 - level
    return c_vision.RandomPosterize(bits=(v, v))


def contrast_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColorAdjust(contrast=(v, v))


def autocontrast_impl(level):
    return c_vision.AutoContrast()


def sharpness_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomSharpness(degrees=(v, v))


def brightness_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColorAdjust(brightness=(v, v))
|
||||
|
||||
|
||||
# define the Auto Augmentation policy
|
||||
imagenet_policy = [
|
||||
[(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],
|
||||
[(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
|
||||
[(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
|
||||
[(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],
|
||||
[(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
|
||||
|
||||
[(c_vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],
|
||||
[(solarize_impl(3), 0.6), (c_vision.Equalize(), 0.6)],
|
||||
[(posterize_impl(5), 0.8), (c_vision.Equalize(), 1.0)],
|
||||
[(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],
|
||||
[(c_vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],
|
||||
|
||||
[(rotate_impl(8), 0.8), (color_impl(0), 0.4)],
|
||||
[(rotate_impl(9), 0.4), (c_vision.Equalize(), 0.6)],
|
||||
[(c_vision.Equalize(), 0.0), (c_vision.Equalize(), 0.8)],
|
||||
[(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
|
||||
[(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
|
||||
|
||||
[(rotate_impl(8), 0.8), (color_impl(2), 1.0)],
|
||||
[(color_impl(8), 0.8), (solarize_impl(7), 0.8)],
|
||||
[(sharpness_impl(7), 0.4), (c_vision.Invert(), 0.6)],
|
||||
[(shear_x(5), 0.6), (c_vision.Equalize(), 1.0)],
|
||||
[(color_impl(0), 0.4), (c_vision.Equalize(), 0.6)],
|
||||
|
||||
[(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
|
||||
[(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
|
||||
[(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
|
||||
[(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
|
||||
[(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
|
||||
]
|
||||
|
||||
|
||||
def autoaugment(dataset_path, repeat_num=1, batch_size=32, target="Ascend"):
    """
    define dataset with autoaugment
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    trans = [
        c_vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    ]

    post_trans = [
        c_vision.RandomHorizontalFlip(prob=0.5),
        c_vision.Normalize(mean=mean, std=std),
        c_vision.HWC2CHW()
    ]
    dataset = ds.map(operations=trans, input_columns="image")
    dataset = dataset.map(operations=c_vision.RandomSelectSubpolicy(imagenet_policy), input_columns=["image"])
    dataset = dataset.map(operations=post_trans, input_columns="image")

    type_cast_op = c_transforms.TypeCast(mstype.int32)
    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    # apply the batch operation
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # apply the repeat operation
    dataset = dataset.repeat(repeat_num)

    return dataset


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        rank_size = int(os.environ.get("RANK_SIZE"))
        rank_id = int(os.environ.get("RANK_ID"))
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""config for train or evaluation"""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"class_num": 1000,
|
||||
"batch_size": 128,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 120,
|
||||
"pretrained": False,
|
||||
"pretrain_epoch_size": 0,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 5,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./",
|
||||
"warmup_epochs": 5,
|
||||
"lr_decay_mode": "poly",
|
||||
"use_label_smooth": True,
|
||||
"use_autoaugment": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"weight_init": "xavier_uniform",
|
||||
"lr_init": 0,
|
||||
"lr_max": 0.6,
|
||||
"lr_end": 0.0
|
||||
})
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
|
||||
def cifar10(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||
"""
|
||||
create a train or evaluate cifar10 dataset for resnet50
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
distribute(bool): data for distribute or not. Default: False
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
if distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
else:
|
||||
device_num = 1
|
||||
if device_num == 1:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
# define map operations
|
||||
trans = []
|
||||
if do_train:
|
||||
trans += [
|
||||
C.RandomCrop((32, 32), (4, 4, 4, 4)),
|
||||
C.RandomHorizontalFlip(prob=0.5)
|
||||
]
|
||||
|
||||
trans += [
|
||||
C.Resize((224, 224)),
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def create_train_dataset(dataset_path, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train imagenet2012 dataset for glore_resnet50

    Args:
        dataset_path(string): the path of dataset.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        # assumption: other targets fall back to a single, unsharded device,
        # since the original code left device_num undefined in this case
        device_num, rank_id = 1, 0

    if device_num == 1:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                         num_shards=device_num, shard_id=rank_id)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # std = [127.5, 127.5, 127.5]
    # mean = [127.5, 127.5, 127.5]

    # define map operations

    trans = [
        C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
        C.RandomHorizontalFlip(prob=0.5),
        C.Normalize(mean=mean, std=std),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set
|
||||
|
||||
|
||||
def create_eval_dataset(dataset_path, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create an eval imagenet2012 dataset for resnet50

    Args:
        dataset_path(string): the path of dataset.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend
    Returns:
        dataset
    """
    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # std = [127.5, 127.5, 127.5]
    # mean = [127.5, 127.5, 127.5]

    # define map operations

    trans = [
        C.Decode(),
        C.Resize(256),
        C.CenterCrop(image_size),
        C.Normalize(mean=mean, std=std),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id
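The ImageNet pipelines above pass `mean` and `std` multiplied by 255 to `C.Normalize` and have no `C.Rescale` step, while the CIFAR-10 pipeline rescales to [0, 1] first and then normalizes with unscaled statistics. The sketch below uses plain Python and no MindSpore calls to show why the two conventions should agree per pixel. The helper names `normalize_raw` and `rescale_then_normalize` are hypothetical and exist only for this illustration.

```python
import math

# ImageNet channel statistics on the [0, 1] scale, as written in the dataset code.
MEAN_01 = [0.485, 0.456, 0.406]
STD_01 = [0.229, 0.224, 0.225]


def normalize_raw(pixel, channel):
    """Normalize a raw 0-255 value with statistics scaled by 255."""
    return (pixel - MEAN_01[channel] * 255) / (STD_01[channel] * 255)


def rescale_then_normalize(pixel, channel):
    """Rescale to [0, 1] first, then normalize with the unscaled statistics."""
    return (pixel / 255.0 - MEAN_01[channel]) / STD_01[channel]


# Both conventions agree up to floating-point rounding.
for channel in range(3):
    for pixel in (0, 64, 128, 255):
        assert math.isclose(normalize_raw(pixel, channel),
                            rescale_then_normalize(pixel, channel),
                            rel_tol=1e-9, abs_tol=1e-12)
```

Scaling the statistics once avoids a separate per-pixel rescale operation in the pipeline; that is presumably why the ImageNet functions are written this way.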
|
|
@ -0,0 +1,394 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create glore_resnet50
"""
from collections import OrderedDict
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.08,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.08,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class GCN(nn.Cell):
    """
    Graph convolution unit (single layer)
    """

    def __init__(self, num_state, num_mode, bias=False):
        super(GCN, self).__init__()
        self.conv1 = nn.Conv1d(num_mode, num_mode, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, has_bias=bias)
        self.transpose = P.Transpose()
        self.add = P.Add()

    def construct(self, x):
        """construct GCN"""
        identity = x
        # (n, num_state, num_node) -> (n, num_node, num_state)
        #                          -> (n, num_state, num_node)
        out = self.transpose(x, (0, 2, 1))
        out = self.conv1(out)
        out = self.transpose(out, (0, 2, 1))
        out = self.add(out, identity)
        out = self.relu(out)
        out = self.conv2(out)
        return out


class GloreUnit(nn.Cell):
    """
    Graph-based Global Reasoning Unit
    Parameter:
        'normalize' is not necessary if the input size is fixed
    Args:
        num_in: Input channel
        num_mid: Base width of the interaction space. The unit uses 2 * num_mid
            state channels (num_s) and num_mid graph nodes (num_n).
        normalize: Whether to scale the projected features by 1 / (h * w). Default: False
    """

    def __init__(self, num_in, num_mid,
                 normalize=False):
        super(GloreUnit, self).__init__()
        self.normalize = normalize
        self.num_s = int(2 * num_mid)  # 512 when num_in = 1024
        self.num_n = int(1 * num_mid)  # 256
        # reduce dim
        self.conv_state = nn.SequentialCell([_bn(num_in),
                                             nn.ReLU(),
                                             _conv1x1(num_in, self.num_s, stride=1)])
        # projection map
        self.conv_proj = nn.SequentialCell([_bn(num_in),
                                            nn.ReLU(),
                                            _conv1x1(num_in, self.num_n, stride=1)])

        self.gcn = GCN(num_state=self.num_s, num_mode=self.num_n)

        self.conv_extend = nn.SequentialCell([_bn_last(self.num_s),
                                              nn.ReLU(),
                                              _conv1x1(self.num_s, num_in, stride=1)])

        self.reshape = P.Reshape()
        self.matmul = P.BatchMatMul()
        self.transpose = P.Transpose()
        self.add = P.Add()
        self.cast = P.Cast()

    def construct(self, x):
        """construct Graph-based Global Reasoning Unit"""
        n = x.shape[0]
        identity = x
        # (n, num_in, h, w) --> (n, num_state, h, w)
        #                   --> (n, num_state, h*w)
        x_conv_state = self.conv_state(x)
        x_state_reshaped = self.reshape(x_conv_state, (n, self.num_s, -1))

        # (n, num_in, h, w) --> (n, num_node, h, w)
        #                   --> (n, num_node, h*w)
        x_conv_proj = self.conv_proj(x)
        x_proj_reshaped = self.reshape(x_conv_proj, (n, self.num_n, -1))

        # the reverse projection reuses the projection map: (n, num_node, h*w)
        x_rproj_reshaped = x_proj_reshaped

        # projection: coordinate space -> interaction space
        # (n, num_state, h*w) x (n, num_node, h*w)T --> (n, num_state, num_node)
        x_proj_reshaped = self.transpose(x_proj_reshaped, (0, 2, 1))

        x_state_reshaped_fp16 = self.cast(x_state_reshaped, mstype.float16)
        x_proj_reshaped_fp16 = self.cast(x_proj_reshaped, mstype.float16)
        x_n_state_fp16 = self.matmul(x_state_reshaped_fp16, x_proj_reshaped_fp16)
        x_n_state = self.cast(x_n_state_fp16, mstype.float32)

        if self.normalize:
            x_n_state = x_n_state * (1. / x_state_reshaped.shape[2])

        # reasoning: (n, num_state, num_node) -> (n, num_state, num_node)
        x_n_rel = self.gcn(x_n_state)

        # reverse projection: interaction space -> coordinate space
        # (n, num_state, num_node) x (n, num_node, h*w) --> (n, num_state, h*w)
        x_n_rel_fp16 = self.cast(x_n_rel, mstype.float16)
        x_rproj_reshaped_fp16 = self.cast(x_rproj_reshaped, mstype.float16)
        x_state_reshaped_fp16 = self.matmul(x_n_rel_fp16, x_rproj_reshaped_fp16)
        x_state_reshaped = self.cast(x_state_reshaped_fp16, mstype.float32)

        # (n, num_state, h*w) --> (n, num_state, h, w)
        x_state = self.reshape(x_state_reshaped, (n, self.num_s, identity.shape[2], identity.shape[3]))

        # (n, num_state, h, w) -> (n, num_in, h, w)
        x_conv_extend = self.conv_extend(x_state)
        out = self.add(x_conv_extend, identity)
        return out


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1)
        self.bn1 = _bn(channel)

        self.conv2 = _conv3x3(channel, channel, stride=stride)
        self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)
        self.relu = nn.ReLU()
        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell(
                [
                    nn.AvgPool2d(kernel_size=stride, stride=stride),
                    _conv1x1(in_channel, out_channel, stride=1),
                    _bn(out_channel)
                ])
        self.add = P.Add()

    def construct(self, x):
        """construct ResidualBlock"""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images belong to.
        use_glore (bool): Whether to insert GloreUnit cells into layer3. Default: False.
    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_glore=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        self.conv1 = nn.SequentialCell(OrderedDict([
            ('conv_1', _conv3x3(3, 32, stride=2)),
            ('bn1', _bn(32)),
            ('relu1', nn.ReLU()),
            ('conv_2', _conv3x3(32, 32, stride=1)),
            ('bn2', _bn(32)),
            ('relu2', nn.ReLU()),
            ('conv_3', _conv3x3(32, 64, stride=1)),
        ]))
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_glore=use_glore)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
                    use_glore=False, glore_pos=None):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_glore (bool): Whether to append GloreUnit cells in this stage. Default: False.
            glore_pos (list): 1-based indices of the blocks after which a GloreUnit is appended.
                Default: None, which means [1, 3, 5] when use_glore is True.
        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        if use_glore and glore_pos is None:
            glore_pos = [1, 3, 5]

        layers = []
        for i in range(1, layer_num + 1):
            resnet_block = block(in_channel=(in_channel if i == 1 else out_channel),
                                 out_channel=out_channel,
                                 stride=(stride if i == 1 else 1))
            layers.append(resnet_block)
            if use_glore and i in glore_pos:
                glore_unit = GloreUnit(out_channel, int(out_channel / 4))
                layers.append(glore_unit)
        return nn.SequentialCell(layers)

    def construct(self, x):
        """construct ResNet"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out


def glore_resnet50(class_num=1000, use_glore=True):
    """
    Get ResNet50 with GloreUnit neural network.

    Args:
        class_num (int): Class number.
        use_glore (bool): Whether to insert GloreUnit cells into layer3. Default: True.
    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = glore_resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_glore=use_glore)
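The placement of the reasoning units and the tensor shapes inside `GloreUnit.construct` can be hard to follow from the cell definitions alone. The sketch below re-traces both in plain Python, without MindSpore. `stage_layout` and `glore_shapes` are hypothetical helpers written for this illustration, and the 14 x 14 feature map is an assumption that corresponds to a 224 x 224 input reaching layer3.

```python
def stage_layout(layer_num, use_glore=False, glore_pos=None):
    """Mirror the loop in _make_layer, returning labels instead of cells."""
    if use_glore and glore_pos is None:
        glore_pos = [1, 3, 5]
    layers = []
    for i in range(1, layer_num + 1):
        layers.append("block%d" % i)
        if use_glore and i in glore_pos:
            layers.append("glore")
    return layers


def glore_shapes(n, num_in, h, w):
    """Trace shapes through GloreUnit.construct, with num_mid = num_in // 4."""
    num_mid = num_in // 4
    num_s, num_n = 2 * num_mid, num_mid
    return {
        "state": (n, num_s, h * w),    # conv_state output, flattened
        "proj": (n, num_n, h * w),     # conv_proj output, flattened
        "graph": (n, num_s, num_n),    # state x proj^T, input and output of GCN
        "restored": (n, num_s, h, w),  # graph x proj, reshaped back
        "out": (n, num_in, h, w),      # conv_extend output plus identity
    }


# layer3 of glore_resnet50 has 6 blocks and 1024 output channels.
print(stage_layout(6, use_glore=True))
print(glore_shapes(n=2, num_in=1024, h=14, w=14))
```

With the default positions, layer3 receives three reasoning units, one after each of its first, third, and fifth residual blocks, and each unit returns a tensor with the same shape as its input.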
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss for glore_resnet50"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import mindspore.ops as ops


class SoftmaxCrossEntropyExpand(nn.Cell):  # pylint: disable=missing-docstring
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = ops.Exp()
        self.sum = ops.ReduceSum(keep_dims=True)
        self.onehot = ops.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = ops.RealDiv()
        self.log = ops.Log()
        self.sum_cross_entropy = ops.ReduceSum(keep_dims=False)
        self.mul = ops.Mul()
        self.mul2 = ops.Mul()
        self.mean = ops.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.max = ops.ReduceMax(keep_dims=True)
        self.sub = ops.Sub()
        self.eps = Tensor(1e-24, mstype.float32)

    def construct(self, logit, label):  # pylint: disable=missing-docstring
        logit_max = self.max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result + self.eps)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(ops.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss


class CrossEntropySmooth(_Loss):
    """CrossEntropy"""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
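Both loss cells combine a target distribution with a softmax over the logits. The sketch below reproduces the arithmetic for a single sample in plain Python: the smoothed one-hot target built from `on_value` and `off_value`, and the max-subtracted softmax with the `1e-24` epsilon used in `SoftmaxCrossEntropyExpand`. The function names are hypothetical, and the code illustrates the formulas rather than replacing the MindSpore cells.

```python
import math


def smoothed_one_hot(label, num_classes, smooth_factor):
    """Build the target used by CrossEntropySmooth for one sample."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if k == label else off_value for k in range(num_classes)]


def stable_cross_entropy(logits, target, eps=1e-24):
    """Cross entropy with the max-subtraction trick, for one sample."""
    logit_max = max(logits)
    exps = [math.exp(v - logit_max) for v in logits]  # largest exponent is 0
    total = sum(exps)
    return -sum(t * math.log(e / total + eps) for e, t in zip(exps, target))


target = smoothed_one_hot(label=1, num_classes=4, smooth_factor=0.1)
print(target, sum(target))
print(stable_cross_entropy([0.0, 0.0, 0.0, 0.0], target))
print(stable_cross_entropy([1000.0, 0.0], [1.0, 0.0]))
```

Subtracting the maximum logit keeps every exponent at or below zero, so very large logits do not overflow. The epsilon keeps the logarithm finite when a probability underflows to zero.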
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)

    lr_each_step = np.array(lr_each_step).astype(np.float32)

    return lr_each_step


def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
    """
    generate learning rate array with cosine

    Args:
        lr(float): base learning rate
        steps_per_epoch(int): steps size of one epoch
        warmup_epochs(int): number of warmup epochs
        max_epoch(int): total epochs of training
        global_step(int): the current start index of lr array
    Returns:
        np.array, learning rate array
    """
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    decay_steps = total_steps - warmup_steps

    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            linear_decay = (total_steps - i) / decay_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
            decayed = linear_decay * cosine_decay + 0.00001
            lr = base_lr * decayed
        lr_each_step.append(lr)

    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[global_step:]
    return learning_rate
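To make the shape of the `poly` schedule in `get_lr` concrete, the sketch below recomputes it with plain Python floats for a tiny step count. `poly_lr` is a hypothetical helper that takes step counts directly instead of epochs; it follows the same formula (linear warmup, then a squared linear decay) but is not the function used by the training script.

```python
def poly_lr(lr_init, lr_max, warmup_steps, total_steps):
    """Linear warmup to lr_max, then quadratic decay towards zero."""
    inc_each_step = (lr_max - lr_init) / warmup_steps if warmup_steps else 0.0
    schedule = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + inc_each_step * i
        else:
            base = 1.0 - (i - warmup_steps) / (total_steps - warmup_steps)
            lr = max(lr_max * base * base, 0.0)
        schedule.append(lr)
    return schedule


# Two warmup steps followed by four decay steps.
for step, value in enumerate(poly_lr(0.0, 0.8, warmup_steps=2, total_steps=6)):
    print(step, round(value, 4))
```

The peak value `lr_max` is reached at the first step after warmup, and the squared term makes the rate fall quickly at first and then flatten out near the end of training.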
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define savecallback, save best model while training."""
from mindspore import save_checkpoint
from mindspore.train.callback import Callback


class SaveCallback(Callback):
    """
    define savecallback, save best model while training.
    """
    def __init__(self, model_save, eval_dataset_save, save_file_path):
        super(SaveCallback, self).__init__()
        self.model = model_save
        self.eval_dataset = eval_dataset_save
        # best accuracy so far; a checkpoint is saved only when this value is exceeded
        self.acc = 0.78
        self.save_path = save_file_path

    def step_end(self, run_context):
        """
        eval and save model while training.
        """
        cb_params = run_context.original_args()

        result = self.model.eval(self.eval_dataset)
        print(result)
        if result['Accuracy'] > self.acc:
            self.acc = result['Accuracy']
            file_name = self.save_path + str(self.acc) + ".ckpt"
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)
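The callback keeps a running best accuracy, starting from 0.78, and writes a checkpoint only when a new evaluation beats it. The sketch below isolates that decision in plain Python so it can be run without MindSpore. `BestMetricTracker` is a hypothetical class for illustration and does not save any files.

```python
class BestMetricTracker:
    """Remember the best value seen so far, starting from a threshold."""

    def __init__(self, threshold):
        self.best = threshold

    def update(self, value):
        """Return True when value beats the current best, and record it."""
        if value > self.best:
            self.best = value
            return True
        return False


tracker = BestMetricTracker(threshold=0.78)
for accuracy in (0.70, 0.79, 0.785, 0.80):
    if tracker.update(accuracy):
        print("would save checkpoint at accuracy", accuracy)
```

Starting from a non-zero threshold means no checkpoints are written early in training, when accuracy is still well below the target.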
|
|
@ -0,0 +1,181 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
################################train glore_resnet50################################
python train.py
"""
import os
import random
import argparse
import ast
import numpy as np

import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import init
import mindspore.common.initializer as weight_init
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.config import config
from src.lr_generator import get_lr
from src.dataset import create_train_dataset, create_eval_dataset, _get_rank_info
from src.save_callback import SaveCallback
from src.glore_resnet50 import glore_resnet50
from src.loss import SoftmaxCrossEntropyExpand, CrossEntropySmooth
from src.autoaugment import autoaugment

parser = argparse.ArgumentParser(description='Image classification with glore_resnet50')
parser.add_argument('--use_glore', type=ast.literal_eval, default=True, help='Enable GloreUnit')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
parser.add_argument('--data_url', type=str, default=None,
                    help='Dataset path')
parser.add_argument('--train_url', type=str)
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=True)
parser.add_argument('--pretrained_ckpt', type=str, default=None, help='Pretrained ckpt path')
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
args_opt = parser.parse_args()

if args_opt.is_modelarts:
    import moxing as mox

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
if args_opt.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
if target == "Ascend":
|
||||
device_id = args_opt.device_id
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False, device_id=device_id)
|
||||
# create dataset
|
||||
train_dataset_path = os.path.join(args_opt.data_url, 'train')
|
||||
eval_dataset_path = os.path.join(args_opt.data_url, 'val')
|
||||
|
||||
# download dataset from obs to cache if train on ModelArts
|
||||
if args_opt.is_modelarts:
|
||||
mox.file.copy_parallel(src_url=args_opt.data_url, dst_url='/cache/dataset/device_' + os.getenv('DEVICE_ID'))
|
||||
train_dataset_path = '/cache/dataset/device_' + os.getenv('DEVICE_ID') + '/train'
|
||||
eval_dataset_path = '/cache/dataset/device_' + os.getenv('DEVICE_ID') + '/val'
|
||||
    if config.use_autoaugment:
        print("===========Use autoaugment==========")
        train_dataset = autoaugment(dataset_path=train_dataset_path, repeat_num=1,
                                    batch_size=config.batch_size, target=target)
    else:
        train_dataset = create_train_dataset(dataset_path=train_dataset_path, repeat_num=1,
                                             batch_size=config.batch_size, target=target)

    eval_dataset = create_eval_dataset(dataset_path=eval_dataset_path, repeat_num=1, batch_size=config.batch_size)

    step_size = train_dataset.get_dataset_size()

    # define net
    net = glore_resnet50(class_num=config.class_num, use_glore=args_opt.use_glore)

    # init weight
    if config.pretrained:
        param_dict = load_checkpoint(args_opt.pretrained_ckpt)
        load_param_into_net(net, param_dict)
    else:
        for _, cell in net.cells_and_names():
            if isinstance(cell, (nn.Conv2d, nn.Conv1d)):
                if config.weight_init == 'xavier_uniform':
                    cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
                                                                        cell.weight.shape,
                                                                        cell.weight.dtype)
                elif config.weight_init == 'he_uniform':
                    cell.weight.default_input = weight_init.initializer(weight_init.HeUniform(),
                                                                        cell.weight.shape,
                                                                        cell.weight.dtype)
                else:  # 'he_normal' or any other value falls back to HeNormal
                    cell.weight.default_input = weight_init.initializer(weight_init.HeNormal(),
                                                                        cell.weight.shape,
                                                                        cell.weight.dtype)

            if isinstance(cell, nn.Dense):
                cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
                                                                    cell.weight.shape,
                                                                    cell.weight.dtype)

    # init lr
    lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
                total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode)
    lr = Tensor(lr)

    # define opt
    # apply weight decay to all parameters except those whose names contain 'beta', 'gamma' or 'bias'
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)

    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    net_opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
    # define loss, model
    if config.use_label_smooth:
        loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
                                  num_classes=config.class_num)
    else:
        loss = SoftmaxCrossEntropyExpand(sparse=True)
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=net_opt, loss_scale_manager=loss_scale, metrics={"Accuracy": Accuracy()})

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    device_num, device_id = _get_rank_info()
    if config.save_checkpoint:
        if args_opt.is_modelarts:
            save_checkpoint_path = '/cache/train_output/device_' + os.getenv('DEVICE_ID') + '/'
        else:
            save_checkpoint_path = config.save_checkpoint_path
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="glore_resnet50", directory=save_checkpoint_path, config=config_ck)
        save_cb = SaveCallback(model, eval_dataset, save_checkpoint_path)
        cb += [ckpt_cb, save_cb]

    # train model
    print("=======Training Begin========")
    model.train(config.epoch_size - config.pretrain_epoch_size, train_dataset, callbacks=cb, dataset_sink_mode=True)

    # copy the training results from the local cache back to OBS
    if args_opt.is_modelarts:
        mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)