add GPU efficientnet to model_zoo

This commit is contained in:
TFbunny 2020-10-15 15:24:57 -04:00
parent 9c79b9d712
commit bd4e441862
12 changed files with 2027 additions and 0 deletions

View File

@ -0,0 +1,111 @@
# EfficientNet-B0 Example
## Description
This is an example of training EfficientNet-B0 in MindSpore.
## Requirements
- Install [Mindspore](http://www.mindspore.cn/install/en).
- Download the dataset.
## Structure
```shell
.
└─nasnet
├─README.md
├─scripts
├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p)
├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p)
└─run_eval_for_gpu.sh # launch evaluating with gpu platform
├─src
├─config.py # parameter configuration
├─dataset.py # data preprocessing
├─efficientnet.py # network definition
├─loss.py # Customized loss function
├─transform_utils.py # random augment utils
├─transform.py # random augment class
├─eval.py # eval net
└─train.py # train net
```
## Parameter Configuration
Parameters for both training and evaluating can be set in config.py
```
'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name
'drop': 0.2, # dropout rate
'drop_connect': 0.2, # drop connect rate
'opt_eps': 0.001, # optimizer epsilon
'lr': 0.064, # learning rate LR
'batch_size': 128, # batch size
'decay_epochs': 2.4, # epoch interval to decay LR
'warmup_epochs': 5, # epochs to warmup LR
'decay_rate': 0.97, # LR decay rate
'weight_decay': 1e-5, # weight decay
'epochs': 600, # number of epochs to train
'workers': 8, # number of data processing processes
'amp_level': 'O0', # amp level
'opt': 'rmsprop', # optimizer
'num_classes': 1000, # number of classes
'gp': 'avg', # type of global pool, "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9, # optimizer momentum
'warmup_lr_init': 0.0001, # init warmup LR
'smoothing': 0.1, # label smoothing factor
'bn_tf': False, # use Tensorflow BatchNorm defaults
'keep_checkpoint_max': 10, # max number ckpts to keep
'loss_scale': 1024, # loss scale
'resume_start_epoch': 0, # resume start epoch
```
## Running the example
### Train
#### Usage
```
# distribute training example(8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training
sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
```
#### Launch
```bash
# distributed training example(8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh /dataset 0
```
#### Result
You can find checkpoint file together with result in log.
### Evaluation
#### Usage
```
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```
#### Launch
```bash
# Evaluation with checkpoint
sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt
```
> checkpoint can be produced in training process.
#### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.

View File

@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate imagenet"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset_val
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification evaluation')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)')
parser.add_argument('--data_path', type=str, default='', help='Dataset path')
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
args_opt = parser.parse_args()
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
net = efficientnet_b0(num_classes=cfg.num_classes,
drop_rate=cfg.drop,
drop_connect_rate=cfg.drop_connect,
global_pool=cfg.gp,
bn_tf=cfg.bn_tf,
)
ckpt = load_checkpoint(args_opt.checkpoint)
load_param_into_net(net, ckpt)
net.set_train(False)
val_data_url = os.path.join(args_opt.data_path, 'val')
dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
metrics = model.eval(dataset)
print("metric: ", metrics)

View File

@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
RANK_SIZE=8
rm ${current_exec_path}/device_parallel/ -rf
mkdir ${current_exec_path}/device_parallel
echo ${curtime} > ${current_exec_path}/device_parallel/starttime
mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \
--GPU \
--distributed \
--data_path ${DATA_DIR} \
--cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 &

View File

@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2
PATH_CHECKPOINT=$3
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
echo ${curtime} > ${current_exec_path}/eval_starttime
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 &

View File

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
rm ${current_exec_path}/device_${DEVICE_ID}/ -rf
mkdir ${current_exec_path}/device_${DEVICE_ID}
echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \
--GPU \
--data_path ${DATA_DIR} \
--cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 &

View File

@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict
efficientnet_b0_config_gpu = edict({
'random_seed': 1,
'model': 'efficientnet_b0',
'drop': 0.2,
'drop_connect': 0.2,
'opt_eps': 0.001,
'lr': 0.064,
'batch_size': 128,
'decay_epochs': 2.4,
'warmup_epochs': 5,
'decay_rate': 0.97,
'weight_decay': 1e-5,
'epochs': 600,
'workers': 8,
'amp_level': 'O0',
'opt': 'rmsprop',
'num_classes': 1000,
#'Type of global pool, "avg", "max", "avgmax", "avgmaxc"
'gp': 'avg',
'momentum': 0.9,
'warmup_lr_init': 0.0001,
'smoothing': 0.1,
#Use Tensorflow BatchNorm defaults for models that support it
'bn_tf': False,
'keep_checkpoint_max': 10,
'loss_scale': 1024,
'resume_start_epoch': 0,
})

