forked from mindspore-Ecosystem/mindspore
add resnetv2_50 gpu
This commit is contained in:
parent
32c8a1529e
commit
fff351e00b
|
@ -0,0 +1,304 @@
|
|||
|
||||
# Resnetv2描述
|
||||
|
||||
## 概述
|
||||
|
||||
ResNet系列模型是在2015年提出的,该网络创新性的提出了残差结构,通过堆叠多个残差结构从而构建了ResNet网络。ResNet一定程度上解决了传统的卷积网络或全连接网络或多或少存在信息丢失的问题。通过将输入信息传递给输出,确保信息完整性,使得网络深度得以不断加深的同时避免了梯度消失或爆炸的影响。ResNetv2是何凯明团队在ResNet发表后,又进一步对其网络结构进行了改进和优化,通过推导证明了前向参数和反向梯度如果直接从Residual Block传递到下一个Residual Block而不用经过ReLU等操作,效果会更好。因此调整了激活层和BN层与卷积层的运算先后顺序,并经过实验验证在深度网络中ResNetv2会有更好的收敛效果。
|
||||
|
||||
如下为MindSpore使用Cifar10/ImageNet2012数据集对ResNetv2_50/ResNetv2_101/ResNetv2_152进行训练的示例。
|
||||
|
||||
## 论文
|
||||
|
||||
1. [论文](https://arxiv.org/pdf/1603.05027.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
|
||||
|
||||
# 模型架构
|
||||
|
||||
[ResNetv2_50](https://arxiv.org/pdf/1603.05027.pdf)的整体网络架构和[Resnet50](https://arxiv.org/pdf/1512.03385.pdf)的架构相仿,仅调整了激活层和BN层与卷积层的先后顺序。
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集:[Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)
|
||||
|
||||
- 数据集大小:共10个类、60,000个32*32彩色图像
|
||||
- 训练集:50,000个图像
|
||||
- 测试集:10,000个图像
|
||||
- 数据格式:二进制文件
|
||||
- 注:数据在dataset.py中处理。
|
||||
- 下载数据集。目录结构如下:
|
||||
|
||||
```text
|
||||
├─cifar-10-batches-bin
|
||||
│
|
||||
└─cifar-10-verify-bin
|
||||
```
|
||||
|
||||
使用的数据集:[ImageNet2012](http://www.image-net.org/)
|
||||
|
||||
- 数据集大小:共1000个类、224*224彩色图像
|
||||
- 训练集:共1,281,167张图像
|
||||
- 测试集:共50,000张图像
|
||||
- 数据格式:JPEG
|
||||
- 注:数据在dataset.py中处理。
|
||||
- 下载数据集,目录结构如下:
|
||||
|
||||
```text
|
||||
└─dataset
|
||||
├─ilsvrc # 训练数据集
|
||||
└─validation_preprocess # 评估数据集
|
||||
```
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件
|
||||
- 准备Ascend处理器搭建硬件环境。如需试用昇腾处理器,请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com,审核通过即可获得资源。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
|
||||
|
||||
# 运行评估示例
|
||||
用法:sh run_eval.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
- GPU处理器环境运行
|
||||
|
||||
```shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
|
||||
|
||||
# 运行评估示例
|
||||
用法:sh run_eval_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```text
|
||||
└──resnetv2
|
||||
├── README.md
|
||||
├── scripts
|
||||
├── run_distribute_train_gpu.sh # 启动gpu分布式训练(8卡)
|
||||
├── run_eval_gpu.sh # 启动gpu评估
|
||||
├── run_standalone_train_gpu.sh # 启动gpu单机训练(单卡)
|
||||
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||
├── run_eval.sh # 启动Ascend评估
|
||||
└── run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||
├── src
|
||||
├── config.py # 参数配置
|
||||
├── dataset.py # 数据预处理
|
||||
├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
|
||||
├── lr_generator.py # 生成每个步骤的学习率
|
||||
└── resnetv2.py # ResNet骨干网络
|
||||
├── eval.py # 评估网络
|
||||
└── train.py # 训练网络
|
||||
└── export.py # 导出网络
|
||||
```
|
||||
|
||||
# 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 配置ResNetv2_50和cifar10数据集。
|
||||
|
||||
```Python
|
||||
"class_num":10, # 数据集类数
|
||||
"batch_size":32, # 输入张量的批次大小
|
||||
"loss_scale":1024, # 损失等级
|
||||
"momentum":0.9, # 动量优化器
|
||||
"weight_decay":5e-4, # 权重衰减
|
||||
"epoch_size":100, # 训练周期大小
|
||||
"save_checkpoint":True, # 是否保存检查点
|
||||
"save_checkpoint_epochs":5, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
|
||||
"keep_checkpoint_max":10, # 只保存最后一个keep_checkpoint_max检查点
|
||||
"save_checkpoint_path":"./checkpoint", # 检查点相对于执行路径的保存路径
|
||||
"low_memory": False, # 显存不足时可设置为Ture
|
||||
"warmup_epochs":5, # 热身周期数
|
||||
"lr_decay_mode":"cosine", # 用于生成学习率的衰减模式
|
||||
"lr_init":0.1, # 基础学习率
|
||||
"lr_end":0.0000000005, # 最终学习率
|
||||
"lr_max":0.1, # 最大学习率
|
||||
```
|
||||
|
||||
- 配置ResNetv2_50和imagenet2012数据集。
|
||||
|
||||
```python
|
||||
"class_num":1001, # 数据集类数
|
||||
"batch_size":64, # 输入张量的批次大小
|
||||
"loss_scale":1024, # 损失等级
|
||||
"momentum":0.9, # 动量优化器
|
||||
"weight_decay":1e-4, # 权重衰减
|
||||
"epoch_size":100, # 训练周期大小
|
||||
"save_checkpoint":True, # 是否保存检查点
|
||||
"save_checkpoint_epochs":5, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
|
||||
"keep_checkpoint_max":10, # 只保存最后一个keep_checkpoint_max检查点
|
||||
"save_checkpoint_path":"./checkpoint", # 检查点相对于执行路径的保存路径
|
||||
"low_memory": True, # 显存不足时可设置为Ture,默认为False
|
||||
"warmup_epochs":5, # 热身周期数
|
||||
"use_label_smooth":True, # 标签平滑
|
||||
"label_smooth_factor":0.1, # 标签平滑因子
|
||||
"lr_decay_mode":"cosine", # 用于生成学习率的衰减模式
|
||||
"lr_init":0.05, # 基础学习率
|
||||
"lr_end":0.0000001, # 最终学习率
|
||||
"lr_max":0.05, # 最大学习率
|
||||
```
|
||||
|
||||
# 训练过程
|
||||
|
||||
## 用法
|
||||
|
||||
### Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
|
||||
```
|
||||
|
||||
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||
|
||||
具体操作,参见[hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)中的说明。
|
||||
|
||||
### GPU处理器环境运行
|
||||
|
||||
```shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]
|
||||
```
|
||||
|
||||
## 结果
|
||||
|
||||
- 使用cifar10数据集训练ResNetv2_50
|
||||
|
||||
```text
|
||||
# Ascend分布式训练结果(8P)
|
||||
epoch: 41 step: 195, loss is 0.17125674
|
||||
epoch time: 4733.000 ms, per step time: 24.143 ms
|
||||
epoch: 42 step: 195, loss is 0.0011220031
|
||||
epoch time: 4735.284 ms, per step time: 24.135 ms
|
||||
epoch: 43 step: 195, loss is 0.105422504
|
||||
epoch time: 4737.401 ms, per step time: 24.166 ms
|
||||
...
|
||||
```
|
||||
|
||||
- 使用imagenet2012数据集训练ResNetv2_50
|
||||
|
||||
```text
|
||||
# Ascend分布式训练结果 (8P)
|
||||
epoch: 61 step: 2502, loss is 2.4235027
|
||||
epoch time: 813367.327 ms, per step time: 325.087 ms
|
||||
epoch: 62 step: 2502, loss is 2.0396166
|
||||
epoch time: 813387.109 ms, per step time: 325.095 ms
|
||||
epoch: 63 step: 2502, loss is 1.7643375
|
||||
epoch time: 813347.102 ms, per step time: 325.075 ms
|
||||
...
|
||||
```
|
||||
|
||||
# 评估过程
|
||||
|
||||
## 用法
|
||||
|
||||
### Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 评估
|
||||
用法:sh run_eval.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
### GPU处理器环境运行
|
||||
|
||||
```shell
|
||||
# 运行评估示例
|
||||
用法:sh run_eval_gpu.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
## 结果
|
||||
|
||||
评估结果可以在当前脚本路径下的日志中找到如下结果:
|
||||
|
||||
- 使用cifar10数据集评估ResNetv2_50
|
||||
|
||||
```text
|
||||
result: {'top_5_accuracy': 0.9988982371794872, 'top_1_accuracy': 0.9502283653846154}
|
||||
```
|
||||
|
||||
- 使用imagenet2012数据集评估ResNetv2_50
|
||||
|
||||
```text
|
||||
result: {'top_1_accuracy': 0.7606515786082474, 'top_5_accuracy': 0.9271504510309279}
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### Cifar10上的ResNetv2_50
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
|---|---|
|
||||
| 模型版本 | ResNetv2_50 |
|
||||
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||
| 上传日期 |2021-03-24 ; |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | Cifar10 |
|
||||
| 训练参数 | epoch=135, steps per epoch=195, batch_size=32 |
|
||||
| 优化器 | Momentum |
|
||||
| 损失函数 |Softmax交叉熵 |
|
||||
| 输出 | 概率 |
|
||||
| 损失 | 0.0007279 |
|
||||
|速度|24.3毫秒/步(8卡) |
|
||||
|总时长 | 10分钟 |
|
||||
| 微调检查点 | 188.36M(.ckpt文件) |
|
||||
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnetv2) |
|
||||
|
||||
#### ImageNet2012上的Resnetv2_50
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
| ------------- | ------------------------------------------------------------ |
|
||||
| 模型版本 | ResNetv2_50 |
|
||||
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||
| 上传日期 | 2021-05-6 ; |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | ImageNet2012 |
|
||||
| 训练参数 | epoch=90, steps per epoch=2502, batch_size=64 |
|
||||
| 优化器 | Momentum |
|
||||
| 损失函数 | Softmax交叉熵 |
|
||||
| 输出 | 概率 |
|
||||
| 损失 | 1.8290355 |
|
||||
| 速度 | 325毫秒/步(8卡) |
|
||||
| 总时长 | 20.3小时 |
|
||||
| 微调检查点 | 195.9M(.ckpt文件) |
|
||||
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnetv2) |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
dataset.py中设置了“create_dataset”函数内的种子,同时还使用了train.py中的随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" eval.py """
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--net', type=str, default='resnetv2_50',
|
||||
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset, cifar10, imagenet2012')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--dataset_path', type=str, default="../cifar-10/cifar-10-verify-bin",
|
||||
help='Dataset path.')
|
||||
parser.add_argument('--checkpoint_path', type=str, default="./checkpoint/train_resnetv2_cifar10-100_1562.ckpt",
|
||||
help='Checkpoint file path.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
# import net
|
||||
if args_opt.net == "resnetv2_50":
|
||||
from src.resnetv2 import PreActResNet50 as resnetv2
|
||||
elif args_opt.net == 'resnetv2_101':
|
||||
from src.resnetv2 import PreActResNet101 as resnetv2
|
||||
elif args_opt.net == 'resnetv2_152':
|
||||
from src.resnetv2 import PreActResNet152 as resnetv2
|
||||
|
||||
# import dataset
|
||||
if args_opt.dataset == "cifar10":
|
||||
from src.dataset import create_dataset1 as create_dataset
|
||||
elif args_opt.dataset == "cifar100":
|
||||
from src.dataset import create_dataset2 as create_dataset
|
||||
elif args_opt.dataset == 'imagenet2012':
|
||||
from src.dataset import create_dataset3 as create_dataset
|
||||
|
||||
# import config
|
||||
if args_opt.net == "resnetv2_50" or args_opt.net == "resnetv2_101" or args_opt.net == "resnetv2_152":
|
||||
if args_opt.dataset == "cifar10":
|
||||
from src.config import config1 as config
|
||||
elif args_opt.dataset == 'cifar100':
|
||||
from src.config import config2 as config
|
||||
elif args_opt.dataset == 'imagenet2012':
|
||||
from src.config import config3 as config
|
||||
|
||||
set_seed(1)
|
||||
|
||||
try:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
except TypeError:
|
||||
device_id = 0
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("============== Starting Evaluating ==============")
|
||||
print(f"start evaluating {args_opt.net} on device {device_id}")
|
||||
|
||||
# init context
|
||||
target = args_opt.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
net = resnetv2(class_num=config.class_num)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
# define loss, model
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction='mean',
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
# eval model
|
||||
res = model.eval(dataset)
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ckpt to air."""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
parser = argparse.ArgumentParser(description='resnet export')
|
||||
parser.add_argument('--net', type=str, default='resnetv2_50',
|
||||
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset, cifar10, cifar100, imagenet2012')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="resnetv2", help="output file name.")
|
||||
parser.add_argument('--width', type=int, default=32, help='input width')
|
||||
parser.add_argument('--height', type=int, default=32, help='input height')
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# import net
|
||||
if args.net == "resnetv2_50":
|
||||
from src.resnetv2 import PreActResNet50 as resnetv2
|
||||
elif args.net == 'resnetv2_101':
|
||||
from src.resnetv2 import PreActResNet101 as resnetv2
|
||||
elif args.net == 'resnetv2_152':
|
||||
from src.resnetv2 import PreActResNet152 as resnetv2
|
||||
else:
|
||||
raise ValueError("network is not support.")
|
||||
|
||||
# import config
|
||||
if args.net == "resnetv2_50" or args.net == "resnetv2_101" or args.net == "resnetv2_152":
|
||||
if args.dataset == "cifar10":
|
||||
from src.config import config1 as config
|
||||
elif args.dataset == 'cifar100':
|
||||
from src.config import config2 as config
|
||||
elif args.dataset == 'imagenet2012':
|
||||
raise ValueError("ImageNet2012 dataset not yet supported")
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
else:
|
||||
raise ValueError("network is not support.")
|
||||
|
||||
net = resnetv2(config.class_num)
|
||||
|
||||
assert args.ckpt_file is not None, "checkpoint_path is None."
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $3 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $4 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$4 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export RANK_TABLE_FILE=$3
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
||||
for((i=0; i<${RANK_SIZE}; i++))
|
||||
do
|
||||
export DEVICE_ID=${i}
|
||||
export RANK_ID=${i}
|
||||
echo "start distributed training for rank $RANK_ID, device $DEVICE_ID"
|
||||
python train.py --net $1 --dataset $2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path $4 &> log.$i &
|
||||
done
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_distribute_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $3 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py --net=$1 --dataset=$2 --run_distribute=True \
|
||||
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$3 &> log &
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $3 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $4 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$4 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --net=$1 --dataset=$2 --dataset_path=$3 --checkpoint_path=$4 &> eval.log &
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $3 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $4 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$4 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$3 --checkpoint_path=$4 &> eval.log &
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $3 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$4 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
python train.py --net $1 --dataset $2 --device_num=$DEVICE_NUM --dataset_path $3 &> log.$DEVICE_ID &
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [resnetv2_50|resnetv2_101|resnetv2_152] [cifar10|imagenet2012] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $1 != "resnetv2_50" ] && [ $1 != "resnetv2_101" ] && [ $1 != "resnetv2_152" ]
|
||||
then
|
||||
echo "error: the selected net is neither resnetv2_50 nor resnetv2_101 and resnetv2_152"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $3 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$4 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=2
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
python train.py --net $1 --dataset $2 --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path $3 &> log.$DEVICE_ID &
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" CrossEntropySmooth.py """
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CrossEntropySmooth(_Loss):
|
||||
"""CrossEntropy"""
|
||||
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropySmooth, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.sparse = sparse
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||
|
||||
def construct(self, logit, label):
|
||||
if self.sparse:
|
||||
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, label)
|
||||
return loss
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" config.py """
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
# config for ResNetv2, cifar10
|
||||
config1 = ed({
|
||||
"class_num": 10,
|
||||
"batch_size": 32,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 5e-4,
|
||||
"epoch_size": 200,
|
||||
"pretrain_epoch_size": 0,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 5,
|
||||
"keep_checkpoint_max": 10,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"low_memory": False,
|
||||
"warmup_epochs": 5,
|
||||
"lr_decay_mode": "cosine",
|
||||
"lr_init": 0.1,
|
||||
"lr_end": 0.0000000005,
|
||||
"lr_max": 0.1,
|
||||
})
|
||||
|
||||
# config for ResNetv2, cifar100
|
||||
config2 = ed({
|
||||
"class_num": 100,
|
||||
"batch_size": 32,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 5e-4,
|
||||
"epoch_size": 100,
|
||||
"pretrain_epoch_size": 0,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 5,
|
||||
"keep_checkpoint_max": 10,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"low_memory": False,
|
||||
"warmup_epochs": 5,
|
||||
"lr_decay_mode": "cosine",
|
||||
"lr_init": 0.1,
|
||||
"lr_end": 0.0000000005,
|
||||
"lr_max": 0.1,
|
||||
})
|
||||
|
||||
# config for ResNetv2, imagenet2012
|
||||
config3 = ed({
|
||||
"class_num": 1001,
|
||||
"batch_size": 64,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 90,
|
||||
"pretrain_epoch_size": 0,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 5,
|
||||
"keep_checkpoint_max": 10,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"low_memory": True,
|
||||
"warmup_epochs": 0,
|
||||
"use_label_smooth": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_decay_mode": "cosine",
|
||||
"lr_init": 0.05,
|
||||
"lr_end": 0.0000001,
|
||||
"lr_max": 0.05,
|
||||
})
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" dataset.py """
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
def create_dataset1(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||
"""
|
||||
create a train or evaluate cifar10 dataset for PreActResnet
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
distribute(bool): data for distribute or not. Default: False
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
rank_size, rank_id = _get_rank_info()
|
||||
else:
|
||||
if distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
else:
|
||||
rank_size = 1
|
||||
if rank_size == 1:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=rank_size, shard_id=rank_id)
|
||||
|
||||
# define map operations
|
||||
trans = []
|
||||
if do_train:
|
||||
trans += [
|
||||
C.RandomCrop((32, 32), (4, 4, 4, 4)),
|
||||
C.RandomHorizontalFlip()
|
||||
]
|
||||
|
||||
trans += [
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def create_dataset2(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||
"""
|
||||
create a train or evaluate cifar100 dataset for PreActResnet
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
distribute(bool): data for distribute or not. Default: False
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
device_num = 1
|
||||
if device_num == 1:
|
||||
data_set = ds.Cifar100Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar100Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
# define map operations
|
||||
trans = []
|
||||
if do_train:
|
||||
trans += [
|
||||
C.RandomCrop((32, 32), (4, 4, 4, 4)),
|
||||
C.RandomHorizontalFlip()
|
||||
]
|
||||
|
||||
trans += [
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="fine_label", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
|
||||
def del_column(col1, col2, col3, batch_info):
|
||||
return(col1, col2,)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, per_batch_map=del_column,
|
||||
input_columns=['image', 'fine_label', 'coarse_label'],
|
||||
output_columns=['image', 'label'],
|
||||
drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def create_dataset3(dataset_path, do_train=True, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||
"""
|
||||
create a train or eval imagenet2012 dataset for PreActResnet
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
distribute(bool): data for distribute or not. Default: False
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
if distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
else:
|
||||
device_num = 1
|
||||
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 224
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(image_size),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
rank_size = 1
|
||||
rank_id = 0
|
||||
|
||||
return rank_size, rank_id
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
""" lr_generator.py """
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies three steps decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr = lr_max * 0.01
|
||||
else:
|
||||
lr = lr_max * 0.001
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies polynomial decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max) * base * base
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies cosine decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
decay_steps = total_steps - warmup_steps
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
lr = float(lr_init) + lr_inc * (i + 1)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = lr_max * decayed
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies liner decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default)
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
|
||||
if lr_decay_mode == 'steps':
|
||||
lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
|
||||
elif lr_decay_mode == 'poly':
|
||||
lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
elif lr_decay_mode == 'cosine':
|
||||
lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
else:
|
||||
lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
|
||||
"""
|
||||
generate learning rate array with cosine
|
||||
|
||||
Args:
|
||||
lr(float): base learning rate
|
||||
steps_per_epoch(int): steps size of one epoch
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
max_epoch(int): total epochs of training
|
||||
global_step(int): the current start index of lr array
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
decay_steps = total_steps - warmup_steps
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = base_lr * decayed
|
||||
lr_each_step.append(lr)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[global_step:]
|
||||
return learning_rate
|
|
@ -0,0 +1,160 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
""" resnetv2.py """
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class PreActBottleNeck(nn.Cell):
|
||||
""" PreActBottleNeck """
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
planes,
|
||||
stride=1):
|
||||
super(PreActBottleNeck, self).__init__()
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_planes, eps=1e-5, momentum=0.9)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1)
|
||||
self.bn2 = nn.BatchNorm2d(planes, eps=1e-5, momentum=0.9)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, pad_mode='pad')
|
||||
self.bn3 = nn.BatchNorm2d(planes, eps=1e-5, momentum=0.9)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, stride=1)
|
||||
|
||||
self.downtown = False
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.downtown = True
|
||||
self.shortcut = nn.SequentialCell([nn.Conv2d(in_planes, self.expansion*planes,
|
||||
kernel_size=1, stride=stride)])
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
""" construct network """
|
||||
out = self.bn1(x)
|
||||
out = self.relu(out)
|
||||
if self.downtown:
|
||||
identity = self.shortcut(out)
|
||||
else:
|
||||
identity = x
|
||||
out = self.conv1(out)
|
||||
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
|
||||
out = self.bn3(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
|
||||
out = self.add(out, identity)
|
||||
return out
|
||||
|
||||
|
||||
class PreActResNet(nn.Cell):
|
||||
""" PreActResNet """
|
||||
def __init__(self,
|
||||
block,
|
||||
num_blocks,
|
||||
in_planes,
|
||||
planes,
|
||||
strides,
|
||||
low_memory,
|
||||
num_classes=10):
|
||||
super(PreActResNet, self).__init__()
|
||||
self.in_planes = in_planes
|
||||
self.low_memory = low_memory
|
||||
|
||||
self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2)
|
||||
self.conv2 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, pad_mode='pad', padding=1)
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
planes=planes[0],
|
||||
num_blocks=num_blocks[0],
|
||||
stride=strides[0])
|
||||
self.layer2 = self._make_layer(block,
|
||||
planes=planes[1],
|
||||
num_blocks=num_blocks[1],
|
||||
stride=strides[1])
|
||||
self.layer3 = self._make_layer(block,
|
||||
planes=planes[2],
|
||||
num_blocks=num_blocks[2],
|
||||
stride=strides[2])
|
||||
self.layer4 = self._make_layer(block,
|
||||
planes=planes[3],
|
||||
num_blocks=num_blocks[3],
|
||||
stride=strides[3])
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear = nn.Dense(planes[3]*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
layers = []
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
|
||||
for s in strides:
|
||||
layers.append(block(self.in_planes, planes, s))
|
||||
self.in_planes = planes * block.expansion
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
""" construct network """
|
||||
if self.low_memory:
|
||||
out = self.conv1(x)
|
||||
else:
|
||||
out = self.conv2(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
|
||||
out = self.mean(out, (2, 3))
|
||||
out = out.view(out.shape[0], -1)
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def PreActResNet50(class_num=10, low_memory=False):
|
||||
return PreActResNet(PreActBottleNeck,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
in_planes=64,
|
||||
planes=[64, 128, 256, 512],
|
||||
strides=[1, 2, 2, 2],
|
||||
low_memory=low_memory,
|
||||
num_classes=class_num)
|
||||
|
||||
|
||||
def PreActResNet101(class_num=10, low_memory=False):
|
||||
return PreActResNet(PreActBottleNeck,
|
||||
num_blocks=[3, 4, 23, 3],
|
||||
in_planes=64,
|
||||
planes=[64, 128, 256, 512],
|
||||
strides=[1, 2, 2, 2],
|
||||
low_memory=low_memory,
|
||||
num_classes=class_num)
|
||||
|
||||
|
||||
def PreActResNet152(class_num=10, low_memory=False):
|
||||
return PreActResNet(PreActBottleNeck,
|
||||
num_blocks=[3, 8, 36, 3],
|
||||
in_planes=64,
|
||||
planes=[64, 128, 256, 512],
|
||||
strides=[1, 2, 2, 2],
|
||||
low_memory=low_memory,
|
||||
num_classes=class_num)
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" train.py """
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context, Model, load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.nn import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
|
||||
from src.lr_generator import get_lr
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification.')
|
||||
parser.add_argument('--net', type=str, default='resnetv2_50',
|
||||
help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10',
|
||||
help='Dataset, cifar10, imagenet2012')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||
parser.add_argument('--dataset_path', type=str, default="../cifar-10/cifar-10-batches-bin",
|
||||
help='Dataset path.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
# import net
|
||||
if args_opt.net == "resnetv2_50":
|
||||
from src.resnetv2 import PreActResNet50 as resnetv2
|
||||
elif args_opt.net == 'resnetv2_101':
|
||||
from src.resnetv2 import PreActResNet101 as resnetv2
|
||||
elif args_opt.net == 'resnetv2_152':
|
||||
from src.resnetv2 import PreActResNet152 as resnetv2
|
||||
|
||||
# import dataset
|
||||
if args_opt.dataset == "cifar10":
|
||||
from src.dataset import create_dataset1 as create_dataset
|
||||
elif args_opt.dataset == "cifar100":
|
||||
from src.dataset import create_dataset2 as create_dataset
|
||||
elif args_opt.dataset == 'imagenet2012':
|
||||
from src.dataset import create_dataset3 as create_dataset
|
||||
|
||||
# import config
|
||||
if args_opt.net == "resnetv2_50" or args_opt.net == "resnetv2_101" or args_opt.net == "resnetv2_152":
|
||||
if args_opt.dataset == "cifar10":
|
||||
from src.config import config1 as config
|
||||
elif args_opt.dataset == 'cifar100':
|
||||
from src.config import config2 as config
|
||||
elif args_opt.dataset == 'imagenet2012':
|
||||
from src.config import config3 as config
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("============== Starting Training ==============")
|
||||
target = args_opt.device_target
|
||||
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
if args_opt.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||
# init parallel training parameters
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
# init HCCL
|
||||
init()
|
||||
else:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
try:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
except TypeError:
|
||||
device_id = 0
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
|
||||
batch_size=config.batch_size, target=target, distribute=args_opt.run_distribute)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
epoch_size = config.epoch_size
|
||||
net = resnetv2(config.class_num, config.low_memory)
|
||||
|
||||
# init weight
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# init lr
|
||||
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
|
||||
lr_decay_mode=config.lr_decay_mode)
|
||||
lr = Tensor(lr)
|
||||
|
||||
# define loss, opt, model
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
|
||||
|
||||
# define callbacks
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
ckpoint_cb = ModelCheckpoint(prefix=f"train_{args_opt.net}_{args_opt.dataset}",
|
||||
directory=ckpt_save_dir, config=config_ck)
|
||||
|
||||
# train
|
||||
if args_opt.run_distribute:
|
||||
callbacks = [time_cb, loss_cb]
|
||||
if target == "GPU" and str(get_rank()) == '0':
|
||||
callbacks = [time_cb, loss_cb, ckpoint_cb]
|
||||
elif target == "Ascend" and device_id == 0:
|
||||
callbacks = [time_cb, loss_cb, ckpoint_cb]
|
||||
else:
|
||||
callbacks = [time_cb, loss_cb, ckpoint_cb]
|
||||
|
||||
model.train(epoch_size, dataset, callbacks=callbacks)
|
Loading…
Reference in New Issue