wideresnet
This commit is contained in:
parent 36820f38dc
commit e8a96b9b81

@ -0,0 +1,258 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [WideResNet Description](#wideresnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
    - [Usage](#usage)
        - [Running on Ascend](#running-on-ascend)
    - [Results](#results)
- [Evaluation Process](#evaluation-process)
    - [Usage](#usage-1)
        - [Running on Ascend](#running-on-ascend-1)
    - [Results](#results-1)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [WideResNet on cifar10](#wideresnet-on-cifar10)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# WideResNet Description

## Overview

Zagoruyko proposed WideResNet on top of ResNet to address the problem that, in thin and deep networks, only a limited number of layers learn useful knowledge while many layers contribute little to the final result. This problem is also known as diminishing feature reuse. The WideResNet authors widen the residual blocks, which speeds up training severalfold and also noticeably improves accuracy.

The following is an example of training WideResNet on the cifar10 dataset with MindSpore.

## Paper

1. [Paper](https://arxiv.org/abs/1605.07146): Sergey Zagoruyko, Nikos Komodakis. "Wide Residual Networks"
# Model Architecture

The overall network architecture of WideResNet is described in the paper: [link](https://arxiv.org/abs/1605.07146)

# Dataset

Dataset used: [cifar10](http://www.cs.toronto.edu/~kriz/cifar.html)

- Dataset size: 10 classes, 32*32 color images
    - Training set: 50,000 images
    - Test set: 10,000 images
- Note: data is processed in dataset.py.
- Download the dataset; the directory structure is as follows:

```text
└─cifar-10-batches-bin
  ├─data_batch_1.bin  # training set
  ├─data_batch_2.bin  # training set
  ├─data_batch_3.bin  # training set
  ├─data_batch_4.bin  # training set
  ├─data_batch_5.bin  # training set
  └─test_batch.bin    # evaluation set
```
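Each cifar10 `.bin` batch file stores 10,000 fixed-size records of 3073 bytes: 1 label byte followed by 3072 image bytes (3 channels x 32 x 32, channel-major). The repo itself reads these files through MindSpore's `Cifar10Dataset`; the sketch below, with the hypothetical helper `read_record`, only illustrates the record layout:

```python
def read_record(buf, index):
    """Parse record `index` from a cifar-10 .bin buffer.

    Each record is 3073 bytes: 1 label byte, then 3072 image bytes
    (3 channels x 32 x 32, channel-major / CHW layout).
    """
    record_size = 1 + 3 * 32 * 32
    offset = index * record_size
    record = buf[offset:offset + record_size]
    label = record[0]
    image = record[1:]  # 3072 raw pixel bytes
    return label, image

# Tiny synthetic buffer with two records: labels 7 and 3, constant pixels.
buf = bytes([7] + [128] * 3072 + [3] + [255] * 3072)
label0, image0 = read_record(buf, 0)
label1, image1 = read_record(buf, 1)
print(label0, len(image0), label1, image1[0])
```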
# Environment Requirements

- Hardware
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation:

- Running on Ascend

```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# run the evaluation example
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
# Script Description

## Script and Sample Code

```text
└──wideresnet
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh    # launch Ascend distributed training (8 devices)
    ├── run_eval.sh                # launch Ascend evaluation
    └── run_standalone_train.sh    # launch Ascend standalone training (single device)
  ├── src
    ├── config.py                  # parameter configuration
    ├── dataset.py                 # data preprocessing
    ├── cross_entropy_smooth.py    # loss definition for the cifar10 dataset
    ├── generator_lr.py            # generate the learning rate for each step
    ├── save_callback.py           # custom callback that saves the best checkpoint
    └── wide_resnet.py             # WideResNet network structure
  ├── eval.py                      # evaluate the network
  ├── export.py                    # export the network
  └── train.py                     # train the network
```
# Script Parameters

Both training and evaluation parameters can be configured in config.py.

- Configuration for WideResNet and the cifar10 dataset.

```Python
"num_classes": 10,            # number of dataset classes
"batch_size": 32,             # batch size of the input tensor
"epoch_size": 300,            # number of training epochs
"save_checkpoint_path": "./", # checkpoint save path, relative to the execution path
"repeat_num": 1,              # dataset repeat count
"widen_factor": 10,           # network width
"depth": 40,                  # network depth
"lr_init": 0.1,               # initial learning rate
"weight_decay": 5e-4,         # weight decay
"momentum": 0.9,              # momentum for the optimizer
"loss_scale": 32,             # loss scale
"save_checkpoint": True,      # whether to save checkpoints
"save_checkpoint_epochs": 5,  # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max": 10,    # keep only the last keep_checkpoint_max checkpoints
"use_label_smooth": True,     # label smoothing
"label_smooth_factor": 0.1,   # label smoothing factor
"pretrain_epoch_size": 0,     # pretraining epochs
"warmup_epochs": 5,           # warmup epochs
```
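With `use_label_smooth` enabled, the loss in src/cross_entropy_smooth.py spreads `label_smooth_factor` over the non-target classes: the target class gets probability `1 - smooth_factor` and every other class gets `smooth_factor / (num_classes - 1)`. A standalone sketch of the smoothed target construction (illustration only, in plain Python instead of MindSpore ops):

```python
def smooth_one_hot(label, num_classes=10, smooth_factor=0.1):
    """Build a label-smoothed one-hot target vector."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if i == label else off_value for i in range(num_classes)]

# With the config defaults, class 3 gets 0.9 and the other nine classes
# share the remaining 0.1; the vector still sums to 1.
target = smooth_one_hot(3, num_classes=10, smooth_factor=0.1)
print(target[3], target[0])
```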
# Training Process

## Usage

### Running on Ascend

```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance.

For details, see the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

Training results are saved in the example path, in folders whose names start with "train" or "train_parallel". You can find checkpoint files and results in the logs under this path, as shown below.
## Results

- Training WideResNet with the cifar10 dataset

```text
# distributed training results (8P)
epoch: 2 step: 195, loss is 1.4352043
epoch: 2 step: 195, loss is 1.4611206
epoch: 2 step: 195, loss is 1.2635705
epoch: 2 step: 195, loss is 1.3457444
epoch: 2 step: 195, loss is 1.4664338
epoch: 2 step: 195, loss is 1.3559061
epoch: 2 step: 195, loss is 1.5225968
epoch: 2 step: 195, loss is 1.246567
epoch: 3 step: 195, loss is 1.0763402
epoch: 3 step: 195, loss is 1.3007892
epoch: 3 step: 195, loss is 1.2473519
epoch: 3 step: 195, loss is 1.3249974
epoch: 3 step: 195, loss is 1.3388557
epoch: 3 step: 195, loss is 1.2402486
epoch: 3 step: 195, loss is 1.2878766
epoch: 3 step: 195, loss is 1.1507874
epoch: 4 step: 195, loss is 1.014946
epoch: 4 step: 195, loss is 1.1934564
epoch: 4 step: 195, loss is 0.9506259
epoch: 4 step: 195, loss is 1.2101849
epoch: 4 step: 195, loss is 1.0160742
epoch: 4 step: 195, loss is 1.2643425
epoch: 4 step: 195, loss is 1.3422029
epoch: 4 step: 195, loss is 1.221174
...
```
# Evaluation Process

## Usage

### Running on Ascend

```Shell
# evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```

```Shell
# evaluation example
sh run_eval.sh /cifar10 WideResNet_best.ckpt
```

Checkpoints can be generated during training.

## Results

Evaluation results are saved in the example path, in a folder named "eval". You can find the following result in the logs under this path:

- Evaluating WideResNet with the cifar10 dataset

```text
result: {'top_1_accuracy': 0.9622395833333334}
```
# Model Description

## Performance

### Evaluation Performance

#### WideResNet on cifar10

| Parameters | Ascend 910 |
|---|---|
| Model version | WideResNet |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded | 2021-05-20 |
| MindSpore version | 1.2.1 |
| Dataset | cifar10 |
| Training parameters | epoch=300, steps per epoch=195, batch_size=32 |
| Optimizer | Momentum |
| Loss function | Softmax cross entropy |
| Output | probability |
| Loss | 0.545541 |
| Speed | 65.2 ms/step (8 devices) |
| Total time | 70 min |
| Parameters (M) | 52.1 |
| Checkpoint for fine-tuning | 426.49M (.ckpt file) |
| Script | [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wideresnet) |
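The 195 steps per epoch in the table follow from sharding the 50,000 training images across 8 devices with batch size 32 and `drop_remainder=True` (a back-of-the-envelope check using numbers from the dataset section and config.py):

```python
train_images = 50_000   # cifar10 training set size
batch_size = 32         # per-device batch size from config.py
devices = 8             # 8P distributed training

# drop_remainder=True in dataset.py, so the final partial batch is discarded
steps_per_epoch = train_images // (batch_size * devices)
print(steps_per_epoch)  # → 195
```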
# Description of Random Situation

A seed is set inside the "create_dataset" function in dataset.py; the random seed in train.py is also used.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test WideResNet example on cifar10#################
python eval.py
"""
import os
import ast
import argparse

from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.cross_entropy_smooth import CrossEntropySmooth
from src.wide_resnet import wideresnet
from src.dataset import create_dataset
from src.config import config_WideResnet as cfg


parser = argparse.ArgumentParser(description='Ascend WideResNet CIFAR10 Eval')
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
parser.add_argument('--ckpt_url', type=str, default=None, help='Location of ckpt')
parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
                    help='training on ModelArts or not, default is False')
args = parser.parse_args()

device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

if __name__ == '__main__':

    target = 'Ascend'

    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
                        device_id=device_id)

    data_path = '/cache/data_path'

    if args.modelart:
        import moxing as mox
        mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
    else:
        data_path = args.data_url

    ds_eval = create_dataset(dataset_path=data_path,
                             do_train=False,
                             repeat_num=cfg.repeat_num,
                             batch_size=cfg.batch_size)

    net = wideresnet()

    ckpt_path = '/cache/ckpt_path/'
    if args.modelart:
        import moxing as mox
        mox.file.copy_parallel(args.ckpt_url, dst_url=ckpt_path)
        param_dict = load_checkpoint(os.path.join(ckpt_path, 'WideResNet_best.ckpt'))
    else:
        param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    if not cfg.use_label_smooth:
        cfg.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(sparse=True, reduction='mean',
                              smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})

    output = model.eval(ds_eval)

    print("result:", output)
@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import os
import ast
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context

from src.wide_resnet import wideresnet

parser = argparse.ArgumentParser(description='WideResNet export')
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on ModelArts, default is False.")
parser.add_argument('--data_url', default=None, help='Directory containing the cifar10 dataset.')
parser.add_argument('--train_url', default=None, help='Directory containing the checkpoint file.')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.environ["DEVICE_ID"]))

# Resolve where the checkpoint lives; on ModelArts it is first copied to the
# local cache (otherwise `local_output_url` would be undefined below).
ckpt_file = args.ckpt_file
if args.run_modelart:
    import moxing as mox
    device_id = int(os.getenv('DEVICE_ID'))
    local_output_url = '/cache/ckpt' + str(device_id)
    mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
                           dst_url=os.path.join(local_output_url, args.ckpt_file))
    ckpt_file = os.path.join(local_output_url, args.ckpt_file)

if __name__ == '__main__':

    net = wideresnet()

    param_dict = load_checkpoint(ckpt_file)
    print('load ckpt')
    load_param_into_net(net, param_dict)
    print('load ckpt to net')
    net.set_train(False)
    input_arr = Tensor(np.ones([args.batch_size, 3, 32, 32]), mstype.float32)
    print('input')
    export(net, input_arr, file_name="WideResNet", file_format=args.file_format)
    if args.run_modelart:
        file_name = "WideResNet." + args.file_format.lower()
        mox.file.copy_parallel(src_url=file_name,
                               dst_url=os.path.join(args.train_url, file_name))
@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================

if [ $# != 4 ]
then
  echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
  exit 1
fi

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
PATH4=$4
echo "$PATH1"
echo "$PATH2"
echo "$PATH3"
echo "$PATH4"

if [ ! -d $PATH2 ]
then
  echo "error: DATA_URL=$PATH2 is not a directory"
  exit 1
fi

if [ ! -d $PATH3 ]
then
  echo "error: CKPT_URL=$PATH3 is not a directory"
  exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1

DATA_URL=$2
export DATA_URL=${DATA_URL}

for((i=0;i<${RANK_SIZE};i++))
do
  rm -rf device$i
  mkdir device$i
  cp ../*.py ./device$i
  cp *.sh ./device$i
  cp -r ../src ./device$i
  cd ./device$i
  export DEVICE_ID=$i
  export RANK_ID=$i
  echo "start training for device $i"
  env > env$i.log

  python train.py --data_url=$PATH2 --ckpt_url=$PATH3 --modelart=$PATH4 &> train.log &

  cd ../
done
@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
  echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
  exit 1
fi

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$3

echo "$PATH1"
echo "$PATH2"
echo "$PATH3"

if [ ! -d $PATH1 ]
then
  echo "error: DATA_URL=$PATH1 is not a directory"
  exit 1
fi

if [ ! -f $PATH2 ]
then
  echo "error: CKPT_URL=$PATH2 is not a file"
  exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
  rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> eval.log &
cd ..
@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
  echo "Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]"
  exit 1
fi

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -d $PATH1 ]
then
  echo "error: DATA_URL=$PATH1 is not a directory"
  exit 1
fi

if [ ! -f $PATH2 ]
then
  echo "error: CKPT_URL=$PATH2 is not a file"
  exit 1
fi

PATH3=$3

echo "$PATH1"
echo "$PATH2"
echo "$PATH3"

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=6
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "train" ];
then
  rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train
echo "start training for device $DEVICE_ID"
env > env.log

python train.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> train.log &

cd ..
@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for WideResNet models."""


class Config_WideResNet:
    """
    Config parameters for the WideResNet.

    Examples:
        Config_WideResNet()
    """
    num_classes = 10
    batch_size = 32
    epoch_size = 300
    save_checkpoint_path = "./"
    repeat_num = 1
    widen_factor = 10
    depth = 40
    lr_init = 0.1
    weight_decay = 5e-4
    momentum = 0.9
    loss_scale = 32
    save_checkpoint = True
    save_checkpoint_epochs = 5
    keep_checkpoint_max = 10
    use_label_smooth = True
    label_smooth_factor = 0.1
    pretrain_epoch_size = 0
    warmup_epochs = 5


config_WideResnet = Config_WideResNet()
@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(_Loss):
    """Cross entropy with label smoothing."""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        # the target class gets 1 - smooth_factor; the rest share smooth_factor
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""

import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or evaluate cifar10 dataset for WideResnet

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset
    """

    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))

    if do_train:
        dataset_path = os.path.join(dataset_path, 'train')
    else:
        dataset_path = os.path.join(dataset_path, 'eval')

    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path)
    else:
        if do_train:
            # shard the training set across devices for data-parallel training
            ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=device_id)
        else:
            ds = de.Cifar10Dataset(dataset_path)

    # define map operations
    trans = []
    if do_train:
        trans += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)

    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)

    return ds
@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate learning rate"""
import numpy as np


def get_lr(total_epochs,
           steps_per_epoch,
           lr_init
           ):
    """
    generate a piecewise-constant learning rate for each step
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    lr = lr_init
    for i in range(int(total_steps)):
        if i <= int(60 * steps_per_epoch):
            lr = lr_init
        elif i <= int(120 * steps_per_epoch):
            lr = lr_init * 0.1 + 0.01
        elif i <= int(160 * steps_per_epoch):
            lr = lr_init * 0.1 * 0.1 + 0.003
        elif i <= int(200 * steps_per_epoch):
            lr = 0.001
        elif i <= int(240 * steps_per_epoch):
            lr = 0.0008
        elif i <= int(260 * steps_per_epoch):
            lr = 0.0006
        # beyond epoch 260 no branch matches, so the last value is kept
        lr_each_step.append(lr)

    lr_each_step = np.array(lr_each_step).astype(np.float32)

    return lr_each_step
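A standalone sanity check of the schedule above (an inline plain-Python copy for illustration; the repo itself imports `get_lr` from `src.generator_lr`): with the config defaults (300 epochs, 195 steps per epoch, lr_init=0.1), the rate starts at 0.1, drops to 0.02 after epoch 60, and stays at 0.0006 from epoch 260 to the end.

```python
def get_lr(total_epochs, steps_per_epoch, lr_init):
    """Piecewise-constant schedule mirroring src/generator_lr.py."""
    lr_each_step = []
    lr = lr_init
    for i in range(int(steps_per_epoch * total_epochs)):
        if i <= int(60 * steps_per_epoch):
            lr = lr_init
        elif i <= int(120 * steps_per_epoch):
            lr = lr_init * 0.1 + 0.01          # 0.02 for lr_init=0.1
        elif i <= int(160 * steps_per_epoch):
            lr = lr_init * 0.1 * 0.1 + 0.003   # 0.004 for lr_init=0.1
        elif i <= int(200 * steps_per_epoch):
            lr = 0.001
        elif i <= int(240 * steps_per_epoch):
            lr = 0.0008
        elif i <= int(260 * steps_per_epoch):
            lr = 0.0006
        # past epoch 260 the last value (0.0006) is kept
        lr_each_step.append(lr)
    return lr_each_step

lr = get_lr(total_epochs=300, steps_per_epoch=195, lr_init=0.1)
print(lr[0], lr[100 * 195], lr[-1])
```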
@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""save best ckpt"""

from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback import Callback

from src.config import config_WideResnet as cfg


class SaveCallback(Callback):
    """
    Evaluate at the end of each step and save the checkpoint whenever
    accuracy beats the best seen so far (starting from a 0.96 threshold).
    """
    def __init__(self, model, eval_dataset, ckpt_path, modelart):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.ckpt_path = ckpt_path
        self.acc = 0.96
        self.cur_acc = 0.0
        self.modelart = modelart

    def step_end(self, run_context):
        """
        step end
        """
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        self.cur_acc = result['accuracy']
        print("cur_acc is", self.cur_acc)

        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = "WideResNet_best" + ".ckpt"
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            if self.modelart:
                import moxing as mox
                mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)
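The callback's keep-best rule is independent of MindSpore and can be sketched in plain Python (the `BestKeeper` class and its `save` stand-in are illustrative, not part of the repo):

```python
class BestKeeper:
    """Keep-best logic as in SaveCallback: save only when accuracy
    beats the best (or the initial 0.96 threshold) seen so far."""
    def __init__(self, threshold=0.96):
        self.best = threshold
        self.saved = []  # accuracies at which a checkpoint was written

    def update(self, accuracy, save):
        if accuracy > self.best:
            self.best = accuracy
            save()  # stand-in for save_checkpoint(...)
            self.saved.append(accuracy)

keeper = BestKeeper()
for acc in [0.90, 0.961, 0.958, 0.97]:
    keeper.update(acc, save=lambda: None)
print(keeper.saved)  # only 0.961 and 0.97 trigger a save
```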
@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""WideResNet"""

import mindspore.nn as nn
import mindspore.ops as ops


class WideBasic(nn.Cell):
    """
    WideBasic residual block: BN-ReLU-Dropout-Conv, BN-ReLU-Conv,
    plus a 1x1 convolution shortcut when the shape changes.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(WideBasic, self).__init__()

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(keep_prob=0.7)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, has_bias=True)

        self.shortcut = nn.SequentialCell()

        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.SequentialCell(
                [nn.Conv2d(in_channels, out_channels, 1, stride=stride, has_bias=True)]
            )

    def construct(self, x):
        """
        basic construct
        """
        identity = x

        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv2(x)

        shortcut = self.shortcut(identity)

        return x + shortcut


class WideResNet(nn.Cell):
    """
    WideResNet
    """
    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        """
        classes, block, depth, widen_factor
        """
        super(WideResNet, self).__init__()

        self.depth = depth
        k = widen_factor
        n = int((depth - 4) / 6)  # blocks per group
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, self.in_channels, 3, 1, padding=0, pad_mode='same')
        self.conv2 = self._make_layer(block, 16 * k, n, 1)
        self.conv3 = self._make_layer(block, 32 * k, n, 2)
        self.conv4 = self._make_layer(block, 64 * k, n, 2)
        self.bn = nn.BatchNorm2d(64 * k, momentum=0.9)
        self.relu = nn.ReLU()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.linear = nn.Dense(64 * k, num_classes, has_bias=True)

        self.bn1 = nn.BatchNorm2d(16)

    def construct(self, x):
        """
        WideResNet construct
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.mean(x, (2, 3))
        x = self.flatten(x)
        x = self.linear(x)

        return x

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """
        make layer
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels

        return nn.SequentialCell(*layers)


def wideresnet(depth=40, widen_factor=10):
    net = WideResNet(10, WideBasic, depth=depth, widen_factor=widen_factor)
    return net
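For the WRN-40-10 built by `wideresnet()`, the network shape follows directly from `n = (depth - 4) / 6` blocks per group and channel widths `16k, 32k, 64k`. A quick standalone check (the helper `wrn_shape` is illustrative, not part of the repo):

```python
def wrn_shape(depth, widen_factor):
    """Blocks per group and channel widths for a WideResNet-depth-k."""
    n = int((depth - 4) / 6)
    widths = [16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    return n, widths

# Defaults used by wideresnet(): depth=40, widen_factor=10
n, widths = wrn_shape(depth=40, widen_factor=10)
print(n, widths)  # 6 blocks per group; widths 160, 320, 640
```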
@@ -0,0 +1,129 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train WideResNet example on cifar10########################
python train.py
"""
import ast
import os
import argparse

import numpy as np

from mindspore.common import set_seed
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.model import Model
import mindspore.nn as nn
import mindspore.common.initializer as weight_init

from src.wide_resnet import wideresnet
from src.dataset import create_dataset
from src.config import config_WideResnet as cfg
from src.generator_lr import get_lr
from src.cross_entropy_smooth import CrossEntropySmooth
from src.save_callback import SaveCallback

set_seed(1)

if __name__ == '__main__':

    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))
    parser = argparse.ArgumentParser(description='Ascend WideResnet+CIFAR10 Training')
    parser.add_argument('--data_url', required=True, default=None, help='Location of data')
    parser.add_argument('--ckpt_url', required=True, default=None, help='Location of ckpt.')
    parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
                        help='training on modelart or not, default is False')
    args = parser.parse_args()

    target = "Ascend"
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
                        device_id=device_id)

    if device_num > 1:
        init()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    dataset_sink_mode = True

    if args.modelart:
        import moxing as mox
        data_path = '/cache/data_path'
        mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
    else:
        data_path = args.data_url

    ds_train = create_dataset(dataset_path=data_path,
                              do_train=True,
                              batch_size=cfg.batch_size)
    ds_eval = create_dataset(dataset_path=data_path,
                             do_train=False,
                             batch_size=cfg.batch_size)
    step_size = ds_train.get_dataset_size()

    net = wideresnet()

    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(gain=np.sqrt(2)),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))

    loss = CrossEntropySmooth(sparse=True, reduction="mean",
                              smooth_factor=cfg.label_smooth_factor,
                              num_classes=cfg.num_classes)
    loss_scale = FixedLossScaleManager(loss_scale=cfg.loss_scale, drop_overflow_update=False)

    lr = get_lr(total_epochs=cfg.epoch_size, steps_per_epoch=step_size, lr_init=cfg.lr_init)
    lr = Tensor(lr)

    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)

    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    opt = Momentum(group_params,
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   loss_scale=cfg.loss_scale,
                   use_nesterov=True,
                   weight_decay=cfg.weight_decay)

    model = Model(net,
                  amp_level="O2",
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale,
                  metrics={'accuracy'},
                  keep_batchnorm_fp32=False)

    loss_cb = LossMonitor()
    time_cb = TimeMonitor()
    cb = [loss_cb, time_cb]
    ckpt_path = args.ckpt_url
    cb += [SaveCallback(model, ds_eval, ckpt_path, args.modelart)]

    model.train(epoch=cfg.epoch_size, train_dataset=ds_train, callbacks=cb,
                dataset_sink_mode=dataset_sink_mode)
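The weight-decay grouping in train.py excludes BatchNorm `beta`/`gamma` and bias parameters from decay, a common trick for BN networks. The name filter can be checked standalone (parameter names below are illustrative):

```python
def split_decay_params(param_names):
    """Mirror of the grouping in train.py: apply weight decay to all
    parameters except BatchNorm beta/gamma and bias parameters."""
    decayed, no_decayed = [], []
    for name in param_names:
        if 'beta' not in name and 'gamma' not in name and 'bias' not in name:
            decayed.append(name)
        else:
            no_decayed.append(name)
    return decayed, no_decayed

names = ['conv1.weight', 'bn1.gamma', 'bn1.beta', 'linear.weight', 'linear.bias']
decayed, no_decayed = split_decay_params(names)
print(decayed)     # only the conv and dense weights are decayed
print(no_decayed)  # BN statistics and biases are exempt
```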