View File

@ -0,0 +1,125 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import math
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore.communication.management import get_group_size, get_rank
from mindspore.dataset.vision import Inter
from src.config import efficientnet_b0_config_gpu as cfg
from src.transform import RandAugment
ds.config.set_seed(cfg.random_seed)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
img_size = (224, 224)
crop_pct = 0.875
rescale = 1.0 / 255.0
shift = 0.0
inter_method = 'bilinear'
resize_value = 224 # img_size
scale = (0.08, 1.0)
ratio = (3./4., 4./3.)
inter_str = 'bicubic'
def str2MsInter(method):
if method == 'bicubic':
return Inter.BICUBIC
if method == 'nearest':
return Inter.NEAREST
return Inter.BILINEAR
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
if not os.path.exists(train_data_url):
raise ValueError('Path not exists')
interpolation = str2MsInter(inter_str)
c_decode_op = C.Decode()
type_cast_op = C2.TypeCast(mstype.int32)
random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio,
interpolation=interpolation)
random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)
efficient_rand_augment = RandAugment()
image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
ds_train = dataset_train.batch(batch_size,
per_batch_map=efficient_rand_augment,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False):
if not os.path.exists(val_data_url):
raise ValueError('Path not exists')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
scale_size = None
interpolation = str2MsInter(inter_method)
if isinstance(img_size, tuple):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
scale_size = tuple([int(x / crop_pct) for x in img_size])
else:
scale_size = int(math.floor(img_size / crop_pct))
type_cast_op = C2.TypeCast(mstype.int32)
decode_op = C.Decode()
resize_op = C.Resize(size=scale_size, interpolation=interpolation)
center_crop = C.CenterCrop(size=224)
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
changeswap_op = C.HWC2CHW()
ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]
dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers)
dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers)
dataset = dataset.repeat(1)
return dataset

View File

