wideresnet

yangwm 2021-07-20 20:54:19 +08:00
parent 36820f38dc
commit e8a96b9b81
13 changed files with 1134 additions and 0 deletions

README.md

@@ -0,0 +1,258 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [WideResNet Description](#wideresnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
    - [Usage](#usage)
        - [Running on Ascend](#running-on-ascend)
    - [Result](#result)
- [Evaluation Process](#evaluation-process)
    - [Usage](#usage-1)
        - [Running on Ascend](#running-on-ascend-1)
    - [Result](#result-1)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [WideResNet on cifar10](#wideresnet-on-cifar10)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# WideResNet Description

## Overview

Zagoruyko and Komodakis proposed WideResNet on top of ResNet to address the problem that in thin, deep networks only a limited number of layers learn useful representations, while most layers contribute little to the final result. This problem is also known as diminishing feature reuse. The WideResNet authors widen the residual blocks instead, which speeds up training several times over and also clearly improves accuracy.

The following is an example of training WideResNet on the cifar10 dataset with MindSpore.

## Paper

1. [Paper](https://arxiv.org/abs/1605.07146): Sergey Zagoruyko, Nikos Komodakis. "Wide Residual Networks"
# Model Architecture

The overall network architecture of WideResNet is described in the [paper](https://arxiv.org/abs/1605.07146).
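The two architecture hyper-parameters are the depth and the widening factor k. Below is a minimal sketch of how src/wide_resnet.py turns them into the per-stage layout; the printed numbers follow from depth=40, widen_factor=10 in src/config.py:

```Python
depth, k = 40, 10                        # values from src/config.py
n = (depth - 4) // 6                     # residual blocks per stage, as in src/wide_resnet.py
stage_widths = [16 * k, 32 * k, 64 * k]  # channel widths of the three stages
print(n, stage_widths)                   # 6 [160, 320, 640]
```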
# Dataset

Dataset used: [cifar10](http://www.cs.toronto.edu/~kriz/cifar.html)

- Dataset size: 60,000 32*32 color images in 10 classes
    - Train: 50,000 images
    - Test: 10,000 images
- Note: data is processed in dataset.py. `create_dataset` joins `train` or `eval` onto the dataset path, so place the five training batches in a `train` subdirectory and test_batch.bin in an `eval` subdirectory.
- Download the dataset. The directory structure is as follows:

```text
└─cifar-10-batches-bin
    ├─data_batch_1.bin  # training set
    ├─data_batch_2.bin  # training set
    ├─data_batch_3.bin  # training set
    ├─data_batch_4.bin  # training set
    ├─data_batch_5.bin  # training set
    └─test_batch.bin  # evaluation set
```
# Environment Requirements

- Hardware (Ascend)
    - Prepare an Ascend processor to set up the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start

After installing MindSpore via the official website, you can start training and evaluation as follows:

- Running on Ascend

```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]

# standalone training
Usage: sh run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]

# run evaluation example
Usage: sh run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
```

Here DATA_URL is the dataset directory, CKPT_URL points to the checkpoint directory (for evaluation, the checkpoint file), and MODELART should be False unless running on ModelArts.
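For example, a local 8-card training run followed by evaluation might look like this; all paths below are illustrative placeholders:

```Shell
sh run_distribute_train.sh ~/hccl_8p.json /data/cifar10 /data/ckpt False
sh run_eval.sh /data/cifar10 /data/ckpt/WideResNet_best.ckpt False
```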
# Script Description

## Script and Sample Code

```text
└──wideresnet
    ├── README.md
    ├── scripts
        ├── run_distribute_train.sh    # launch Ascend distributed training (8 pcs)
        ├── run_eval.sh                # launch Ascend evaluation
        └── run_standalone_train.sh    # launch Ascend standalone training (1 pcs)
    ├── src
        ├── config.py                  # parameter configuration
        ├── dataset.py                 # data preprocessing
        ├── cross_entropy_smooth.py    # loss definition for the cifar10 dataset
        ├── generator_lr.py            # generate learning rate for each step
        ├── save_callback.py           # custom callback that saves the best ckpt
        └── wide_resnet.py             # WideResNet network structure
    ├── eval.py                        # evaluate network
    ├── export.py                      # export network
    └── train.py                       # train network
```
# Script Parameters

Both training and evaluation parameters can be configured in config.py.

- Configuration for WideResNet and the cifar10 dataset.

```Python
"num_classes": 10,            # number of dataset classes
"batch_size": 32,             # batch size of the input tensor
"epoch_size": 300,            # number of training epochs
"save_checkpoint_path": "./", # checkpoint save path, relative to the execution path
"repeat_num": 1,              # dataset repeat count
"widen_factor": 10,           # network width
"depth": 40,                  # network depth
"lr_init": 0.1,               # initial learning rate
"weight_decay": 5e-4,         # weight decay
"momentum": 0.9,              # momentum for the optimizer
"loss_scale": 32,             # loss scale
"save_checkpoint": True,      # whether to save checkpoints
"save_checkpoint_epochs": 5,  # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max": 10,    # keep only the last keep_checkpoint_max checkpoints
"use_label_smooth": True,     # label smoothing
"label_smooth_factor": 0.1,   # label smoothing factor
"pretrain_epoch_size": 0,     # pretraining epochs
"warmup_epochs": 5,           # warm-up epochs
```
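When use_label_smooth is enabled, CrossEntropySmooth in src/cross_entropy_smooth.py spreads label_smooth_factor across the wrong classes. A minimal sketch of the target values it builds:

```Python
# one-hot target values computed by CrossEntropySmooth (src/cross_entropy_smooth.py)
num_classes = 10
label_smooth_factor = 0.1
on_value = 1.0 - label_smooth_factor                 # 0.9 for the true class
off_value = label_smooth_factor / (num_classes - 1)  # ~0.0111 for each of the other 9 classes
```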
# Training Process

## Usage

### Running on Ascend

```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]

# standalone training
Usage: sh run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
```

For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
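For reference, a single-server 8-device rank table generated by hccl_tools has roughly the following shape; the `server_id` and `device_ip` values below are placeholders, not values to copy:

```json
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "10.155.111.140",
            "device": [
                {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
                {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"},
                {"device_id": "2", "device_ip": "192.3.27.6", "rank_id": "2"},
                {"device_id": "3", "device_ip": "192.4.27.6", "rank_id": "3"},
                {"device_id": "4", "device_ip": "192.1.27.7", "rank_id": "4"},
                {"device_id": "5", "device_ip": "192.2.27.7", "rank_id": "5"},
                {"device_id": "6", "device_ip": "192.3.27.7", "rank_id": "6"},
                {"device_id": "7", "device_ip": "192.4.27.7", "rank_id": "7"}
            ],
            "host_nic_ip": "reserve"
        }
    ],
    "status": "completed"
}
```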
Training results are saved in the example path, in folders whose names start with "train" or "train_parallel". You can find checkpoint files and results in the log under this path, as shown below.
## Result

- Training WideResNet with the cifar10 dataset

```text
# distributed training result (8 pcs)
epoch: 2 step: 195, loss is 1.4352043
epoch: 2 step: 195, loss is 1.4611206
epoch: 2 step: 195, loss is 1.2635705
epoch: 2 step: 195, loss is 1.3457444
epoch: 2 step: 195, loss is 1.4664338
epoch: 2 step: 195, loss is 1.3559061
epoch: 2 step: 195, loss is 1.5225968
epoch: 2 step: 195, loss is 1.246567
epoch: 3 step: 195, loss is 1.0763402
epoch: 3 step: 195, loss is 1.3007892
epoch: 3 step: 195, loss is 1.2473519
epoch: 3 step: 195, loss is 1.3249974
epoch: 3 step: 195, loss is 1.3388557
epoch: 3 step: 195, loss is 1.2402486
epoch: 3 step: 195, loss is 1.2878766
epoch: 3 step: 195, loss is 1.1507874
epoch: 4 step: 195, loss is 1.014946
epoch: 4 step: 195, loss is 1.1934564
epoch: 4 step: 195, loss is 0.9506259
epoch: 4 step: 195, loss is 1.2101849
epoch: 4 step: 195, loss is 1.0160742
epoch: 4 step: 195, loss is 1.2643425
epoch: 4 step: 195, loss is 1.3422029
epoch: 4 step: 195, loss is 1.221174
...
```
# Evaluation Process

## Usage

### Running on Ascend

```Shell
# evaluation
Usage: sh run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
```

```Shell
# evaluation example
sh run_eval.sh /cifar10 WideResNet_best.ckpt False
```

Checkpoints are generated during training.

## Result

Evaluation results are saved in the example path, in a folder named "eval". You can find the following result in the log under this path:

- Evaluating WideResNet with the cifar10 dataset

```text
result: {'top_1_accuracy': 0.9622395833333334}
```
# Model Description

## Performance

### Evaluation Performance

#### WideResNet on cifar10

| Parameters | Ascend 910 |
| --- | --- |
| Model Version | WideResNet |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded Date | 2021-05-20 |
| MindSpore Version | 1.2.1 |
| Dataset | cifar10 |
| Training Parameters | epoch=300, steps per epoch=195, batch_size=32 |
| Optimizer | Momentum |
| Loss Function | Softmax cross entropy |
| Outputs | probability |
| Loss | 0.545541 |
| Speed | 65.2 ms/step (8 pcs) |
| Total time | 70 min |
| Parameters (M) | 52.1 |
| Checkpoint for fine tuning | 426.49M (.ckpt file) |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wideresnet) |
# Description of Random Situation

dataset.py sets the seed inside the "create_dataset" function; the random seed in train.py is used as well.
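Concretely, the global seed is pinned at module level in train.py before the network is built:

```Python
from mindspore.common import set_seed

set_seed(1)  # as called at the top of train.py
```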
# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py

@@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test WideResNet example on cifar10#################
python eval.py
"""
import os
import ast
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.cross_entropy_smooth import CrossEntropySmooth
from src.wide_resnet import wideresnet
from src.dataset import create_dataset
from src.config import config_WideResnet as cfg
parser = argparse.ArgumentParser(description='Ascend WideResNet CIFAR10 Eval')
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
parser.add_argument('--ckpt_url', type=str, default=None, help='location of ckpt')
parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
                    help='training on modelart or not, default is False')
args = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

if __name__ == '__main__':
    target = 'Ascend'
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
                        device_id=device_id)
    data_path = '/cache/data_path'
    if args.modelart:
        import moxing as mox
        mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
    else:
        data_path = args.data_url
    ds_eval = create_dataset(dataset_path=data_path,
                             do_train=False,
                             repeat_num=cfg.repeat_num,
                             batch_size=cfg.batch_size)
    net = wideresnet()
    ckpt_path = '/cache/ckpt_path/'
    if args.modelart:
        mox.file.copy_parallel(args.ckpt_url, dst_url=ckpt_path)
        param_dict = load_checkpoint(os.path.join(ckpt_path, 'WideResNet_best.ckpt'))
    else:
        param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    if not cfg.use_label_smooth:
        cfg.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(sparse=True, reduction='mean',
                              smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})
    output = model.eval(ds_eval)
    print("result:", output)

export.py

@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import os
import ast
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.wide_resnet import wideresnet
parser = argparse.ArgumentParser(description='WideResNet export')
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
parser.add_argument('--data_url', default=None, help='Directory contains cifar10 dataset.')
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.environ["DEVICE_ID"]))
if args.run_modelart:
    import moxing as mox
    device_id = int(os.getenv('DEVICE_ID'))
    local_output_url = '/cache/ckpt' + str(device_id)
    mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
                           dst_url=os.path.join(local_output_url, args.ckpt_file))
else:
    # without ModelArts, load the checkpoint straight from train_url (or the working directory)
    local_output_url = args.train_url if args.train_url else '.'

if __name__ == '__main__':
    net = wideresnet()
    param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_file))
    print('load ckpt')
    load_param_into_net(net, param_dict)
    print('load ckpt to net')
    net.set_train(False)
    input_arr = Tensor(np.ones([args.batch_size, 3, 32, 32]), mstype.float32)
    print('input')
    export(net, input_arr, file_name="WideResNet", file_format=args.file_format)
    if args.run_modelart:
        file_name = "WideResNet." + args.file_format.lower()
        mox.file.copy_parallel(src_url=file_name,
                               dst_url=os.path.join(args.train_url, file_name))

scripts/run_distribute_train.sh

@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================
if [ $# != 4 ]
then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
PATH4=$4
echo "$PATH1"
echo "$PATH2"
echo "$PATH3"
echo "$PATH4"

if [ ! -d $PATH2 ]
then
    echo "error: DATA_URL=$PATH2 is not a directory"
    exit 1
fi

if [ ! -d $PATH3 ]
then
    echo "error: CKPT_URL=$PATH3 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export DATA_URL=$2

for((i=0;i<${RANK_SIZE};i++))
do
    # give every device its own working copy of the sources
    rm -rf device$i
    mkdir device$i
    cp ../*.py ./device$i
    cp *.sh ./device$i
    cp -r ../src ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py --data_url=$PATH2 --ckpt_url=$PATH3 --modelart=$PATH4 &> train.log &
    cd ../
done

scripts/run_eval.sh

@@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
    echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$3
echo "$PATH1"
echo "$PATH2"
echo "$PATH3"

if [ ! -d $PATH1 ]
then
    echo "error: DATA_URL=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CKPT_URL=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> eval.log &
cd ..

scripts/run_standalone_train.sh

@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
    echo "Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$3

if [ ! -d $PATH1 ]
then
    echo "error: DATA_URL=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CKPT_URL=$PATH2 is not a file"
    exit 1
fi

echo "$PATH1"
echo "$PATH2"
echo "$PATH3"

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=6  # device used for standalone training; adjust to a free device if needed
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> train.log &
cd ..

src/config.py

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for WideResNet models."""
class Config_WideResNet:
    """
    Config parameters for the WideResNet.

    Examples:
        Config_WideResNet()
    """
    num_classes = 10
    batch_size = 32
    epoch_size = 300
    save_checkpoint_path = "./"
    repeat_num = 1
    widen_factor = 10
    depth = 40
    lr_init = 0.1
    weight_decay = 5e-4
    momentum = 0.9
    loss_scale = 32
    save_checkpoint = True
    save_checkpoint_epochs = 5
    keep_checkpoint_max = 10
    use_label_smooth = True
    label_smooth_factor = 0.1
    pretrain_epoch_size = 0
    warmup_epochs = 5


config_WideResnet = Config_WideResNet()

src/cross_entropy_smooth.py

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
    """CrossEntropy"""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

src/dataset.py

@@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or evaluate cifar10 dataset for WideResNet

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset
    """
    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))
    if do_train:
        dataset_path = os.path.join(dataset_path, 'train')
    else:
        dataset_path = os.path.join(dataset_path, 'eval')
    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path)
    else:
        if do_train:
            # shard the training set across devices for data-parallel training
            ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=device_id)
        else:
            ds = de.Cifar10Dataset(dataset_path)

    # define map operations
    trans = []
    if do_train:
        trans += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5)
        ]
    trans += [
        # standard CIFAR-10 per-channel mean/std statistics
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW()
    ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds

src/generator_lr.py

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate learning rate"""
import numpy as np
def get_lr(total_epochs, steps_per_epoch, lr_init):
    """
    generate learning rate

    Piecewise-constant schedule: the rate starts at lr_init and is stepped
    down after epochs 60, 120, 160, 200, 240 and 260.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for i in range(int(total_steps)):
        if i <= int(60 * steps_per_epoch):
            lr = lr_init
        elif i <= int(120 * steps_per_epoch):
            lr = lr_init * 0.1 + 0.01
        elif i <= int(160 * steps_per_epoch):
            lr = lr_init * 0.1 * 0.1 + 0.003
        elif i <= int(200 * steps_per_epoch):
            lr = 0.001
        elif i <= int(240 * steps_per_epoch):
            lr = 0.0008
        elif i <= int(260 * steps_per_epoch):
            lr = 0.0006
        else:
            lr = 0.0006  # hold the final value for any steps beyond epoch 260
        lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step

src/save_callback.py

@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""save best ckpt"""
from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback import Callback
from src.config import config_WideResnet as cfg
class SaveCallback(Callback):
    """
    save best ckpt
    """
    def __init__(self, model, eval_dataset, ckpt_path, modelart):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.ckpt_path = ckpt_path
        self.acc = 0.96  # only checkpoints beating this accuracy are saved
        self.cur_acc = 0.0
        self.modelart = modelart

    def step_end(self, run_context):
        """
        step end
        """
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        self.cur_acc = result['accuracy']
        print("cur_acc is", self.cur_acc)
        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = "WideResNet_best" + ".ckpt"
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            if self.modelart:
                import moxing as mox
                mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)

src/wide_resnet.py

@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""WideResNet"""
import mindspore.nn as nn
import mindspore.ops as ops
class WideBasic(nn.Cell):
    """
    WideBasic
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(WideBasic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # nn.Dropout keeps activations with probability keep_prob, i.e. a 30% drop rate here
        self.dropout = nn.Dropout(keep_prob=0.7)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, has_bias=True)
        self.shortcut = nn.SequentialCell()
        if in_channels != out_channels or stride != 1:
            # 1x1 projection so the shortcut matches the main branch's shape
            self.shortcut = nn.SequentialCell(
                [nn.Conv2d(in_channels, out_channels, 1, stride=stride, has_bias=True)]
            )

    def construct(self, x):
        """
        basic construct
        """
        identity = x
        # pre-activation ordering: BN and ReLU come before each convolution
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv2(x)
        shortcut = self.shortcut(identity)
        return x + shortcut


class WideResNet(nn.Cell):
    """
    WideResNet
    """
    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        """
        classes, block, depth, widen_factor
        """
        super(WideResNet, self).__init__()
        self.depth = depth
        k = widen_factor
        n = int((depth - 4) / 6)  # residual blocks per stage; depth must satisfy depth = 6n + 4
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, self.in_channels, 3, 1, padding=0, pad_mode='same')
        self.conv2 = self._make_layer(block, 16 * k, n, 1)
        self.conv3 = self._make_layer(block, 32 * k, n, 2)
        self.conv4 = self._make_layer(block, 64 * k, n, 2)
        self.bn = nn.BatchNorm2d(64 * k, momentum=0.9)
        self.relu = nn.ReLU()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.linear = nn.Dense(64 * k, num_classes, has_bias=True)
        self.bn1 = nn.BatchNorm2d(16)

    def construct(self, x):
        """
        WideResNet construct
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.mean(x, (2, 3))  # global average pooling over H and W
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """
        make layer
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        # SequentialCell expects the cells wrapped in a list
        return nn.SequentialCell(layers)


def wideresnet(depth=40, widen_factor=10):
    net = WideResNet(10, WideBasic, depth=depth, widen_factor=widen_factor)
    return net

train.py

@@ -0,0 +1,129 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train WideResNet example on cifar10########################
python train.py
"""
import ast
import os
import argparse
import numpy as np
from mindspore.common import set_seed
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.model import Model
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.wide_resnet import wideresnet
from src.dataset import create_dataset
from src.config import config_WideResnet as cfg
from src.generator_lr import get_lr
from src.cross_entropy_smooth import CrossEntropySmooth
from src.save_callback import SaveCallback
set_seed(1)
if __name__ == '__main__':
    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))

    parser = argparse.ArgumentParser(description='Ascend WideResNet+CIFAR10 Training')
    parser.add_argument('--data_url', required=True, default=None, help='Location of data')
    parser.add_argument('--ckpt_url', required=True, default=None, help='Location of ckpt.')
    parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
                        help='training on modelart or not, default is False')
    args = parser.parse_args()

    target = "Ascend"
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
                        device_id=device_id)
    if device_num > 1:
        init()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    dataset_sink_mode = True

    if args.modelart:
        import moxing as mox
        data_path = '/cache/data_path'
        mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
    else:
        data_path = args.data_url

    ds_train = create_dataset(dataset_path=data_path,
                              do_train=True,
                              batch_size=cfg.batch_size)
    ds_eval = create_dataset(dataset_path=data_path,
                             do_train=False,
                             batch_size=cfg.batch_size)
    step_size = ds_train.get_dataset_size()

    net = wideresnet()
    # Xavier-initialize all convolution weights
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(gain=np.sqrt(2)),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))

    loss = CrossEntropySmooth(sparse=True, reduction="mean",
                              smooth_factor=cfg.label_smooth_factor,
                              num_classes=cfg.num_classes)
    loss_scale = FixedLossScaleManager(loss_scale=cfg.loss_scale, drop_overflow_update=False)
    lr = get_lr(total_epochs=cfg.epoch_size, steps_per_epoch=step_size, lr_init=cfg.lr_init)
    lr = Tensor(lr)

    # apply weight decay only to conv/dense weights, not to BN parameters or biases
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    opt = Momentum(group_params,
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   loss_scale=cfg.loss_scale,
                   use_nesterov=True,
                   weight_decay=cfg.weight_decay)

    model = Model(net,
                  amp_level="O2",
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale,
                  metrics={'accuracy'},
                  keep_batchnorm_fp32=False
                  )

    loss_cb = LossMonitor()
    time_cb = TimeMonitor()
    cb = [loss_cb, time_cb]
    ckpt_path = args.ckpt_url
    cb += [SaveCallback(model, ds_eval, ckpt_path, args.modelart)]
    model.train(epoch=cfg.epoch_size, train_dataset=ds_train, callbacks=cb,
                dataset_sink_mode=dataset_sink_mode)