!15989 Add resnetv2_50 gpu

From: @brandonye
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-05-28 19:36:46 +08:00 committed by Gitee
commit bf65427154
15 changed files with 1626 additions and 0 deletions


@ -0,0 +1,304 @@
# ResNetv2 Description
## Overview
The ResNet family of models was proposed in 2015. The paper innovatively introduced the residual structure, and the ResNet network is built by stacking multiple residual blocks. ResNet alleviates, to a degree, the information loss that traditional convolutional or fully connected networks suffer from: by passing the input directly to the output it preserves information integrity, so the network can keep growing deeper while avoiding vanishing or exploding gradients. ResNetv2 is a follow-up improvement of the architecture by Kaiming He's team: they showed by derivation that results improve if the forward activations and backward gradients pass directly from one residual block to the next without going through operations such as ReLU. They therefore reordered the activation and BN layers relative to the convolution layers, and experiments confirmed that ResNetv2 converges better in deep networks.
The following is an example of training ResNetv2_50/ResNetv2_101/ResNetv2_152 with MindSpore on the Cifar10/ImageNet2012 datasets.
## Paper
1. [Paper](https://arxiv.org/pdf/1603.05027.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
# Model Architecture
The overall network architecture of [ResNetv2_50](https://arxiv.org/pdf/1603.05027.pdf) is similar to that of [ResNet50](https://arxiv.org/pdf/1512.03385.pdf); only the ordering of the activation and BN layers relative to the convolution layers is changed.
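To make the reordering concrete, here is a minimal sketch of a single pre-activation unit (an illustration only; the full `PreActBottleNeck` used by this model lives in src/resnetv2.py):
```python
import mindspore.nn as nn

class PreActUnit(nn.Cell):
    """v1 orders a branch as conv -> BN -> ReLU; v2 moves BN -> ReLU in front."""
    def __init__(self, channels):
        super(PreActUnit, self).__init__()
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, pad_mode='pad', padding=1)

    def construct(self, x):
        # the identity path x -> (x + branch) contains no ReLU/BN, so forward
        # activations and backward gradients flow between blocks unchanged
        return x + self.conv(self.relu(self.bn(x)))
```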
# Dataset
Dataset used: [Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size: 60,000 32×32 color images in 10 classes
    - Training set: 50,000 images
    - Test set: 10,000 images
- Data format: binary files
    - Note: data is processed in dataset.py.
- Download the dataset. The directory structure is as follows (a loading sketch follows the tree):
```text
├─cifar-10-batches-bin
└─cifar-10-verify-bin
```
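As an illustration of how these directories are consumed, a minimal sketch (assuming you run from the repository root and use the `create_dataset1` helper shipped in src/dataset.py):
```python
from src.dataset import create_dataset1

# the training pipeline reads the binary training batches; evaluation reads the verify set
train_ds = create_dataset1("./cifar-10-batches-bin", do_train=True, batch_size=32)
eval_ds = create_dataset1("./cifar-10-verify-bin", do_train=False)
print(train_ds.get_dataset_size())  # number of batches per epoch
```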
Dataset used: [ImageNet2012](http://www.image-net.org/)
- Dataset size: 1,000 classes of 224×224 color images
    - Training set: 1,281,167 images
    - Test set: 50,000 images
- Data format: JPEG
    - Note: data is processed in dataset.py.
- Download the dataset. The directory structure is as follows (a loading sketch follows the tree):
```text
└─dataset
├─ilsvrc # 训练数据集
└─validation_preprocess # 评估数据集
```
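The ImageNet2012 data is loaded the same way through `create_dataset3` (again only a sketch; for evaluation it applies the Decode/Resize(256)/CenterCrop(224) transforms defined in src/dataset.py):
```python
from src.dataset import create_dataset3

train_ds = create_dataset3("./dataset/ilsvrc", do_train=True, batch_size=64)
eval_ds = create_dataset3("./dataset/validation_preprocess", do_train=False)
```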
# Environment Requirements
- Hardware
    - Prepare the hardware environment with Ascend processors. To apply for trial use of an Ascend processor, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com; once approved, you can obtain the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For details, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
- Running on Ascend
```Shell
# distributed training
Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
# run evaluation example
Usage: bash run_eval.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
- Running on GPU
```shell
# distributed training
Usage: bash run_distribute_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
# run evaluation example
Usage: bash run_eval_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
# Script Description
## Script and Sample Code
```text
└──resnetv2
  ├── README.md
  ├── scripts
    ├── run_distribute_train_gpu.sh    # launch GPU distributed training (8 pcs)
    ├── run_eval_gpu.sh                # launch GPU evaluation
    ├── run_standalone_train_gpu.sh    # launch GPU standalone training (1 pcs)
    ├── run_distribute_train.sh        # launch Ascend distributed training (8 pcs)
    ├── run_eval.sh                    # launch Ascend evaluation
    └── run_standalone_train.sh        # launch Ascend standalone training (1 pcs)
  ├── src
    ├── config.py                      # parameter configuration
    ├── dataset.py                     # data preprocessing
    ├── CrossEntropySmooth.py          # loss definition for the ImageNet2012 dataset
    ├── lr_generator.py                # generate learning rate for each step
    └── resnetv2.py                    # ResNet backbone network
  ├── eval.py                          # evaluate network
  ├── train.py                         # train network
  └── export.py                        # export network
```
# Script Parameters
Both training and evaluation parameters can be configured in config.py; a short sketch of reading these values follows the two configuration blocks below.
- Configuration for ResNetv2_50 with the cifar10 dataset.
```Python
"class_num":10,                  # number of dataset classes
"batch_size":32,                 # batch size of the input tensor
"loss_scale":1024,               # loss scale
"momentum":0.9,                  # momentum of the optimizer
"weight_decay":5e-4,             # weight decay
"epoch_size":200,                # number of training epochs
"save_checkpoint":True,          # whether to save checkpoints
"save_checkpoint_epochs":5,      # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max":10,        # keep at most the last keep_checkpoint_max checkpoints
"save_checkpoint_path":"./checkpoint",  # checkpoint save path relative to the execution path
"low_memory": False,             # can be set to True when device memory is insufficient
"warmup_epochs":5,               # number of warmup epochs
"lr_decay_mode":"cosine",        # decay mode used to generate the learning rate
"lr_init":0.1,                   # initial learning rate
"lr_end":0.0000000005,           # final learning rate
"lr_max":0.1,                    # maximum learning rate
```
- Configuration for ResNetv2_50 with the imagenet2012 dataset.
```python
"class_num":1001,                # number of dataset classes
"batch_size":64,                 # batch size of the input tensor
"loss_scale":1024,               # loss scale
"momentum":0.9,                  # momentum of the optimizer
"weight_decay":1e-4,             # weight decay
"epoch_size":90,                 # number of training epochs
"save_checkpoint":True,          # whether to save checkpoints
"save_checkpoint_epochs":5,      # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max":10,        # keep at most the last keep_checkpoint_max checkpoints
"save_checkpoint_path":"./checkpoint",  # checkpoint save path relative to the execution path
"low_memory": True,              # can be set to True when device memory is insufficient (default: False)
"warmup_epochs":0,               # number of warmup epochs
"use_label_smooth":True,         # whether to use label smoothing
"label_smooth_factor":0.1,       # label smoothing factor
"lr_decay_mode":"cosine",        # decay mode used to generate the learning rate
"lr_init":0.05,                  # initial learning rate
"lr_end":0.0000001,              # final learning rate
"lr_max":0.05,                   # maximum learning rate
```
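For reference, here is a minimal sketch of reading and overriding these parameters at runtime, assuming only the easydict-based config objects defined in src/config.py:
```python
from src.config import config1 as config  # cifar10 configuration

# easydict allows attribute-style access to the fields listed above
print(config.batch_size, config.lr_max)

# fields can be overridden before training, e.g. for a quick smoke test
config.epoch_size = 1
config.save_checkpoint = False
```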
# Training Process
## Usage
### Running on Ascend
```Shell
# distributed training
Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
```
Distributed training requires an HCCL configuration file in JSON format to be created in advance.
For details, see the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
### Running on GPU
```shell
# distributed training
Usage: bash run_distribute_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
```
## Results
- Training ResNetv2_50 with the cifar10 dataset
```text
# Ascend distributed training result (8P)
epoch: 41 step: 195, loss is 0.17125674
epoch time: 4733.000 ms, per step time: 24.143 ms
epoch: 42 step: 195, loss is 0.0011220031
epoch time: 4735.284 ms, per step time: 24.135 ms
epoch: 43 step: 195, loss is 0.105422504
epoch time: 4737.401 ms, per step time: 24.166 ms
...
```
- Training ResNetv2_50 with the imagenet2012 dataset
```text
# Ascend distributed training result (8P)
epoch: 61 step: 2502, loss is 2.4235027
epoch time: 813367.327 ms, per step time: 325.087 ms
epoch: 62 step: 2502, loss is 2.0396166
epoch time: 813387.109 ms, per step time: 325.095 ms
epoch: 63 step: 2502, loss is 1.7643375
epoch time: 813347.102 ms, per step time: 325.075 ms
...
```
# Evaluation Process
## Usage
### Running on Ascend
```Shell
# evaluation
Usage: bash run_eval.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
### Running on GPU
```shell
# run evaluation example
Usage: bash run_eval_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
## Results
The evaluation results can be found in the log under the current script path, for example:
- Evaluating ResNetv2_50 with the cifar10 dataset
```text
result: {'top_5_accuracy': 0.9988982371794872, 'top_1_accuracy': 0.9502283653846154}
```
- Evaluating ResNetv2_50 with the imagenet2012 dataset
```text
result: {'top_1_accuracy': 0.7606515786082474, 'top_5_accuracy': 0.9271504510309279}
```
# Model Description
## Performance
### Evaluation Performance
#### ResNetv2_50 on Cifar10
| Parameters | Ascend 910 |
|---|---|
| Model version | ResNetv2_50 |
| Resources | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Upload date | 2021-03-24 |
| MindSpore version | 1.2.0 |
| Dataset | Cifar10 |
| Training parameters | epoch=135, steps per epoch=195, batch_size=32 |
| Optimizer | Momentum |
| Loss function | Softmax cross-entropy |
| Output | probability |
| Loss | 0.0007279 |
| Speed | 24.3 ms/step (8 pcs) |
| Total duration | 10 min |
| Checkpoint for fine-tuning | 188.36M (.ckpt file) |
| Script | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnetv2) |
#### ResNetv2_50 on ImageNet2012
| Parameters | Ascend 910 |
| ------------- | ------------------------------------------------------------ |
| Model version | ResNetv2_50 |
| Resources | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Upload date | 2021-05-06 |
| MindSpore version | 1.2.0 |
| Dataset | ImageNet2012 |
| Training parameters | epoch=90, steps per epoch=2502, batch_size=64 |
| Optimizer | Momentum |
| Loss function | Softmax cross-entropy |
| Output | probability |
| Loss | 1.8290355 |
| Speed | 325 ms/step (8 pcs) |
| Total duration | 20.3 hours |
| Checkpoint for fine-tuning | 195.9M (.ckpt file) |
| Script | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnetv2) |
# Description of Random Situation
A seed is set inside the "create_dataset" functions in dataset.py, and a random seed is also set in train.py.
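Concretely, both entry scripts pin the global seed near the top:
```python
from mindspore.common import set_seed

set_seed(1)  # as done in train.py and eval.py
```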
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" eval.py """
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='resnetv2_50',
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset, cifar10, imagenet2012')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default="../cifar-10/cifar-10-verify-bin",
help='Dataset path.')
parser.add_argument('--checkpoint_path', type=str, default="./checkpoint/train_resnetv2_cifar10-100_1562.ckpt",
help='Checkpoint file path.')
args_opt = parser.parse_args()
# import net
if args_opt.net == "resnetv2_50":
from src.resnetv2 import PreActResNet50 as resnetv2
elif args_opt.net == 'resnetv2_101':
from src.resnetv2 import PreActResNet101 as resnetv2
elif args_opt.net == 'resnetv2_152':
from src.resnetv2 import PreActResNet152 as resnetv2
# import dataset
if args_opt.dataset == "cifar10":
from src.dataset import create_dataset1 as create_dataset
elif args_opt.dataset == "cifar100":
from src.dataset import create_dataset2 as create_dataset
elif args_opt.dataset == 'imagenet2012':
from src.dataset import create_dataset3 as create_dataset
# import config
if args_opt.net == "resnetv2_50" or args_opt.net == "resnetv2_101" or args_opt.net == "resnetv2_152":
if args_opt.dataset == "cifar10":
from src.config import config1 as config
elif args_opt.dataset == 'cifar100':
from src.config import config2 as config
elif args_opt.dataset == 'imagenet2012':
from src.config import config3 as config
set_seed(1)
try:
device_id = int(os.getenv('DEVICE_ID'))
except TypeError:
device_id = 0
context.set_context(device_id=device_id)
if __name__ == '__main__':
print("============== Starting Evaluating ==============")
print(f"start evaluating {args_opt.net} on device {device_id}")
# init context
target = args_opt.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False)
step_size = dataset.get_dataset_size()
# define net
net = resnetv2(class_num=config.class_num)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction='mean',
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)


@ -0,0 +1,72 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import argparse
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
parser = argparse.ArgumentParser(description='resnet export')
parser.add_argument('--net', type=str, default='resnetv2_50',
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset, cifar10, cifar100, imagenet2012')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="resnetv2", help="output file name.")
parser.add_argument('--width', type=int, default=32, help='input width')
parser.add_argument('--height', type=int, default=32, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
# import net
if args.net == "resnetv2_50":
from src.resnetv2 import PreActResNet50 as resnetv2
elif args.net == 'resnetv2_101':
from src.resnetv2 import PreActResNet101 as resnetv2
elif args.net == 'resnetv2_152':
from src.resnetv2 import PreActResNet152 as resnetv2
else:
raise ValueError("network is not support.")
# import config
if args.net == "resnetv2_50" or args.net == "resnetv2_101" or args.net == "resnetv2_152":
if args.dataset == "cifar10":
from src.config import config1 as config
elif args.dataset == 'cifar100':
from src.config import config2 as config
elif args.dataset == 'imagenet2012':
raise ValueError("ImageNet2012 dataset not yet supported")
else:
raise ValueError("dataset is not support.")
else:
raise ValueError("network is not support.")
net = resnetv2(config.class_num)
assert args.ckpt_file is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
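    # dummy input used only to trace the export graph; only shape/dtype matter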
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)


@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -f $3 ]
then
echo "error: RANK_TABLE_FILE=$3 is not a file"
exit 1
fi
if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
ulimit -u unlimited
export RANK_TABLE_FILE=$3
export DEVICE_NUM=8
export RANK_SIZE=8
for((i=0; i<${RANK_SIZE}; i++))
do
export DEVICE_ID=${i}
export RANK_ID=${i}
echo "start distributed training for rank $RANK_ID, device $DEVICE_ID"
python train.py --net $1 --dataset $2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path $4 &> log.$i &
done


@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
exit 1
fi
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: DATASET_PATH=$3 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --net=$1 --dataset=$2 --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$3 &> log &


@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: DATASET_PATH=$3 is not a directory"
exit 1
fi
if [ ! -f $4 ]
then
echo "error: CHECKPOINT_PATH=$4 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
echo "start evaluation for device $DEVICE_ID"
python eval.py --net=$1 --dataset=$2 --dataset_path=$3 --checkpoint_path=$4 &> eval.log &


@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: DATASET_PATH=$3 is not a directory"
exit 1
fi
if [ ! -f $4 ]
then
echo "error: CHECKPOINT_PATH=$4 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
echo "start evaluation for device $DEVICE_ID"
python eval.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$3 --checkpoint_path=$4 &> eval.log &


@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
exit 1
fi
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_ID=0
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
python train.py --net $1 --dataset $2 --device_num=$DEVICE_NUM --dataset_path $3 &> log.$DEVICE_ID &


@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
exit 1
fi
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
then
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_ID=2
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
python train.py --net $1 --dataset $2 --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path $3 &> log.$DEVICE_ID &


@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" CrossEntropySmooth.py """
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
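        # label smoothing: the true class gets probability 1 - smooth_factor,
        # while the remaining mass is spread evenly over the other classes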
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss


@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" config.py """
from easydict import EasyDict as ed
# config for ResNetv2, cifar10
config1 = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 200,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./checkpoint",
"low_memory": False,
"warmup_epochs": 5,
"lr_decay_mode": "cosine",
"lr_init": 0.1,
"lr_end": 0.0000000005,
"lr_max": 0.1,
})
# config for ResNetv2, cifar100
config2 = ed({
"class_num": 100,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 100,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./checkpoint",
"low_memory": False,
"warmup_epochs": 5,
"lr_decay_mode": "cosine",
"lr_init": 0.1,
"lr_end": 0.0000000005,
"lr_max": 0.1,
})
# config for ResNetv2, imagenet2012
config3 = ed({
"class_num": 1001,
"batch_size": 64,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 90,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./checkpoint",
"low_memory": True,
"warmup_epochs": 0,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_decay_mode": "cosine",
"lr_init": 0.05,
"lr_end": 0.0000001,
"lr_max": 0.05,
})


@ -0,0 +1,213 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" dataset.py """
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset1(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
"""
create a train or evaluate cifar10 dataset for PreActResnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
distribute(bool): data for distribute or not. Default: False
Returns:
dataset
"""
if target == "Ascend":
rank_size, rank_id = _get_rank_info()
else:
if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
else:
rank_size = 1
if rank_size == 1:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
# define map operations
trans = []
if do_train:
trans += [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip()
]
trans += [
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def create_dataset2(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
"""
create a train or evaluate cifar100 dataset for PreActResnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
distribute(bool): data for distribute or not. Default: False
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
device_num = 1
if device_num == 1:
data_set = ds.Cifar100Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.Cifar100Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
# define map operations
trans = []
if do_train:
trans += [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip()
]
trans += [
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op, input_columns="fine_label", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    def del_column(col1, col2, col3, batch_info):
        # per_batch_map hook (batch_info is required by the signature): drop the
        # coarse_label column so batch() outputs (image, label) only
        return (col1, col2)
# apply batch operations
data_set = data_set.batch(batch_size, per_batch_map=del_column,
input_columns=['image', 'fine_label', 'coarse_label'],
output_columns=['image', 'label'],
drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def create_dataset3(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
"""
create a train or eval imagenet2012 dataset for PreActResnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
distribute(bool): data for distribute or not. Default: False
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
if distribute:
init()
rank_id = get_rank()
device_num = get_group_size()
else:
device_num = 1
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id


@ -0,0 +1,206 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
""" lr_generator.py """
import math
import numpy as np
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
"""
Applies three steps decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
return lr_each_step
def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
"""
Applies polynomial decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
lr_each_step = []
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
return lr_each_step
def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
"""
Applies cosine decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
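            # combine a linear ramp-down with a cosine factor; the cosine angle
            # sweeps about 0.94*pi over the decay phase, so the multiplier
            # approaches (but never exactly reaches) zero at the end of training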
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
return lr_each_step
def _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies linear decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
return lr_each_step
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or linear (default)
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
elif lr_decay_mode == 'poly':
lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
elif lr_decay_mode == 'cosine':
lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
else:
        lr_each_step = _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
"""
generate learning rate array with cosine
Args:
lr(float): base learning rate
steps_per_epoch(int): steps size of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epochs of training
global_step(int): the current start index of lr array
Returns:
np.array, learning rate array
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = base_lr * decayed
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[global_step:]
return learning_rate


@ -0,0 +1,160 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
""" resnetv2.py """
import mindspore.nn as nn
from mindspore.ops import operations as P
class PreActBottleNeck(nn.Cell):
""" PreActBottleNeck """
expansion = 4
def __init__(self,
in_planes,
planes,
stride=1):
super(PreActBottleNeck, self).__init__()
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2d(in_planes, eps=1e-5, momentum=0.9)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1)
self.bn2 = nn.BatchNorm2d(planes, eps=1e-5, momentum=0.9)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, pad_mode='pad')
self.bn3 = nn.BatchNorm2d(planes, eps=1e-5, momentum=0.9)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, stride=1)
        self.downsample = False
        if stride != 1 or in_planes != self.expansion*planes:
            self.downsample = True
self.shortcut = nn.SequentialCell([nn.Conv2d(in_planes, self.expansion*planes,
kernel_size=1, stride=stride)])
self.add = P.TensorAdd()
def construct(self, x):
""" construct network """
out = self.bn1(x)
out = self.relu(out)
        if self.downsample:
identity = self.shortcut(out)
else:
identity = x
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
out = self.add(out, identity)
return out
class PreActResNet(nn.Cell):
""" PreActResNet """
def __init__(self,
block,
num_blocks,
in_planes,
planes,
strides,
low_memory,
num_classes=10):
super(PreActResNet, self).__init__()
self.in_planes = in_planes
self.low_memory = low_memory
self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2)
self.conv2 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, pad_mode='pad', padding=1)
self.layer1 = self._make_layer(block,
planes=planes[0],
num_blocks=num_blocks[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
planes=planes[1],
num_blocks=num_blocks[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
planes=planes[2],
num_blocks=num_blocks[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
planes=planes[3],
num_blocks=num_blocks[3],
stride=strides[3])
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.linear = nn.Dense(planes[3]*block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
layers = []
strides = [stride] + [1]*(num_blocks-1)
for s in strides:
layers.append(block(self.in_planes, planes, s))
self.in_planes = planes * block.expansion
return nn.SequentialCell(layers)
def construct(self, x):
""" construct network """
if self.low_memory:
out = self.conv1(x)
else:
out = self.conv2(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.mean(out, (2, 3))
out = out.view(out.shape[0], -1)
out = self.linear(out)
return out
def PreActResNet50(class_num=10, low_memory=False):
return PreActResNet(PreActBottleNeck,
num_blocks=[3, 4, 6, 3],
in_planes=64,
planes=[64, 128, 256, 512],
strides=[1, 2, 2, 2],
low_memory=low_memory,
num_classes=class_num)
def PreActResNet101(class_num=10, low_memory=False):
return PreActResNet(PreActBottleNeck,
num_blocks=[3, 4, 23, 3],
in_planes=64,
planes=[64, 128, 256, 512],
strides=[1, 2, 2, 2],
low_memory=low_memory,
num_classes=class_num)
def PreActResNet152(class_num=10, low_memory=False):
return PreActResNet(PreActBottleNeck,
num_blocks=[3, 8, 36, 3],
in_planes=64,
planes=[64, 128, 256, 512],
strides=[1, 2, 2, 2],
low_memory=low_memory,
num_classes=class_num)


@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" train.py """
import os
import argparse
from mindspore.nn import Momentum
from mindspore.context import ParallelMode
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.common import set_seed
from mindspore.common.tensor import Tensor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification.')
parser.add_argument('--net', type=str, default='resnetv2_50',
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
parser.add_argument('--dataset', type=str, default='cifar10',
help='Dataset, cifar10, imagenet2012')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--dataset_path', type=str, default="../cifar-10/cifar-10-batches-bin",
help='Dataset path.')
args_opt = parser.parse_args()
# import net
if args_opt.net == "resnetv2_50":
from src.resnetv2 import PreActResNet50 as resnetv2
elif args_opt.net == 'resnetv2_101':
from src.resnetv2 import PreActResNet101 as resnetv2
elif args_opt.net == 'resnetv2_152':
from src.resnetv2 import PreActResNet152 as resnetv2
# import dataset
if args_opt.dataset == "cifar10":
from src.dataset import create_dataset1 as create_dataset
elif args_opt.dataset == "cifar100":
from src.dataset import create_dataset2 as create_dataset
elif args_opt.dataset == 'imagenet2012':
from src.dataset import create_dataset3 as create_dataset
# import config
if args_opt.net == "resnetv2_50" or args_opt.net == "resnetv2_101" or args_opt.net == "resnetv2_152":
if args_opt.dataset == "cifar10":
from src.config import config1 as config
elif args_opt.dataset == 'cifar100':
from src.config import config2 as config
elif args_opt.dataset == 'imagenet2012':
from src.config import config3 as config
set_seed(1)
if __name__ == '__main__':
print("============== Starting Training ==============")
target = args_opt.device_target
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
# init parallel training parameters
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
# init HCCL
init()
else:
init()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
try:
device_id = int(os.getenv('DEVICE_ID'))
except TypeError:
device_id = 0
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target, distribute=args_opt.run_distribute)
step_size = dataset.get_dataset_size()
# define net
epoch_size = config.epoch_size
net = resnetv2(config.class_num, config.low_memory)
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
# init lr
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
# define loss, opt, model
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_save_dir = config.save_checkpoint_path
ckpoint_cb = ModelCheckpoint(prefix=f"train_{args_opt.net}_{args_opt.dataset}",
directory=ckpt_save_dir, config=config_ck)
# train
if args_opt.run_distribute:
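        # in distributed runs only rank/device 0 keeps the checkpoint callback,
        # so multiple processes do not write the same checkpoint files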
callbacks = [time_cb, loss_cb]
if target == "GPU" and str(get_rank()) == '0':
callbacks = [time_cb, loss_cb, ckpoint_cb]
elif target == "Ascend" and device_id == 0:
callbacks = [time_cb, loss_cb, ckpoint_cb]
else:
callbacks = [time_cb, loss_cb, ckpoint_cb]
model.train(epoch_size, dataset, callbacks=callbacks)