@ -0,0 +1,746 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EfficientNet model definition"""
import logging
import math
import re
from copy import deepcopy
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, ms_function
from mindspore.common.initializer import (Normal, One, Uniform, Zero,
initializer)
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.composite import clip_by_value
relu = P.ReLU()
sigmoid = P.Sigmoid()
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = {
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
'efficientnet_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'efficientnet_b3': _cfg(
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'efficientnet_b4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
}
_DEBUG = False
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
if layer_type not in ('conv', 'bn', 'fc'):
raise ValueError('The layer type is not known, the supported are conv, bn and fc')
if bias:
return Zero()
if layer_type == 'conv':
assert isinstance(shape, (tuple, list)) and len(
shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
n = shape[1] * shape[1] * shape[2]
return Normal(math.sqrt(2.0 / n))
if layer_type == 'bn':
return One()
assert isinstance(shape, (tuple, list)) and len(
shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
n = shape[1]
init_range = 1.0 / math.sqrt(n)
return Uniform(init_range)
def _initialize_weight_default(shape=None, layer_type='conv', bias=False):
if layer_type not in ('conv', 'bn', 'fc'):
raise ValueError('The layer type is not known, the supported are conv, bn and fc')
if bias and layer_type == 'bn':
return Zero()
if layer_type == 'conv':
return One()
if layer_type == 'bn':
return One()
return One()
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False):
weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
weight_init_value = _initialize_weight_goog(shape=(in_channels, 1, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False):
weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias)
def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init, beta_init=beta_init)
def _dense(in_channels, output_channels, bias=True, activation=None):
weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), layer_type='fc')
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, bias_init=bias_init_value,
has_bias=bias, activation=activation)
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, has_bias=bias,
activation=activation)
def _resolve_bn_args(kwargs):
bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
channel_min = channel_min or divisor
new_channels = max(
int(channels + divisor / 2) // divisor * divisor,
channel_min)
if new_channels < 0.9 * channels:
new_channels += divisor
return new_channels
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str, depth_multiplier=1.0):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0]
ops = ops[1:]
options = {}
noskip = False
for op in ops:
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
print('not support')
elif v == 'r6':
print('not support')
elif v == 'hs':
print('not support')
elif v == 'sw':
print('not support')
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type in ('ds', 'dsa'):
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_fn=act_fn,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
arch_args = []
for _, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
stack_args.append(ba)
repeats.append(rep)
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
@ms_function
def hard_swish(x):
x = P.Cast()(x, ms.float32)
y = x + 3.0
y = clip_by_value(y, 0.0, 6.0)
y = y / 6.0
return x * y
class BlockBuilder(nn.Cell):
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, channel_divisor=8,
channel_min=None, pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
bn_args=None, drop_connect_rate=0., verbose=False):
super(BlockBuilder, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.pad_type = pad_type
self.act_fn = act_fn
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
self.bn_args = bn_args
self.drop_connect_rate = drop_connect_rate
self.verbose = verbose
self.in_chs = None
self.block_idx = 0
self.block_count = 0
self.layer = self._make_layer(builder_in_channels, builder_block_args)
def _round_channels(self, chs):
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba):
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['bn_args'] = self.bn_args
ba['pad_type'] = self.pad_type
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid
if self.verbose:
logging.info(' InvertedResidual %d, Args: %s', self.block_idx, str(ba))
block = InvertedResidual(**ba)
elif bt in ('ds', 'dsa'):
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
if self.verbose:
logging.info(' DepthwiseSeparable %d, Args: %s', self.block_idx, str(ba))
block = DepthwiseSeparableConv(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs']
return block
def _make_stack(self, stack_args):
blocks = []
# each stack (stage) contains a list of block arguments
for i, ba in enumerate(stack_args):
if self.verbose:
logging.info(' Block: %d', i)
if i >= 1:
# only the first block in any stack can have a stride > 1
ba['stride'] = 1
block = self._make_block(ba)
blocks.append(block)
self.block_idx += 1
return nn.SequentialCell(blocks)
def _make_layer(self, in_chs, block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
if self.verbose:
logging.info('Building model trunk with %d stages...', len(block_args))
self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = []
for stack_idx, stack in enumerate(block_args):
if self.verbose:
logging.info('Stack: %d', stack_idx)
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
return nn.SequentialCell(blocks)
def construct(self, x):
return self.layer(x)
class DepthWiseConv(nn.Cell):
def __init__(self, in_planes, kernel_size, stride):
super(DepthWiseConv, self).__init__()
platform = context.get_context("device_target")
weight_shape = [1, kernel_size, in_planes]
weight_init = _initialize_weight_goog(shape=weight_shape)
if platform == "GPU":
self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size,
stride=stride, pad_mode="same", group=in_planes)
self.weight = Parameter(initializer(
weight_init, [in_planes * 1, 1, kernel_size, kernel_size]), name='depthwise_weight')
else:
self.depthwise_conv = P.DepthwiseConv2dNative(
channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',)
self.weight = Parameter(initializer(
weight_init, [1, in_planes, kernel_size, kernel_size]), name='depthwise_weight')
def construct(self, x):
x = self.depthwise_conv(x, self.weight)
return x
class DropConnect(nn.Cell):
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
super(DropConnect, self).__init__()
self.shape = P.Shape()
self.dtype = P.DType()
self.keep_prob = 1 - drop_connect_rate
self.dropout = P.Dropout(keep_prob=self.keep_prob)
def construct(self, x):
shape = self.shape(x)
dtype = self.dtype(x)
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
_, mask_ = self.dropout(ones_tensor)
x = x * mask_
return x
def drop_connect(inputs, training=False, drop_connect_rate=0.):
if not training:
return inputs
return DropConnect(drop_connect_rate)(inputs)
class SqueezeExcite(nn.Cell):
def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
super(SqueezeExcite, self).__init__()
self.act_fn = act_fn
self.gate_fn = gate_fn
reduce_chs = reduce_chs or in_chs
self.conv_reduce = _dense(in_chs, reduce_chs, bias=True)
self.conv_expand = _dense(reduce_chs, in_chs, bias=True)
self.avg_global_pool = P.ReduceMean(keep_dims=False)
def construct(self, x):
x_se = self.avg_global_pool(x, (2, 3))
x_se = self.conv_reduce(x_se)
x_se = self.act_fn(x_se)
x_se = self.conv_expand(x_se)
x_se = self.gate_fn(x_se)
x_se = P.ExpandDims()(x_se, 2)
x_se = P.ExpandDims()(x_se, 3)
x = x * x_se
return x
class DepthwiseSeparableConv(nn.Cell):
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=relu, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
bn_args=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
assert stride in [1, 2], 'stride must be 1 or 2'
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
self.bn1 = _fused_bn(in_chs, **bn_args)
#
if self.has_se:
self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pw = _conv1x1(in_chs, out_chs)
self.bn2 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
identity = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
if self.has_pw_act:
x = self.act_fn(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class InvertedResidual(nn.Cell):
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
pad_type='', act_fn=relu, pw_kernel_size=1,
noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
bn_args=None, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
self.bn1 = _fused_bn(mid_chs, **bn_args)
self.shuffle_type = shuffle_type
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = None
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
self.bn2 = _fused_bn(mid_chs, **bn_args)
if self.has_se:
se_base_chs = mid_chs if se_reduce_mid else in_chs
self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
self.bn3 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
identity = x
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class GenEfficientNet(nn.Cell):
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
global_pool='avg', head_conv='default', weight_init='goog'):
super(GenEfficientNet, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
self.num_classes = num_classes
self.drop_rate = drop_rate
self.act_fn = act_fn
self.num_features = num_features
stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = _conv(in_chans, stem_size, 3, stride=2)
self.bn1 = _fused_bn(stem_size, **bn_args)
in_chans = stem_size
self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier, channel_divisor, channel_min,
pad_type, act_fn, se_gate_fn, se_reduce_mid,
bn_args, drop_connect_rate, verbose=_DEBUG)
in_chs = self.blocks.in_chs
if not head_conv or head_conv == 'none':
self.efficient_head = False
self.conv_head = None
assert in_chs == self.num_features
else:
self.efficient_head = head_conv == 'efficient'
self.conv_head = _conv1x1(in_chs, self.num_features)
self.bn2 = None if self.efficient_head else _fused_bn(self.num_features, **bn_args)
self.global_pool = P.ReduceMean(keep_dims=True)
self.classifier = _dense(self.num_features, self.num_classes)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate)
def construct(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act_fn(x)
x = self.blocks(x)
if self.efficient_head:
x = self.global_pool(x, (2, 3))
x = self.conv_head(x)
x = self.act_fn(x)
x = self.reshape(self.shape(x)[0], -1)
else:
if self.conv_head is not None:
x = self.conv_head(x)
x = self.bn2(x)
x = self.act_fn(x)
x = self.global_pool(x, (2, 3))
x = self.reshape(x, (self.shape(x)[0], -1))
if self.training and self.drop_rate > 0.:
x = self.drop_out(x)
return self.classifier(x)
def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
num_features = _round_channels(1280, channel_multiplier, 8, None)
model = GenEfficientNet(
_decode_arch_def(arch_def, depth_multiplier),
num_classes=num_classes,
stem_size=32,
channel_multiplier=channel_multiplier,
num_features=num_features,
bn_args=_resolve_bn_args(kwargs),
act_fn=hard_swish,
**kwargs
)
return model
def efficientnet_b0(num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B0 """
default_cfg = default_cfgs['efficientnet_b0']
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
return model

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore import Tensor
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logits, label):
one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
loss_logit = self.ce(logits, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
return loss_logit

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
random augment class
"""
import numpy as np
import mindspore.dataset.vision.py_transforms as P
from src import transform_utils
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class RandAugment:
# config_str belongs to str
# hparams belongs to dict
def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
hparams = hparams if hparams is not None else {}
self.config_str = config_str
self.hparams = hparams
def __call__(self, imgs, labels, batchInfo):
# assert the imgs objetc are pil_images
ret_imgs = []
ret_labels = []
py_to_pil_op = P.ToPIL()
to_tensor = P.ToTensor()
normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
for i, image in enumerate(imgs):
img_pil = py_to_pil_op(image)
img_pil = rand_augment_ops(img_pil)
img_array = to_tensor(img_pil)
img_array = normalize_op(img_array)
ret_imgs.append(img_array)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)

View File

@ -0,0 +1,571 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
random augment utils
"""
import math
import random
import re
import numpy as np
import PIL
from PIL import Image, ImageEnhance, ImageOps
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128)
_MAX_LEVEL = 10.
_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL)
_RAND_TRANSFORMS = [
'Distort',
'Zoom',
'Blur',
'Skew',
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeTpu',
'Solarize',
'SolarizeAdd',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
]
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 0.3,
'ShearX': 0.2,
'ShearY': 0.2,
'TranslateXRel': 0.1,
'TranslateYRel': 0.1,
'Color': .025,
'Sharpness': 0.025,
'AutoContrast': 0.025,
'Solarize': .005,
'SolarizeAdd': .005,
'Contrast': .005,
'Brightness': .005,
'Equalize': .005,
'PosterizeTpu': 0,
'Invert': 0,
'Distort': 0,
'Zoom': 0,
'Blur': 0,
'Skew': 0,
}
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.BILINEAR)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
return interpolation
def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
# define all kinds of functions
def _randomly_negate(v):
return -v if random.random() > 0.5 else v
def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def rotate(img, degrees, **kwargs):
kwargs_new = kwargs
kwargs_new.pop('resample')
kwargs_new['resample'] = Image.BICUBIC
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs_new)
if _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs_new)
return img.rotate(degrees, resample=kwargs['resample'])
def auto_contrast(img, **__):
return ImageOps.autocontrast(img)
def invert(img, **__):
return ImageOps.invert(img)
def equalize(img, **__):
return ImageOps.equalize(img)
def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
return img
def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
return ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)
def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)
def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)
def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
def _rotate_level_to_arg(level, _hparams):
# range [-30, 30]
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level)
return (level,)
def _enhance_level_to_arg(level, _hparams):
# range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)
def _translate_abs_level_to_arg(level, hparams):
translate_const = hparams['translate_const']
level = (level / _MAX_LEVEL) * float(translate_const)
level = _randomly_negate(level)
return (level,)
def _translate_rel_level_to_arg(level, _hparams):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level)
return (level,)
def _posterize_original_level_to_arg(level, _hparams):
# As per original AutoAugment paper description
# range [4, 8], 'keep 4 up to 8 MSB of image'
return (int((level / _MAX_LEVEL) * 4) + 4,)
def _posterize_research_level_to_arg(level, _hparams):
# As per Tensorflow models research and UDA impl
# range [4, 0], 'keep 4 down to 0 MSB of original image'
return (4 - int((level / _MAX_LEVEL) * 4),)
def _posterize_tpu_level_to_arg(level, _hparams):
# As per Tensorflow TPU EfficientNet impl
# range [0, 4], 'keep 0 up to 4 MSB of original image'
return (int((level / _MAX_LEVEL) * 4),)
def _solarize_level_to_arg(level, _hparams):
# range [0, 256]
return (int((level / _MAX_LEVEL) * 256),)
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return (int((level / _MAX_LEVEL) * 110),)
def _distort_level_to_arg(level, _hparams):
return (int((level / _MAX_LEVEL) * 10 + 10),)
def _zoom_level_to_arg(level, _hparams):
return ((level / _MAX_LEVEL) * 0.4,)
def _blur_level_to_arg(level, _hparams):
level = (level / _MAX_LEVEL) * 0.5
level = _randomly_negate(level)
return (level,)
def _skew_level_to_arg(level, _hparams):
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)
def distort(img, v, **__):
w, h = img.size
horizontal_tiles = int(0.1 * v)
vertical_tiles = int(0.1 * v)
width_of_square = int(math.floor(w / float(horizontal_tiles)))
height_of_square = int(math.floor(h / float(vertical_tiles)))
width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
height_of_last_square = h - (height_of_square * (vertical_tiles - 1))
dimensions = []
for vertical_tile in range(vertical_tiles):
for horizontal_tile in range(horizontal_tiles):
if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1):
dimensions.append([horizontal_tile * width_of_square,
vertical_tile * height_of_square,
width_of_last_square + (horizontal_tile * width_of_square),
height_of_last_square + (height_of_square * vertical_tile)])
elif vertical_tile == (vertical_tiles - 1):
dimensions.append([horizontal_tile * width_of_square,
vertical_tile * height_of_square,
width_of_square + (horizontal_tile * width_of_square),
height_of_last_square + (height_of_square * vertical_tile)])
elif horizontal_tile == (horizontal_tiles - 1):
dimensions.append([horizontal_tile * width_of_square,
vertical_tile * height_of_square,
width_of_last_square + (horizontal_tile * width_of_square),
height_of_square + (height_of_square * vertical_tile)])
else:
dimensions.append([horizontal_tile * width_of_square,
vertical_tile * height_of_square,
width_of_square + (horizontal_tile * width_of_square),
height_of_square + (height_of_square * vertical_tile)])
last_column = []
for i in range(vertical_tiles):
last_column.append((horizontal_tiles - 1) + horizontal_tiles * i)
last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles)
polygons = []
for x1, y1, x2, y2 in dimensions:
polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])
polygon_indices = []
for i in range((vertical_tiles * horizontal_tiles) - 1):
if i not in last_row and i not in last_column:
polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles])
for a, b, c, d in polygon_indices:
dx = v
dy = v
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
polygons[a] = [x1, y1,
x2, y2,
x3 + dx, y3 + dy,
x4, y4]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
polygons[b] = [x1, y1,
x2 + dx, y2 + dy,
x3, y3,
x4, y4]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
polygons[c] = [x1, y1,
x2, y2,
x3, y3,
x4 + dx, y4 + dy]
x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
polygons[d] = [x1 + dx, y1 + dy,
x2, y2,
x3, y3,
x4, y4]
generated_mesh = []
for idx, i in enumerate(dimensions):
generated_mesh.append([dimensions[idx], polygons[idx]])
return img.transform(img.size, PIL.Image.MESH, generated_mesh, resample=PIL.Image.BICUBIC)
def zoom(img, v, **__):
#assert 0.1 <= v <= 2
w, h = img.size
image_zoomed = img.resize((int(round(img.size[0] * v)),
int(round(img.size[1] * v))),
resample=PIL.Image.BICUBIC)
w_zoomed, h_zoomed = image_zoomed.size
return image_zoomed.crop((math.floor((float(w_zoomed) / 2) - (float(w) / 2)),
math.floor((float(h_zoomed) / 2) - (float(h) / 2)),
math.floor((float(w_zoomed) / 2) + (float(w) / 2)),
math.floor((float(h_zoomed) / 2) + (float(h) / 2))))
def erase(img, v, **__):
#assert 0.1<= v <= 1
w, h = img.size
w_occlusion = int(w * v)
h_occlusion = int(h * v)
if len(img.getbands()) == 1:
rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255))
else:
rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(img.getbands())) * 255))
random_position_x = random.randint(0, w - w_occlusion)
random_position_y = random.randint(0, h - h_occlusion)
img.paste(rectangle, (random_position_x, random_position_y))
return img
def skew(img, v, **__):
#assert -1 <= v <= 1
w, h = img.size
x1 = 0
x2 = h
y1 = 0
y2 = w
original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
max_skew_amount = max(w, h)
max_skew_amount = int(math.ceil(max_skew_amount * v))
skew_amount = max_skew_amount
new_plane = [(y1 - skew_amount, x1), # Top Left
(y2, x1 - skew_amount), # Top Right
(y2 + skew_amount, x2), # Bottom Right
(y1, x2 + skew_amount)]
matrix = []
for p1, p2 in zip(new_plane, original_plane):
matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
A = np.matrix(matrix, dtype=np.float)
B = np.array(original_plane).reshape(8)
perspective_skew_coefficients_matrix = np.dot(np.linalg.pinv(A), B)
perspective_skew_coefficients_matrix = np.array(perspective_skew_coefficients_matrix).reshape(8)
return img.transform(img.size, PIL.Image.PERSPECTIVE, perspective_skew_coefficients_matrix,
resample=PIL.Image.BICUBIC)
def blur(img, v, **__):
#assert -3 <= v <= 3
return img.filter(PIL.ImageFilter.GaussianBlur(v))
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AutoAugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
def _select_rand_weights(weight_idx=0, transforms=None):
transforms = transforms or _RAND_TRANSFORMS
assert weight_idx == 0 # only one set of weights currently
rand_weights = _RAND_CHOICE_WEIGHTS_0
probs = [rand_weights[k] for k in transforms]
probs /= np.sum(probs)
return probs
def rand_augment_transform(config_str, hparams):
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'w':
weight_idx = int(val)
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
final_result = RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
return final_result
LEVEL_TO_ARG = {
'Distort': _distort_level_to_arg,
'Zoom': _zoom_level_to_arg,
'Blur': _blur_level_to_arg,
'Skew': _skew_level_to_arg,
'AutoContrast': None,
'Equalize': None,
'Invert': None,
'Rotate': _rotate_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg,
'PosterizeResearch': _posterize_research_level_to_arg,
'PosterizeTpu': _posterize_tpu_level_to_arg,
'Solarize': _solarize_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': _translate_abs_level_to_arg,
'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg,
}
NAME_TO_OP = {
'Distort': distort,
'Zoom': zoom,
'Blur': blur,
'Skew': skew,
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'PosterizeOriginal': posterize,
'PosterizeResearch': posterize,
'PosterizeTpu': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
'Brightness': brightness,
'Sharpness': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}
class AutoAugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
self.aug_fn = NAME_TO_OP[name]
self.level_fn = LEVEL_TO_ARG[name]
self.prob = prob
self.magnitude = magnitude
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img):
if random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs)
class RandAugment:
def __init__(self, ops, num_layers=2, choice_weights=None):
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
def __call__(self, img):
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
img = op(img)
return img

