Add autoaugment

lukelook 2021-06-21 15:55:29 +08:00
parent a8553078b9
commit 1440b1d0a0
29 changed files with 2746 additions and 0 deletions


@ -0,0 +1,365 @@
## Contents
- [Contents](#contents)
- [AutoAugment Description](#autoaugment-description)
    - [Overview](#overview)
    - [AutoAugment Paper](#autoaugment-paper)
- [Model Architecture](#model-architecture)
    - [WideResNet Paper](#wideresnet-paper)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script Parameters](#script-parameters)
- [Script Usage](#script-usage)
    - [AutoAugment Operator Usage](#autoaugment-operator-usage)
    - [Training Script Usage](#training-script-usage)
    - [Evaluation Script Usage](#evaluation-script-usage)
    - [Export Script Usage](#export-script-usage)
- [Model Description](#model-description)
- [Description of Randomness](#description-of-randomness)
- [ModelZoo Homepage](#modelzoo-homepage)
## AutoAugment Description
### Overview
Data augmentation is an important technique for improving the accuracy and generalization of image classifiers. Traditional augmentation methods rely mostly on hand-crafted, fixed pipelines (for example, combining the `RandomCrop` and `RandomHorizontalFlip` transformation operators).
Unlike traditional methods, AutoAugment proposes an effective policy-space design for data augmentation, which allows researchers to use different search algorithms (such as reinforcement learning, evolutionary algorithms, or even random search) to automatically customize augmentation pipelines for specific models and datasets. Specifically, the policy space proposed by AutoAugment covers the following concepts:
| Concept | Description |
|:---|:---|
| Operation | An image transformation operator such as translation or rotation. The operators selected by AutoAugment do not change the size or type of the input image. Each operator has two searchable parameters: a probability and a magnitude. |
| Probability | The probability of randomly applying an operator; if the operator is not applied, the input image is returned unchanged. |
| Magnitude | The strength with which an operator is applied, e.g., the number of pixels to translate or the angle to rotate. |
| Subpolicy | Each subpolicy contains two operations; when a subpolicy is applied, the two operations transform the input image in order, according to their probabilities and magnitudes. |
| Policy | Each policy contains several subpolicies; during augmentation, the policy randomly selects one subpolicy for each image. |
Since the number of operators is finite and the probability and magnitude of each operator can be discretized, the policy space proposed by AutoAugment leads to a finite, discrete search problem. In particular, experiments show that this policy space also transfers reasonably well: a policy found for one model/dataset combination can be transferred to other models on the same dataset, and a policy found on one dataset can be transferred to other similar datasets.
This example mainly implements the policy space proposed by AutoAugment. Based on it, developers can augment datasets with the "good policies" listed in the AutoAugment paper, or design search algorithms to customize augmentation pipelines automatically.
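For reference, a policy in this implementation is just a Python list of subpolicies, and each subpolicy is a pair of `(operation name, probability, magnitude)` tuples. The minimal sketch below copies two subpolicies from the Cifar10 policies bundled in `src/dataset/autoaugment/third_party/policies.py`:
```python
# A policy in the same format as the "good policies" shipped with this
# example: a list of subpolicies, each made of two
# (operation, probability, magnitude) tuples.
example_policy = [
    # Rotate with probability 0.7 at magnitude 2, then translate along the
    # x-axis with probability 0.3 at magnitude 9.
    [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
    # Auto-contrast with probability 0.5 at magnitude 8, then equalize with
    # probability 0.9 at magnitude 2.
    [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
]
```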
### AutoAugment Paper
Cubuk, Ekin D., et al. "Autoaugment: Learning augmentation strategies from data." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
## Model Architecture
In addition to implementing the policy space proposed by AutoAugment, this example also provides a simple implementation of the Wide-ResNet model for developers' reference.
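As a minimal sketch, the network can be instantiated as follows, mirroring the constructor call used in [export.py](./export.py); refer to [src/network/wrn.py](./src/network/wrn.py) for the exact meaning of the first two arguments:
```python
# Instantiate the Wide-ResNet provided by this example, following the
# WRN(160, 3, num_classes) call used in export.py.
from src.network import WRN

net = WRN(160, 3, 10)  # 10 output classes for Cifar10
```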
### WideResNet Paper
Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).
## Dataset
This example uses Cifar10 to demonstrate how to use AutoAugment and to verify the effectiveness of this implementation.
This example uses the [CIFAR-10 binary version](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz), whose directory structure is as follows:
```bash
cifar-10-batches-bin
├── batches.meta.txt
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
├── data_batch_5.bin
├── readme.html
└── test_batch.bin
```
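Once the dataset is in place, the training pipeline can be built with the helper provided in [src/dataset/cifar10.py](./src/dataset/cifar10.py); the sketch below uses illustrative argument values (see the script parameters section for the actual defaults):
```python
# A minimal sketch of building the Cifar10 training pipeline with the helper
# exported by src/dataset/__init__.py; `augment=True` enables AutoAugment.
from src.dataset import create_cifar10_dataset

dataset = create_cifar10_dataset(
    dataset_path='./cifar-10-batches-bin',
    do_train=True,
    batch_size=128,
    target='Ascend',
    distribute=False,
    augment=True,
)
```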
## Environment Requirements
- Hardware
    - Prepare the hardware environment with Ascend processors. To apply for trial access to Ascend processors, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo体验资源申请表.docx) to ascend@huawei.com. You will be granted access once the application is approved.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/)
- For more information, please check the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
## Quick Start
After preparing the device and the framework environment, developers can run the following commands to train and evaluate this example.
- Running on Ascend processors
```bash
# Distributed training on 8 devices
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
# Standalone training on a single device
Usage: bash run_standalone_train.sh [DATASET_PATH]
# Evaluation on a single device
Usage: bash run_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
```
Distributed training requires an HCCL configuration file in JSON format to be created in advance.
For details, please refer to the instructions in [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
## Script Description
```bash
.
├── export.py                       # Export the network
├── mindspore_hub_conf.py           # MindSpore Hub configuration
├── README.md                       # Documentation
├── run_distribute_train.sh         # Multi-device training script for the Ascend environment
├── run_eval.sh                     # Evaluation script for the Ascend environment
├── run_standalone_train.sh         # Single-device training script for the Ascend environment
├── src
│   ├── config.py                   # Model training/testing configuration
│   ├── dataset
│   │   ├── autoaugment
│   │   │   ├── aug.py              # AutoAugment policies
│   │   │   ├── aug_test.py         # AutoAugment policy testing and visualization
│   │   │   ├── ops
│   │   │   │   ├── crop.py         # RandomCrop operator
│   │   │   │   ├── cutout.py       # RandomCutout operator
│   │   │   │   ├── effect.py       # Image effect operators
│   │   │   │   ├── enhance.py      # Image enhancement operators
│   │   │   │   ├── ops_test.py     # Operator testing and visualization
│   │   │   │   └── transform.py    # Image transformation operators
│   │   │   └── third_party
│   │   │       └── policies.py     # "Good policies" found by AutoAugment
│   │   └── cifar10.py              # Cifar10 dataset processing
│   ├── network
│   │   └── wrn.py                  # Wide-ResNet model definition
│   ├── optim
│   │   └── lr.py                   # Cosine learning rate definition
│   └── utils                       # Logging initialization, etc.
├── test.py                         # Test the network
└── train.py                        # Train the network
```
## Script Parameters
Training parameters, dataset paths, and other settings can be configured in [src/config.py](./src/config.py).
```python
# Set to mute logs with lower levels.
self.log_level = logging.INFO
# Random seed.
self.seed = 1
# Type of device(s) where the model would be deployed to.
# Choices: ['Ascend', 'GPU', 'CPU']
self.device_target = 'Ascend'
# The model to use. Choices: ['wrn']
self.net = 'wrn'
# The dataset to train or test against. Choices: ['cifar10']
self.dataset = 'cifar10'
# The number of classes.
self.class_num = 10
# Path to the folder where the intended dataset is stored.
self.dataset_path = './cifar-10-batches-bin'
# Batch size for both training mode and testing mode.
self.batch_size = 128
# Indicates training or testing mode.
self.training = training
# Testing parameters.
if not self.training:
# The checkpoint to load and test against.
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
self.checkpoint_path = None
# Training parameters.
if self.training:
# Whether to apply auto-augment or not.
self.augment = True
# The number of device(s) to be used for training.
self.device_num = 1
# Whether to train the model in a distributed mode or not.
self.run_distribute = False
# The pre-trained checkpoint to load and train from.
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
self.pre_trained = None
# Number of epochs to train.
self.epoch_size = 200
# Momentum factor.
self.momentum = 0.9
# L2 penalty.
self.weight_decay = 5e-4
# Learning rate decaying mode. Choices: ['cosine']
self.lr_decay_mode = 'cosine'
# The starting learning rate.
self.lr_init = 0.1
# The maximum learning rate.
self.lr_max = 0.1
# The number of warmup epochs. Note that during the warmup period,
# the learning rate grows from `lr_init` to `lr_max` linearly.
self.warmup_epochs = 5
# Loss scaling for mixed-precision training.
self.loss_scale = 1024
# Create a checkpoint per `save_checkpoint_epochs` epochs.
self.save_checkpoint_epochs = 5
# The maximum number of checkpoints to keep.
self.keep_checkpoint_max = 10
# The folder path to save checkpoints.
self.save_checkpoint_path = './checkpoint'
```
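Many of these parameters can also be overridden from the command line (see [Script Usage](#script-usage) below). Alternatively, the `Config` object can be constructed programmatically, as [export.py](./export.py) does; a minimal sketch:
```python
# Build the configuration object directly instead of relying on command-line
# parsing (load_args=False skips argparse), mirroring export.py.
from src.config import Config

conf = Config(training=True, load_args=False)
conf.lr_max = 0.1          # Override a parameter declared in __init__.
print(conf.dataset_path)   # -> './cifar-10-batches-bin'
```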
## Script Usage
### AutoAugment Operator Usage
Similar to [src/dataset/cifar10.py](./src/dataset/cifar10.py), to use the AutoAugment operator, first import the `Augment` class:
```python
# Developers need to copy the entire "src/dataset/autoaugment/" folder into the current directory, or use a symbolic link.
from autoaugment import Augment
```
The AutoAugment operator is compatible with MindSpore datasets and can be applied directly as a dataset transformation operator:
```python
dataset = dataset.map(operations=[Augment(mean=MEAN, std=STD)],
input_columns='image', num_parallel_workers=8)
```
The parameters supported by AutoAugment are as follows:
```python
Args:
index (int or None): If index is not None, the indexed policy would
always be used. Otherwise, a policy would be randomly chosen from
the policies set for each image.
policies (policies found by AutoAugment or None): A set of policies
to sample from. When the given policies is None, good policies found
on cifar10 would be used.
enable_basic (bool): Whether to apply basic augmentations after
auto-augment or not. Note that basic augmentations
include RandomFlip, RandomCrop, and RandomCutout.
from_pil (bool): Whether the image passed to the operator is already a
PIL image.
as_pil (bool): Whether the returned image should be kept as a PIL image.
mean, std (list): Per-channel mean and std used to normalize the output
image. Only applicable when as_pil is False.
```
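For example, an explicit policy set can be passed in and the basic augmentations can be disabled. The sketch below is illustrative: the two subpolicies are copied from the bundled Cifar10 policies, the mean/std values follow [src/dataset/autoaugment/aug_test.py](./src/dataset/autoaugment/aug_test.py), and `dataset` is assumed to be an existing MindSpore dataset:
```python
# Sample from a fixed set of subpolicies, skip the basic
# RandomFlip/RandomCrop/RandomCutout augmentations, and normalize the output.
from autoaugment import Augment

MEAN = [0.49139968, 0.48215841, 0.44653091]
STD = [0.24703223, 0.24348513, 0.26158784]
my_policies = [
    [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
    [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
]

aug = Augment(policies=my_policies, enable_basic=False, mean=MEAN, std=STD)
dataset = dataset.map(operations=[aug],
                      input_columns='image', num_parallel_workers=8)
```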
### Training Script Usage
Augment the dataset with the AutoAugment operator and train the model:
```bash
# python train.py -h
usage: train.py [-h] [--device_target {Ascend,GPU,CPU}] [--dataset {cifar10}]
[--dataset_path DATASET_PATH] [--augment AUGMENT]
[--device_num DEVICE_NUM] [--run_distribute RUN_DISTRIBUTE]
[--lr_max LR_MAX] [--pre_trained PRE_TRAINED]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
AutoAugment for image classification.
optional arguments:
-h, --help show this help message and exit
--device_target {Ascend,GPU,CPU}
Type of device(s) where the model would be deployed
to.
--dataset {cifar10} The dataset to train or test against.
--dataset_path DATASET_PATH
Path to the folder where the intended dataset is
stored.
--augment AUGMENT Whether to apply auto-augment or not.
--device_num DEVICE_NUM
The number of device(s) to be used for training.
--run_distribute RUN_DISTRIBUTE
Whether to train the model in distributed mode or not.
--lr_max LR_MAX The maximum learning rate.
--pre_trained PRE_TRAINED
The pre-trained checkpoint to load and train from.
Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt
--save_checkpoint_path SAVE_CHECKPOINT_PATH
The folder path to save checkpoints.
```
### Evaluation Script Usage
Evaluate the accuracy of the trained model:
```bash
# python test.py -h
usage: test.py [-h] [--device_target {Ascend,GPU,CPU}] [--dataset {cifar10}]
[--dataset_path DATASET_PATH]
[--checkpoint_path CHECKPOINT_PATH]
AutoAugment for image classification.
optional arguments:
-h, --help show this help message and exit
--device_target {Ascend,GPU,CPU}
Type of device(s) where the model would be deployed
to.
--dataset {cifar10} The dataset to train or test against.
--dataset_path DATASET_PATH
Path to the folder where the intended dataset is
stored.
--checkpoint_path CHECKPOINT_PATH
The checkpoint to load and test against.
Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt
```
### Export Script Usage
Export the trained model to AIR, ONNX, or MINDIR format:
```bash
# python export.py -h
usage: export.py [-h] [--device_id DEVICE_ID] --checkpoint_path
CHECKPOINT_PATH [--file_name FILE_NAME]
[--file_format {AIR,ONNX,MINDIR}]
[--device_target {Ascend,GPU,CPU}]
WRN with AutoAugment export.
optional arguments:
-h, --help show this help message and exit
--device_id DEVICE_ID
Device id.
--checkpoint_path CHECKPOINT_PATH
Checkpoint file path.
--file_name FILE_NAME
Output file name.
--file_format {AIR,ONNX,MINDIR}
Export format.
--device_target {Ascend,GPU,CPU}
Device target.
```
## Model Description
| Parameters | 1x Ascend 910 | 8x Ascend 910 |
|:---|:---|:---|
| Resource | Ascend 910 | Ascend 910 |
| Upload date | 2021.06.21 | 2021.06.24 |
| MindSpore version | 1.2.0 | 1.2.0 |
| Training dataset | Cifar10 | Cifar10 |
| Training parameters | epoch=200, batch_size=128 | epoch=200, batch_size=128, lr_max=0.8 |
| Optimizer | Momentum | Momentum |
| Output | Loss | Loss |
| Accuracy | 97.42% | 97.39% |
| Speed | 97.73 ms/step | 106.29 ms/step |
| Total time | 127 min | 17 min |
| Fine-tuned checkpoint | 277M (.ckpt file) | 277M (.ckpt file) |
| Script | [autoaugment](./) | [autoaugment](./) |
## Description of Randomness
A random seed is set in [train.py](./train.py) to ensure that training is reproducible.
## ModelZoo Homepage
Please check out the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export checkpoint file into air, onnx, mindir models."""
import argparse
import numpy as np
from mindspore import (
context,
export,
load_checkpoint,
load_param_into_net,
Tensor,
)
from src.config import Config
from src.network import WRN
parser = argparse.ArgumentParser(description='WRN with AutoAugment export.')
parser.add_argument(
'--device_id', type=int, default=0,
help='Device id.',
)
parser.add_argument(
'--checkpoint_path', type=str, required=True,
help='Checkpoint file path.',
)
parser.add_argument(
'--file_name', type=str, default='wrn-autoaugment',
help='Output file name.',
)
parser.add_argument(
'--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'],
default='AIR', help='Export format.',
)
parser.add_argument(
'--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'],
default='Ascend', help='Device target.',
)
if __name__ == '__main__':
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
device_target=args.device_target,
)
if args.device_target == 'Ascend':
context.set_context(device_id=args.device_id)
conf = Config(training=False, load_args=False)
net = WRN(160, 3, conf.class_num)
param_dict = load_checkpoint(args.checkpoint_path)
load_param_into_net(net, param_dict)
image = Tensor(np.ones((1, 3, 32, 32), np.float32))
export(
net, image,
file_name=args.file_name,
file_format=args.file_format,
)


@ -0,0 +1,24 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Hub config"""
from src.network import WRN
def create_network(name, *args, **kwargs):
"""Creates a WideResNet."""
if name == 'WRN':
return WRN(*args, **kwargs)
raise NotImplementedError('%s is not implemented in the repo' % name)


@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage:
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
export RANK_TABLE_FILE=$1
export DEVICE_NUM=8
export RANK_SIZE=8
PID_LIST=()
for ((i=0; i<${RANK_SIZE}; i++)); do
export DEVICE_ID=${i}
export RANK_ID=${i}
echo "Start distributed training for rank ${RANK_ID}, device ${DEVICE_ID}"
python ../train.py \
--dataset=cifar10 \
--dataset_path=$2 \
--run_distribute=True \
--lr_max=0.8 \
> train-${i}.log 2>&1 &
pid=$!
PID_LIST+=("${pid}")
done
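# Set RUN_BACKGROUND=0 below to make the script wait for all training processes to finish before exiting.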
RUN_BACKGROUND=1
if (( RUN_BACKGROUND == 0 )); then
echo "Waiting for all processes to exit..."
for pid in ${PID_LIST[*]}; do
wait ${pid}
echo "Process ${pid} exited"
done
fi


@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=0
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
if [ $# == 2 ]; then
python ../test.py \
--checkpoint_path $1 \
--dataset cifar10 \
--dataset_path $2 \
> eval.log 2>&1 &
else
echo "Usage: \
bash run_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]"
exit 1
fi


@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=0
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
if [ $# == 1 ]; then
python ../train.py \
--dataset cifar10 \
--dataset_path $1 \
> train.log 2>&1 &
else
echo "Usage: \
bash run_standalone_train.sh [DATASET_PATH]"
exit 1
fi


@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization.
"""


@ -0,0 +1,224 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Configurable parameters.
"""
import argparse
import logging
logger = logging.getLogger('config')
class Config:
"""
Define configurable parameters.
Args:
training (bool): Whether it is used in training mode or testing mode.
load_args (bool): Whether to load cli arguments automatically or not.
"""
def __init__(self, training, load_args=True):
# Set to mute logs with lower levels.
self.log_level = logging.INFO
# Random seed.
self.seed = 1
# Type of device(s) where the model would be deployed to.
# Choices: ['Ascend', 'GPU', 'CPU']
self.device_target = 'Ascend'
# The model to use. Choices: ['wrn']
self.net = 'wrn'
# The dataset to train or test against. Choices: ['cifar10']
self.dataset = 'cifar10'
# The number of classes.
self.class_num = 10
# Path to the folder where the intended dataset is stored.
self.dataset_path = './cifar-10-batches-bin'
# Batch size for both training mode and testing mode.
self.batch_size = 128
# Indicates training or testing mode.
self.training = training
# Testing parameters.
if not self.training:
# The checkpoint to load and test against.
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
self.checkpoint_path = None
# Training parameters.
if self.training:
# Whether to apply auto-augment or not.
self.augment = True
# The number of device(s) to be used for training.
self.device_num = 1
# Whether to train the model in a distributed mode or not.
self.run_distribute = False
# The pre-trained checkpoint to load and train from.
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
self.pre_trained = None
# Number of epochs to train.
self.epoch_size = 200
# Momentum factor.
self.momentum = 0.9
# L2 penalty.
self.weight_decay = 5e-4
# Learning rate decaying mode. Choices: ['cosine']
self.lr_decay_mode = 'cosine'
# The starting learning rate.
self.lr_init = 0.1
# The maximum learning rate.
self.lr_max = 0.1
# The number of warmup epochs. Note that during the warmup period,
# the learning rate grows from `lr_init` to `lr_max` linearly.
self.warmup_epochs = 5
# Loss scaling for mixed-precision training.
self.loss_scale = 1024
# Create a checkpoint per `save_checkpoint_epochs` epochs.
self.save_checkpoint_epochs = 5
# The maximum number of checkpoints to keep.
self.keep_checkpoint_max = 10
# The folder path to save checkpoints.
self.save_checkpoint_path = './checkpoint'
# _init is an initialization guard that helps warn when attributes are set
# outside __init__.
self._init = True
if load_args:
self.load_args()
def __setattr__(self, name, value):
"""___setattr__ is customized to warn adding attributes outside
__init__ and encourage declaring configurable parameters explicitly in
__init__."""
if getattr(self, '_init', False) and not hasattr(self, name):
logger.warning('attempting to add an attribute '
'outside __init__: %s=%s', name, value)
object.__setattr__(self, name, value)
def load_args(self):
"""load_args overwrites configurations by cli arguments."""
hooks = {} # hooks are used to assign values.
parser = argparse.ArgumentParser(
description='AutoAugment for image classification.')
parser.add_argument(
'--device_target', type=str, default='Ascend',
choices=['Ascend', 'GPU', 'CPU'],
help='Type of device(s) where the model would be deployed to.',
)
def hook_device_target(x):
"""Sets the device_target value."""
self.device_target = x
hooks['device_target'] = hook_device_target
parser.add_argument(
'--dataset', type=str, default='cifar10',
choices=['cifar10'],
help='The dataset to train or test against.',
)
def hook_dataset(x):
"""Sets the dataset value."""
self.dataset = x
hooks['dataset'] = hook_dataset
parser.add_argument(
'--dataset_path', type=str, default='./cifar-10-batches-bin',
help='Path to the folder where the intended dataset is stored.',
)
def hook_dataset_path(x):
"""Sets the dataset_path value."""
self.dataset_path = x
hooks['dataset_path'] = hook_dataset_path
if not self.training:
parser.add_argument(
'--checkpoint_path', type=str, default=None,
help='The checkpoint to load and test against. '
'Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt',
)
def hook_checkpoint_path(x):
"""Sets the checkpoint_path value."""
self.checkpoint_path = x
hooks['checkpoint_path'] = hook_checkpoint_path
if self.training:
parser.add_argument(
'--augment', type=bool, default=True,
help='Whether to apply auto-augment or not.',
)
def hook_augment(x):
"""Sets the augment value."""
self.augment = x
hooks['augment'] = hook_augment
parser.add_argument(
'--device_num', type=int, default=1,
help='The number of device(s) to be used for training.',
)
def hook_device_num(x):
"""Sets the device_num value."""
self.device_num = x
hooks['device_num'] = hook_device_num
parser.add_argument(
'--run_distribute', type=bool, default=False,
help='Whether to train the model in distributed mode or not.',
)
def hook_distribute(x):
"""Sets the run_distribute value."""
self.run_distribute = x
hooks['run_distribute'] = hook_distribute
parser.add_argument(
'--lr_max', type=float, default=0.1,
help='The maximum learning rate.',
)
def hook_lr_max(x):
"""Sets the lr_max value."""
self.lr_max = x
hooks['lr_max'] = hook_lr_max
parser.add_argument(
'--pre_trained', type=str, default=None,
help='The pre-trained checkpoint to load and train from. '
'Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt',
)
def hook_pre_trained(x):
"""Sets the pre_trained value."""
self.pre_trained = x
hooks['pre_trained'] = hook_pre_trained
parser.add_argument(
'--save_checkpoint_path', type=str, default='./checkpoint',
help='The folder path to save checkpoints.',
)
def hook_save_checkpoint_path(x):
"""Sets the save_checkpoint_path value."""
self.save_checkpoint_path = x
hooks['save_checkpoint_path'] = hook_save_checkpoint_path
# Overwrite default configurations by cli arguments
args_opt = parser.parse_args()
for name, val in args_opt.__dict__.items():
hooks[name](val)


@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for dataset helpers.
"""
from .cifar10 import create_dataset as create_cifar10_dataset


@ -0,0 +1,2 @@
cifar-10-batches-bin/
*.png


@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for the AutoAugment operator.
"""
from .aug import Augment


@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The Augment operator.
"""
import random
import mindspore.dataset.vision.py_transforms as py_trans
from .third_party.policies import good_policies
from .ops import OperatorClasses, RandomCutout
class Augment:
"""
Augment acts as a single transformation operator and applies policies found
by AutoAugment.
Args:
index (int or None): If index is not None, the indexed policy would
always be used. Otherwise, a policy would be randomly chosen from
the policies set for each image.
policies (policies found by AutoAugment or None): A set of policies
to sample from. When the given policies is None, good policies found
on cifar10 would be used.
enable_basic (bool): Whether to apply basic augmentations after
auto-augment or not. Note that basic augmentations
include RandomFlip, RandomCrop, and RandomCutout.
from_pil (bool): Whether the image passed to the operator is already a
PIL image.
as_pil (bool): Whether the returned image should be kept as a PIL image.
mean, std (list): Per-channel mean and std used to normalize the output
image. Only applicable when as_pil is False.
"""
def __init__(
self, index=None, policies=None, enable_basic=True,
from_pil=False, as_pil=False,
mean=None, std=None,
):
self.index = index
if policies is None:
self.policies = good_policies()
else:
self.policies = policies
self.oc = OperatorClasses()
self.to_pil = py_trans.ToPIL()
self.to_tensor = py_trans.ToTensor()
self.enable_basic = enable_basic
self.random_crop = self.oc.RandomCrop(None)
self.random_flip = self.oc.RandomHorizontalFlip(None)
self.cutout = RandomCutout(size=16, value=(0, 0, 0))
self.from_pil = from_pil
self.as_pil = as_pil
self.normalize = None
if mean is not None and std is not None:
self.normalize = py_trans.Normalize(mean, std)
def _apply(self, name, prob, level, img):
if random.random() > prob:
# Untouched
return img
# Apply the operator
return getattr(self.oc, name)(level)(img)
def __call__(self, img):
"""
Call method.
Args:
img (raw image or PIL image): Image to be auto-augmented.
Returns:
img (PIL image or Tensor), Auto-augmented image.
"""
if self.index is None:
policy = random.choice(self.policies)
else:
policy = self.policies[self.index]
if not self.from_pil:
img = self.to_pil(img)
for name, prob, level in policy:
img = self._apply(name, prob, level, img)
if self.enable_basic:
img = self.random_crop(img)
img = self.random_flip(img)
img = self.cutout(img)
if not self.as_pil:
img = self.to_tensor(img)
if self.normalize is not None:
img = self.normalize(img)
return img


@ -0,0 +1,106 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Visualization for testing purposes.
"""
import sys
import matplotlib.pyplot as plt
import numpy as np
import mindspore.dataset as ds
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
def compare(data_path, index=None, ops=None, rescale=False):
"""Visualize images before and after applying auto-augment."""
# Load dataset
ds.config.set_seed(8)
dataset_orig = ds.Cifar10Dataset(
data_path,
num_samples=5,
shuffle=True,
)
# Apply transformations
dataset_augmented = dataset_orig.map(
operations=[Augment(index)] if ops is None else ops,
input_columns=['image'],
)
# Collect images
image_orig_list, image_augmented_list, label_list = [], [], []
for data in dataset_orig.create_dict_iterator():
image_orig_list.append(data['image'])
label_list.append(data['label'])
print('Original image: shape {}, label {}'.format(
data['image'].shape, data['label'],
))
for data in dataset_augmented.create_dict_iterator():
image_augmented_list.append(data['image'])
print('Augmented image: shape {}, label {}'.format(
data['image'].shape, data['label'],
))
num_samples = len(image_orig_list)
fig, mesh = plt.subplots(ncols=num_samples, nrows=2, figsize=(5, 3))
axes = mesh[0]
for i in range(num_samples):
axes[i].axis('off')
axes[i].imshow(image_orig_list[i].asnumpy())
axes[i].set_title(label_list[i].asnumpy())
axes = mesh[1]
for i in range(num_samples):
axes[i].axis('off')
img = image_augmented_list[i].asnumpy().transpose((1, 2, 0))
if rescale:
max_val = max(np.abs(img.min()), img.max())
img = (img / max_val + 1) / 2
print('min and max of the transformed image:', img.min(), img.max())
axes[i].imshow(img)
fig.tight_layout()
fig.savefig(
'aug_test.png' if index is None else 'aug_test_{}.png'.format(index),
)
if __name__ == '__main__':
sys.path.append('..')
from autoaugment.third_party.policies import good_policies
from autoaugment import Augment
cifar10_data_path = './cifar-10-batches-bin/'
# Test the feasibility of each policy
for ind, policy in enumerate(good_policies()):
# Uncomment to test only the first few policies:
# if ind >= 3:
#     break
print(policy)
compare(cifar10_data_path, ind)
# Test the random policy selection and the normalize operation
MEAN = [0.49139968, 0.48215841, 0.44653091]
STD = [0.24703223, 0.24348513, 0.26158784]
compare(
cifar10_data_path,
ops=[Augment(mean=MEAN, std=STD, enable_basic=False)],
)
compare(
cifar10_data_path,
ops=[Augment(mean=MEAN, std=STD, enable_basic=False)],
rescale=True,
)


@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for custom PIL operators.
"""
from mindspore.dataset.vision import py_transforms
from .crop import RandomCrop
from .cutout import RandomCutout
from .effect import (
Posterize,
Solarize,
)
from .enhance import (
Brightness,
Color,
Contrast,
Sharpness,
)
from .transform import (
Rotate,
ShearX,
ShearY,
TranslateX,
TranslateY,
)
class OperatorClasses:
"""OperatorClasses gathers all unary-image transformations listed in the
Table 6 of https://arxiv.org/abs/1805.09501 and uses discrete levels for
these transformations (The Sample Pairing transformation is an
exception, which involves multiple images from a single mini-batch and
is not exploited in this implementation.)
Additionally, there are RandomHorizontalFlip and RandomCrop.
"""
def __init__(self):
self.Rotate = self.decorate(Rotate, max_val=30, rounding=True)
self.ShearX = self.decorate(ShearX, max_val=0.3)
self.ShearY = self.decorate(ShearY, max_val=0.3)
self.TranslateX = self.decorate(TranslateX, max_val=10, rounding=True)
self.TranslateY = self.decorate(TranslateY, max_val=10, rounding=True)
self.AutoContrast = self.decorate(py_transforms.AutoContrast)
self.Invert = self.decorate(py_transforms.Invert)
self.Equalize = self.decorate(py_transforms.Equalize)
self.Solarize = self.decorate(
Solarize, max_val=256, rounding=True, post=lambda x: 256 - x)
self.Posterize = self.decorate(
Posterize, max_val=4, rounding=True, post=lambda x: 4 - x)
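# Note: for Solarize and Posterize, a higher level should mean a stronger
# effect, so the `post` lambdas above invert the interpreted value
# (lower solarize threshold / fewer posterize bits).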
def post(x):
"""Post operation to avoid 0 value."""
return x + 0.1
self.Brightness = self.decorate(Brightness, max_val=1.8, post=post)
self.Color = self.decorate(Color, max_val=1.8, post=post)
self.Contrast = self.decorate(Contrast, max_val=1.8, post=post)
self.Sharpness = self.decorate(Sharpness, max_val=1.8, post=post)
self.Cutout = self.decorate(RandomCutout, max_val=20, rounding=True)
self.RandomHorizontalFlip = self.decorate(
py_transforms.RandomHorizontalFlip)
self.RandomCrop = self.decorate(RandomCrop)
def vars(self):
"""vars returns all available operators as a dictionary."""
return vars(self)
def decorate(self, op, max_val=None, rounding=False, post=None):
"""
decorate interprets discrete levels for the given operator when
applicable.
Args:
op (Augmentation Operator): Operator to be decorated.
max_val (int or float): The maximum value that level 10 corresponds to.
rounding (bool): Whether the corresponding value should be rounded
to an integer.
post (function): User-defined post-processing value function.
Returns:
Decorated operator.
"""
if max_val is None:
def no_arg_fn(_):
"""Decorates an operator without level parameter."""
return op()
return no_arg_fn
def fn(level):
"""Decorates an operator with level parameter."""
val = max_val * level / 10
if rounding:
val = int(val)
if post is not None:
val = post(val)
return op(val)
return fn


@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
RandomCrop operator.
"""
from mindspore.dataset.vision import py_transforms
from mindspore.dataset.vision import py_transforms_util
from mindspore.dataset.vision import utils
class RandomCrop(py_transforms.RandomCrop):
"""
RandomCrop inherits from py_transforms.RandomCrop but derives/uses the
original image size as the output size.
Please refer to py_transforms.RandomCrop for argument specifications.
"""
def __init__(self, padding=4, pad_if_needed=False,
fill_value=0, padding_mode=utils.Border.CONSTANT):
# Note the `1` for the size argument is only set for passing the check.
super(RandomCrop, self).__init__(1, padding=padding, pad_if_needed=pad_if_needed,
fill_value=fill_value, padding_mode=padding_mode)
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be padded and then randomly cropped back
to the same size.
Returns:
img (PIL image), Randomly cropped image.
"""
if not py_transforms_util.is_pil(img):
raise TypeError(
py_transforms_util.augment_error_message.format(type(img)))
return py_transforms_util.random_crop(
img, img.size, self.padding, self.pad_if_needed,
self.fill_value, self.padding_mode,
)


@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
RandomCutout operator.
"""
import random
class RandomCutout:
"""
RandomCutout is similar to py_transforms.Cutout but is simplified and
crafted for PIL images.
Args:
size (int): the side-size of each square cutout patch.
num_patches (int): the number of square cutout patches to add.
value (RGB value): the pixel value used to fill each cutout patch.
"""
def __init__(self, size=30, num_patches=1, value=(125, 122, 113)):
self.size = size
self.num_patches = num_patches
self.value = value
@staticmethod
def _clip(x, lower, upper):
"""Clip value to the [lower, upper] range."""
return max(lower, min(x, upper))
@staticmethod
def _get_cutout_area(img_w, img_h, size):
"""Randomly create a cutout area."""
x = random.randint(0, img_w)
y = random.randint(0, img_h)
x1 = x - size // 2
x2 = x1 + size
y1 = y - size // 2
y2 = y1 + size
x1 = RandomCutout._clip(x1, 0, img_w)
x2 = RandomCutout._clip(x2, 0, img_w)
y1 = RandomCutout._clip(y1, 0, img_h)
y2 = RandomCutout._clip(y2, 0, img_h)
return x1, y1, x2, y2
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be cutout.
Returns:
img (PIL image), Randomly cutout image.
"""
img_w, img_h = img.size
pixels = img.load()
for _ in range(self.num_patches):
x1, y1, x2, y2 = self._get_cutout_area(img_w, img_h, self.size)
for i in range(x1, x2): # columns
for j in range(y1, y2): # rows
pixels[i, j] = self.value
return img


@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Operators for applying image effects.
"""
from PIL import ImageOps
from mindspore.dataset.vision import py_transforms_util
class Solarize:
"""
Solarize inverts image pixels with values above the configured threshold.
Args:
threshold (int): All pixels above the threshold would be inverted.
Ranging within [0, 255].
"""
def __init__(self, threshold):
self.threshold = threshold
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be solarized.
Returns:
img (PIL image), Solarized image.
"""
if not py_transforms_util.is_pil(img):
raise TypeError(
py_transforms_util.augment_error_message.format(type(img)))
return ImageOps.solarize(img, self.threshold)
class Posterize:
"""
Posterize reduces the number of bits for each color channel.
Args:
bits (int): The number of bits to keep for each channel.
Ranging within [1, 8].
"""
def __init__(self, bits):
self.bits = bits
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be posterized.
Returns:
img (PIL image), Posterized image.
"""
if not py_transforms_util.is_pil(img):
raise TypeError(
py_transforms_util.augment_error_message.format(type(img)))
return ImageOps.posterize(img, self.bits)


@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Operators for enhancing images.
"""
from PIL import ImageEnhance
from mindspore.dataset.vision import py_transforms_util
class Contrast:
"""
Contrast adjusts the contrast of images.
Args:
degree (float): contrast degree ranging within [0.1, 1.9], where 1.0
indicates an unchanged contrast.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): PIL image to be adjusted.
Returns:
img (PIL image), Contrast adjusted image.
"""
return py_transforms_util.adjust_contrast(img, self.degree)
class Color:
"""
Color adjusts the saturation of images.
Args:
degree (float): saturation degree ranging within [0.1, 1.9], where 1.0
indicates an unchanged saturation.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): PIL image to be adjusted.
Returns:
img (PIL image), Saturation adjusted image.
"""
return py_transforms_util.adjust_saturation(img, self.degree)
class Brightness:
"""
Brightness adjusts the brightness of images.
Args:
degree (float): brightness degree ranging within [0.1, 1.9], where 1.0
indicates an unchanged brightness.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be adjusted.
Returns:
img (PIL image), Brightness adjusted image.
"""
return py_transforms_util.adjust_brightness(img, self.degree)
class Sharpness:
"""
Sharpness adjusts the sharpness of images.
Args:
degree (float): sharpness degree ranging within [0.1, 1.9], where 1.0
indicates an unchanged sharpness.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be sharpness adjusted.
Returns:
img (PIL image), Sharpness adjusted image.
"""
if not py_transforms_util.is_pil(img):
raise TypeError(
py_transforms_util.augment_error_message.format(type(img)))
return ImageEnhance.Sharpness(img).enhance(self.degree)


@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Visualization for testing purposes.
"""
import matplotlib.pyplot as plt
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as py_trans
from mindspore import context
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
def compare(data_path, trans, output_path='./ops_test.png'):
"""Visualize images before and after applying the given transformations."""
# Load dataset
ds.config.set_seed(8)
dataset_orig = ds.Cifar10Dataset(
data_path,
num_samples=5,
shuffle=True,
)
# Apply transformations
dataset_augmented = dataset_orig.map(
operations=[py_trans.ToPIL()] + trans + [py_trans.ToTensor()],
input_columns=['image'],
)
# Collect images
image_orig_list, image_augmented_list, label_list = [], [], []
for data in dataset_orig.create_dict_iterator():
image_orig_list.append(data['image'])
label_list.append(data['label'])
print('Original image: shape {}, label {}'.format(
data['image'].shape, data['label'],
))
for data in dataset_augmented.create_dict_iterator():
image_augmented_list.append(data['image'])
print('Augmented image: shape {}, label {}'.format(
data['image'].shape, data['label'],
))
# Plot images
num_samples = len(image_orig_list)
fig, mesh = plt.subplots(ncols=num_samples, nrows=2, figsize=(5, 3))
axes = mesh[0]
for i in range(num_samples):
axes[i].axis('off')
axes[i].imshow(image_orig_list[i].asnumpy())
axes[i].set_title(label_list[i].asnumpy())
axes = mesh[1]
for i in range(num_samples):
axes[i].axis('off')
axes[i].imshow(image_augmented_list[i].asnumpy().transpose((1, 2, 0)))
fig.tight_layout()
fig.savefig(output_path)
if __name__ == '__main__':
import sys
sys.path.append('..')
from ops import OperatorClasses
oc = OperatorClasses()
levels = {
'Contrast': 7,
'Rotate': 7,
'TranslateX': 9,
'Sharpness': 9,
'ShearY': 8,
'TranslateY': 9,
'AutoContrast': 9,
'Equalize': 9,
'Solarize': 8,
'Color': 9,
'Posterize': 7,
'Brightness': 9,
'Cutout': 4,
'ShearX': 4,
'Invert': 7,
'RandomHorizontalFlip': None,
'RandomCrop': None,
}
cifar10_data_path = './cifar-10-batches-bin/'
for name, op in oc.vars().items():
compare(cifar10_data_path, [op(levels[name])], './ops_{}_{}.png'.format(
name, levels[name],
))


@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Operators for affine transformations.
"""
import numbers
import random
from PIL import Image, __version__
from mindspore.dataset.vision.py_transforms import DE_PY_INTER_MODE
from mindspore.dataset.vision.py_transforms_util import (
augment_error_message,
is_pil,
rotate,
)
from mindspore.dataset.vision.utils import Inter
class ShearX:
"""
ShearX shears images along the x-axis.
Args:
shear (float): the shear factor to apply along the x-axis.
resample (enum): the interpolation mode.
fill_value (int or tuple): the filling value to fill the area outside
the transform in the output image.
"""
def __init__(self, shear, resample=Inter.NEAREST, fill_value=0):
if not isinstance(shear, numbers.Number):
raise TypeError('shear must be a single number.')
self.shear = shear
self.resample = DE_PY_INTER_MODE[resample]
self.fill_value = fill_value
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to apply shear_x transformation.
Returns:
img (PIL image), X-axis sheared image.
"""
if not is_pil(img):
raise ValueError('Input image should be a Pillow image.')
output_size = img.size
shear = self.shear if random.random() > 0.5 else -self.shear
matrix = (1, shear, 0, 0, 1, 0)
if __version__ >= '5':
kwargs = {'fillcolor': self.fill_value}
else:
kwargs = {}
return img.transform(output_size, Image.AFFINE, matrix,
self.resample, **kwargs)
class ShearY:
"""
ShearY shears images along the y-axis.
Args:
shear (float): the shear factor to apply along the y-axis.
resample (enum): the interpolation mode.
fill_value (int or tuple): the filling value to fill the area outside
the transform in the output image.
"""
def __init__(self, shear, resample=Inter.NEAREST, fill_value=0):
if not isinstance(shear, numbers.Number):
raise TypeError('shear must be a single number.')
self.shear = shear
self.resample = DE_PY_INTER_MODE[resample]
self.fill_value = fill_value
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to apply shear_y transformation.
Returns:
img (PIL image), Y-axis sheared image.
"""
if not is_pil(img):
raise ValueError('Input image should be a Pillow image.')
output_size = img.size
shear = self.shear if random.random() > 0.5 else -self.shear
matrix = (1, 0, 0, shear, 1, 0)
if __version__ >= '5':
kwargs = {'fillcolor': self.fill_value}
else:
kwargs = {}
return img.transform(output_size, Image.AFFINE, matrix,
self.resample, **kwargs)
class TranslateX:
"""
TranslateX translates images along the x-axis.
Args:
translate (int): the pixel size to translate.
resample (enum): the interpolation mode.
fill_value (int or tuple): the filling value to fill the area outside
the transform in the output image.
"""
def __init__(self, translate, resample=Inter.NEAREST, fill_value=0):
if not isinstance(translate, numbers.Number):
raise TypeError('translate must be a single number.')
self.translate = translate
self.resample = DE_PY_INTER_MODE[resample]
self.fill_value = fill_value
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to apply translate_x transformation.
Returns:
img (PIL image), X-axis translated image.
"""
if not is_pil(img):
raise ValueError('Input image should be a Pillow image.')
output_size = img.size
trans = self.translate if random.random() > 0.5 else -self.translate
matrix = (1, 0, trans, 0, 1, 0)
if __version__ >= '5':
kwargs = {'fillcolor': self.fill_value}
else:
kwargs = {}
return img.transform(output_size, Image.AFFINE, matrix,
self.resample, **kwargs)
class TranslateY:
"""
TranslateY translates images along the y-axis.
Args:
translate (int): the pixel size to translate.
resample (enum): the interpolation mode.
fill_value (int or tuple): the filling value to fill the area outside
the transform in the output image.
"""
def __init__(self, translate, resample=Inter.NEAREST, fill_value=0):
if not isinstance(translate, numbers.Number):
raise TypeError('translate must be a single number.')
self.translate = translate
self.resample = DE_PY_INTER_MODE[resample]
self.fill_value = fill_value
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to apply translate_y transformation.
Returns:
img (PIL image), Y-axis translated image.
"""
if not is_pil(img):
raise ValueError('Input image should be a Pillow image.')
output_size = img.size
trans = self.translate if random.random() > 0.5 else -self.translate
matrix = (1, 0, 0, 0, 1, trans)
if __version__ >= '5':
kwargs = {'fillcolor': self.fill_value}
else:
kwargs = {}
return img.transform(output_size, Image.AFFINE, matrix,
self.resample, **kwargs)
class Rotate:
"""
Rotate is similar to py_vision.RandomRotation but uses a fixed degree.
Args:
degree (int): the degree to rotate.
Please refer to py_transforms.RandomRotation for more argument
specifications.
"""
def __init__(
self, degree,
resample=Inter.NEAREST, expand=False, center=None, fill_value=0,
):
if not isinstance(degree, numbers.Number):
raise TypeError('degree must be a single number.')
self.degree = degree
self.resample = DE_PY_INTER_MODE[resample]
self.expand = expand
self.center = center
self.fill_value = fill_value
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be rotated.
Returns:
img (PIL image), Rotated image.
"""
if not is_pil(img):
raise TypeError(augment_error_message.format(type(img)))
degree = self.degree if random.random() > 0.5 else -self.degree
return rotate(img, degree, self.resample, self.expand,
self.center, self.fill_value)


@ -0,0 +1,140 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Good policies found by the AutoAugment paper.
"""
def good_policies():
"""AutoAugment policies found on Cifar."""
exp0_0 = [
[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
[('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
[('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
[('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
exp0_1 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
[('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
[('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
exp0_2 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
[('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
[('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
exp0_3 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
[('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
[('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
[('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
[('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
exp1_0 = [
[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
[('Color', 0.4, 3), ('Brightness', 0.6, 7)],
[('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
[('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
[('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
exp1_1 = [
[('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
[('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
[('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
[('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
[('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
exp1_2 = [
[('Solarize', 0.2, 6), ('Color', 0.8, 6)],
[('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
[('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
[('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
[('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
exp1_3 = [
[('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
[('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
[('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
[('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
[('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
exp1_4 = [
[('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
[('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
[('Equalize', 0.6, 8), ('Color', 0.6, 2)],
[('Color', 0.3, 7), ('Color', 0.2, 4)],
[('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
exp1_5 = [
[('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
[('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
[('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
[('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
[('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
exp1_6 = [
[('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
[('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
[('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
[('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
[('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
exp2_0 = [
[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
[('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
[('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
[('Brightness', 0.9, 6), ('Color', 0.2, 8)],
[('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
exp2_1 = [
[('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
[('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
[('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
[('Color', 0.1, 8), ('ShearY', 0.2, 3)],
[('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
exp2_2 = [
[('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
[('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
[('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
[('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
[('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
exp2_3 = [
[('Equalize', 0.9, 5), ('Color', 0.7, 0)],
[('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
[('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
[('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
[('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
exp2_4 = [
[('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
[('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
[('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
[('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
[('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
exp2_5 = [
[('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
[('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
[('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
[('Solarize', 0.4, 3), ('Color', 0.2, 4)],
[('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
exp2_6 = [
[('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
[('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
[('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
[('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
exp2_7 = [
[('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
[('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
[('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
[('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
[('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + \
exp2_7
return exp0s + exp1s + exp2s
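if __name__ == '__main__':
    # Hedged usage sketch, not part of the policy table itself: draw one random
    # sub-policy the way an augmentation pipeline typically would, and print its
    # two (operation name, probability, magnitude) triples.  The mapping from
    # operation names to concrete transforms lives in the ops package.
    import random
    demo_subpolicy = random.choice(good_policies())
    for op_name, prob, magnitude in demo_subpolicy:
        print('op=%-12s prob=%.1f magnitude=%d' % (op_name, prob, magnitude))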

View File

@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Helpers for creating Cifar-10 datasets (optionally with AutoAugment enabled).
"""
import os
import mindspore.common.dtype as mstype
from mindspore.communication.management import init, get_rank, get_group_size
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from .autoaugment import Augment
def _get_rank_info():
"""Get rank size and rank id."""
rank_size = int(os.environ.get('RANK_SIZE', 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id
def create_cifar10_dataset(dataset_path, do_train=True, repeat_num=1,
                           batch_size=32, target='Ascend', distribute=False,
                           augment=True):
"""
Create a train or test cifar10 dataset.
Args:
dataset_path (string): Path to the dataset.
do_train (bool): Whether dataset is used for training or testing.
repeat_num (int): Repeat times of the dataset.
batch_size (int): Batch size of the dataset.
target (str): Device target.
distribute (bool): For distributed training or not.
augment (bool): Whether to enable auto-augment or not.
Returns:
dataset
"""
if target == 'Ascend':
rank_size, rank_id = _get_rank_info()
else:
if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
        else:
            rank_size = 1
            rank_id = 0
num_shards = None if rank_size == 1 else rank_size
shard_id = None if rank_size == 1 else rank_id
dataset = ds.Cifar10Dataset(
dataset_path, usage='train' if do_train else 'test',
num_parallel_workers=8, shuffle=True,
num_shards=num_shards, shard_id=shard_id,
)
# Define map operations
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2023, 0.1994, 0.2010]
trans = []
if do_train and augment:
trans += [
Augment(mean=MEAN, std=STD),
]
else:
if do_train:
trans += [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip(),
]
trans += [
C.Rescale(1. / 255., 0.),
C.Normalize(MEAN, STD),
C.HWC2CHW(),
]
dataset = dataset.map(operations=trans,
input_columns='image', num_parallel_workers=8)
type_cast_op = C2.TypeCast(mstype.int32)
dataset = dataset.map(operations=type_cast_op,
input_columns='label', num_parallel_workers=8)
# Apply batch operations
dataset = dataset.batch(batch_size, drop_remainder=True)
# Apply dataset repeat operation
dataset = dataset.repeat(repeat_num)
return dataset
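if __name__ == '__main__':
    # Minimal smoke-test sketch; the path below is an assumption and should
    # point at a local 'cifar-10-batches-bin' directory.
    demo_ds = create_cifar10_dataset('cifar-10-batches-bin', do_train=True,
                                     batch_size=32, augment=True)
    print('batches per epoch:', demo_ds.get_dataset_size())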

View File

@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for network definitions.
"""
from .wrn import WRN

View File

@ -0,0 +1,233 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
WideResNet building blocks.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
def _get_optional_avg_pool(stride):
"""
Create an average pool op if stride is larger than 1, return None otherwise.
Args:
stride (int): Stride size, a positive integer.
Returns:
nn.AvgPool2d or None.
"""
if stride == 1:
return None
return nn.AvgPool2d(kernel_size=stride, stride=stride)
def _get_optional_pad(in_channels, out_channels):
"""
Create a zero-pad op if out_channels is larger than in_channels, return None
otherwise.
Args:
in_channels (int): The input channel size.
out_channels (int): The output channel size (must not be smaller than
in_channels).
Returns:
nn.Pad or None.
"""
if in_channels == out_channels:
return None
pad_left = (out_channels - in_channels) // 2
pad_right = out_channels - in_channels - pad_left
return nn.Pad((
(0, 0),
(pad_left, pad_right),
(0, 0),
(0, 0),
))
class ResidualBlock(nn.Cell):
"""
ResidualBlock is the basic building block for wide-resnet.
Args:
in_channels (int): The input channel size.
out_channels (int): The output channel size.
stride (int): The stride size used in the first convolution layer.
activate_before_residual (bool): Whether to apply bn and relu before
residual.
"""
def __init__(
self,
in_channels,
out_channels,
stride,
activate_before_residual=False,
):
super(ResidualBlock, self).__init__()
self.activate_before_residual = activate_before_residual
self.in_channels = in_channels
self.out_channels = out_channels
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, stride=1)
self.avg_pool = _get_optional_avg_pool(stride)
self.pad = _get_optional_pad(in_channels, out_channels)
self.relu = nn.ReLU()
self.add = P.Add()
def construct(self, x):
"""Construct the forward network."""
if self.activate_before_residual:
out = self.bn1(x)
out = self.relu(out)
orig_x = out
else:
orig_x = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
if self.avg_pool is not None:
orig_x = self.avg_pool(orig_x)
if self.pad is not None:
orig_x = self.pad(orig_x)
return self.add(out, orig_x)
class ResidualGroup(nn.Cell):
"""
ResidualGroup gathers a group of ResidualBlocks (default: 4).
Args:
in_channels (int): The input channel size.
out_channels (int): The output channel size.
stride (int): The stride size used in the first ResidualBlock.
activate_before_residual (bool): Whether to apply bn and relu before
residual in the first ResidualBlock.
num_blocks (int): Number of ResidualBlocks in the group.
"""
def __init__(
self,
in_channels,
out_channels,
stride,
activate_before_residual=False,
num_blocks=4,
):
super(ResidualGroup, self).__init__()
self.rb = ResidualBlock(in_channels, out_channels, stride,
activate_before_residual)
self.rbs = nn.SequentialCell([
ResidualBlock(out_channels, out_channels, 1)
for _ in range(num_blocks - 1)
])
self.avg_pool = _get_optional_avg_pool(stride)
self.pad = _get_optional_pad(in_channels, out_channels)
self.add = P.Add()
def construct(self, x):
"""Construct the forward network."""
orig_x = x
out = self.rb(x)
out = self.rbs(out)
if self.avg_pool is not None:
orig_x = self.avg_pool(orig_x)
if self.pad is not None:
orig_x = self.pad(orig_x)
return self.add(out, orig_x)
class WRN(nn.Cell):
"""
WRN is short for Wide-ResNet.
Args:
wrn_size (int): Wide-ResNet size.
in_channels (int): The input channel size.
num_classes (int): Number of classes to predict.
"""
def __init__(self, wrn_size, in_channels, num_classes):
super(WRN, self).__init__()
sizes = [
min(wrn_size, 16),
wrn_size,
wrn_size * 2,
wrn_size * 4,
]
strides = [1, 2, 2]
self.conv1 = nn.Conv2d(in_channels, sizes[0], 3)
self.rg1 = ResidualGroup(sizes[0], sizes[1], strides[0], True)
self.rg2 = ResidualGroup(sizes[1], sizes[2], strides[1], False)
self.rg3 = ResidualGroup(sizes[2], sizes[3], strides[2], False)
final_stride = 1
for s in strides:
final_stride *= s
self.avg_pool = _get_optional_avg_pool(final_stride)
self.pad = _get_optional_pad(sizes[0], sizes[-1])
self.bn = nn.BatchNorm2d(sizes[-1])
self.fc = nn.Dense(sizes[-1], num_classes)
self.mean = P.ReduceMean()
self.relu = nn.ReLU()
self.add = P.Add()
def construct(self, x):
"""Construct the forward network."""
out = self.conv1(x)
orig_x = out
out = self.rg1(out)
out = self.rg2(out)
out = self.rg3(out)
if self.avg_pool is not None:
orig_x = self.avg_pool(orig_x)
if self.pad is not None:
orig_x = self.pad(orig_x)
out = self.add(out, orig_x)
out = self.bn(out)
out = self.relu(out)
out = self.mean(out, (2, 3))
out = self.fc(out)
return out
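if __name__ == '__main__':
    # Shape sanity-check sketch (assumes MindSpore's default execution context
    # is usable on the current device).  A small wrn_size keeps the check fast;
    # the train/eval scripts build the full model with WRN(160, 3, class_num).
    import numpy as np
    from mindspore import Tensor
    demo_net = WRN(16, 3, 10)
    demo_net.set_train(False)
    demo_out = demo_net(Tensor(np.zeros((1, 3, 32, 32), dtype=np.float32)))
    print(demo_out.shape)  # expected: (1, 10)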

View File

@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for network optimization helpers.
"""
from .lr import get_lr

View File

@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for generating learning rates."""
import math
import numpy as np
def _generate_cosine_lr(lr_init, lr_max, total_steps, warmup_steps):
"""
Create an array of learning rates conforming to the cosine decay mode.
Args:
lr_init (float): Starting learning rate.
lr_max (float): Maximum learning rate.
total_steps (int): Total number of train steps.
warmup_steps (int): Number of warmup steps.
    Returns:
        list, the learning rate for each training step.
"""
decay_steps = total_steps - warmup_steps
lr_each_step = []
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * \
(1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
return lr_each_step
def get_lr(
lr_init, lr_max,
warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode='cosine',
):
"""
Args:
lr_init (float): Starting learning rate.
lr_max (float): Maximum learning rate.
warmup_epochs (int): Number of warmup epochs.
total_epochs (int): Total number of train epochs.
steps_per_epoch (int): Steps per epoch.
lr_decay_mode (string): Learning rate decay mode.
Returns:
np.array, a learning rate array.
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'cosine':
lr_each_step = _generate_cosine_lr(
lr_init, lr_max, total_steps, warmup_steps,
)
    else:
        raise ValueError(
            'lr_decay_mode {} is not supported'.format(lr_decay_mode))
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
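if __name__ == '__main__':
    # Quick illustration with hypothetical hyper-parameters: warm up for 5
    # epochs, then decay over 200 epochs at 390 steps per epoch.
    demo_lr = get_lr(lr_init=0.0, lr_max=0.1, warmup_epochs=5,
                     total_epochs=200, steps_per_epoch=390)
    print('first steps:', demo_lr[:3], 'last steps:', demo_lr[-3:])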

View File

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for common utils.
"""
import logging
logger = logging.getLogger('utils')
def init_utils(config):
"""
Initialize common utils.
Args:
config (Config): Contains configurable parameters throughout the
project.
"""
init_logging(config.log_level)
for k, v in vars(config).items():
logger.info('%s=%s', k, v)
def init_logging(level=logging.INFO):
"""
Initialize logging formats.
Args:
        level (int): Logging level; messages below this level are suppressed.
"""
FMT = r'[%(asctime)s][%(name)s][%(levelname)8s] %(message)s'
DATE_FMT = r'%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=FMT, datefmt=DATE_FMT, level=level)
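if __name__ == '__main__':
    # Demonstration sketch: any object exposing attributes works here, so a
    # SimpleNamespace stands in for the project's Config (values are made up).
    from types import SimpleNamespace
    demo_conf = SimpleNamespace(log_level=logging.INFO, seed=1,
                                dataset='cifar10')
    init_utils(demo_conf)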

View File

@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model evaluation entrypoint.
"""
import os
from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import Config
from src.dataset import create_cifar10_dataset
from src.network import WRN
from src.utils import init_utils
if __name__ == '__main__':
conf = Config(training=False)
init_utils(conf)
set_seed(conf.seed)
# Initialize context
try:
device_id = int(os.getenv('DEVICE_ID'))
except TypeError:
device_id = 0
context.set_context(
mode=context.GRAPH_MODE,
device_target=conf.device_target,
save_graphs=False,
device_id=device_id,
)
# Create dataset
if conf.dataset == 'cifar10':
dataset = create_cifar10_dataset(
dataset_path=conf.dataset_path,
do_train=False,
repeat_num=1,
batch_size=conf.batch_size,
target=conf.device_target,
)
step_size = dataset.get_dataset_size()
# Define net
net = WRN(160, 3, conf.class_num)
# Load checkpoint
param_dict = load_checkpoint(conf.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Define loss and model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, metrics={
'top_1_accuracy', 'top_5_accuracy',
})
# Eval model
res = model.eval(dataset)
print('result:', res, 'checkpoint:', conf.checkpoint_path)

View File

@ -0,0 +1,140 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model training entrypoint.
"""
import os
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.common.tensor import Tensor
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.context import ParallelMode
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, \
LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.config import Config
from src.dataset import create_cifar10_dataset
from src.network import WRN
from src.optim import get_lr
from src.utils import init_utils
if __name__ == '__main__':
conf = Config(training=True)
init_utils(conf)
set_seed(conf.seed)
# Initialize context
context.set_context(
mode=context.GRAPH_MODE,
device_target=conf.device_target,
save_graphs=False,
)
if conf.run_distribute:
if conf.device_target == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(
device_id=device_id,
enable_auto_mixed_precision=True,
)
context.set_auto_parallel_context(
device_num=conf.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
)
init()
elif conf.device_target == 'GPU':
init()
context.set_auto_parallel_context(
device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
)
else:
try:
device_id = int(os.getenv('DEVICE_ID'))
except TypeError:
device_id = 0
context.set_context(device_id=device_id)
# Create dataset
if conf.dataset == 'cifar10':
dataset = create_cifar10_dataset(
dataset_path=conf.dataset_path,
do_train=True,
repeat_num=1,
batch_size=conf.batch_size,
target=conf.device_target,
distribute=conf.run_distribute,
augment=conf.augment,
)
step_size = dataset.get_dataset_size()
# Define net
net = WRN(160, 3, conf.class_num)
# Load weight if pre_trained is configured
if conf.pre_trained:
param_dict = load_checkpoint(conf.pre_trained)
load_param_into_net(net, param_dict)
# Initialize learning rate
lr = Tensor(get_lr(
lr_init=conf.lr_init, lr_max=conf.lr_max,
warmup_epochs=conf.warmup_epochs, total_epochs=conf.epoch_size,
steps_per_epoch=step_size, lr_decay_mode=conf.lr_decay_mode,
))
# Define loss, opt, and model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(
conf.loss_scale,
drop_overflow_update=False,
)
opt = Momentum(
filter(lambda x: x.requires_grad, net.get_parameters()),
lr, conf.momentum, conf.weight_decay, conf.loss_scale,
)
model = Model(net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'})
# Define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
config_ck = CheckpointConfig(
save_checkpoint_steps=conf.save_checkpoint_epochs * step_size,
keep_checkpoint_max=conf.keep_checkpoint_max,
)
ck_cb = ModelCheckpoint(
prefix='train_%s_%s' % (conf.net, conf.dataset),
directory=conf.save_checkpoint_path,
config=config_ck,
)
# Train
if conf.run_distribute:
callbacks = [time_cb, loss_cb]
        if conf.device_target == 'GPU' and get_rank() == 0:
callbacks = [time_cb, loss_cb, ck_cb]
elif conf.device_target == 'Ascend' and device_id == 0:
callbacks = [time_cb, loss_cb, ck_cb]
else:
callbacks = [time_cb, loss_cb, ck_cb]
model.train(conf.epoch_size, dataset, callbacks=callbacks)