wideresnet
This commit is contained in:
parent
36820f38dc
commit
e8a96b9b81
|
@ -0,0 +1,258 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [WideResNet描述](#wideresnet描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器环境运行](#ascend处理器环境运行)
|
||||
- [结果](#结果)
|
||||
- [评估过程](#评估过程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器环境运行](#ascend处理器环境运行)
|
||||
- [结果](#结果)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [cifar10上的WideResNet](#cifar10上的wideresnet)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# WideResNet描述
|
||||
|
||||
## 概述
|
||||
|
||||
szagoruyko在ResNet基础上提出WideResNet,用于解决网络模型瘦长时,只有有限层学到了有用的知识,更多的层对最终结果只做出了很少的贡献。这个问题也被称为diminishing feature reuse,WideResNet作者加宽了残差块,将训练速度提升几倍,精度也有明显改善。
|
||||
|
||||
如下为MindSpore使用cifar10数据集对WideResNet进行训练的示例。
|
||||
|
||||
## 论文
|
||||
|
||||
1. [论文](https://arxiv.org/abs/1605.07146): Sergey Zagoruyko."Wide Residual Netwoks"
|
||||
|
||||
# 模型架构
|
||||
|
||||
WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.07146)
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集:[cifar10](http://www.cs.toronto.edu/~kriz/cifar.html)
|
||||
|
||||
- 数据集大小:共10个类、32*32彩色图像
|
||||
- 训练集:共50,000张图像
|
||||
- 测试集:共10,000张图像
|
||||
- 注:数据在dataset.py中处理。
|
||||
- 下载数据集,目录结构如下:
|
||||
|
||||
```text
|
||||
└─cifar-10-batches-bin
|
||||
├─data_batch_1.bin # 训练数据集
|
||||
├─data_batch_2.bin # 训练数据集
|
||||
├─data_batch_3.bin # 训练数据集
|
||||
├─data_batch_4.bin # 训练数据集
|
||||
├─data_batch_5.bin # 训练数据集
|
||||
└─test_batch.bin # 评估数据集
|
||||
```
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件
|
||||
- 准备Ascend处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 运行评估示例
|
||||
用法:sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```text
|
||||
└──wideresnet
|
||||
├── README.md
|
||||
├── scripts
|
||||
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||
├── run_eval.sh # 启动Ascend评估
|
||||
└── run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||
├── src
|
||||
├── config.py # 参数配置
|
||||
├── dataset.py # 数据预处理
|
||||
├── cross_entropy_smooth.py # cifar10数据集的损失定义
|
||||
├── generator_lr.py # 生成每个步骤的学习率
|
||||
├── save_callback.py # 自定义回调函数保存最优ckpt
|
||||
└── wide_resnet.py # WideResNet网络结构
|
||||
├── eval.py # 评估网络
|
||||
├── export.py # 导出网络
|
||||
└── train.py # 训练网络
|
||||
```
|
||||
|
||||
# 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 配置WideResNet和cifar10数据集。
|
||||
|
||||
```Python
|
||||
"num_classes":10, # 数据集类数
|
||||
"batch_size":32, # 输入张量的批次大小
|
||||
"epoch_size":300, # 训练周期大小
|
||||
"save_checkpoint_path":"./", # 检查点相对执行路劲的保存路径
|
||||
"repeat_num":1, # 数据集重复次数
|
||||
"widen_factor":10, # 网络宽度
|
||||
"depth":40, # 网络深度
|
||||
"lr_init":0.1, # 初始学习率
|
||||
"weight_decay":5e-4, # 权重衰减
|
||||
"momentum":0.9, # 动量优化器
|
||||
"loss_scale":32, # 损失等级
|
||||
"save_checkpoint":True, # 是否保存检查点
|
||||
"save_checkpoint_epochs":5, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
|
||||
"keep_checkpoint_max":10, # 只保存最后一个keep_checkpoint_max检查点
|
||||
"use_label_smooth":True, # 标签平滑
|
||||
"label_smooth_factor":0.1, # 标签平滑因子
|
||||
"pretrain_epoch_size":0, # 预训练周期
|
||||
"warmup_epochs":5, # 热身周期
|
||||
```
|
||||
|
||||
# 训练过程
|
||||
|
||||
## 用法
|
||||
|
||||
## Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
```
|
||||
|
||||
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||
|
||||
具体操作,参见[hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)中的说明。
|
||||
|
||||
训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
|
||||
|
||||
## 结果
|
||||
|
||||
- 使用cifar10数据集训练WideResNet
|
||||
|
||||
```text
|
||||
# 分布式训练结果(8P)
|
||||
epoch: 2 step: 195, loss is 1.4352043
|
||||
epoch: 2 step: 195, loss is 1.4611206
|
||||
epoch: 2 step: 195, loss is 1.2635705
|
||||
epoch: 2 step: 195, loss is 1.3457444
|
||||
epoch: 2 step: 195, loss is 1.4664338
|
||||
epoch: 2 step: 195, loss is 1.3559061
|
||||
epoch: 2 step: 195, loss is 1.5225968
|
||||
epoch: 2 step: 195, loss is 1.246567
|
||||
epoch: 3 step: 195, loss is 1.0763402
|
||||
epoch: 3 step: 195, loss is 1.3007892
|
||||
epoch: 3 step: 195, loss is 1.2473519
|
||||
epoch: 3 step: 195, loss is 1.3249974
|
||||
epoch: 3 step: 195, loss is 1.3388557
|
||||
epoch: 3 step: 195, loss is 1.2402486
|
||||
epoch: 3 step: 195, loss is 1.2878766
|
||||
epoch: 3 step: 195, loss is 1.1507874
|
||||
epoch: 4 step: 195, loss is 1.014946
|
||||
epoch: 4 step: 195, loss is 1.1934564
|
||||
epoch: 4 step: 195, loss is 0.9506259
|
||||
epoch: 4 step: 195, loss is 1.2101849
|
||||
epoch: 4 step: 195, loss is 1.0160742
|
||||
epoch: 4 step: 195, loss is 1.2643425
|
||||
epoch: 4 step: 195, loss is 1.3422029
|
||||
epoch: 4 step: 195, loss is 1.221174
|
||||
...
|
||||
```
|
||||
|
||||
# 评估过程
|
||||
|
||||
## 用法
|
||||
|
||||
### Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 评估
|
||||
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
```Shell
|
||||
# 评估示例
|
||||
sh run_eval.sh /cifar10 WideResNet_best.ckpt
|
||||
```
|
||||
|
||||
训练过程中可以生成检查点。
|
||||
|
||||
## 结果
|
||||
|
||||
评估结果保存在示例路径中,文件夹名为“eval”。您可在此路径下的日志找到如下结果:
|
||||
|
||||
- 使用cifar10数据集评估WideResNet
|
||||
|
||||
```text
|
||||
result: {'top_1_accuracy': 0.9622395833333334}
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### cifar10上的WideResNet
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
|---|---|
|
||||
| 模型版本 | WideResNet |
|
||||
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||
| 上传日期 |2021-05-20 ; |
|
||||
| MindSpore版本 | 1.2.1 |
|
||||
| 数据集 | cifar10 |
|
||||
| 训练参数 | epoch=300, steps per epoch=195, batch_size = 32 |
|
||||
| 优化器 | Momentum |
|
||||
| 损失函数 |Softmax交叉熵 |
|
||||
| 输出 | 概率 |
|
||||
| 损失 | 0.545541 |
|
||||
|速度|65.2毫秒/步(8卡) |
|
||||
|总时长 | 70分钟 |
|
||||
|参数(M) | 52.1 |
|
||||
| 微调检查点 | 426.49M(.ckpt文件) |
|
||||
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wideresnet) |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
dataset.py中设置了“create_dataset”函数内的种子,同时还使用了train.py中的随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############test WideResNet example on cifar10#################
|
||||
python eval.py
|
||||
"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.cross_entropy_smooth import CrossEntropySmooth
|
||||
from src.wide_resnet import wideresnet
|
||||
from src.dataset import create_dataset
|
||||
from src.config import config_WideResnet as cfg
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Ascend WideResNet CIFAR10 Eval')
|
||||
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
|
||||
parser.add_argument('--ckpt_url', type=str, default=None, help='location of ckpt')
|
||||
parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
|
||||
help='training on modelart or not, default is False')
|
||||
args = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
target = 'Ascend'
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
|
||||
device_id=int(os.environ["DEVICE_ID"]))
|
||||
|
||||
data_path = '/cache/data_path'
|
||||
|
||||
if args.modelart:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
|
||||
else:
|
||||
data_path = args.data_url
|
||||
|
||||
ds_eval = create_dataset(dataset_path=data_path,
|
||||
do_train=False,
|
||||
repeat_num=cfg.repeat_num,
|
||||
batch_size=cfg.batch_size)
|
||||
|
||||
net = wideresnet()
|
||||
|
||||
ckpt_path = '/cache/ckpt_path/'
|
||||
if args.modelart:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(args.ckpt_url, dst_url=ckpt_path)
|
||||
param_dict = load_checkpoint('/cache/ckpt_path/WideResNet_best.ckpt')
|
||||
else:
|
||||
param_dict = load_checkpoint(args.ckpt_url)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction='mean',
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})
|
||||
|
||||
output = model.eval(ds_eval)
|
||||
|
||||
print("result:", output)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.wide_resnet import wideresnet
|
||||
|
||||
parser = argparse.ArgumentParser(description='WideResNet export')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
parser.add_argument('--data_url', default=None, help='Directory contains cifar10 dataset.')
|
||||
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=int(os.environ["DEVICE_ID"]))
|
||||
|
||||
if args.run_modelart:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
local_output_url = '/cache/ckpt' + str(device_id)
|
||||
mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
|
||||
dst_url=os.path.join(local_output_url, args.ckpt_file))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
net = wideresnet()
|
||||
|
||||
param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_file))
|
||||
print('load ckpt')
|
||||
load_param_into_net(net, param_dict)
|
||||
print('load ckpt to net')
|
||||
net.set_train(False)
|
||||
input_arr = Tensor(np.ones([args.batch_size, 3, 32, 32]), mstype.float32)
|
||||
print('input')
|
||||
export(net, input_arr, file_name="WideResNet", file_format=args.file_format)
|
||||
if args.run_modelart:
|
||||
file_name = "WideResNet." + args.file_format.lower()
|
||||
mox.file.copy_parallel(src_url=file_name,
|
||||
dst_url=os.path.join(args.train_url, file_name))
|
|
@ -0,0 +1,74 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==========================================================================
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH3=$(get_real_path $3)
|
||||
PATH4=$4
|
||||
echo "$PATH1"
|
||||
echo "$PATH2"
|
||||
echo "$PATH3"
|
||||
echo "$PATH4"
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: CKPT_URL=$PATH3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
|
||||
|
||||
DATA_URL=$2
|
||||
export DATA_URL=${DATA_URL}
|
||||
|
||||
for((i=0;i<${RANK_SIZE};i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cp ../*.py ./device$i
|
||||
cp *.sh ./device$i
|
||||
cp -r ../src ./device$i
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
|
||||
if [ $# == 4 ]
|
||||
then
|
||||
python train.py --data_url=$PATH2 --ckpt_url=$PATH3 --modelart=$PATH4 &> train.log &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,69 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
#if [$# != 3]
|
||||
#then
|
||||
#echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
|
||||
#exit 1
|
||||
#fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH3=$3
|
||||
|
||||
echo "$PATH1"
|
||||
echo "$PATH2"
|
||||
echo "$PATH3"
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CKPT_URL=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> eval.log &
|
||||
cd ..
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PATH2=$(get_real_path $2)
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CKPT_URL=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH3=$3
|
||||
|
||||
echo "$PATH1"
|
||||
echo "$PATH2"
|
||||
echo "$PATH3"
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=6
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train.py --data_url=$PATH1 --ckpt_url=$PATH2 --modelart=$PATH3 &> train.log &
|
||||
fi
|
||||
cd ..
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Config parameters for WideResNet models."""
|
||||
|
||||
|
||||
class Config_WideResNet:
|
||||
"""
|
||||
Config parameters for the WideResNet.
|
||||
|
||||
Examples:
|
||||
Config_WideResNet()
|
||||
"""
|
||||
num_classes = 10
|
||||
batch_size = 32
|
||||
epoch_size = 300
|
||||
save_checkpoint_path = "./"
|
||||
repeat_num = 1
|
||||
widen_factor = 10
|
||||
depth = 40
|
||||
lr_init = 0.1
|
||||
weight_decay = 5e-4
|
||||
momentum = 0.9
|
||||
loss_scale = 32
|
||||
save_checkpoint = True
|
||||
save_checkpoint_epochs = 5
|
||||
keep_checkpoint_max = 10
|
||||
use_label_smooth = True
|
||||
label_smooth_factor = 0.1
|
||||
pretrain_epoch_size = 0
|
||||
warmup_epochs = 5
|
||||
|
||||
|
||||
config_WideResnet = Config_WideResNet()
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CrossEntropySmooth(_Loss):
|
||||
"""CrossEntropy"""
|
||||
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropySmooth, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.sparse = sparse
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||
|
||||
def construct(self, logit, label):
|
||||
if self.sparse:
|
||||
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, label)
|
||||
return loss
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
||||
"""
|
||||
create a train or evaluate cifar10 dataset for WideResnet
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
|
||||
if do_train:
|
||||
dataset_path = os.path.join(dataset_path, 'train')
|
||||
else:
|
||||
dataset_path = os.path.join(dataset_path, 'eval')
|
||||
|
||||
if device_num == 1:
|
||||
ds = de.Cifar10Dataset(dataset_path)
|
||||
else:
|
||||
if do_train:
|
||||
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=device_id)
|
||||
else:
|
||||
ds = de.Cifar10Dataset(dataset_path)
|
||||
|
||||
# define map operations
|
||||
trans = []
|
||||
if do_train:
|
||||
trans += [
|
||||
C.RandomCrop((32, 32), (4, 4, 4, 4)),
|
||||
C.RandomHorizontalFlip(prob=0.5)
|
||||
]
|
||||
|
||||
trans += [
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
||||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""generate learning rate"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(total_epochs,
|
||||
steps_per_epoch,
|
||||
lr_init
|
||||
):
|
||||
"""
|
||||
generate learning rate
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
for i in range(int(total_steps)):
|
||||
if i <= int(60 * steps_per_epoch):
|
||||
lr = lr_init
|
||||
elif i <= int(120 * steps_per_epoch):
|
||||
lr = lr_init * 0.1 + 0.01
|
||||
elif i <= int(160 * steps_per_epoch):
|
||||
lr = lr_init * 0.1 * 0.1 + 0.003
|
||||
elif i <= int(200 * steps_per_epoch):
|
||||
lr = 0.001
|
||||
elif i <= int(240 * steps_per_epoch):
|
||||
lr = 0.0008
|
||||
elif i <= int(260 * steps_per_epoch):
|
||||
lr = 0.0006
|
||||
lr_each_step.append(lr)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
return lr_each_step
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""save best ckpt"""
|
||||
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.train.callback import Callback
|
||||
from src.config import config_WideResnet as cfg
|
||||
|
||||
|
||||
class SaveCallback(Callback):
|
||||
"""
|
||||
save best ckpt
|
||||
"""
|
||||
def __init__(self, model, eval_dataset, ckpt_path, modelart):
|
||||
super(SaveCallback, self).__init__()
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.cpkt_path = ckpt_path
|
||||
self.acc = 0.96
|
||||
self.cur_acc = 0.0
|
||||
self.modelart = modelart
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
step end
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
result = self.model.eval(self.eval_dataset)
|
||||
self.cur_acc = result['accuracy']
|
||||
print("cur_acc is", self.cur_acc)
|
||||
|
||||
if result['accuracy'] > self.acc:
|
||||
self.acc = result['accuracy']
|
||||
file_name = "WideResNet_best" + ".ckpt"
|
||||
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
|
||||
if self.modelart:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.cpkt_path)
|
||||
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""WideResNet"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class WideBasic(nn.Cell):
|
||||
"""
|
||||
WideBasic
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super(WideBasic, self).__init__()
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=True)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
self.dropout = nn.Dropout(keep_prob=0.7)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, has_bias=True)
|
||||
|
||||
self.shortcut = nn.SequentialCell()
|
||||
|
||||
if in_channels != out_channels or stride != 1:
|
||||
self.shortcut = nn.SequentialCell(
|
||||
[nn.Conv2d(in_channels, out_channels, 1, stride=stride, has_bias=True)]
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
basic construct
|
||||
"""
|
||||
|
||||
identity = x
|
||||
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.dropout(x)
|
||||
x = self.conv1(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
shortcut = self.shortcut(identity)
|
||||
|
||||
return x + shortcut
|
||||
|
||||
|
||||
class WideResNet(nn.Cell):
|
||||
"""
|
||||
WideReNet
|
||||
"""
|
||||
def __init__(self, num_classes, block, depth=50, widen_factor=1):
|
||||
"""
|
||||
classes, block, depth, widen_factor
|
||||
"""
|
||||
super(WideResNet, self).__init__()
|
||||
|
||||
self.depth = depth
|
||||
k = widen_factor
|
||||
n = int((depth - 4) / 6)
|
||||
self.in_channels = 16
|
||||
self.conv1 = nn.Conv2d(3, self.in_channels, 3, 1, padding=0, pad_mode='same')
|
||||
self.conv2 = self._make_layer(block, 16 * k, n, 1)
|
||||
self.conv3 = self._make_layer(block, 32 * k, n, 2)
|
||||
self.conv4 = self._make_layer(block, 64 * k, n, 2)
|
||||
self.bn = nn.BatchNorm2d(64 * k, momentum=0.9)
|
||||
self.relu = nn.ReLU()
|
||||
self.mean = ops.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear = nn.Dense(64 * k, num_classes, has_bias=True)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
WideResNet construct
|
||||
"""
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
x = self.flatten(x)
|
||||
x = self.linear(x)
|
||||
|
||||
return x
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""
|
||||
make layer
|
||||
"""
|
||||
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for s in strides:
|
||||
layers.append(block(self.in_channels, out_channels, s))
|
||||
self.in_channels = out_channels
|
||||
|
||||
return nn.SequentialCell(*layers)
|
||||
|
||||
|
||||
def wideresnet(depth=40, widen_factor=10):
|
||||
net = WideResNet(10, WideBasic, depth=depth, widen_factor=widen_factor)
|
||||
return net
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
#################train WideResNet example on cifar10########################
|
||||
python train.py
|
||||
"""
|
||||
import ast
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
|
||||
from src.wide_resnet import wideresnet
|
||||
from src.dataset import create_dataset
|
||||
from src.config import config_WideResnet as cfg
|
||||
from src.generator_lr import get_lr
|
||||
from src.cross_entropy_smooth import CrossEntropySmooth
|
||||
from src.save_callback import SaveCallback
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
parser = argparse.ArgumentParser(description='Ascend WideResnet+CIFAR10 Training')
|
||||
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
|
||||
parser.add_argument('--ckpt_url', required=True, default=None, help='Location of ckpt.')
|
||||
parser.add_argument('--modelart', required=True, type=ast.literal_eval, default=False,
|
||||
help='training on modelart or not, default is False')
|
||||
args = parser.parse_args()
|
||||
|
||||
target = "Ascend"
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False,
|
||||
device_id=device_id)
|
||||
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
dataset_sink_mode = True
|
||||
|
||||
if args.modelart:
|
||||
import moxing as mox
|
||||
data_path = '/cache/data_path'
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=data_path)
|
||||
else:
|
||||
data_path = args.data_url
|
||||
|
||||
ds_train = create_dataset(dataset_path=data_path,
|
||||
do_train=True,
|
||||
batch_size=cfg.batch_size)
|
||||
ds_eval = create_dataset(dataset_path=data_path,
|
||||
do_train=False,
|
||||
batch_size=cfg.batch_size)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
net = wideresnet()
|
||||
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(gain=np.sqrt(2)),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor,
|
||||
num_classes=cfg.num_classes)
|
||||
loss_scale = FixedLossScaleManager(loss_scale=cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
lr = get_lr(total_epochs=cfg.epoch_size, steps_per_epoch=step_size, lr_init=cfg.lr_init)
|
||||
lr = Tensor(lr)
|
||||
|
||||
decayed_params = []
|
||||
no_decayed_params = []
|
||||
for param in net.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||
decayed_params.append(param)
|
||||
else:
|
||||
no_decayed_params.append(param)
|
||||
|
||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
opt = Momentum(group_params,
|
||||
learning_rate=lr,
|
||||
momentum=cfg.momentum,
|
||||
loss_scale=cfg.loss_scale,
|
||||
use_nesterov=True,
|
||||
weight_decay=cfg.weight_decay)
|
||||
|
||||
model = Model(net,
|
||||
amp_level="O2",
|
||||
loss_fn=loss,
|
||||
optimizer=opt,
|
||||
loss_scale_manager=loss_scale,
|
||||
metrics={'accuracy'},
|
||||
keep_batchnorm_fp32=False
|
||||
)
|
||||
|
||||
loss_cb = LossMonitor()
|
||||
time_cb = TimeMonitor()
|
||||
cb = [loss_cb, time_cb]
|
||||
ckpt_path = args.ckpt_url
|
||||
cb += [SaveCallback(model, ds_eval, ckpt_path, args.modelart)]
|
||||
|
||||
model.train(epoch=cfg.epoch_size, train_dataset=ds_train, callbacks=cb,
|
||||
dataset_sink_mode=dataset_sink_mode)
|
Loading…
Reference in New Issue