forked from mindspore-Ecosystem/mindspore
Add autoaugment
This commit is contained in:
parent
a8553078b9
commit
1440b1d0a0
|
@ -0,0 +1,365 @@

## Contents

- [Contents](#contents)
- [AutoAugment Description](#autoaugment-description)
- [Overview](#overview)
- [AutoAugment Paper](#autoaugment-paper)
- [Model Architecture](#model-architecture)
- [WideResNet Paper](#wideresnet-paper)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script Parameters](#script-parameters)
- [Script Usage](#script-usage)
- [AutoAugment Operator Usage](#autoaugment-operator-usage)
- [Training Script Usage](#training-script-usage)
- [Evaluation Script Usage](#evaluation-script-usage)
- [Export Script Usage](#export-script-usage)
- [Model Description](#model-description)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## AutoAugment Description

### Overview

Data augmentation is an important technique for improving the accuracy and generalization of image classifiers. Traditional data augmentation relies on manual design and a fixed augmentation pipeline (for example, combining the `RandomCrop` and `RandomHorizontalFlip` image transformation operators).

Unlike traditional approaches, AutoAugment proposes an effective policy-space design for data augmentation, which allows researchers to use different search algorithms (such as reinforcement learning, evolutionary algorithms, or even random search) to automatically customize augmentation pipelines for specific models and datasets. Specifically, the policy space proposed by AutoAugment covers the following concepts:

| Concept | English Term | Brief Description |
|:---|:---|:---|
| Operator | Operation | An image transformation operator (e.g. translation, rotation). The operators selected by AutoAugment do not change the size or type of the input image. Each operator has two searchable parameters: probability and magnitude. |
| Probability | Probability | The probability of randomly applying an image transformation operator. If the operator is not applied, the input image is returned unchanged. |
| Magnitude | Magnitude | The strength of an applied image transformation operator, e.g. the number of pixels to translate or the angle to rotate. |
| Subpolicy | Subpolicy | Each subpolicy contains two operators. When a subpolicy is applied, the two operators transform the input image in order, according to their probabilities and magnitudes. |
| Policy | Policy | Each policy contains several subpolicies. When augmenting data, the policy randomly selects one subpolicy for each image. |

Since the number of operators is finite and the probability and magnitude parameters of each operator can be discretized, the policy space proposed by AutoAugment leads to a discrete search problem with a finite number of states. In particular, experiments show that this policy space is transferable to some extent: a policy found with one model/dataset combination can be transferred to other models on the same dataset, and a policy found on one dataset can be transferred to other similar datasets.

This example mainly implements the policy space proposed by AutoAugment. Developers can use the "good policies" listed in the AutoAugment paper to augment their datasets, or design search algorithms on top of this example to automatically customize augmentation pipelines.
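
To make the policy space concrete: in this implementation a policy is expressed as a list of subpolicies, and each subpolicy is a short sequence of `(operator name, probability, level)` tuples. The sketch below is purely illustrative; the operator names come from [src/dataset/autoaugment/ops](./src/dataset/autoaugment/ops), but the probabilities and levels are made-up placeholder values, not searched ones (see [src/dataset/autoaugment/third_party/policies.py](./src/dataset/autoaugment/third_party/policies.py) for the actual "good policies"):

```python
# An illustrative policy: a list of subpolicies, each made of two
# (operator_name, probability, level) tuples. Operator names correspond to
# the operators defined under src/dataset/autoaugment/ops; levels are
# discrete values in [0, 10]. The numbers here are examples, not tuned values.
example_policy = [
    [('Rotate', 0.8, 7), ('Equalize', 0.9, 3)],    # subpolicy 1
    [('ShearX', 0.5, 4), ('Brightness', 0.7, 9)],  # subpolicy 2
]
```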

### AutoAugment Paper

Cubuk, Ekin D., et al. "Autoaugment: Learning augmentation strategies from data." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.

## Model Architecture

Besides implementing the policy space proposed by AutoAugment, this example also provides a simple implementation of the Wide-ResNet model for developers to refer to.

### WideResNet Paper

Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).

## Dataset

This example uses Cifar10 to demonstrate how AutoAugment is used and to verify the effectiveness of this implementation.

This example uses the [CIFAR-10 binary version](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz), whose directory structure is as follows:

```bash
cifar-10-batches-bin
├── batches.meta.txt
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
├── data_batch_5.bin
├── readme.html
└── test_batch.bin
```
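
Once downloaded and extracted, the folder can be loaded directly with MindSpore's `Cifar10Dataset`, which is also how the test scripts in this example read the data; a minimal sketch:

```python
import mindspore.dataset as ds

# Path to the extracted folder shown above.
data_path = './cifar-10-batches-bin'
dataset = ds.Cifar10Dataset(data_path, shuffle=True)
```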

## Environment Requirements

- Hardware
    - Prepare an Ascend processor to set up the hardware environment. To apply for a trial of the Ascend processor, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo体验资源申请表.docx) to ascend@huawei.com. You will get access to the resources once the application is approved.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/)
- For more information, please check the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

## Quick Start

After preparing the computing device and the framework environment, developers can run the following commands to train and evaluate this example.

- Running in the Ascend processor environment

```bash
# Distributed training with 8 devices
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]

# Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH]

# Standalone evaluation
Usage: bash run_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance.

For details, please refer to the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

## Script Description

```bash
.
├── export.py                           # Export the network
├── mindspore_hub_conf.py               # MindSpore Hub configuration
├── README.md                           # Documentation
├── run_distribute_train.sh             # Multi-device training script for the Ascend environment
├── run_eval.sh                         # Evaluation script for the Ascend environment
├── run_standalone_train.sh             # Single-device training script for the Ascend environment
├── src
│   ├── config.py                       # Model training/testing configuration
│   ├── dataset
│   │   ├── autoaugment
│   │   │   ├── aug.py                  # AutoAugment policies
│   │   │   ├── aug_test.py             # AutoAugment policy tests and visualization
│   │   │   ├── ops
│   │   │   │   ├── crop.py             # RandomCrop operator
│   │   │   │   ├── cutout.py           # RandomCutout operator
│   │   │   │   ├── effect.py           # Image effect operators
│   │   │   │   ├── enhance.py          # Image enhancement operators
│   │   │   │   ├── ops_test.py         # Operator tests and visualization
│   │   │   │   └── transform.py        # Image transformation operators
│   │   │   └── third_party
│   │   │       └── policies.py         # "Good policies" found by AutoAugment
│   │   └── cifar10.py                  # Cifar10 dataset processing
│   ├── network
│   │   └── wrn.py                      # Wide-ResNet model definition
│   ├── optim
│   │   └── lr.py                       # Cosine learning rate definition
│   └── utils                           # Logging format initialization, etc.
├── test.py                             # Test the network
└── train.py                            # Train the network
```

## Script Parameters

Training parameters, dataset paths, and other parameters can be configured in [src/config.py](./src/config.py).

```python
# Set to mute logs with lower levels.
self.log_level = logging.INFO

# Random seed.
self.seed = 1

# Type of device(s) where the model would be deployed to.
# Choices: ['Ascend', 'GPU', 'CPU']
self.device_target = 'Ascend'

# The model to use. Choices: ['wrn']
self.net = 'wrn'

# The dataset to train or test against. Choices: ['cifar10']
self.dataset = 'cifar10'
# The number of classes.
self.class_num = 10
# Path to the folder where the intended dataset is stored.
self.dataset_path = './cifar-10-batches-bin'

# Batch size for both training mode and testing mode.
self.batch_size = 128

# Indicates training or testing mode.
self.training = training

# Testing parameters.
if not self.training:
    # The checkpoint to load and test against.
    # Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
    self.checkpoint_path = None

# Training parameters.
if self.training:
    # Whether to apply auto-augment or not.
    self.augment = True

    # The number of device(s) to be used for training.
    self.device_num = 1
    # Whether to train the model in a distributed mode or not.
    self.run_distribute = False
    # The pre-trained checkpoint to load and train from.
    # Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
    self.pre_trained = None

    # Number of epochs to train.
    self.epoch_size = 200
    # Momentum factor.
    self.momentum = 0.9
    # L2 penalty.
    self.weight_decay = 5e-4
    # Learning rate decaying mode. Choices: ['cosine']
    self.lr_decay_mode = 'cosine'
    # The starting learning rate.
    self.lr_init = 0.1
    # The maximum learning rate.
    self.lr_max = 0.1
    # The number of warmup epochs. Note that during the warmup period,
    # the learning rate grows from `lr_init` to `lr_max` linearly.
    self.warmup_epochs = 5
    # Loss scaling for mixed-precision training.
    self.loss_scale = 1024

    # Create a checkpoint per `save_checkpoint_epochs` epochs.
    self.save_checkpoint_epochs = 5
    # The maximum number of checkpoints to keep.
    self.keep_checkpoint_max = 10
    # The folder path to save checkpoints.
    self.save_checkpoint_path = './checkpoint'
```

## Script Usage

### AutoAugment Operator Usage

Similar to [src/dataset/cifar10.py](./src/dataset/cifar10.py), to use the AutoAugment operator, first import the `Augment` class:

```python
# Copy the complete "src/dataset/autoaugment/" folder to the current directory, or use a soft link.
from autoaugment import Augment
```

The AutoAugment operator is compatible with MindSpore datasets and can be used directly as a transformation operator for the dataset:

```python
dataset = dataset.map(operations=[Augment(mean=MEAN, std=STD)],
                      input_columns='image', num_parallel_workers=8)
```

AutoAugment supports the following arguments:

```python
Args:
    index (int or None): If index is not None, the indexed policy would
        always be used. Otherwise, a policy would be randomly chosen from
        the policies set for each image.
    policies (policies found by AutoAugment or None): A set of policies
        to sample from. When the given policies is None, good policies found
        on cifar10 would be used.
    enable_basic (bool): Whether to apply basic augmentations after
        auto-augment or not. Note that basic augmentations
        include RandomFlip, RandomCrop, and RandomCutout.
    from_pil (bool): Whether the image passed to the operator is already a
        PIL image.
    as_pil (bool): Whether the returned image should be kept as a PIL image.
    mean, std (list): Per-channel mean and std used to normalize the output
        image. Only applicable when as_pil is False.
```
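
For example, continuing the snippet above, a custom policy set can be passed in as nested lists of `(operator name, probability, level)` tuples, or a single policy can be pinned with `index`. The operator names and numbers below are illustrative placeholders, not tuned settings:

```python
# A minimal sketch of non-default usage; MEAN/STD are the per-channel
# statistics of the dataset, and the policy values are placeholders.
custom_policies = [
    [('Rotate', 0.8, 7), ('Equalize', 0.9, 3)],
    [('ShearX', 0.5, 4), ('Brightness', 0.7, 9)],
]
dataset = dataset.map(
    operations=[Augment(policies=custom_policies, mean=MEAN, std=STD)],
    input_columns='image',
    num_parallel_workers=8,
)
```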

### Training Script Usage

Augment the dataset with the AutoAugment operator and train the model:

```bash
# python train.py -h
usage: train.py [-h] [--device_target {Ascend,GPU,CPU}] [--dataset {cifar10}]
                [--dataset_path DATASET_PATH] [--augment AUGMENT]
                [--device_num DEVICE_NUM] [--run_distribute RUN_DISTRIBUTE]
                [--lr_max LR_MAX] [--pre_trained PRE_TRAINED]
                [--save_checkpoint_path SAVE_CHECKPOINT_PATH]

AutoAugment for image classification.

optional arguments:
  -h, --help            show this help message and exit
  --device_target {Ascend,GPU,CPU}
                        Type of device(s) where the model would be deployed
                        to.
  --dataset {cifar10}   The dataset to train or test against.
  --dataset_path DATASET_PATH
                        Path to the folder where the intended dataset is
                        stored.
  --augment AUGMENT     Whether to apply auto-augment or not.
  --device_num DEVICE_NUM
                        The number of device(s) to be used for training.
  --run_distribute RUN_DISTRIBUTE
                        Whether to train the model in distributed mode or not.
  --lr_max LR_MAX       The maximum learning rate.
  --pre_trained PRE_TRAINED
                        The pre-trained checkpoint to load and train from.
                        Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt
  --save_checkpoint_path SAVE_CHECKPOINT_PATH
                        The folder path to save checkpoints.
```

### Evaluation Script Usage

Evaluate the accuracy of a trained model:

```bash
# python test.py -h
usage: test.py [-h] [--device_target {Ascend,GPU,CPU}] [--dataset {cifar10}]
               [--dataset_path DATASET_PATH]
               [--checkpoint_path CHECKPOINT_PATH]

AutoAugment for image classification.

optional arguments:
  -h, --help            show this help message and exit
  --device_target {Ascend,GPU,CPU}
                        Type of device(s) where the model would be deployed
                        to.
  --dataset {cifar10}   The dataset to train or test against.
  --dataset_path DATASET_PATH
                        Path to the folder where the intended dataset is
                        stored.
  --checkpoint_path CHECKPOINT_PATH
                        The checkpoint to load and test against.
                        Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt
```

### Export Script Usage

Export the trained model to the AIR, ONNX, or MINDIR format:

```bash
# python export.py -h
usage: export.py [-h] [--device_id DEVICE_ID] --checkpoint_path
                 CHECKPOINT_PATH [--file_name FILE_NAME]
                 [--file_format {AIR,ONNX,MINDIR}]
                 [--device_target {Ascend,GPU,CPU}]

WRN with AutoAugment export.

optional arguments:
  -h, --help            show this help message and exit
  --device_id DEVICE_ID
                        Device id.
  --checkpoint_path CHECKPOINT_PATH
                        Checkpoint file path.
  --file_name FILE_NAME
                        Output file name.
  --file_format {AIR,ONNX,MINDIR}
                        Export format.
  --device_target {Ascend,GPU,CPU}
                        Device target.
```

## Model Description

| Parameters | Single Ascend 910 | 8x Ascend 910 |
|:---|:---|:---|
| Resource | Ascend 910 | Ascend 910 |
| Upload date | 2021.06.21 | 2021.06.24 |
| MindSpore version | 1.2.0 | 1.2.0 |
| Training dataset | Cifar10 | Cifar10 |
| Training parameters | epoch=200, batch_size=128 | epoch=200, batch_size=128, lr_max=0.8 |
| Optimizer | Momentum | Momentum |
| Output | Loss | Loss |
| Accuracy | 97.42% | 97.39% |
| Speed | 97.73 ms/step | 106.29 ms/step |
| Total time | 127 min | 17 min |
| Fine-tuned checkpoint | 277M (.ckpt file) | 277M (.ckpt file) |
| Script | [autoaugment](./) | [autoaugment](./) |

## Description of Random Situation

A random seed is set in [train.py](./train.py) to ensure the reproducibility of training.

## ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Export checkpoint file into air, onnx, mindir models."""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import (
|
||||
context,
|
||||
export,
|
||||
load_checkpoint,
|
||||
load_param_into_net,
|
||||
Tensor,
|
||||
)
|
||||
|
||||
from src.config import Config
|
||||
from src.network import WRN
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='WRN with AutoAugment export.')
|
||||
parser.add_argument(
|
||||
'--device_id', type=int, default=0,
|
||||
help='Device id.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--checkpoint_path', type=str, required=True,
|
||||
help='Checkpoint file path.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--file_name', type=str, default='wrn-autoaugment',
|
||||
help='Output file name.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'],
|
||||
default='AIR', help='Export format.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'],
|
||||
default='Ascend', help='Device target.',
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target=args.device_target,
|
||||
)
|
||||
if args.device_target == 'Ascend':
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
conf = Config(training=False, load_args=False)
|
||||
net = WRN(160, 3, conf.class_num)
|
||||
|
||||
param_dict = load_checkpoint(args.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
image = Tensor(np.ones((1, 3, 32, 32), np.float32))
|
||||
export(
|
||||
net, image,
|
||||
file_name=args.file_name,
|
||||
file_format=args.file_format,
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MindSpore Hub config"""
|
||||
|
||||
from src.network import WRN
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""Creates a WideResNet."""
|
||||
if name == 'WRN':
|
||||
return WRN(*args, **kwargs)
|
||||
raise NotImplementedError('%s is not implemented in the repo' % name)
|
|
@ -0,0 +1,49 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage:
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export RANK_TABLE_FILE=$1
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
||||
PID_LIST=()
|
||||
for ((i=0; i<${RANK_SIZE}; i++)); do
|
||||
export DEVICE_ID=${i}
|
||||
export RANK_ID=${i}
|
||||
echo "Start distributed training for rank ${RANK_ID}, device ${DEVICE_ID}"
|
||||
python ../train.py \
|
||||
--dataset=cifar10 \
|
||||
--dataset_path=$2 \
|
||||
--run_distribute=True \
|
||||
--lr_max=0.8 \
|
||||
> train-${i}.log 2>&1 &
|
||||
pid=$!
|
||||
PID_LIST+=("${pid}")
|
||||
done
|
||||
|
||||
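# RUN_BACKGROUND=1 returns immediately and leaves the per-rank training
# processes running in the background; set it to 0 to wait for every rank.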
RUN_BACKGROUND=1
|
||||
if (( RUN_BACKGROUND == 0 )); then
|
||||
echo "Waiting for all processes to exit..."
|
||||
for pid in ${PID_LIST[*]}; do
|
||||
wait ${pid}
|
||||
echo "Process ${pid} exited"
|
||||
done
|
||||
fi
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ $# == 2 ]; then
|
||||
python ../test.py \
|
||||
--checkpoint_path $1 \
|
||||
--dataset cifar10 \
|
||||
--dataset_path $2 \
|
||||
> eval.log 2>&1 &
|
||||
else
|
||||
echo "Usage: \
|
||||
bash run_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ $# == 1 ]; then
|
||||
python ../train.py \
|
||||
--dataset cifar10 \
|
||||
--dataset_path $1 \
|
||||
> train.log 2>&1 &
|
||||
else
|
||||
echo "Usage: \
|
||||
bash run_standalone_train.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Package initialization.
|
||||
"""
|
|
@ -0,0 +1,224 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Configurable parameters.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logger = logging.getLogger('config')
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Define configurable parameters.
|
||||
|
||||
Args:
|
||||
training (bool): Whether it is used in training mode or testing mode.
|
||||
load_args (bool): Whether to load cli arguments automatically or not.
|
||||
"""
|
||||
|
||||
def __init__(self, training, load_args=True):
|
||||
# Set to mute logs with lower levels.
|
||||
self.log_level = logging.INFO
|
||||
|
||||
# Random seed.
|
||||
self.seed = 1
|
||||
|
||||
# Type of device(s) where the model would be deployed to.
|
||||
# Choices: ['Ascend', 'GPU', 'CPU']
|
||||
self.device_target = 'Ascend'
|
||||
|
||||
# The model to use. Choices: ['wrn']
|
||||
self.net = 'wrn'
|
||||
|
||||
# The dataset to train or test against. Choices: ['cifar10']
|
||||
self.dataset = 'cifar10'
|
||||
# The number of classes.
|
||||
self.class_num = 10
|
||||
# Path to the folder where the intended dataset is stored.
|
||||
self.dataset_path = './cifar-10-batches-bin'
|
||||
|
||||
# Batch size for both training mode and testing mode.
|
||||
self.batch_size = 128
|
||||
|
||||
# Indicates training or testing mode.
|
||||
self.training = training
|
||||
|
||||
# Testing parameters.
|
||||
if not self.training:
|
||||
# The checkpoint to load and test against.
|
||||
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
|
||||
self.checkpoint_path = None
|
||||
|
||||
# Training parameters.
|
||||
if self.training:
|
||||
# Whether to apply auto-augment or not.
|
||||
self.augment = True
|
||||
|
||||
# The number of device(s) to be used for training.
|
||||
self.device_num = 1
|
||||
# Whether to train the model in a distributed mode or not.
|
||||
self.run_distribute = False
|
||||
# The pre-trained checkpoint to load and train from.
|
||||
# Example: './checkpoint/train_wrn_cifar10-200_390.ckpt'
|
||||
self.pre_trained = None
|
||||
|
||||
# Number of epochs to train.
|
||||
self.epoch_size = 200
|
||||
# Momentum factor.
|
||||
self.momentum = 0.9
|
||||
# L2 penalty.
|
||||
self.weight_decay = 5e-4
|
||||
# Learning rate decaying mode. Choices: ['cosine']
|
||||
self.lr_decay_mode = 'cosine'
|
||||
# The starting learning rate.
|
||||
self.lr_init = 0.1
|
||||
# The maximum learning rate.
|
||||
self.lr_max = 0.1
|
||||
# The number of warmup epochs. Note that during the warmup period,
|
||||
# the learning rate grows from `lr_init` to `lr_max` linearly.
|
||||
self.warmup_epochs = 5
|
||||
# Loss scaling for mixed-precision training.
|
||||
self.loss_scale = 1024
|
||||
|
||||
# Create a checkpoint per `save_checkpoint_epochs` epochs.
|
||||
self.save_checkpoint_epochs = 5
|
||||
# The maximum number of checkpoints to keep.
|
||||
self.keep_checkpoint_max = 10
|
||||
# The folder path to save checkpoints.
|
||||
self.save_checkpoint_path = './checkpoint'
|
||||
|
||||
# _init is an initialization guard, which helps warn setting attributes
|
||||
# outside __init__.
|
||||
self._init = True
|
||||
if load_args:
|
||||
self.load_args()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""___setattr__ is customized to warn adding attributes outside
|
||||
__init__ and encourage declaring configurable parameters explicitly in
|
||||
__init__."""
|
||||
if getattr(self, '_init', False) and not hasattr(self, name):
|
||||
logger.warning('attempting to add an attribute '
|
||||
'outside __init__: %s=%s', name, value)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def load_args(self):
|
||||
"""load_args overwrites configurations by cli arguments."""
|
||||
hooks = {} # hooks are used to assign values.
|
||||
parser = argparse.ArgumentParser(
|
||||
description='AutoAugment for image classification.')
|
||||
|
||||
parser.add_argument(
|
||||
'--device_target', type=str, default='Ascend',
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='Type of device(s) where the model would be deployed to.',
|
||||
)
|
||||
def hook_device_target(x):
|
||||
"""Sets the device_target value."""
|
||||
self.device_target = x
|
||||
hooks['device_target'] = hook_device_target
|
||||
|
||||
parser.add_argument(
|
||||
'--dataset', type=str, default='cifar10',
|
||||
choices=['cifar10'],
|
||||
help='The dataset to train or test against.',
|
||||
)
|
||||
def hook_dataset(x):
|
||||
"""Sets the dataset value."""
|
||||
self.dataset = x
|
||||
hooks['dataset'] = hook_dataset
|
||||
|
||||
parser.add_argument(
|
||||
'--dataset_path', type=str, default='./cifar-10-batches-bin',
|
||||
help='Path to the folder where the intended dataset is stored.',
|
||||
)
|
||||
def hook_dataset_path(x):
|
||||
"""Sets the dataset_path value."""
|
||||
self.dataset_path = x
|
||||
hooks['dataset_path'] = hook_dataset_path
|
||||
|
||||
if not self.training:
|
||||
parser.add_argument(
|
||||
'--checkpoint_path', type=str, default=None,
|
||||
help='The checkpoint to load and test against. '
|
||||
'Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt',
|
||||
)
|
||||
def hook_checkpoint_path(x):
|
||||
"""Sets the checkpoint_path value."""
|
||||
self.checkpoint_path = x
|
||||
hooks['checkpoint_path'] = hook_checkpoint_path
|
||||
|
||||
if self.training:
|
||||
parser.add_argument(
|
||||
'--augment', type=bool, default=True,
|
||||
help='Whether to apply auto-augment or not.',
|
||||
)
|
||||
def hook_augment(x):
|
||||
"""Sets the augment value."""
|
||||
self.augment = x
|
||||
hooks['augment'] = hook_augment
|
||||
|
||||
parser.add_argument(
|
||||
'--device_num', type=int, default=1,
|
||||
help='The number of device(s) to be used for training.',
|
||||
)
|
||||
def hook_device_num(x):
|
||||
"""Sets the device_num value."""
|
||||
self.device_num = x
|
||||
hooks['device_num'] = hook_device_num
|
||||
|
||||
parser.add_argument(
|
||||
'--run_distribute', type=bool, default=False,
|
||||
help='Whether to train the model in distributed mode or not.',
|
||||
)
|
||||
def hook_distribute(x):
|
||||
"""Sets the run_distribute value."""
|
||||
self.run_distribute = x
|
||||
hooks['run_distribute'] = hook_distribute
|
||||
|
||||
parser.add_argument(
|
||||
'--lr_max', type=float, default=0.1,
|
||||
help='The maximum learning rate.',
|
||||
)
|
||||
def hook_lr_max(x):
|
||||
"""Sets the lr_max value."""
|
||||
self.lr_max = x
|
||||
hooks['lr_max'] = hook_lr_max
|
||||
|
||||
parser.add_argument(
|
||||
'--pre_trained', type=str, default=None,
|
||||
help='The pre-trained checkpoint to load and train from. '
|
||||
'Example: ./checkpoint/train_wrn_cifar10-200_390.ckpt',
|
||||
)
|
||||
def hook_pre_trained(x):
|
||||
"""Sets the pre_trained value."""
|
||||
self.pre_trained = x
|
||||
hooks['pre_trained'] = hook_pre_trained
|
||||
|
||||
parser.add_argument(
|
||||
'--save_checkpoint_path', type=str, default='./checkpoint',
|
||||
help='The folder path to save checkpoints.',
|
||||
)
|
||||
def hook_save_checkpoint_path(x):
|
||||
"""Sets the save_checkpoint_path value."""
|
||||
self.save_checkpoint_path = x
|
||||
hooks['save_checkpoint_path'] = hook_save_checkpoint_path
|
||||
|
||||
# Overwrite default configurations by cli arguments
|
||||
args_opt = parser.parse_args()
|
||||
for name, val in args_opt.__dict__.items():
|
||||
hooks[name](val)
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Package initialization for dataset helpers.
|
||||
"""
|
||||
|
||||
from .cifar10 import create_dataset as create_cifar10_dataset
|
|
@ -0,0 +1,2 @@
|
|||
cifar-10-batches-bin/
|
||||
*.png
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Package initialization for the AutoAugment operator.
|
||||
"""
|
||||
|
||||
from .aug import Augment
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
The Augment operator.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import mindspore.dataset.vision.py_transforms as py_trans
|
||||
|
||||
from .third_party.policies import good_policies
|
||||
from .ops import OperatorClasses, RandomCutout
|
||||
|
||||
|
||||
class Augment:
|
||||
"""
|
||||
Augment acts as a single transformation operator and applies policies found
|
||||
by AutoAugment.
|
||||
|
||||
Args:
|
||||
index (int or None): If index is not None, the indexed policy would
|
||||
always be used. Otherwise, a policy would be randomly chosen from
|
||||
the policies set for each image.
|
||||
policies (policies found by AutoAugment or None): A set of policies
|
||||
to sample from. When the given policies is None, good policies found
|
||||
on cifar10 would be used.
|
||||
enable_basic (bool): Whether to apply basic augmentations after
|
||||
auto-augment or not. Note that basic augmentations
|
||||
include RandomFlip, RandomCrop, and RandomCutout.
|
||||
from_pil (bool): Whether the image passed to the operator is already a
|
||||
PIL image.
|
||||
as_pil (bool): Whether the returned image should be kept as a PIL image.
|
||||
mean, std (list): Per-channel mean and std used to normalize the output
|
||||
image. Only applicable when as_pil is False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, index=None, policies=None, enable_basic=True,
|
||||
from_pil=False, as_pil=False,
|
||||
mean=None, std=None,
|
||||
):
|
||||
self.index = index
|
||||
if policies is None:
|
||||
self.policies = good_policies()
|
||||
else:
|
||||
self.policies = policies
|
||||
|
||||
self.oc = OperatorClasses()
|
||||
self.to_pil = py_trans.ToPIL()
|
||||
self.to_tensor = py_trans.ToTensor()
|
||||
|
||||
self.enable_basic = enable_basic
|
||||
self.random_crop = self.oc.RandomCrop(None)
|
||||
self.random_flip = self.oc.RandomHorizontalFlip(None)
|
||||
self.cutout = RandomCutout(size=16, value=(0, 0, 0))
|
||||
|
||||
self.from_pil = from_pil
|
||||
self.as_pil = as_pil
|
||||
self.normalize = None
|
||||
if mean is not None and std is not None:
|
||||
self.normalize = py_trans.Normalize(mean, std)
|
||||
|
||||
def _apply(self, name, prob, level, img):
|
||||
if random.random() > prob:
|
||||
# Untouched
|
||||
return img
|
||||
# Apply the operator
|
||||
return getattr(self.oc, name)(level)(img)
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (raw image or PIL image): Image to be auto-augmented.
|
||||
|
||||
Returns:
|
||||
img (PIL image or Tensor), Auto-augmented image.
|
||||
"""
|
||||
if self.index is None:
|
||||
policy = random.choice(self.policies)
|
||||
else:
|
||||
policy = self.policies[self.index]
|
||||
|
||||
if not self.from_pil:
|
||||
img = self.to_pil(img)
|
||||
|
||||
for name, prob, level in policy:
|
||||
img = self._apply(name, prob, level, img)
|
||||
|
||||
if self.enable_basic:
|
||||
img = self.random_crop(img)
|
||||
img = self.random_flip(img)
|
||||
img = self.cutout(img)
|
||||
|
||||
if not self.as_pil:
|
||||
img = self.to_tensor(img)
|
||||
if self.normalize is not None:
|
||||
img = self.normalize(img)
|
||||
return img
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Visualization for testing purposes.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
def compare(data_path, index=None, ops=None, rescale=False):
|
||||
"""Visualize images before and after applying auto-augment."""
|
||||
# Load dataset
|
||||
ds.config.set_seed(8)
|
||||
dataset_orig = ds.Cifar10Dataset(
|
||||
data_path,
|
||||
num_samples=5,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# Apply transformations
|
||||
dataset_augmented = dataset_orig.map(
|
||||
operations=[Augment(index)] if ops is None else ops,
|
||||
input_columns=['image'],
|
||||
)
|
||||
|
||||
# Collect images
|
||||
image_orig_list, image_augmented_list, label_list = [], [], []
|
||||
for data in dataset_orig.create_dict_iterator():
|
||||
image_orig_list.append(data['image'])
|
||||
label_list.append(data['label'])
|
||||
print('Original image: shape {}, label {}'.format(
|
||||
data['image'].shape, data['label'],
|
||||
))
|
||||
for data in dataset_augmented.create_dict_iterator():
|
||||
image_augmented_list.append(data['image'])
|
||||
print('Augmented image: shape {}, label {}'.format(
|
||||
data['image'].shape, data['label'],
|
||||
))
|
||||
|
||||
num_samples = len(image_orig_list)
|
||||
fig, mesh = plt.subplots(ncols=num_samples, nrows=2, figsize=(5, 3))
|
||||
axes = mesh[0]
|
||||
for i in range(num_samples):
|
||||
axes[i].axis('off')
|
||||
axes[i].imshow(image_orig_list[i].asnumpy())
|
||||
axes[i].set_title(label_list[i].asnumpy())
|
||||
axes = mesh[1]
|
||||
for i in range(num_samples):
|
||||
axes[i].axis('off')
|
||||
img = image_augmented_list[i].asnumpy().transpose((1, 2, 0))
|
||||
if rescale:
|
||||
max_val = max(np.abs(img.min()), img.max())
|
||||
img = (img / max_val + 1) / 2
|
||||
print('min and max of the transformed image:', img.min(), img.max())
|
||||
axes[i].imshow(img)
|
||||
fig.tight_layout()
|
||||
fig.savefig(
|
||||
'aug_test.png' if index is None else 'aug_test_{}.png'.format(index),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.append('..')
|
||||
from autoaugment.third_party.policies import good_policies
|
||||
from autoaugment import Augment
|
||||
|
||||
cifar10_data_path = './cifar-10-batches-bin/'
|
||||
|
||||
# Test the feasibility of each policy
|
||||
for ind, policy in enumerate(good_policies()):
|
||||
if ind >= 3:
|
||||
pass
|
||||
# break
|
||||
print(policy)
|
||||
compare(cifar10_data_path, ind)
|
||||
|
||||
# Test the random policy selection and the normalize operation
|
||||
MEAN = [0.49139968, 0.48215841, 0.44653091]
|
||||
STD = [0.24703223, 0.24348513, 0.26158784]
|
||||
compare(
|
||||
cifar10_data_path,
|
||||
ops=[Augment(mean=MEAN, std=STD, enable_basic=False)],
|
||||
)
|
||||
compare(
|
||||
cifar10_data_path,
|
||||
ops=[Augment(mean=MEAN, std=STD, enable_basic=False)],
|
||||
rescale=True,
|
||||
)
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Package initialization for custom PIL operators.
|
||||
"""
|
||||
|
||||
from mindspore.dataset.vision import py_transforms
|
||||
|
||||
from .crop import RandomCrop
|
||||
from .cutout import RandomCutout
|
||||
from .effect import (
|
||||
Posterize,
|
||||
Solarize,
|
||||
)
|
||||
from .enhance import (
|
||||
Brightness,
|
||||
Color,
|
||||
Contrast,
|
||||
Sharpness,
|
||||
)
|
||||
from .transform import (
|
||||
Rotate,
|
||||
ShearX,
|
||||
ShearY,
|
||||
TranslateX,
|
||||
TranslateY,
|
||||
)
|
||||
|
||||
|
||||
class OperatorClasses:
|
||||
"""OperatorClasses gathers all unary-image transformations listed in the
|
||||
Table 6 of https://arxiv.org/abs/1805.09501 and uses discrete levels for
|
||||
these transformations (The Sample Pairing transformation is an
|
||||
exception, which involves multiple images from a single mini-batch and
|
||||
is not exploited in this implementation.)
|
||||
|
||||
Additionally, there are RandomHorizontalFlip and RandomCrop.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.Rotate = self.decorate(Rotate, max_val=30, rounding=True)
|
||||
self.ShearX = self.decorate(ShearX, max_val=0.3)
|
||||
self.ShearY = self.decorate(ShearY, max_val=0.3)
|
||||
self.TranslateX = self.decorate(TranslateX, max_val=10, rounding=True)
|
||||
self.TranslateY = self.decorate(TranslateY, max_val=10, rounding=True)
|
||||
|
||||
self.AutoContrast = self.decorate(py_transforms.AutoContrast)
|
||||
self.Invert = self.decorate(py_transforms.Invert)
|
||||
self.Equalize = self.decorate(py_transforms.Equalize)
|
||||
|
||||
self.Solarize = self.decorate(
|
||||
Solarize, max_val=256, rounding=True, post=lambda x: 256 - x)
|
||||
self.Posterize = self.decorate(
|
||||
Posterize, max_val=4, rounding=True, post=lambda x: 4 - x)
|
||||
|
||||
def post(x):
|
||||
"""Post operation to avoid 0 value."""
|
||||
return x + 0.1
|
||||
self.Brightness = self.decorate(Brightness, max_val=1.8, post=post)
|
||||
self.Color = self.decorate(Color, max_val=1.8, post=post)
|
||||
self.Contrast = self.decorate(Contrast, max_val=1.8, post=post)
|
||||
self.Sharpness = self.decorate(Sharpness, max_val=1.8, post=post)
|
||||
|
||||
self.Cutout = self.decorate(RandomCutout, max_val=20, rounding=True)
|
||||
|
||||
self.RandomHorizontalFlip = self.decorate(
|
||||
py_transforms.RandomHorizontalFlip)
|
||||
self.RandomCrop = self.decorate(RandomCrop)
|
||||
|
||||
def vars(self):
|
||||
"""vars returns all available operators as a dictionary."""
|
||||
return vars(self)
|
||||
|
||||
def decorate(self, op, max_val=None, rounding=False, post=None):
|
||||
"""
|
||||
decorate interprets discrete levels for the given operator when
|
||||
applicable.
|
||||
|
||||
Args:
|
||||
op (Augmentation Operator): Operator to be decorated.
|
||||
max_val (int or float): Maximum value level-10 corresponds to.
|
||||
rounding (bool): Whether the corresponding value should be rounded
|
||||
to an integer.
|
||||
post (function): User-defined post-processing value function.
|
||||
|
||||
Returns:
|
||||
Decorated operator.
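
Example (illustrative): `self.decorate(Rotate, max_val=30, rounding=True)`
maps level 7 to the argument int(30 * 7 / 10) = 21 for the Rotate operator.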
|
||||
"""
|
||||
if max_val is None:
|
||||
def no_arg_fn(_):
|
||||
"""Decorates an operator without level parameter."""
|
||||
return op()
|
||||
return no_arg_fn
|
||||
|
||||
def fn(level):
|
||||
"""Decorates an operator with level parameter."""
|
||||
val = max_val * level / 10
|
||||
if rounding:
|
||||
val = int(val)
|
||||
if post is not None:
|
||||
val = post(val)
|
||||
return op(val)
|
||||
return fn
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
RandomCrop operator.
|
||||
"""
|
||||
|
||||
from mindspore.dataset.vision import py_transforms
|
||||
from mindspore.dataset.vision import py_transforms_util
|
||||
from mindspore.dataset.vision import utils
|
||||
|
||||
|
||||
class RandomCrop(py_transforms.RandomCrop):
|
||||
"""
|
||||
RandomCrop inherits from py_transforms.RandomCrop but derives/uses the
|
||||
original image size as the output size.
|
||||
|
||||
Please refer to py_transforms.RandomCrop for argument specifications.
|
||||
"""
|
||||
|
||||
def __init__(self, padding=4, pad_if_needed=False,
|
||||
fill_value=0, padding_mode=utils.Border.CONSTANT):
|
||||
# Note the `1` for the size argument is only set for passing the check.
|
||||
super(RandomCrop, self).__init__(1, padding=padding, pad_if_needed=pad_if_needed,
|
||||
fill_value=fill_value, padding_mode=padding_mode)
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be padded and then randomly cropped back
|
||||
to the same size.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Randomly cropped image.
|
||||
"""
|
||||
if not py_transforms_util.is_pil(img):
|
||||
raise TypeError(
|
||||
py_transforms_util.augment_error_message.format(type(img)))
|
||||
|
||||
return py_transforms_util.random_crop(
|
||||
img, img.size, self.padding, self.pad_if_needed,
|
||||
self.fill_value, self.padding_mode,
|
||||
)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
RandomCutout operator.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class RandomCutout:
|
||||
"""
|
||||
RandomCutout is similar to py_transforms.Cutout but is simplified and
|
||||
crafted for PIL images.
|
||||
|
||||
Args:
|
||||
size (int): the side-size of each square cutout patch.
|
||||
num_patches (int): the number of square cutout patches to add.
|
||||
value (RGB value): the pixel value to fill in each cutout patch.
|
||||
"""
|
||||
|
||||
def __init__(self, size=30, num_patches=1, value=(125, 122, 113)):
|
||||
self.size = size
|
||||
self.num_patches = num_patches
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def _clip(x, lower, upper):
|
||||
"""Clip value to the [lower, upper] range."""
|
||||
return max(lower, min(x, upper))
|
||||
|
||||
@staticmethod
|
||||
def _get_cutout_area(img_w, img_h, size):
|
||||
"""Randomly create a cutout area."""
|
||||
x = random.randint(0, img_w)
|
||||
y = random.randint(0, img_h)
|
||||
x1 = x - size // 2
|
||||
x2 = x1 + size
|
||||
y1 = y - size // 2
|
||||
y2 = y1 + size
|
||||
x1 = RandomCutout._clip(x1, 0, img_w)
|
||||
x2 = RandomCutout._clip(x2, 0, img_w)
|
||||
y1 = RandomCutout._clip(y1, 0, img_h)
|
||||
y2 = RandomCutout._clip(y2, 0, img_h)
|
||||
return x1, y1, x2, y2
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be cutout.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Randomly cutout image.
|
||||
"""
|
||||
img_w, img_h = img.size
|
||||
pixels = img.load()
|
||||
|
||||
for _ in range(self.num_patches):
|
||||
x1, y1, x2, y2 = self._get_cutout_area(img_w, img_h, self.size)
|
||||
for i in range(x1, x2): # columns
|
||||
for j in range(y1, y2): # rows
|
||||
pixels[i, j] = self.value
|
||||
|
||||
return img
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Operators for applying image effects.
|
||||
"""
|
||||
|
||||
from PIL import ImageOps
|
||||
|
||||
from mindspore.dataset.vision import py_transforms_util
|
||||
|
||||
|
||||
class Solarize:
|
||||
"""
|
||||
Solarize inverts image pixels with values above the configured threshold.
|
||||
|
||||
Args:
|
||||
threshold (int): All pixels above the threshold would be inverted.
|
||||
Ranging within [0, 255].
|
||||
"""
|
||||
|
||||
def __init__(self, threshold):
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be solarized.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Solarized image.
|
||||
"""
|
||||
if not py_transforms_util.is_pil(img):
|
||||
raise TypeError(
|
||||
py_transforms_util.augment_error_message.format(type(img)))
|
||||
|
||||
return ImageOps.solarize(img, self.threshold)
|
||||
|
||||
|
||||
class Posterize:
|
||||
"""
|
||||
Posterize reduces the number of bits for each color channel.
|
||||
|
||||
Args:
|
||||
bits (int): The number of bits to keep for each channel.
|
||||
Ranging within [1, 8].
|
||||
"""
|
||||
|
||||
def __init__(self, bits):
|
||||
self.bits = bits
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be posterized.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Posterized image.
|
||||
"""
|
||||
if not py_transforms_util.is_pil(img):
|
||||
raise TypeError(
|
||||
py_transforms_util.augment_error_message.format(type(img)))
|
||||
|
||||
return ImageOps.posterize(img, self.bits)
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Operators for enhancing images.
|
||||
"""
|
||||
|
||||
from PIL import ImageEnhance
|
||||
|
||||
from mindspore.dataset.vision import py_transforms_util
|
||||
|
||||
|
||||
class Contrast:
|
||||
"""
|
||||
Contrast adjusts the contrast of images.
|
||||
|
||||
Args:
|
||||
degree (float): contrast degree ranging within [0.1, 1.9], where 1.0
|
||||
indicates an unchanged contrast.
|
||||
"""
|
||||
|
||||
def __init__(self, degree):
|
||||
self.degree = degree
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): PIL image to be adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Contrast adjusted image.
|
||||
"""
|
||||
return py_transforms_util.adjust_contrast(img, self.degree)
|
||||
|
||||
|
||||
class Color:
|
||||
"""
|
||||
Color adjusts the saturation of images.
|
||||
|
||||
Args:
|
||||
degree (float): saturation degree ranging within [0.1, 1.9], where 1.0
|
||||
indicates an unchanged saturation.
|
||||
"""
|
||||
|
||||
def __init__(self, degree):
|
||||
self.degree = degree
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): PIL image to be adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Saturation adjusted image.
|
||||
"""
|
||||
return py_transforms_util.adjust_saturation(img, self.degree)
|
||||
|
||||
|
||||
class Brightness:
|
||||
"""
|
||||
Brightness adjusts the brightness of images.
|
||||
|
||||
Args:
|
||||
degree (float): brightness degree ranging within [0.1, 1.9], where 1.0
|
||||
indicates an unchanged brightness.
|
||||
"""
|
||||
|
||||
def __init__(self, degree):
|
||||
self.degree = degree
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Brightness adjusted image.
|
||||
"""
|
||||
return py_transforms_util.adjust_brightness(img, self.degree)
|
||||
|
||||
|
||||
class Sharpness:
|
||||
"""
|
||||
Sharpness adjusts the sharpness of images.
|
||||
|
||||
Args:
|
||||
degree (float): sharpness degree ranging within [0.1, 1.9], where 1.0
|
||||
indicates an unchanged sharpness.
|
||||
"""
|
||||
|
||||
def __init__(self, degree):
|
||||
self.degree = degree
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL image): Image to be sharpness adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL image), Sharpness adjusted image.
|
||||
"""
|
||||
if not py_transforms_util.is_pil(img):
|
||||
raise TypeError(
|
||||
py_transforms_util.augment_error_message.format(type(img)))
|
||||
|
||||
return ImageEnhance.Sharpness(img).enhance(self.degree)
|
|
@@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Visualization for testing purposes.
"""

import matplotlib.pyplot as plt

import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as py_trans
from mindspore import context
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')


def compare(data_path, trans, output_path='./ops_test.png'):
    """Visualize images before and after applying the given transformations."""
    # Load dataset
    ds.config.set_seed(8)
    dataset_orig = ds.Cifar10Dataset(
        data_path,
        num_samples=5,
        shuffle=True,
    )

    # Apply transformations
    dataset_augmented = dataset_orig.map(
        operations=[py_trans.ToPIL()] + trans + [py_trans.ToTensor()],
        input_columns=['image'],
    )

    # Collect images
    image_orig_list, image_augmented_list, label_list = [], [], []
    for data in dataset_orig.create_dict_iterator():
        image_orig_list.append(data['image'])
        label_list.append(data['label'])
        print('Original image: shape {}, label {}'.format(
            data['image'].shape, data['label'],
        ))
    for data in dataset_augmented.create_dict_iterator():
        image_augmented_list.append(data['image'])
        print('Augmented image: shape {}, label {}'.format(
            data['image'].shape, data['label'],
        ))

    # Plot images
    num_samples = len(image_orig_list)
    fig, mesh = plt.subplots(ncols=num_samples, nrows=2, figsize=(5, 3))
    axes = mesh[0]
    for i in range(num_samples):
        axes[i].axis('off')
        axes[i].imshow(image_orig_list[i].asnumpy())
        axes[i].set_title(label_list[i].asnumpy())
    axes = mesh[1]
    for i in range(num_samples):
        axes[i].axis('off')
        axes[i].imshow(image_augmented_list[i].asnumpy().transpose((1, 2, 0)))
    fig.tight_layout()
    fig.savefig(output_path)


if __name__ == '__main__':
    import sys
    sys.path.append('..')
    from ops import OperatorClasses
    oc = OperatorClasses()

    levels = {
        'Contrast': 7,
        'Rotate': 7,
        'TranslateX': 9,
        'Sharpness': 9,
        'ShearY': 8,
        'TranslateY': 9,
        'AutoContrast': 9,
        'Equalize': 9,
        'Solarize': 8,
        'Color': 9,
        'Posterize': 7,
        'Brightness': 9,
        'Cutout': 4,
        'ShearX': 4,
        'Invert': 7,
        'RandomHorizontalFlip': None,
        'RandomCrop': None,
    }

    cifar10_data_path = './cifar-10-batches-bin/'
    for name, op in oc.vars().items():
        compare(cifar10_data_path, [op(levels[name])], './ops_{}_{}.png'.format(
            name, levels[name],
        ))
@@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Operators for affine transformations.
"""

import numbers
import random

from PIL import Image, __version__

from mindspore.dataset.vision.py_transforms import DE_PY_INTER_MODE
from mindspore.dataset.vision.py_transforms_util import (
    augment_error_message,
    is_pil,
    rotate,
)
from mindspore.dataset.vision.utils import Inter


class ShearX:
    """
    ShearX shears images along the x-axis.

    Args:
        shear (int): the pixel size to shear.
        resample (enum): the interpolation mode.
        fill_value (int or tuple): the filling value to fill the area outside
            the transform in the output image.
    """

    def __init__(self, shear, resample=Inter.NEAREST, fill_value=0):
        if not isinstance(shear, numbers.Number):
            raise TypeError('shear must be a single number.')

        self.shear = shear
        self.resample = DE_PY_INTER_MODE[resample]
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL image): Image to apply shear_x transformation.

        Returns:
            img (PIL image), X-axis sheared image.
        """
        if not is_pil(img):
            raise ValueError('Input image should be a Pillow image.')

        output_size = img.size
        shear = self.shear if random.random() > 0.5 else -self.shear
        matrix = (1, shear, 0, 0, 1, 0)

        if __version__ >= '5':
            kwargs = {'fillcolor': self.fill_value}
        else:
            kwargs = {}

        return img.transform(output_size, Image.AFFINE, matrix,
                             self.resample, **kwargs)


class ShearY:
    """
    ShearY shears images along the y-axis.

    Args:
        shear (int): the pixel size to shear.
        resample (enum): the interpolation mode.
        fill_value (int or tuple): the filling value to fill the area outside
            the transform in the output image.
    """

    def __init__(self, shear, resample=Inter.NEAREST, fill_value=0):
        if not isinstance(shear, numbers.Number):
            raise TypeError('shear must be a single number.')

        self.shear = shear
        self.resample = DE_PY_INTER_MODE[resample]
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL image): Image to apply shear_y transformation.

        Returns:
            img (PIL image), Y-axis sheared image.
        """
        if not is_pil(img):
            raise ValueError('Input image should be a Pillow image.')

        output_size = img.size
        shear = self.shear if random.random() > 0.5 else -self.shear
        matrix = (1, 0, 0, shear, 1, 0)

        if __version__ >= '5':
            kwargs = {'fillcolor': self.fill_value}
        else:
            kwargs = {}

        return img.transform(output_size, Image.AFFINE, matrix,
                             self.resample, **kwargs)


class TranslateX:
    """
    TranslateX translates images along the x-axis.

    Args:
        translate (int): the pixel size to translate.
        resample (enum): the interpolation mode.
        fill_value (int or tuple): the filling value to fill the area outside
            the transform in the output image.
    """

    def __init__(self, translate, resample=Inter.NEAREST, fill_value=0):
        if not isinstance(translate, numbers.Number):
            raise TypeError('translate must be a single number.')

        self.translate = translate
        self.resample = DE_PY_INTER_MODE[resample]
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL image): Image to apply translate_x transformation.

        Returns:
            img (PIL image), X-axis translated image.
        """
        if not is_pil(img):
            raise ValueError('Input image should be a Pillow image.')

        output_size = img.size
        trans = self.translate if random.random() > 0.5 else -self.translate
        matrix = (1, 0, trans, 0, 1, 0)

        if __version__ >= '5':
            kwargs = {'fillcolor': self.fill_value}
        else:
            kwargs = {}

        return img.transform(output_size, Image.AFFINE, matrix,
                             self.resample, **kwargs)


class TranslateY:
    """
    TranslateY translates images along the y-axis.

    Args:
        translate (int): the pixel size to translate.
        resample (enum): the interpolation mode.
        fill_value (int or tuple): the filling value to fill the area outside
            the transform in the output image.
    """

    def __init__(self, translate, resample=Inter.NEAREST, fill_value=0):
        if not isinstance(translate, numbers.Number):
            raise TypeError('translate must be a single number.')

        self.translate = translate
        self.resample = DE_PY_INTER_MODE[resample]
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL image): Image to apply translate_y transformation.

        Returns:
            img (PIL image), Y-axis translated image.
        """
        if not is_pil(img):
            raise ValueError('Input image should be a Pillow image.')

        output_size = img.size
        trans = self.translate if random.random() > 0.5 else -self.translate
        matrix = (1, 0, 0, 0, 1, trans)

        if __version__ >= '5':
            kwargs = {'fillcolor': self.fill_value}
        else:
            kwargs = {}

        return img.transform(output_size, Image.AFFINE, matrix,
                             self.resample, **kwargs)


class Rotate:
    """
    Rotate is similar to py_vision.RandomRotation but uses a fixed degree.

    Args:
        degree (int): the degree to rotate.

    Please refer to py_transforms.RandomRotation for more argument
    specifications.
    """

    def __init__(
            self, degree,
            resample=Inter.NEAREST, expand=False, center=None, fill_value=0,
    ):
        if not isinstance(degree, numbers.Number):
            raise TypeError('degree must be a single number.')

        self.degree = degree
        self.resample = DE_PY_INTER_MODE[resample]
        self.expand = expand
        self.center = center
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL image): Image to be rotated.

        Returns:
            img (PIL image), Rotated image.
        """
        if not is_pil(img):
            raise TypeError(augment_error_message.format(type(img)))

        degree = self.degree if random.random() > 0.5 else -self.degree
        return rotate(img, degree, self.resample, self.expand,
                      self.center, self.fill_value)
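These affine operators flip the sign of the shear, translation, or rotation internally with probability 0.5, so the caller only supplies a magnitude. A small sketch of applying them directly to a PIL image; the import path is hypothetical:

```python
from PIL import Image

from transform_utils import ShearX, Rotate  # hypothetical import path

img = Image.open('example.png').convert('RGB')

# Shear with magnitude 3 along the x-axis; the direction is chosen at random
# inside the operator, and uncovered pixels are filled with fill_value.
sheared = ShearX(3, fill_value=0)(img)

# Rotate by 10 degrees; again, the sign of the angle is randomized internally.
rotated = Rotate(10)(img)
```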
model_zoo/research/cv/autoaugment/src/dataset/autoaugment/third_party/policies.py (vendored, new file)
|
@@ -0,0 +1,140 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Good policies found by the AutoAugment paper.
"""


def good_policies():
    """AutoAugment policies found on Cifar."""
    exp0_0 = [
        [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
        [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
        [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
        [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
    exp0_1 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
        [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
        [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
        [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
    exp0_2 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
        [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
        [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
        [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
    exp0_3 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
        [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
        [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
        [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
        [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
    exp1_0 = [
        [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
        [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
        [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
        [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
        [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
    exp1_1 = [
        [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
        [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
        [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
        [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
        [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
    exp1_2 = [
        [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
        [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
        [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
        [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
        [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
    exp1_3 = [
        [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
        [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
        [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
        [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
        [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
    exp1_4 = [
        [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
        [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
        [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
        [('Color', 0.3, 7), ('Color', 0.2, 4)],
        [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
    exp1_5 = [
        [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
        [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
        [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
        [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
        [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
    exp1_6 = [
        [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
        [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
        [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
        [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
        [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
    exp2_0 = [
        [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
        [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
        [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
        [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
        [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
    exp2_1 = [
        [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
        [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
        [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
        [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
        [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
    exp2_2 = [
        [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
        [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
        [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
        [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
        [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
    exp2_3 = [
        [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
        [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
        [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
        [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
        [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
    exp2_4 = [
        [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
        [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
        [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
        [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
        [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
    exp2_5 = [
        [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
        [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
        [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
        [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
        [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
    exp2_6 = [
        [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
        [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
        [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
        [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
    exp2_7 = [
        [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
        [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
        [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
        [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
        [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
    exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
    exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
    exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + \
        exp2_7
    return exp0s + exp1s + exp2s
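Each entry returned by `good_policies()` is a sub-policy: a list of two `(operation, probability, magnitude_level)` tuples, matching the concepts table in the README. A short sketch of how a caller might pick and inspect one; the import path is hypothetical:

```python
import random

from policies import good_policies  # hypothetical import path

policies = good_policies()
print(len(policies))  # 95 sub-policies (20 + 35 + 40 from the three groups)

# AutoAugment samples one sub-policy per image and applies its two operations
# in order, each gated by its own probability.
sub_policy = random.choice(policies)
for name, prob, level in sub_policy:
    print('apply {} with probability {} at magnitude level {}'.format(
        name, prob, level))
```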
@@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Helpers for creating Cifar-10 datasets (optionally with AutoAugment enabled).
"""

import os

import mindspore.common.dtype as mstype
from mindspore.communication.management import init, get_rank, get_group_size
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C

from .autoaugment import Augment


def _get_rank_info():
    """Get rank size and rank id."""
    rank_size = int(os.environ.get('RANK_SIZE', 1))

    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id


def create_dataset(dataset_path, do_train=True, repeat_num=1, batch_size=32,
                   target='Ascend', distribute=False, augment=True):
    """
    Create a train or test cifar10 dataset.

    Args:
        dataset_path (string): Path to the dataset.
        do_train (bool): Whether dataset is used for training or testing.
        repeat_num (int): Repeat times of the dataset.
        batch_size (int): Batch size of the dataset.
        target (str): Device target.
        distribute (bool): For distributed training or not.
        augment (bool): Whether to enable auto-augment or not.

    Returns:
        dataset
    """
    if target == 'Ascend':
        rank_size, rank_id = _get_rank_info()
    else:
        if distribute:
            init()
            rank_id = get_rank()
            rank_size = get_group_size()
        else:
            rank_size = 1

    num_shards = None if rank_size == 1 else rank_size
    shard_id = None if rank_size == 1 else rank_id
    dataset = ds.Cifar10Dataset(
        dataset_path, usage='train' if do_train else 'test',
        num_parallel_workers=8, shuffle=True,
        num_shards=num_shards, shard_id=shard_id,
    )

    # Define map operations
    MEAN = [0.4914, 0.4822, 0.4465]
    STD = [0.2023, 0.1994, 0.2010]
    trans = []
    if do_train and augment:
        trans += [
            Augment(mean=MEAN, std=STD),
        ]
    else:
        if do_train:
            trans += [
                C.RandomCrop((32, 32), (4, 4, 4, 4)),
                C.RandomHorizontalFlip(),
            ]
        trans += [
            C.Rescale(1. / 255., 0.),
            C.Normalize(MEAN, STD),
            C.HWC2CHW(),
        ]
    dataset = dataset.map(operations=trans,
                          input_columns='image', num_parallel_workers=8)

    type_cast_op = C2.TypeCast(mstype.int32)
    dataset = dataset.map(operations=type_cast_op,
                          input_columns='label', num_parallel_workers=8)

    # Apply batch operations
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # Apply dataset repeat operation
    dataset = dataset.repeat(repeat_num)

    return dataset
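A minimal single-device usage sketch for `create_dataset`, assuming the CIFAR-10 binary files sit under `./cifar-10-batches-bin/` as described earlier; the module path in the import is hypothetical:

```python
from src.dataset.cifar10 import create_dataset  # hypothetical module path

train_set = create_dataset(
    './cifar-10-batches-bin/',
    do_train=True,
    batch_size=128,
    target='Ascend',
    distribute=False,
    augment=True,  # route training images through the AutoAugment pipeline
)
print('steps per epoch:', train_set.get_dataset_size())
```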
@@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for network definitions.
"""

from .wrn import WRN
@@ -0,0 +1,233 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
WideResNet building blocks.
"""

import mindspore.nn as nn
from mindspore.ops import operations as P


def _get_optional_avg_pool(stride):
    """
    Create an average pool op if stride is larger than 1, return None otherwise.

    Args:
        stride (int): Stride size, a positive integer.

    Returns:
        nn.AvgPool2d or None.
    """
    if stride == 1:
        return None
    return nn.AvgPool2d(kernel_size=stride, stride=stride)


def _get_optional_pad(in_channels, out_channels):
    """
    Create a zero-pad op if out_channels is larger than in_channels, return None
    otherwise.

    Args:
        in_channels (int): The input channel size.
        out_channels (int): The output channel size (must not be smaller than
            in_channels).

    Returns:
        nn.Pad or None.
    """
    if in_channels == out_channels:
        return None
    pad_left = (out_channels - in_channels) // 2
    pad_right = out_channels - in_channels - pad_left
    return nn.Pad((
        (0, 0),
        (pad_left, pad_right),
        (0, 0),
        (0, 0),
    ))


class ResidualBlock(nn.Cell):
    """
    ResidualBlock is the basic building block for wide-resnet.

    Args:
        in_channels (int): The input channel size.
        out_channels (int): The output channel size.
        stride (int): The stride size used in the first convolution layer.
        activate_before_residual (bool): Whether to apply bn and relu before
            residual.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            stride,
            activate_before_residual=False,
    ):
        super(ResidualBlock, self).__init__()
        self.activate_before_residual = activate_before_residual
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride)

        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1)

        self.avg_pool = _get_optional_avg_pool(stride)
        self.pad = _get_optional_pad(in_channels, out_channels)

        self.relu = nn.ReLU()
        self.add = P.Add()

    def construct(self, x):
        """Construct the forward network."""
        if self.activate_before_residual:
            out = self.bn1(x)
            out = self.relu(out)
            orig_x = out
        else:
            orig_x = x
            out = self.bn1(x)
            out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
        if self.pad is not None:
            orig_x = self.pad(orig_x)
        return self.add(out, orig_x)


class ResidualGroup(nn.Cell):
    """
    ResidualGroup gathers a group of ResidualBlocks (default: 4).

    Args:
        in_channels (int): The input channel size.
        out_channels (int): The output channel size.
        stride (int): The stride size used in the first ResidualBlock.
        activate_before_residual (bool): Whether to apply bn and relu before
            residual in the first ResidualBlock.
        num_blocks (int): Number of ResidualBlocks in the group.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            stride,
            activate_before_residual=False,
            num_blocks=4,
    ):
        super(ResidualGroup, self).__init__()

        self.rb = ResidualBlock(in_channels, out_channels, stride,
                                activate_before_residual)
        self.rbs = nn.SequentialCell([
            ResidualBlock(out_channels, out_channels, 1)
            for _ in range(num_blocks - 1)
        ])

        self.avg_pool = _get_optional_avg_pool(stride)
        self.pad = _get_optional_pad(in_channels, out_channels)

        self.add = P.Add()

    def construct(self, x):
        """Construct the forward network."""
        orig_x = x

        out = self.rb(x)
        out = self.rbs(out)

        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
        if self.pad is not None:
            orig_x = self.pad(orig_x)
        return self.add(out, orig_x)


class WRN(nn.Cell):
    """
    WRN is short for Wide-ResNet.

    Args:
        wrn_size (int): Wide-ResNet size.
        in_channels (int): The input channel size.
        num_classes (int): Number of classes to predict.
    """

    def __init__(self, wrn_size, in_channels, num_classes):
        super(WRN, self).__init__()

        sizes = [
            min(wrn_size, 16),
            wrn_size,
            wrn_size * 2,
            wrn_size * 4,
        ]
        strides = [1, 2, 2]

        self.conv1 = nn.Conv2d(in_channels, sizes[0], 3)

        self.rg1 = ResidualGroup(sizes[0], sizes[1], strides[0], True)
        self.rg2 = ResidualGroup(sizes[1], sizes[2], strides[1], False)
        self.rg3 = ResidualGroup(sizes[2], sizes[3], strides[2], False)

        final_stride = 1
        for s in strides:
            final_stride *= s
        self.avg_pool = _get_optional_avg_pool(final_stride)
        self.pad = _get_optional_pad(sizes[0], sizes[-1])

        self.bn = nn.BatchNorm2d(sizes[-1])
        self.fc = nn.Dense(sizes[-1], num_classes)

        self.mean = P.ReduceMean()
        self.relu = nn.ReLU()
        self.add = P.Add()

    def construct(self, x):
        """Construct the forward network."""
        out = self.conv1(x)
        orig_x = out

        out = self.rg1(out)
        out = self.rg2(out)
        out = self.rg3(out)

        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
        if self.pad is not None:
            orig_x = self.pad(orig_x)
        out = self.add(out, orig_x)

        out = self.bn(out)
        out = self.relu(out)
        out = self.mean(out, (2, 3))
        out = self.fc(out)

        return out
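A quick shape check for the network, using the same constructor arguments as `train.py` (`wrn_size=160`, RGB input, 10 classes); run in PyNative mode on CPU purely as a sketch:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context

from src.network import WRN  # same import used by train.py and test.py

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

net = WRN(160, 3, 10)  # wrn_size=160, 3 input channels, 10 CIFAR-10 classes
dummy = Tensor(np.zeros((2, 3, 32, 32)), mstype.float32)
logits = net(dummy)
print(logits.shape)  # expected: (2, 10)
```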
@@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for network optimization helpers.
"""

from .lr import get_lr
@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for generating learning rates."""

import math

import numpy as np


def _generate_cosine_lr(lr_init, lr_max, total_steps, warmup_steps):
    """
    Create an array of learning rates conforming to the cosine decay mode.

    Args:
        lr_init (float): Starting learning rate.
        lr_max (float): Maximum learning rate.
        total_steps (int): Total number of train steps.
        warmup_steps (int): Number of warmup steps.

    Returns:
        np.array, a learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            linear_decay = (total_steps - i) / decay_steps
            cosine_decay = 0.5 * \
                (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
            decayed = linear_decay * cosine_decay + 0.00001
            lr = lr_max * decayed
        lr_each_step.append(lr)
    return lr_each_step


def get_lr(
        lr_init, lr_max,
        warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode='cosine',
):
    """
    Generate a per-step learning rate array for the whole training run.

    Args:
        lr_init (float): Starting learning rate.
        lr_max (float): Maximum learning rate.
        warmup_epochs (int): Number of warmup epochs.
        total_epochs (int): Total number of train epochs.
        steps_per_epoch (int): Steps per epoch.
        lr_decay_mode (string): Learning rate decay mode.

    Returns:
        np.array, a learning rate array.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'cosine':
        lr_each_step = _generate_cosine_lr(
            lr_init, lr_max, total_steps, warmup_steps,
        )
    else:
        assert False, 'lr_decay_mode {} is not supported'.format(lr_decay_mode)

    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
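A sketch of how `train.py` consumes `get_lr`; the hyper-parameters below are illustrative, not the values from the provided configuration:

```python
from mindspore import Tensor

from src.optim import get_lr  # same import used by train.py

lr_steps = get_lr(
    lr_init=0.0, lr_max=0.1,
    warmup_epochs=5, total_epochs=200, steps_per_epoch=390,
    lr_decay_mode='cosine',
)
print(len(lr_steps))  # 200 * 390 = 78000 per-step learning rates
lr_tensor = Tensor(lr_steps)  # passed to the Momentum optimizer in train.py
```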
@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Package initialization for common utils.
"""

import logging
logger = logging.getLogger('utils')


def init_utils(config):
    """
    Initialize common utils.

    Args:
        config (Config): Contains configurable parameters throughout the
            project.
    """
    init_logging(config.log_level)
    for k, v in vars(config).items():
        logger.info('%s=%s', k, v)


def init_logging(level=logging.INFO):
    """
    Initialize logging formats.

    Args:
        level (logging level constant): Set to mute logs with lower levels.
    """
    FMT = r'[%(asctime)s][%(name)s][%(levelname)8s] %(message)s'
    DATE_FMT = r'%Y-%m-%d %H:%M:%S'
    logging.basicConfig(format=FMT, datefmt=DATE_FMT, level=level)
@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model testing entrypoint.
"""

import os

from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import Config
from src.dataset import create_cifar10_dataset
from src.network import WRN
from src.utils import init_utils


if __name__ == '__main__':
    conf = Config(training=False)
    init_utils(conf)
    set_seed(conf.seed)

    # Initialize context
    try:
        device_id = int(os.getenv('DEVICE_ID'))
    except TypeError:
        device_id = 0
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=conf.device_target,
        save_graphs=False,
        device_id=device_id,
    )

    # Create dataset
    if conf.dataset == 'cifar10':
        dataset = create_cifar10_dataset(
            dataset_path=conf.dataset_path,
            do_train=False,
            repeat_num=1,
            batch_size=conf.batch_size,
            target=conf.device_target,
        )
    step_size = dataset.get_dataset_size()

    # Define net
    net = WRN(160, 3, conf.class_num)

    # Load checkpoint
    param_dict = load_checkpoint(conf.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # Define loss and model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss_fn=loss, metrics={
        'top_1_accuracy', 'top_5_accuracy',
    })

    # Eval model
    res = model.eval(dataset)
    print('result:', res, 'checkpoint:', conf.checkpoint_path)
@@ -0,0 +1,140 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model training entrypoint.
"""

import os

from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.common.tensor import Tensor
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.context import ParallelMode
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, \
    LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager

from src.config import Config
from src.dataset import create_cifar10_dataset
from src.network import WRN
from src.optim import get_lr
from src.utils import init_utils


if __name__ == '__main__':
    conf = Config(training=True)
    init_utils(conf)
    set_seed(conf.seed)

    # Initialize context
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=conf.device_target,
        save_graphs=False,
    )
    if conf.run_distribute:
        if conf.device_target == 'Ascend':
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_id=device_id,
                enable_auto_mixed_precision=True,
            )
            context.set_auto_parallel_context(
                device_num=conf.device_num,
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True,
            )
            init()
        elif conf.device_target == 'GPU':
            init()
            context.set_auto_parallel_context(
                device_num=get_group_size(),
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True,
            )
    else:
        try:
            device_id = int(os.getenv('DEVICE_ID'))
        except TypeError:
            device_id = 0
        context.set_context(device_id=device_id)

    # Create dataset
    if conf.dataset == 'cifar10':
        dataset = create_cifar10_dataset(
            dataset_path=conf.dataset_path,
            do_train=True,
            repeat_num=1,
            batch_size=conf.batch_size,
            target=conf.device_target,
            distribute=conf.run_distribute,
            augment=conf.augment,
        )
    step_size = dataset.get_dataset_size()

    # Define net
    net = WRN(160, 3, conf.class_num)

    # Load weight if pre_trained is configured
    if conf.pre_trained:
        param_dict = load_checkpoint(conf.pre_trained)
        load_param_into_net(net, param_dict)

    # Initialize learning rate
    lr = Tensor(get_lr(
        lr_init=conf.lr_init, lr_max=conf.lr_max,
        warmup_epochs=conf.warmup_epochs, total_epochs=conf.epoch_size,
        steps_per_epoch=step_size, lr_decay_mode=conf.lr_decay_mode,
    ))

    # Define loss, opt, and model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss_scale = FixedLossScaleManager(
        conf.loss_scale,
        drop_overflow_update=False,
    )
    opt = Momentum(
        filter(lambda x: x.requires_grad, net.get_parameters()),
        lr, conf.momentum, conf.weight_decay, conf.loss_scale,
    )
    model = Model(net, loss_fn=loss, optimizer=opt,
                  loss_scale_manager=loss_scale, metrics={'acc'})

    # Define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    config_ck = CheckpointConfig(
        save_checkpoint_steps=conf.save_checkpoint_epochs * step_size,
        keep_checkpoint_max=conf.keep_checkpoint_max,
    )
    ck_cb = ModelCheckpoint(
        prefix='train_%s_%s' % (conf.net, conf.dataset),
        directory=conf.save_checkpoint_path,
        config=config_ck,
    )

    # Train
    if conf.run_distribute:
        callbacks = [time_cb, loss_cb]
        if conf.device_target == 'GPU' and str(get_rank()) == '0':
            callbacks = [time_cb, loss_cb, ck_cb]
        elif conf.device_target == 'Ascend' and device_id == 0:
            callbacks = [time_cb, loss_cb, ck_cb]
    else:
        callbacks = [time_cb, loss_cb, ck_cb]

    model.train(conf.epoch_size, dataset, callbacks=callbacks)