View File

@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
global_steps = steps_per_epoch * global_epoch
self_warmup_delta = ((base_lr - warmup_lr_init) /
warmup_steps) if warmup_steps > 0 else 0
self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
for i in range(total_steps):
steps = math.floor(i / steps_per_epoch)
cond = 1 if (steps < warmup_steps) else 0
warmup_lr = warmup_lr_init + steps * self_warmup_delta
decay_nums = math.floor(steps / decay_steps)
decay_rate = math.pow(self_decay_rate, decay_nums)
decay_lr = base_lr * decay_rate
lr = cond * warmup_lr + (1 - cond) * decay_lr
lr_each_step.append(lr)
lr_each_step = lr_each_step[global_steps:]
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
os.makedirs(outdir)
elif inc:
count = 1
outdir_inc = outdir + '-' + str(count)
while os.path.exists(outdir_inc):
count = count + 1
outdir_inc = outdir + '-' + str(count)
assert count < 100
outdir = outdir_inc
os.makedirs(outdir)
return outdir
parser = argparse.ArgumentParser(
description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
def main():
args, _ = parser.parse_known_args()
devid, rank_id, rank_size = 0, 0, 1
context.set_context(mode=context.GRAPH_MODE)
if args.distributed:
if args.GPU:
init("nccl")
context.set_context(device_target='GPU')
else:
init()
devid = int(os.getenv('DEVICE_ID'))
context.set_context(
device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
context.reset_auto_parallel_context()
rank_id = get_rank()
rank_size = get_group_size()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, device_num=rank_size)
else:
if args.GPU:
context.set_context(device_target='GPU')
is_master = not args.distributed or (rank_id == 0)
net = efficientnet_b0(num_classes=cfg.num_classes,
drop_rate=cfg.drop,
drop_connect_rate=cfg.drop_connect,
global_pool=cfg.gp,
bn_tf=cfg.bn_tf,
)
cur_time = args.cur_time
output_base = './output'
exp_name = '-'.join([
cur_time,
cfg.model,
str(224)
])
time.sleep(rank_id)
output_dir = get_outdir(output_base, exp_name)
train_data_url = os.path.join(args.data_path, 'train')
train_dataset = create_dataset(
cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
batches_per_epoch = train_dataset.get_dataset_size()
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
time_cb = TimeMonitor(data_size=batches_per_epoch)
loss_scale_manager = FixedLossScaleManager(
cfg.loss_scale, drop_overflow_update=False)
config_ck = CheckpointConfig(
save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(
prefix=cfg.model, directory=output_dir, config=config_ck)
lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
global_epoch=cfg.resume_start_epoch))
if cfg.opt == 'sgd':
optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale
)
elif cfg.opt == 'rmsprop':
optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
)
loss.add_flags_recursive(fp32=True, fp16=False)
if args.resume:
ckpt = load_checkpoint(args.resume)
load_param_into_net(net, ckpt)
model = Model(net, loss, optimizer,
loss_scale_manager=loss_scale_manager,
amp_level=cfg.amp_level
)
callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []
if args.resume:
real_epoch = cfg.epochs - cfg.resume_start_epoch
model.train(real_epoch, train_dataset,
callbacks=callbacks, dataset_sink_mode=True)
else:
model.train(cfg.epochs, train_dataset,
callbacks=callbacks, dataset_sink_mode=True)
if __name__ == '__main__':
main()