From: @sanshenghua2
Reviewed-by: @c_34, @oacjiewen
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-04-28 19:45:31 +08:00 committed by Gitee
commit fca07903e1
30 changed files with 1972 additions and 0 deletions

@@ -0,0 +1,298 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [VGG Description](#vgg-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training](#training)
- [Evaluation](#evaluation)
- [Parameter Configuration](#parameter-configuration)
- [Training Process](#training-process)
- [Training](#training-1)
- [Running VGG19 on GPU](#running-vgg19-on-gpu)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation-1)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# VGG Description

VGG, proposed in 2014, is a very deep convolutional network for large-scale image recognition. It won first place in object localization and second place in image classification in the ImageNet Large Scale Visual Recognition Challenge 2014 (ILSVRC14).

[Paper](https://arxiv.org/abs/1409.1556): Simonyan K, Zisserman A. Very Deep Convolutional Networks for Large-Scale Image Recognition[J]. arXiv preprint arXiv:1409.1556, 2014.
# Model Architecture

The VGG19 network consists of several basic blocks, built from convolution and pooling layers, followed by three consecutive dense layers.

The basic blocks here consist of two operations: **3×3 convolution** and **2×2 max pooling**.
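As an illustration only (the channel numbers below are made up; the actual layer configuration is built by `_make_layer` in src/vgg.py), one such basic block can be sketched in MindSpore as:

```python
# Hypothetical sketch of one VGG basic block: 3x3 convolution + ReLU,
# followed by 2x2 max pooling. Channel sizes are for illustration only.
import mindspore.nn as nn

block = nn.SequentialCell([
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,
              pad_mode='pad', padding=1),   # 3x3 convolution
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),  # 2x2 max pooling
])
```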
# Dataset

## Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size: about 146 GB, 1.28 million color images in 1000 classes
    - Training set: 140 GB, 1,281,167 images
    - Test set: 6.4 GB, 50,000 images
- Data format: RGB images
- Note: data is processed in src/dataset.py.
## Dataset Organization

ImageNet2012

> Unzip the ImageNet2012 dataset to any path; the folder structure should contain the training dataset and the evaluation dataset, as shown below:
>
> ```bash
> .
> └─dataset
>     ├─ilsvrc                # training dataset
>     └─validation_preprocess # evaluation dataset
> ```
# Features

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html) training method accelerates deep neural network training by using both single-precision and half-precision data, while preserving the accuracy that pure single-precision training can reach. Mixed precision speeds up computation and reduces memory usage, and it also makes it possible to train larger models or use larger batch sizes on specific hardware.

Taking the FP16 operator as an example: if the input data type is FP32, the MindSpore backend automatically lowers the precision to process the data. You can open the INFO log and search for "reduce precision" to see which operators had their precision reduced.
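A minimal sketch of turning mixed precision on, mirroring how train.py in this directory passes `amp_level="O2"` to `Model` (the tiny network below is a stand-in for illustration):

```python
# Hedged sketch: amp_level="O2" asks Model to run most operators in FP16
# while keeping batch norm (and the loss) in FP32.
import mindspore.nn as nn
from mindspore.train.model import Model

net = nn.Dense(16, 10)  # stand-in network, for illustration only
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
```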
# Environment Requirements

- Hardware: Ascend or GPU
    - Prepare the hardware environment with Ascend or GPU processors. To try out Ascend processors, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation:

- Running in an Ascend environment

```python
# training example
python train.py --data_path=[DATA_PATH] --device_id=[DEVICE_ID] > output.train.log 2>&1 &
# distributed training example
sh run_distribute_train.sh [RANK_TABLE_JSON] [DATA_PATH]
# evaluation example
python eval.py --data_path=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance. For details, see:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
- Running in a GPU environment

```python
# training example
python train.py --device_target="GPU" --device_id=[DEVICE_ID] --dataset=[DATASET_TYPE] --data_path=[DATA_PATH] > output.train.log 2>&1 &
# distributed training example
sh run_distribute_train_gpu.sh [DATA_PATH]
# evaluation example
python eval.py --device_target="GPU" --device_id=[DEVICE_ID] --dataset=[DATASET_TYPE] --data_path=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
```
# Script Description

## Script and Sample Code

```bash
├── model_zoo
    ├── README.md                             // description of all models
    ├── vgg19
        ├── README.md                         // VGG19 description
        ├── scripts
        │   ├── run_distribute_train.sh       // shell script for distributed training on Ascend
        │   ├── run_distribute_train_gpu.sh   // shell script for distributed training on GPU
        ├── src
        │   ├── utils
        │   │   ├── logging.py                // logging format setup
        │   │   ├── sampler.py                // sampler for the dataset
        │   │   ├── util.py                   // utility functions
        │   │   ├── var_init.py               // network parameter initialization methods
        │   ├── config.py                     // parameter configuration
        │   ├── crossentropy.py               // loss definition
        │   ├── dataset.py                    // dataset creation
        │   ├── linear_warmup.py              // linear learning-rate warmup
        │   ├── warmup_cosine_annealing_lr.py // cosine annealing learning rate
        │   ├── warmup_step_lr.py             // single-step or multi-step learning rate
        │   ├── vgg.py                        // VGG architecture
        ├── train.py                          // training script
        ├── eval.py                           // evaluation script
```
## Script Parameters

### Training

```bash
usage: train.py [--device_target TARGET][--data_path DATA_PATH]
                [--dataset DATASET_TYPE][--is_distributed VALUE]
                [--device_id DEVICE_ID][--pre_trained PRE_TRAINED]
                [--ckpt_path CHECKPOINT_PATH][--ckpt_interval INTERVAL_STEP]
options:
    --device_target       training backend type, Ascend or GPU, default is Ascend.
    --dataset             dataset type, cifar10 or imagenet2012.
    --is_distributed      whether the training is distributed, value can be 0 or 1.
    --data_path           dataset storage path.
    --device_id           device used to train the model.
    --pre_trained         path of the pretrained checkpoint file.
    --ckpt_path           path where checkpoints are stored.
    --ckpt_interval       epoch interval for saving checkpoints.
```
### Evaluation

```bash
usage: eval.py [--device_target TARGET][--data_path DATA_PATH]
               [--dataset DATASET_TYPE][--pre_trained PRE_TRAINED]
               [--device_id DEVICE_ID]
options:
    --device_target       evaluation backend type, Ascend or GPU, default is Ascend.
    --dataset             dataset type, cifar10 or imagenet2012.
    --data_path           dataset storage path.
    --device_id           device used to evaluate the model.
    --pre_trained         path of the checkpoint file used for evaluation.
```
## Parameter Configuration

Both training parameters and evaluation parameters can be configured in config.py.

- VGG19 configuration for the ImageNet2012 dataset

```bash
"num_classes": 1000,                 # number of dataset classes
"lr": 0.01,                          # learning rate
"lr_init": 0.01,                     # initial learning rate
"lr_max": 0.1,                       # maximum learning rate
"lr_epochs": '30,60,90,120',         # epochs at which the learning rate changes
"lr_scheduler": "cosine_annealing",  # learning rate mode
"warmup_epochs": 0,                  # number of warmup epochs
"batch_size": 32,                    # batch size of the input tensor
"max_epoch": 150,                    # only valid for training; fixed at 1 for inference
"momentum": 0.9,                     # momentum
"weight_decay": 1e-4,                # weight decay
"loss_scale": 1024,                  # loss scale
"label_smooth": 1,                   # label smoothing
"label_smooth_factor": 0.1,          # label smoothing factor
"buffer_size": 10,                   # shuffle buffer size
"image_size": '224,224',             # image size
"pad_mode": 'pad',                   # padding mode of conv2d
"padding": 1,                        # padding value of conv2d
"has_bias": True,                    # whether conv2d has a bias
"batch_norm": False,                 # whether batch_norm is used in conv2d
"keep_checkpoint_max": 10,           # keep at most the last keep_checkpoint_max checkpoints
"initialize_mode": "KaimingNormal",  # conv2d initialization mode
"has_dropout": True                  # whether to use the Dropout layer
```
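A small sketch of how these values are consumed (the configuration objects live in src/config.py as easydict instances, so fields are plain attributes):

```python
# Hedged sketch: inspect or tweak the ImageNet2012 configuration before training.
from src.config import imagenet_cfg as cfg

print(cfg.num_classes, cfg.lr_scheduler)  # 1000 cosine_annealing
cfg.warmup_epochs = 5                     # fields can be overridden in place
```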
## Training Process

### Training

#### Running VGG19 on GPU

- Training on a single device (1p)

```bash
python train.py --device_target="GPU" --dataset="imagenet2012" --is_distributed=0 --data_path=$DATA_PATH > output.train.log 2>&1 &
```

- Distributed training

```bash
# distributed training (8p)
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train
```
## Evaluation Process

### Evaluation

- The evaluation procedure is as follows. The dataset type must be specified as "cifar10" or "imagenet2012".

```bash
# using the ImageNet2012 dataset
python eval.py --data_path=your_data_path --dataset="imagenet2012" --device_target="GPU" --pre_trained=./*-150-5004.ckpt > output.eval.log 2>&1 &
```

- The python command above runs in the background; you can check the results in the `output.eval.log` file. The accuracy is as follows:

```bash
# using the ImageNet2012 dataset
after allreduce eval: top1_correct=37101, tot=49984, acc=74.23%
after allreduce eval: top5_correct=46007, tot=49984, acc=92.04%
```
# Model Description

## Performance

### Training Performance

| Parameter                  | VGG19 (Ascend)                                 |
| -------------------------- | ---------------------------------------------- |
| Model version              | VGG19                                          |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Upload date                | 2021-03-18                                     |
| MindSpore version          | 1.1.1-alpha                                    |
| Dataset                    | ImageNet2012                                   |
| Training parameters        | epoch=90, steps=2502, batch_size=64, lr=0.1    |
| Optimizer                  | Momentum                                       |
| Loss function              | SoftmaxCrossEntropy                            |
| Output                     | probability                                    |
| Loss                       | 1.5~2.0                                        |
| Speed                      | 97.4 ms/step (8 devices)                       |
| Total duration             | 6.1 hours (8 devices)                          |
| Tuned checkpoint           | 1.1 GB (.ckpt file)                            |
| Script                     | [VGG19](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg19) |
### Evaluation Performance

| Parameter           | VGG19 (Ascend)              |
| ------------------- | --------------------------- |
| Model version       | VGG19                       |
| Resource            | Ascend 910                  |
| Upload date         | 2021-03-18                  |
| MindSpore version   | 1.1.1-alpha                 |
| Dataset             | ImageNet2012, 5,000 images  |
| batch_size          | 64                          |
| Output              | probability                 |
| Accuracy            | top1_correct 74.23%, top5_correct 92.04% (8 devices) |
# Description of Random Situation

dataset.py sets the seed inside the "create_dataset" function, and train.py also uses a random seed.
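For reference, a minimal sketch of the global seeding done in train.py (the dataset sampler additionally advances its own seed every epoch):

```python
# Sketch of the seeding used in train.py for reproducibility.
from mindspore.common import set_seed

set_seed(1)
```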
# ModelZoo Homepage

Please check out the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,212 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import os
import time
import argparse
import datetime
import glob
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from src.utils.logging import get_logger
from src.vgg import vgg19
from src.dataset import vgg_create_dataset
from src.dataset import classification_dataset
class ParameterReduce(nn.Cell):
"""ParameterReduce"""
def __init__(self):
super(ParameterReduce, self).__init__()
self.cast = P.Cast()
self.reduce = P.AllReduce()
def construct(self, x):
one = self.cast(F.scalar_to_array(1.0), mstype.float32)
out = x * one
ret = self.reduce(out)
return ret
def parse_args(cloud_args=None):
"""parse_args"""
parser = argparse.ArgumentParser('mindspore classification test')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
# dataset related
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
    parser.add_argument('--pre_trained', default='', type=str, help='full path of the pretrained model to load. '
                        'If it is a directory, all ckpt files under it will be tested')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt = merge_args(args_opt, cloud_args)
if args_opt.dataset == "cifar10":
from src.config import cifar_cfg as cfg
else:
from src.config import imagenet_cfg as cfg
args_opt.image_size = cfg.image_size
args_opt.num_classes = cfg.num_classes
args_opt.per_batch_size = cfg.batch_size
args_opt.momentum = cfg.momentum
args_opt.weight_decay = cfg.weight_decay
args_opt.buffer_size = cfg.buffer_size
args_opt.pad_mode = cfg.pad_mode
args_opt.padding = cfg.padding
args_opt.has_bias = cfg.has_bias
args_opt.batch_norm = cfg.batch_norm
args_opt.initialize_mode = cfg.initialize_mode
args_opt.has_dropout = cfg.has_dropout
args_opt.image_size = list(map(int, args_opt.image_size.split(',')))
return args_opt
def get_top5_acc(top5_arg, gt_class):
    """Count how many ground-truth labels appear in their sample's top-5 predictions."""
    sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def merge_args(args, cloud_args):
"""merge_args"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def test(cloud_args=None):
"""test"""
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
if args.dataset == "cifar10":
net = vgg19(num_classes=args.num_classes, args=args)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
weight_decay=args.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)
net.set_train(False)
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False)
res = model.eval(dataset)
print("result: ", res)
else:
# network
args.logger.important_info('start create network')
if os.path.isdir(args.pre_trained):
models = list(glob.glob(os.path.join(args.pre_trained, '*.ckpt')))
print(models)
if args.graph_ckpt:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
else:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
args.models = sorted(models, key=f)
else:
args.models = [args.pre_trained,]
for model in args.models:
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, mode='eval')
eval_dataloader = dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
network = vgg19(args.num_classes, args, phase="test")
# pre_trained
load_param_into_net(network, load_checkpoint(model))
network.add_flags_recursive(fp16=True)
img_tot = 0
top1_correct = 0
top5_correct = 0
network.set_train(False)
t_end = time.time()
it = 0
for data, gt_classes in eval_dataloader:
output = network(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += args.per_batch_size
if args.rank == 0 and it == 0:
t_end = time.time()
it = 1
if args.rank == 0:
time_used = time.time() - t_end
fps = (img_tot - args.per_batch_size) * args.group_size / time_used
args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
results = np.array(results)
args.logger.info('after results={}'.format(results))
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
args.logger.info('after allreduce eval: top1_correct={}, tot={},'
'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
args.logger.info('after allreduce eval: top5_correct={}, tot={},'
'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
if __name__ == "__main__":
test()

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into models"""
import argparse
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, export
from src.vgg import vgg19
parser = argparse.ArgumentParser(description='VGG19 export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10", help='ckpt file')
parser.add_argument('--ckpt_file', type=str, required=True, help='vgg19 ckpt file.')
parser.add_argument('--file_name', type=str, default='vgg19', help='vgg19 output file name.')
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
if args.dataset == "cifar10":
from src.config import cifar_cfg as cfg
else:
from src.config import imagenet_cfg as cfg
args.num_classes = cfg.num_classes
args.pad_mode = cfg.pad_mode
args.padding = cfg.padding
args.has_bias = cfg.has_bias
args.initialize_mode = cfg.initialize_mode
args.batch_norm = cfg.batch_norm
args.has_dropout = cfg.has_dropout
args.image_size = list(map(int, cfg.image_size.split(',')))
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
if args.dataset == "cifar10":
net = vgg19(num_classes=args.num_classes, args=args)
else:
net = vgg19(args.num_classes, args, phase="test")
    net.add_flags_recursive(fp16=True)
load_checkpoint(args.ckpt_file, net=net)
net.set_train(False)
input_data = Tensor(np.zeros([cfg.batch_size, 3, args.image_size[0], args.image_size[1]]), mstype.float32)
export(net, input_data, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vgg import vgg19 as VGG19
def vgg19(*args, **kwargs):
return VGG19(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == "vgg19":
return vgg19(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
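# Usage sketch: create_network("vgg19", num_classes=10) builds the network; when no
# args namespace is passed, vgg19 in src/vgg.py falls back to the cifar10 config.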

@@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
if [ ! -d $2 ]
then
echo "error: DATA_PATH=$2 is not a directory"
exit 1
fi
dataset_type='cifar10'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
dataset_type=$3
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$1
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $RANK_SIZE`
gap=`expr $avg \- 1`
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
src_dir=$script_dir/..
start_idx=0
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
export DEVICE_ID=`expr $i \+ $start_idx`
export RANK_ID=$i
rm -rf ./train_parallel$DEVICE_ID
mkdir ./train_parallel$DEVICE_ID
cp $src_dir/*.py ./train_parallel$DEVICE_ID
cp -r $src_dir/src ./train_parallel$DEVICE_ID
cd ./train_parallel$DEVICE_ID || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
env > env.log
taskset -c $cmdopt python train.py --data_path=$2 --device_target="Ascend" --device_id=$DEVICE_ID --is_distributed=1 --dataset=$dataset_type &> log &
cd ..
done

@@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh DATA_PATH"
echo "for example: bash run_distribute_train_gpu.sh /path/ImageNet2012/train"
echo "=============================================================================================================="
DATA_PATH=$1
mpirun -n 8 --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--device_target="GPU" \
--dataset="imagenet2012" \
--is_distributed=1 \
--data_path=$DATA_PATH > output.train.log 2>&1 &

@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DATA_PATH DATASET_TYPE DEVICE_TYPE CHECKPOINT_PATH"
echo "for example: bash run_eval.sh /path/ImageNet2012/train cifar10 Ascend /path/a.ckpt "
echo "=============================================================================================================="
DATA_PATH=$1
DATASET_TYPE=$2
DEVICE_TYPE=$3
CHECKPOINT_PATH=$4
python eval.py \
--data_path=$DATA_PATH \
--dataset=$DATASET_TYPE \
--device_target=$DEVICE_TYPE \
--pre_trained=$CHECKPOINT_PATH > output.eval.log 2>&1 &

@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

@@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
# config for vgg19, cifar10
cifar_cfg = edict({
"num_classes": 10,
"lr": 0.01,
"lr_init": 0.01,
"lr_max": 0.1,
"lr_epochs": '30,60,90,120',
"lr_scheduler": "step",
"warmup_epochs": 5,
"batch_size": 64,
"max_epoch": 70,
"momentum": 0.9,
"weight_decay": 5e-4,
"loss_scale": 1.0,
"label_smooth": 0,
"label_smooth_factor": 0,
"buffer_size": 10,
"image_size": '224,224',
"pad_mode": 'same',
"padding": 0,
"has_bias": False,
"batch_norm": True,
"keep_checkpoint_max": 10,
"initialize_mode": "XavierUniform",
"has_dropout": False
})
# config for vgg19, imagenet2012
imagenet_cfg = edict({
"num_classes": 1000,
"lr": 0.04,
"lr_init": 0.01,
"lr_max": 0.1,
"lr_epochs": '30,60,90,120',
"lr_scheduler": 'cosine_annealing',
"warmup_epochs": 0,
"batch_size": 64,
"max_epoch": 90,
"momentum": 0.9,
"weight_decay": 1e-4,
"loss_scale": 1024,
"label_smooth": 1,
"label_smooth_factor": 0.1,
"buffer_size": 10,
"image_size": '224,224',
"pad_mode": 'pad',
"padding": 1,
"has_bias": False,
"batch_norm": False,
"keep_checkpoint_max": 10,
"initialize_mode": "KaimingNormal",
"has_dropout": True
})

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

@@ -0,0 +1,195 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
"""Data operations."""
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize(image_size) # interpolation default BILINEAR
rescale_op = vision.Rescale(rescale, shift)
normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
changeswap_op = vision.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
changeswap_op]
# apply map operations on images
data_set = data_set.map(operations=type_cast_op, input_columns="label")
data_set = data_set.map(operations=c_trans, input_columns="image")
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
return data_set
def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1,
mode='train',
input_mode='folder',
root='',
num_parallel_workers=None,
shuffle=None,
sampler=None,
repeat_num=1,
class_indexing=None,
drop_remainder=True,
transform=None,
target_transform=None):
"""
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
are written into a textfile.
Args:
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
Or path of the textfile that contains every image's path of the dataset.
image_size (Union(int, sequence)): Size of the input images.
        per_batch_size (int): the batch size of every step during training.
        rank (int): The shard ID within num_shards (default=None).
        group_size (int): Number of shards that the dataset should be divided
            into (default=None).
        mode (str): "train" or others. Default: "train".
        input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
        root (str): the images path for input_mode="txt". Default: "".
num_parallel_workers (int): Number of workers to read the data. Default: None.
shuffle (bool): Whether or not to perform shuffle on the dataset
(default=None, performs shuffle).
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
repeat_num (int): the num of repeat dataset.
class_indexing (dict): A str-to-int mapping from folder name to index
(default=None, the folder names will be sorted
alphabetically and each class will be given a
unique index starting from 0).
Examples:
>>> from src.dataset import classification_dataset
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
>>> data_dir = "/path/to/imagefolder_directory"
        >>> de_dataset = classification_dataset(data_dir, image_size=[224, 224],
>>> per_batch_size=64, rank=0, group_size=4)
>>> # Path of the textfile that contains every image's path of the dataset.
>>> data_dir = "/path/to/dataset/images/train.txt"
>>> images_dir = "/path/to/dataset/images"
        >>> de_dataset = classification_dataset(data_dir, image_size=[224, 224],
>>> per_batch_size=64, rank=0, group_size=4,
>>> input_mode="txt", root=images_dir)
"""
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if transform is None:
if mode == 'train':
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)),
vision.RandomHorizontalFlip(prob=0.5),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = [
vision.Decode(),
vision.Resize((256, 256)),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = transform
if target_transform is None:
transform_label = [C.TypeCast(mstype.int32)]
else:
transform_label = target_transform
if input_mode == 'folder':
de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=group_size, shard_id=rank)
else:
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset = de_dataset.map(operations=transform_img, input_columns="image", num_parallel_workers=8)
de_dataset = de_dataset.map(operations=transform_label, input_columns="label", num_parallel_workers=8)
columns_to_project = ["image", "label"]
de_dataset = de_dataset.project(columns=columns_to_project)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
de_dataset = de_dataset.repeat(repeat_num)
return de_dataset
class TxtDataset:
"""
create txt dataset.
Args:
Returns:
de_dataset.
"""
def __init__(self, root, txt_name):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
fin = open(txt_name, "r")
for line in fin:
img_name, label = line.strip().split(' ')
self.imgs.append(os.path.join(root, img_name))
self.labels.append(int(label))
fin.close()
def __getitem__(self, index):
img = Image.open(self.imgs[index]).convert('RGB')
return img, self.labels[index]
def __len__(self):
return len(self.imgs)

@@ -0,0 +1,23 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
linear warm up learning rate.
"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
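# Usage sketch: the warmup schedulers in this directory call this once per step, e.g.
#     lr = linear_warmup_lr(step + 1, warmup_steps, base_lr, warmup_init_lr)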

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np
class DistributedSampler():
"""
sampling the dataset.
Args:
Returns:
num_samples, number of samples.
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_length = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_length).tolist()
else:
            indices = list(range(self.dataset_length))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples
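# Usage sketch (mirrors the "txt" input mode in src/dataset.py):
#     sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
#     de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)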

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # BN gamma does not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # BN beta does not use weight decay
            no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
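# Usage sketch (mirrors train.py): pass both groups straight to the optimizer so
# biases and BN parameters skip weight decay:
#     opt = Momentum(params=get_param_groups(network), learning_rate=Tensor(lr),
#                    momentum=args.momentum, weight_decay=args.weight_decay)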

@@ -0,0 +1,210 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
        >>> w = np.empty((3, 5))
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
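# Usage sketch (mirrors src/vgg.py): default_recurisive_init(network) re-initializes
# every Conv2d and Dense cell with Kaiming-uniform weights and uniform biases.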

@@ -0,0 +1,150 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classifiation.
"""
import math
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common import initializer as init
from mindspore.common.initializer import initializer
from .utils.var_init import default_recurisive_init, KaimingNormal
def _make_layer(base, args, batch_norm):
"""Make stage network of VGG."""
layers = []
in_channels = 3
for v in base:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
weight = 'ones'
if args.initialize_mode == "XavierUniform":
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=v,
kernel_size=3,
padding=args.padding,
pad_mode=args.pad_mode,
has_bias=args.has_bias,
weight_init=weight)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
else:
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.SequentialCell(layers)
class Vgg(nn.Cell):
"""
VGG network definition.
Args:
base (list): Configuration for different layers, mainly the channel number of Conv layer.
num_classes (int): Class numbers. Default: 1000.
batch_norm (bool): Whether to do the batchnorm. Default: False.
batch_size (int): Batch size. Default: 1.
include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True.
Returns:
Tensor, infer output tensor.
Examples:
>>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
>>> num_classes=1000, batch_norm=False, batch_size=1)
"""
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train",
include_top=True):
super(Vgg, self).__init__()
_ = batch_size
self.layers = _make_layer(base, args, batch_norm=batch_norm)
self.include_top = include_top
self.flatten = nn.Flatten()
dropout_ratio = 0.5
if not args.has_dropout or phase == "test":
dropout_ratio = 1.0
self.classifier = nn.SequentialCell([
nn.Dense(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dropout(dropout_ratio),
nn.Dense(4096, 4096),
nn.ReLU(),
nn.Dropout(dropout_ratio),
nn.Dense(4096, num_classes)])
if args.initialize_mode == "KaimingNormal":
default_recurisive_init(self)
self.custom_init_weight()
def construct(self, x):
x = self.layers(x)
if self.include_top:
x = self.flatten(x)
x = self.classifier(x)
return x
def custom_init_weight(self):
"""
Init the weight of Conv2d and Dense in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(
KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(init.initializer(
'zeros', cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer(
init.Normal(0.01), cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(init.initializer(
'zeros', cell.bias.shape, cell.bias.dtype))
cfg = {
'11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def vgg19(num_classes=1000, args=None, phase="train"):
"""
    Get VGG19 neural network (with batch normalization when args.batch_norm is True).
Args:
num_classes (int): Class numbers. Default: 1000.
args(namespace): param for net init.
phase(str): train or test mode.
Returns:
        Cell, cell instance of VGG19 neural network.
Examples:
>>> vgg19(num_classes=1000, args=args)
"""
if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['19'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net

@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up cosine annealing learning rate.
"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""warm up cosine annealing learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
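# Usage sketch (mirrors train.py):
#     lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
#                                     args.warmup_epochs, args.max_epoch,
#                                     args.T_max, args.eta_min)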

@@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up step learning rate.
"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate."""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr_value = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max) * base * base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""warmup_step_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

@@ -0,0 +1,235 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg19 example on cifar10 or imagenet2012########################
"""
import argparse
import datetime
import os
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
from src.dataset import vgg_create_dataset
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.warmup_step_lr import lr_steps
from src.utils.logging import get_logger
from src.utils.util import get_param_groups
from src.vgg import vgg19
set_seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')
# dataset related
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
parser.add_argument('--data_path', type=str, default='', help='train data dir')
# network related
parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--T_max', type=int, default=90, help='T-max in cosine_annealing scheduler')
# logging and checkpoint related
parser.add_argument('--log_interval', type=int, default=1, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt = merge_args(args_opt, cloud_args)
if args_opt.dataset == "cifar10":
from src.config import cifar_cfg as cfg
else:
from src.config import imagenet_cfg as cfg
args_opt.label_smooth = cfg.label_smooth
args_opt.label_smooth_factor = cfg.label_smooth_factor
args_opt.lr_scheduler = cfg.lr_scheduler
args_opt.loss_scale = cfg.loss_scale
args_opt.max_epoch = cfg.max_epoch
args_opt.warmup_epochs = cfg.warmup_epochs
args_opt.lr = cfg.lr
args_opt.lr_init = cfg.lr_init
args_opt.lr_max = cfg.lr_max
args_opt.momentum = cfg.momentum
args_opt.weight_decay = cfg.weight_decay
args_opt.per_batch_size = cfg.batch_size
args_opt.num_classes = cfg.num_classes
args_opt.buffer_size = cfg.buffer_size
args_opt.ckpt_save_max = cfg.keep_checkpoint_max
args_opt.pad_mode = cfg.pad_mode
args_opt.padding = cfg.padding
args_opt.has_bias = cfg.has_bias
args_opt.batch_norm = cfg.batch_norm
args_opt.initialize_mode = cfg.initialize_mode
args_opt.has_dropout = cfg.has_dropout
args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
args_opt.image_size = list(map(int, cfg.image_size.split(',')))
return args_opt
def merge_args(args_opt, cloud_args):
"""dictionary"""
args_dict = vars(args_opt)
if isinstance(cloud_args, dict):
for key_arg in cloud_args.keys():
val = cloud_args[key_arg]
if key_arg in args_dict and val:
arg_type = type(args_dict[key_arg])
                if arg_type is not type(None):
val = arg_type(val)
args_dict[key_arg] = val
return args_opt
if __name__ == '__main__':
args = parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
device_num = int(os.environ.get("DEVICE_NUM", 1))
if args.is_distributed:
if args.device_target == "Ascend":
init()
context.set_context(device_id=args.device_id)
elif args.device_target == "GPU":
init()
args.rank = get_rank()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, all_reduce_fusion_config=[2, 18])
else:
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
if args.dataset == "cifar10":
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size)
else:
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size,
args.rank, args.group_size)
batch_num = dataset.get_dataset_size()
args.steps_per_epoch = dataset.get_dataset_size()
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# get network and init
network = vgg19(args.num_classes, args)
# pre_trained
if args.pre_trained:
load_param_into_net(network, load_checkpoint(args.pre_trained))
# lr scheduler
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
elif args.lr_scheduler == 'step':
lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
total_epochs=args.max_epoch, steps_per_epoch=batch_num)
else:
raise NotImplementedError(args.lr_scheduler)
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
if args.dataset == "cifar10":
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
else:
if not args.label_smooth:
args.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
# define callbacks
time_cb = TimeMonitor(data_size=batch_num)
loss_cb = LossMonitor(per_print_times=batch_num)
callbacks = [time_cb, loss_cb]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(args.max_epoch, dataset, callbacks=callbacks)