forked from mindspore-Ecosystem/mindspore
add GPU efficientnet to model_zoo
This commit is contained in:
parent 9c79b9d712
commit bd4e441862
@@ -0,0 +1,111 @@
# EfficientNet-B0 Example

## Description

This is an example of training EfficientNet-B0 in MindSpore.

## Requirements

- Install [MindSpore](http://www.mindspore.cn/install/en).
- Download the ImageNet dataset.

## Structure

```shell
.
└─efficientnet
  ├─README.md
  ├─scripts
    ├─run_standalone_train_for_gpu.sh # launch standalone training on GPU (1p)
    ├─run_distribute_train_for_gpu.sh # launch distributed training on GPU (8p)
    └─run_eval_for_gpu.sh             # launch evaluation on GPU
  ├─src
    ├─config.py                       # parameter configuration
    ├─dataset.py                      # data preprocessing
    ├─efficientnet.py                 # network definition
    ├─loss.py                         # customized loss function
    ├─transform_utils.py              # random augment utils
    └─transform.py                    # random augment class
  ├─eval.py                           # eval net
  └─train.py                          # train net
```

## Parameter Configuration

Parameters for both training and evaluation can be set in config.py.

```
'random_seed': 1,            # fix random seed
'model': 'efficientnet_b0',  # model name
'drop': 0.2,                 # dropout rate
'drop_connect': 0.2,         # drop connect rate
'opt_eps': 0.001,            # optimizer epsilon
'lr': 0.064,                 # learning rate (LR)
'batch_size': 128,           # batch size
'decay_epochs': 2.4,         # epoch interval to decay LR
'warmup_epochs': 5,          # epochs to warm up LR
'decay_rate': 0.97,          # LR decay rate
'weight_decay': 1e-5,        # weight decay
'epochs': 600,               # number of epochs to train
'workers': 8,                # number of data processing processes
'amp_level': 'O0',           # AMP level
'opt': 'rmsprop',            # optimizer
'num_classes': 1000,         # number of classes
'gp': 'avg',                 # type of global pool: "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9,             # optimizer momentum
'warmup_lr_init': 0.0001,    # initial warmup LR
'smoothing': 0.1,            # label smoothing factor
'bn_tf': False,              # use TensorFlow BatchNorm defaults
'keep_checkpoint_max': 10,   # max number of checkpoints to keep
'loss_scale': 1024,          # loss scale
'resume_start_epoch': 0,     # resume start epoch
```
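
These settings live in `src/config.py` as a single `easydict` object; below is a minimal sketch (matching how `eval.py` in this commit reads them, and presumably `train.py` as well):

```python
from src.config import efficientnet_b0_config_gpu as cfg

# every key listed above becomes an attribute of cfg
print(cfg.model, cfg.batch_size, cfg.lr)   # efficientnet_b0 128 0.064
```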

## Running the example

### Train

#### Usage

```
# distributed training example (8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training example
sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
```

#### Launch

```bash
# distributed training example (8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh /dataset 0
```

#### Result

You can find the checkpoint files together with the results in the log.

### Evaluation

#### Usage

```
# evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```

#### Launch

```bash
# evaluation with a checkpoint
sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt
```

> The checkpoint is produced during the training process.

#### Result

The evaluation result will be stored in the scripts path, where you can find results like the following in the log.
|
@@ -0,0 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluate imagenet"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import efficientnet_b0_config_gpu as cfg
|
||||
from src.dataset import create_dataset_val
|
||||
from src.efficientnet import efficientnet_b0
|
||||
from src.loss import LabelSmoothingCrossEntropy
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification evaluation')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)')
|
||||
parser.add_argument('--data_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
|
||||
|
||||
net = efficientnet_b0(num_classes=cfg.num_classes,
|
||||
drop_rate=cfg.drop,
|
||||
drop_connect_rate=cfg.drop_connect,
|
||||
global_pool=cfg.gp,
|
||||
bn_tf=cfg.bn_tf,
|
||||
)
|
||||
|
||||
ckpt = load_checkpoint(args_opt.checkpoint)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
val_data_url = os.path.join(args_opt.data_path, 'val')
|
||||
dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False)
|
||||
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
|
||||
|
||||
metrics = model.eval(dataset)
|
||||
print("metric: ", metrics)
|
|
@@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
curtime=`date '+%Y%m%d-%H%M%S'`
|
||||
RANK_SIZE=8
|
||||
|
||||
rm ${current_exec_path}/device_parallel/ -rf
|
||||
mkdir ${current_exec_path}/device_parallel
|
||||
echo ${curtime} > ${current_exec_path}/device_parallel/starttime
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \
|
||||
--GPU \
|
||||
--distributed \
|
||||
--data_path ${DATA_DIR} \
|
||||
--cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 &
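# The job is launched in the background (note the trailing '&'); progress can be
# followed with, for example:  tail -f ./device_parallel/efficientnet_b0.log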
|
|
@@ -0,0 +1,27 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
DEVICE_ID=$2
|
||||
PATH_CHECKPOINT=$3
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
curtime=`date '+%Y%m%d-%H%M%S'`
|
||||
|
||||
echo ${curtime} > ${current_exec_path}/eval_starttime
|
||||
|
||||
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 &
|
|
@@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
DEVICE_ID=$2
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
curtime=`date '+%Y%m%d-%H%M%S'`
|
||||
|
||||
rm ${current_exec_path}/device_${DEVICE_ID}/ -rf
|
||||
mkdir ${current_exec_path}/device_${DEVICE_ID}
|
||||
echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime
|
||||
|
||||
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \
|
||||
--GPU \
|
||||
--data_path ${DATA_DIR} \
|
||||
--cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 &
|
|
@@ -0,0 +1,47 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
efficientnet_b0_config_gpu = edict({
|
||||
'random_seed': 1,
|
||||
'model': 'efficientnet_b0',
|
||||
'drop': 0.2,
|
||||
'drop_connect': 0.2,
|
||||
'opt_eps': 0.001,
|
||||
'lr': 0.064,
|
||||
'batch_size': 128,
|
||||
'decay_epochs': 2.4,
|
||||
'warmup_epochs': 5,
|
||||
'decay_rate': 0.97,
|
||||
'weight_decay': 1e-5,
|
||||
'epochs': 600,
|
||||
'workers': 8,
|
||||
'amp_level': 'O0',
|
||||
'opt': 'rmsprop',
|
||||
'num_classes': 1000,
|
||||
# Type of global pool: "avg", "max", "avgmax", "avgmaxc"
|
||||
'gp': 'avg',
|
||||
'momentum': 0.9,
|
||||
'warmup_lr_init': 0.0001,
|
||||
'smoothing': 0.1,
|
||||
# Use TensorFlow BatchNorm defaults for models that support it
|
||||
'bn_tf': False,
|
||||
'keep_checkpoint_max': 10,
|
||||
'loss_scale': 1024,
|
||||
'resume_start_epoch': 0,
|
||||
})
|
|
@@ -0,0 +1,125 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, which will be used in train.py and eval.py.
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
from mindspore.dataset.vision import Inter
|
||||
|
||||
from src.config import efficientnet_b0_config_gpu as cfg
|
||||
from src.transform import RandAugment
|
||||
|
||||
ds.config.set_seed(cfg.random_seed)
|
||||
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
img_size = (224, 224)
|
||||
crop_pct = 0.875
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
inter_method = 'bilinear'
|
||||
resize_value = 224 # img_size
|
||||
scale = (0.08, 1.0)
|
||||
ratio = (3./4., 4./3.)
|
||||
inter_str = 'bicubic'
|
||||
|
||||
def str2MsInter(method):
|
||||
if method == 'bicubic':
|
||||
return Inter.BICUBIC
|
||||
if method == 'nearest':
|
||||
return Inter.NEAREST
|
||||
return Inter.BILINEAR
|
||||
|
||||
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
|
||||
if not os.path.exists(train_data_url):
|
||||
raise ValueError('Path does not exist: {}'.format(train_data_url))
|
||||
interpolation = str2MsInter(inter_str)
|
||||
|
||||
c_decode_op = C.Decode()
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio,
|
||||
interpolation=interpolation)
|
||||
random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)
|
||||
|
||||
efficient_rand_augment = RandAugment()
|
||||
|
||||
image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]
|
||||
|
||||
rank_id = get_rank() if distributed else 0
|
||||
rank_size = get_group_size() if distributed else 1
|
||||
|
||||
dataset_train = ds.ImageFolderDataset(train_data_url,
|
||||
num_parallel_workers=workers,
|
||||
shuffle=True,
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id)
|
||||
dataset_train = dataset_train.map(input_columns=["image"],
|
||||
operations=image_ops,
|
||||
num_parallel_workers=workers)
|
||||
dataset_train = dataset_train.map(input_columns=["label"],
|
||||
operations=type_cast_op,
|
||||
num_parallel_workers=workers)
|
||||
ds_train = dataset_train.batch(batch_size,
|
||||
per_batch_map=efficient_rand_augment,
|
||||
input_columns=["image", "label"],
|
||||
num_parallel_workers=2,
|
||||
drop_remainder=True)
|
||||
ds_train = ds_train.repeat(1)
|
||||
return ds_train
|
||||
|
||||
|
||||
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False):
|
||||
if not os.path.exists(val_data_url):
|
||||
raise ValueError('Path does not exist: {}'.format(val_data_url))
|
||||
rank_id = get_rank() if distributed else 0
|
||||
rank_size = get_group_size() if distributed else 1
|
||||
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=False)
|
||||
scale_size = None
|
||||
interpolation = str2MsInter(inter_method)
|
||||
|
||||
if isinstance(img_size, tuple):
|
||||
assert len(img_size) == 2
|
||||
if img_size[-1] == img_size[-2]:
|
||||
scale_size = int(math.floor(img_size[0] / crop_pct))
|
||||
else:
|
||||
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
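# With this module's defaults, img_size == (224, 224) and crop_pct == 0.875, so
# scale_size works out to 256: resize the shorter side to 256, then center-crop 224x224.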
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
decode_op = C.Decode()
|
||||
resize_op = C.Resize(size=scale_size, interpolation=interpolation)
|
||||
center_crop = C.CenterCrop(size=224)
|
||||
rescale_op = C.Rescale(rescale, shift)
|
||||
normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||
changeswap_op = C.HWC2CHW()
|
||||
|
||||
ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]
|
||||
|
||||
dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers)
|
||||
dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers)
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers)
|
||||
dataset = dataset.repeat(1)
|
||||
return dataset
|
|
@@ -0,0 +1,746 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""EfficientNet model definition"""
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, ms_function
|
||||
from mindspore.common.initializer import (Normal, One, Uniform, Zero,
|
||||
initializer)
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite import clip_by_value
|
||||
|
||||
relu = P.ReLU()
|
||||
sigmoid = P.Sigmoid()
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
||||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
||||
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
|
||||
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
|
||||
'efficientnet_b1': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'efficientnet_b2': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'efficientnet_b3': _cfg(
|
||||
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'efficientnet_b4': _cfg(
|
||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
}
|
||||
|
||||
_DEBUG = False
|
||||
|
||||
_BN_MOMENTUM_PT_DEFAULT = 0.1
|
||||
_BN_EPS_PT_DEFAULT = 1e-5
|
||||
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
|
||||
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
||||
_BN_EPS_TF_DEFAULT = 1e-3
|
||||
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
|
||||
|
||||
|
||||
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
|
||||
if layer_type not in ('conv', 'bn', 'fc'):
|
||||
raise ValueError('The layer type is not known, the supported are conv, bn and fc')
|
||||
if bias:
|
||||
return Zero()
|
||||
if layer_type == 'conv':
|
||||
assert isinstance(shape, (tuple, list)) and len(
|
||||
shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
|
||||
n = shape[1] * shape[1] * shape[2]
|
||||
return Normal(math.sqrt(2.0 / n))
|
||||
if layer_type == 'bn':
|
||||
return One()
|
||||
assert isinstance(shape, (tuple, list)) and len(
|
||||
shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
|
||||
n = shape[1]
|
||||
init_range = 1.0 / math.sqrt(n)
|
||||
return Uniform(init_range)
|
||||
|
||||
|
||||
def _initialize_weight_default(shape=None, layer_type='conv', bias=False):
|
||||
if layer_type not in ('conv', 'bn', 'fc'):
|
||||
raise ValueError('The layer type is not known, the supported are conv, bn and fc')
|
||||
if bias and layer_type == 'bn':
|
||||
return Zero()
|
||||
if layer_type == 'conv':
|
||||
return One()
|
||||
if layer_type == 'bn':
|
||||
return One()
|
||||
return One()
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False):
|
||||
weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias, bias_init=bias_init_value)
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias)
|
||||
|
||||
|
||||
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
|
||||
weight_init_value = _initialize_weight_goog(shape=(in_channels, 1, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias, bias_init=bias_init_value)
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias)
|
||||
|
||||
|
||||
def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False):
|
||||
weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
group=group, has_bias=bias, bias_init=bias_init_value)
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
group=group, has_bias=bias)
|
||||
|
||||
|
||||
def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
|
||||
return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init, beta_init=beta_init)
|
||||
|
||||
|
||||
def _dense(in_channels, output_channels, bias=True, activation=None):
|
||||
weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), layer_type='fc')
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, bias_init=bias_init_value,
|
||||
has_bias=bias, activation=activation)
|
||||
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, has_bias=bias,
|
||||
activation=activation)
|
||||
|
||||
|
||||
def _resolve_bn_args(kwargs):
|
||||
bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
|
||||
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||
if bn_momentum is not None:
|
||||
bn_args['momentum'] = bn_momentum
|
||||
bn_eps = kwargs.pop('bn_eps', None)
|
||||
if bn_eps is not None:
|
||||
bn_args['eps'] = bn_eps
|
||||
return bn_args
|
||||
|
||||
|
||||
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if not multiplier:
|
||||
return channels
|
||||
|
||||
channels *= multiplier
|
||||
channel_min = channel_min or divisor
|
||||
new_channels = max(
|
||||
int(channels + divisor / 2) // divisor * divisor,
|
||||
channel_min)
|
||||
if new_channels < 0.9 * channels:
|
||||
new_channels += divisor
|
||||
return new_channels
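# Worked examples of the rounding above (illustrative values, not from the original source):
# _round_channels(32, multiplier=1.25) -> 40, while _round_channels(19) -> 24, because
# rounding down to 16 would fall below 0.9 * 19 and is bumped up by one divisor.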
|
||||
|
||||
|
||||
def _parse_ksize(ss):
|
||||
if ss.isdigit():
|
||||
return int(ss)
|
||||
return [int(k) for k in ss.split('.')]
|
||||
|
||||
|
||||
def _decode_block_str(block_str, depth_multiplier=1.0):
|
||||
""" Decode block definition string
|
||||
|
||||
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||
|
||||
All args can exist in any order with the exception of the leading string which
|
||||
is assumed to indicate the block type.
|
||||
|
||||
leading string - block type (
|
||||
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
|
||||
r - number of repeat blocks,
|
||||
k - kernel size,
|
||||
s - strides (1-9),
|
||||
e - expansion ratio,
|
||||
c - output channels,
|
||||
se - squeeze/excitation ratio
|
||||
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||
Args:
|
||||
block_str: a string representation of block arguments.
|
||||
Returns:
|
||||
A list of block args (dicts)
|
||||
Raises:
|
||||
ValueError: if the string definition is not properly specified (TODO)
|
||||
"""
|
||||
assert isinstance(block_str, str)
|
||||
ops = block_str.split('_')
|
||||
block_type = ops[0]
|
||||
ops = ops[1:]
|
||||
options = {}
|
||||
noskip = False
|
||||
for op in ops:
|
||||
if op == 'noskip':
|
||||
noskip = True
|
||||
elif op.startswith('n'):
|
||||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
print('not supported')
elif v == 'r6':
print('not supported')
elif v == 'hs':
print('not supported')
elif v == 'sw':
print('not supported')
else:
continue
options[key] = v
|
||||
else:
|
||||
# all numeric options
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
act_fn = options['n'] if 'n' in options else None
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
if block_type == 'ir':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type in ('ds', 'dsa'):
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
pw_act=block_type == 'dsa',
|
||||
noskip=block_type == 'dsa' or noskip,
|
||||
)
|
||||
elif block_type == 'er':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
exp_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type == 'cn':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
kernel_size=int(options['k']),
|
||||
out_chs=int(options['c']),
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
return block_args, num_repeat
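# Illustrative decode (derived from the parsing rules above): the stage string
# 'ir_r2_k3_s2_e6_c24_se0.25' used in arch_def below yields num_repeat == 2 and
# block_args == dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1,
#                    pw_kernel_size=1, out_chs=24, exp_ratio=6.0, se_ratio=0.25,
#                    stride=2, act_fn=None, noskip=False).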
|
||||
|
||||
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
arch_args = []
|
||||
for _, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
@ms_function
|
||||
def hard_swish(x):
|
||||
x = P.Cast()(x, ms.float32)
|
||||
y = x + 3.0
|
||||
y = clip_by_value(y, 0.0, 6.0)
|
||||
y = y / 6.0
|
||||
return x * y
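# i.e. hard_swish(x) = x * clip(x + 3, 0, 6) / 6, the piecewise-linear approximation of
# swish used as the default activation for this network (see _gen_efficientnet below).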
|
||||
|
||||
|
||||
class BlockBuilder(nn.Cell):
|
||||
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, channel_divisor=8,
|
||||
channel_min=None, pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
|
||||
bn_args=None, drop_connect_rate=0., verbose=False):
|
||||
super(BlockBuilder, self).__init__()
|
||||
|
||||
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.channel_divisor = channel_divisor
|
||||
self.channel_min = channel_min
|
||||
self.pad_type = pad_type
|
||||
self.act_fn = act_fn
|
||||
self.se_gate_fn = se_gate_fn
|
||||
self.se_reduce_mid = se_reduce_mid
|
||||
self.bn_args = bn_args
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.verbose = verbose
|
||||
|
||||
self.in_chs = None
|
||||
self.block_idx = 0
|
||||
self.block_count = 0
|
||||
self.layer = self._make_layer(builder_in_channels, builder_block_args)
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
|
||||
def _make_block(self, ba):
|
||||
bt = ba.pop('block_type')
|
||||
ba['in_chs'] = self.in_chs
|
||||
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
||||
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
||||
ba['bn_args'] = self.bn_args
|
||||
ba['pad_type'] = self.pad_type
|
||||
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
||||
assert ba['act_fn'] is not None
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
ba['se_gate_fn'] = self.se_gate_fn
|
||||
ba['se_reduce_mid'] = self.se_reduce_mid
|
||||
if self.verbose:
|
||||
logging.info(' InvertedResidual %d, Args: %s', self.block_idx, str(ba))
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt in ('ds', 'dsa'):
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||
if self.verbose:
|
||||
logging.info(' DepthwiseSeparable %d, Args: %s', self.block_idx, str(ba))
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s) while building model.' % bt
|
||||
self.in_chs = ba['out_chs']
|
||||
|
||||
return block
|
||||
|
||||
def _make_stack(self, stack_args):
|
||||
blocks = []
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for i, ba in enumerate(stack_args):
|
||||
if self.verbose:
|
||||
logging.info(' Block: %d', i)
|
||||
if i >= 1:
|
||||
# only the first block in any stack can have a stride > 1
|
||||
ba['stride'] = 1
|
||||
block = self._make_block(ba)
|
||||
blocks.append(block)
|
||||
self.block_idx += 1
|
||||
return nn.SequentialCell(blocks)
|
||||
|
||||
def _make_layer(self, in_chs, block_args):
|
||||
""" Build the blocks
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
if self.verbose:
|
||||
logging.info('Building model trunk with %d stages...', len(block_args))
|
||||
self.in_chs = in_chs
|
||||
self.block_count = sum([len(x) for x in block_args])
|
||||
self.block_idx = 0
|
||||
blocks = []
|
||||
|
||||
for stack_idx, stack in enumerate(block_args):
|
||||
if self.verbose:
|
||||
logging.info('Stack: %d', stack_idx)
|
||||
assert isinstance(stack, list)
|
||||
stack = self._make_stack(stack)
|
||||
blocks.append(stack)
|
||||
return nn.SequentialCell(blocks)
|
||||
|
||||
def construct(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
def __init__(self, in_planes, kernel_size, stride):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
platform = context.get_context("device_target")
|
||||
weight_shape = [1, kernel_size, in_planes]
|
||||
weight_init = _initialize_weight_goog(shape=weight_shape)
|
||||
if platform == "GPU":
|
||||
self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size,
|
||||
stride=stride, pad_mode="same", group=in_planes)
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [in_planes * 1, 1, kernel_size, kernel_size]), name='depthwise_weight')
|
||||
else:
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(
|
||||
channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',)
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [1, in_planes, kernel_size, kernel_size]), name='depthwise_weight')
|
||||
|
||||
def construct(self, x):
|
||||
x = self.depthwise_conv(x, self.weight)
|
||||
return x
|
||||
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
|
||||
super(DropConnect, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.keep_prob = 1 - drop_connect_rate
|
||||
self.dropout = P.Dropout(keep_prob=self.keep_prob)
|
||||
|
||||
def construct(self, x):
|
||||
shape = self.shape(x)
|
||||
dtype = self.dtype(x)
|
||||
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
|
||||
_, mask_ = self.dropout(ones_tensor)
|
||||
x = x * mask_
|
||||
return x
|
||||
|
||||
|
||||
def drop_connect(inputs, training=False, drop_connect_rate=0.):
|
||||
if not training:
|
||||
return inputs
|
||||
return DropConnect(drop_connect_rate)(inputs)
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Cell):
|
||||
def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.act_fn = act_fn
|
||||
self.gate_fn = gate_fn
|
||||
reduce_chs = reduce_chs or in_chs
|
||||
self.conv_reduce = _dense(in_chs, reduce_chs, bias=True)
|
||||
self.conv_expand = _dense(reduce_chs, in_chs, bias=True)
|
||||
self.avg_global_pool = P.ReduceMean(keep_dims=False)
|
||||
|
||||
def construct(self, x):
|
||||
x_se = self.avg_global_pool(x, (2, 3))
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act_fn(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x_se = self.gate_fn(x_se)
|
||||
x_se = P.ExpandDims()(x_se, 2)
|
||||
x_se = P.ExpandDims()(x_se, 3)
|
||||
x = x * x_se
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Cell):
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_fn=relu, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
|
||||
bn_args=None, drop_connect_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
|
||||
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
|
||||
assert stride in [1, 2], 'stride must be 1 or 2'
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act
|
||||
self.act_fn = act_fn
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
|
||||
self.bn1 = _fused_bn(in_chs, **bn_args)
|
||||
|
||||
#
|
||||
if self.has_se:
|
||||
self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
|
||||
act_fn=act_fn, gate_fn=se_gate_fn)
|
||||
self.conv_pw = _conv1x1(in_chs, out_chs)
|
||||
self.bn2 = _fused_bn(out_chs, **bn_args)
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
if self.has_pw_act:
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x = x + identity
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
|
||||
pad_type='', act_fn=relu, pw_kernel_size=1,
|
||||
noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
|
||||
se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
|
||||
bn_args=None, drop_connect_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
|
||||
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
|
||||
mid_chs = int(in_chs * exp_ratio)
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.act_fn = act_fn
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
|
||||
self.bn1 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
self.shuffle_type = shuffle_type
|
||||
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
|
||||
self.shuffle = None
|
||||
|
||||
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
|
||||
self.bn2 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
se_base_chs = mid_chs if se_reduce_mid else in_chs
|
||||
self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
|
||||
act_fn=act_fn, gate_fn=se_gate_fn)
|
||||
|
||||
self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
|
||||
self.bn3 = _fused_bn(out_chs, **bn_args)
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x = x + identity
|
||||
return x
|
||||
|
||||
|
||||
class GenEfficientNet(nn.Cell):
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
|
||||
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
|
||||
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
|
||||
global_pool='avg', head_conv='default', weight_init='goog'):
|
||||
super(GenEfficientNet, self).__init__()
|
||||
|
||||
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
self.act_fn = act_fn
|
||||
self.num_features = num_features
|
||||
|
||||
stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
self.conv_stem = _conv(in_chans, stem_size, 3, stride=2)
|
||||
self.bn1 = _fused_bn(stem_size, **bn_args)
|
||||
in_chans = stem_size
|
||||
self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier, channel_divisor, channel_min,
|
||||
pad_type, act_fn, se_gate_fn, se_reduce_mid,
|
||||
bn_args, drop_connect_rate, verbose=_DEBUG)
|
||||
in_chs = self.blocks.in_chs
|
||||
|
||||
if not head_conv or head_conv == 'none':
|
||||
self.efficient_head = False
|
||||
self.conv_head = None
|
||||
assert in_chs == self.num_features
|
||||
else:
|
||||
self.efficient_head = head_conv == 'efficient'
|
||||
self.conv_head = _conv1x1(in_chs, self.num_features)
|
||||
self.bn2 = None if self.efficient_head else _fused_bn(self.num_features, **bn_args)
|
||||
self.global_pool = P.ReduceMean(keep_dims=True)
|
||||
self.classifier = _dense(self.num_features, self.num_classes)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.blocks(x)
|
||||
|
||||
if self.efficient_head:
|
||||
x = self.global_pool(x, (2, 3))
|
||||
x = self.conv_head(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.reshape(x, (self.shape(x)[0], -1))
|
||||
else:
|
||||
if self.conv_head is not None:
|
||||
x = self.conv_head(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.global_pool(x, (2, 3))
|
||||
x = self.reshape(x, (self.shape(x)[0], -1))
|
||||
|
||||
if self.training and self.drop_rate > 0.:
|
||||
x = self.drop_out(x)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
"""Creates an EfficientNet model.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||
Paper: https://arxiv.org/abs/1905.11946
|
||||
|
||||
EfficientNet params
|
||||
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
||||
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
||||
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
||||
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
||||
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
||||
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
||||
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
||||
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
||||
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer
|
||||
depth_multiplier: multiplier to number of repeats per stage
|
||||
|
||||
"""
|
||||
arch_def = [
|
||||
['ds_r1_k3_s1_e1_c16_se0.25'],
|
||||
['ir_r2_k3_s2_e6_c24_se0.25'],
|
||||
['ir_r2_k5_s2_e6_c40_se0.25'],
|
||||
['ir_r3_k3_s2_e6_c80_se0.25'],
|
||||
['ir_r3_k5_s1_e6_c112_se0.25'],
|
||||
['ir_r4_k5_s2_e6_c192_se0.25'],
|
||||
['ir_r1_k3_s1_e6_c320_se0.25'],
|
||||
]
|
||||
num_features = _round_channels(1280, channel_multiplier, 8, None)
|
||||
model = GenEfficientNet(
|
||||
_decode_arch_def(arch_def, depth_multiplier),
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
num_features=num_features,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=hard_swish,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def efficientnet_b0(num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B0 """
|
||||
default_cfg = default_cfgs['efficientnet_b0']
|
||||
model = _gen_efficientnet(
|
||||
channel_multiplier=1.0, depth_multiplier=1.0,
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
|
@@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network."""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
|
||||
class LabelSmoothingCrossEntropy(_Loss):
|
||||
|
||||
def __init__(self, smooth_factor=0.1, num_classes=1000):
|
||||
super(LabelSmoothingCrossEntropy, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logits, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logits, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
return loss_logit
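# Example of the smoothing above: with smooth_factor=0.1 and num_classes=1000 the target
# distribution puts 0.9 on the true class and 0.1 / 999 (~1.0e-4) on every other class.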
|
|
@@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
random augment class
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore.dataset.vision.py_transforms as P
|
||||
from src import transform_utils
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
class RandAugment:
|
||||
# config_str: a str, the RandAugment policy string (e.g. "rand-m9-mstd0.5")
# hparams: a dict of extra hyper-parameters
|
||||
def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
|
||||
hparams = hparams if hparams is not None else {}
|
||||
self.config_str = config_str
|
||||
self.hparams = hparams
|
||||
|
||||
def __call__(self, imgs, labels, batchInfo):
|
||||
# the incoming images are numpy arrays; ToPIL converts them before the rand-augment ops are applied
|
||||
ret_imgs = []
|
||||
ret_labels = []
|
||||
py_to_pil_op = P.ToPIL()
|
||||
to_tensor = P.ToTensor()
|
||||
normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||
rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
|
||||
for i, image in enumerate(imgs):
|
||||
img_pil = py_to_pil_op(image)
|
||||
img_pil = rand_augment_ops(img_pil)
|
||||
img_array = to_tensor(img_pil)
|
||||
img_array = normalize_op(img_array)
|
||||
ret_imgs.append(img_array)
|
||||
ret_labels.append(labels[i])
|
||||
return np.array(ret_imgs), np.array(ret_labels)
|
|
@@ -0,0 +1,571 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
random augment utils
|
||||
"""
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
_FILL = (128, 128, 128)
|
||||
_MAX_LEVEL = 10.
|
||||
_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL)
|
||||
_RAND_TRANSFORMS = [
|
||||
'Distort',
|
||||
'Zoom',
|
||||
'Blur',
|
||||
'Skew',
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeTpu',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
'Contrast',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
]
|
||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'PosterizeTpu': 0,
|
||||
'Invert': 0,
|
||||
'Distort': 0,
|
||||
'Zoom': 0,
|
||||
'Blur': 0,
|
||||
'Skew': 0,
|
||||
}
|
||||
|
||||
|
||||
def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', Image.BILINEAR)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
return interpolation
|
||||
|
||||
|
||||
def _check_args_tf(kwargs):
|
||||
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
|
||||
kwargs.pop('fillcolor')
|
||||
kwargs['resample'] = _interpolation(kwargs)
|
||||
|
||||
# define all kinds of functions
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
return -v if random.random() > 0.5 else v
|
||||
|
||||
|
||||
def shear_x(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def shear_y(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_x_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[0]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_y_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[1]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
|
||||
|
||||
def translate_x_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_y_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
|
||||
|
||||
def rotate(img, degrees, **kwargs):
|
||||
kwargs_new = kwargs
|
||||
kwargs_new.pop('resample')
|
||||
kwargs_new['resample'] = Image.BICUBIC
|
||||
if _PIL_VER >= (5, 2):
|
||||
return img.rotate(degrees, **kwargs_new)
|
||||
if _PIL_VER >= (5, 0):
|
||||
w, h = img.size
|
||||
post_trans = (0, 0)
|
||||
rotn_center = (w / 2.0, h / 2.0)
|
||||
angle = -math.radians(degrees)
|
||||
matrix = [
|
||||
round(math.cos(angle), 15),
|
||||
round(math.sin(angle), 15),
|
||||
0.0,
|
||||
round(-math.sin(angle), 15),
|
||||
round(math.cos(angle), 15),
|
||||
0.0,
|
||||
]
|
||||
|
||||
def transform(x, y, matrix):
|
||||
(a, b, c, d, e, f) = matrix
|
||||
return a * x + b * y + c, d * x + e * y + f
|
||||
|
||||
matrix[2], matrix[5] = transform(
|
||||
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
|
||||
)
|
||||
matrix[2] += rotn_center[0]
|
||||
matrix[5] += rotn_center[1]
|
||||
return img.transform(img.size, Image.AFFINE, matrix, **kwargs_new)
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
|
||||
|
||||
def auto_contrast(img, **__):
|
||||
return ImageOps.autocontrast(img)
|
||||
|
||||
|
||||
def invert(img, **__):
|
||||
return ImageOps.invert(img)
|
||||
|
||||
|
||||
def equalize(img, **__):
|
||||
return ImageOps.equalize(img)
|
||||
|
||||
|
||||
def solarize(img, thresh, **__):
|
||||
return ImageOps.solarize(img, thresh)
|
||||
|
||||
|
||||
def solarize_add(img, add, thresh=128, **__):
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _MAX_LEVEL) * 30.
    level = _randomly_negate(level)
    return (level,)


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _MAX_LEVEL) * 0.3
    level = _randomly_negate(level)
    return (level,)


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams['translate_const']
    level = (level / _MAX_LEVEL) * float(translate_const)
    level = _randomly_negate(level)
    return (level,)


def _translate_rel_level_to_arg(level, _hparams):
    # range [-0.45, 0.45]
    level = (level / _MAX_LEVEL) * 0.45
    level = _randomly_negate(level)
    return (level,)


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    return (int((level / _MAX_LEVEL) * 4) + 4,)


def _posterize_research_level_to_arg(level, _hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image'
    return (4 - int((level / _MAX_LEVEL) * 4),)


def _posterize_tpu_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    return (int((level / _MAX_LEVEL) * 4),)


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    return (int((level / _MAX_LEVEL) * 256),)


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return (int((level / _MAX_LEVEL) * 110),)


def _distort_level_to_arg(level, _hparams):
    return (int((level / _MAX_LEVEL) * 10 + 10),)


def _zoom_level_to_arg(level, _hparams):
    return ((level / _MAX_LEVEL) * 0.4,)


def _blur_level_to_arg(level, _hparams):
    level = (level / _MAX_LEVEL) * 0.5
    level = _randomly_negate(level)
    return (level,)


def _skew_level_to_arg(level, _hparams):
    level = (level / _MAX_LEVEL) * 0.3
    level = _randomly_negate(level)
    return (level,)
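
# Each _*_level_to_arg helper maps the integer RandAugment magnitude (0.._MAX_LEVEL, i.e.
# 0..10) onto the argument range of one op and returns it as a 1-tuple. At full magnitude,
# for example, Rotate receives +/-30 degrees, ShearX/Y +/-0.3, and TranslateX/YRel +/-0.45
# of the image size; the sign is chosen by _randomly_negate().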


def distort(img, v, **__):
    w, h = img.size
    horizontal_tiles = int(0.1 * v)
    vertical_tiles = int(0.1 * v)

    width_of_square = int(math.floor(w / float(horizontal_tiles)))
    height_of_square = int(math.floor(h / float(vertical_tiles)))
    width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
    height_of_last_square = h - (height_of_square * (vertical_tiles - 1))
    dimensions = []

    for vertical_tile in range(vertical_tiles):
        for horizontal_tile in range(horizontal_tiles):
            if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1):
                dimensions.append([horizontal_tile * width_of_square,
                                   vertical_tile * height_of_square,
                                   width_of_last_square + (horizontal_tile * width_of_square),
                                   height_of_last_square + (height_of_square * vertical_tile)])
            elif vertical_tile == (vertical_tiles - 1):
                dimensions.append([horizontal_tile * width_of_square,
                                   vertical_tile * height_of_square,
                                   width_of_square + (horizontal_tile * width_of_square),
                                   height_of_last_square + (height_of_square * vertical_tile)])
            elif horizontal_tile == (horizontal_tiles - 1):
                dimensions.append([horizontal_tile * width_of_square,
                                   vertical_tile * height_of_square,
                                   width_of_last_square + (horizontal_tile * width_of_square),
                                   height_of_square + (height_of_square * vertical_tile)])
            else:
                dimensions.append([horizontal_tile * width_of_square,
                                   vertical_tile * height_of_square,
                                   width_of_square + (horizontal_tile * width_of_square),
                                   height_of_square + (height_of_square * vertical_tile)])
    last_column = []
    for i in range(vertical_tiles):
        last_column.append((horizontal_tiles - 1) + horizontal_tiles * i)

    last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles)

    polygons = []
    for x1, y1, x2, y2 in dimensions:
        polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])

    polygon_indices = []
    for i in range((vertical_tiles * horizontal_tiles) - 1):
        if i not in last_row and i not in last_column:
            polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles])

    for a, b, c, d in polygon_indices:
        dx = v
        dy = v

        x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
        polygons[a] = [x1, y1,
                       x2, y2,
                       x3 + dx, y3 + dy,
                       x4, y4]

        x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
        polygons[b] = [x1, y1,
                       x2 + dx, y2 + dy,
                       x3, y3,
                       x4, y4]

        x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
        polygons[c] = [x1, y1,
                       x2, y2,
                       x3, y3,
                       x4 + dx, y4 + dy]

        x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
        polygons[d] = [x1 + dx, y1 + dy,
                       x2, y2,
                       x3, y3,
                       x4, y4]

    generated_mesh = []
    for idx, i in enumerate(dimensions):
        generated_mesh.append([dimensions[idx], polygons[idx]])
    return img.transform(img.size, PIL.Image.MESH, generated_mesh, resample=PIL.Image.BICUBIC)
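
# distort() splits the image into roughly int(0.1 * v) x int(0.1 * v) tiles, shifts the
# shared interior corner of each group of four neighbouring tiles by (v, v) pixels, and
# feeds the (source box, displaced quad) pairs to a PIL MESH transform, producing a local
# grid distortion of the image.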


def zoom(img, v, **__):
    #assert 0.1 <= v <= 2
    w, h = img.size
    image_zoomed = img.resize((int(round(img.size[0] * v)),
                               int(round(img.size[1] * v))),
                              resample=PIL.Image.BICUBIC)
    w_zoomed, h_zoomed = image_zoomed.size

    return image_zoomed.crop((math.floor((float(w_zoomed) / 2) - (float(w) / 2)),
                              math.floor((float(h_zoomed) / 2) - (float(h) / 2)),
                              math.floor((float(w_zoomed) / 2) + (float(w) / 2)),
                              math.floor((float(h_zoomed) / 2) + (float(h) / 2))))


def erase(img, v, **__):
    #assert 0.1 <= v <= 1
    w, h = img.size
    w_occlusion = int(w * v)
    h_occlusion = int(h * v)
    if len(img.getbands()) == 1:
        rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255))
    else:
        rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(img.getbands())) * 255))

    random_position_x = random.randint(0, w - w_occlusion)
    random_position_y = random.randint(0, h - h_occlusion)
    img.paste(rectangle, (random_position_x, random_position_y))
    return img


def skew(img, v, **__):
    #assert -1 <= v <= 1
    w, h = img.size
    x1 = 0
    x2 = h
    y1 = 0
    y2 = w
    original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
    max_skew_amount = max(w, h)
    max_skew_amount = int(math.ceil(max_skew_amount * v))
    skew_amount = max_skew_amount
    new_plane = [(y1 - skew_amount, x1),  # Top Left
                 (y2, x1 - skew_amount),  # Top Right
                 (y2 + skew_amount, x2),  # Bottom Right
                 (y1, x2 + skew_amount)]  # Bottom Left
    matrix = []
    for p1, p2 in zip(new_plane, original_plane):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    # dtype=float (builtin) instead of the deprecated np.float alias; behaviour is identical
    A = np.matrix(matrix, dtype=float)
    B = np.array(original_plane).reshape(8)
    perspective_skew_coefficients_matrix = np.dot(np.linalg.pinv(A), B)
    perspective_skew_coefficients_matrix = np.array(perspective_skew_coefficients_matrix).reshape(8)

    return img.transform(img.size, PIL.Image.PERSPECTIVE, perspective_skew_coefficients_matrix,
                         resample=PIL.Image.BICUBIC)
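
# skew() solves for the 8 coefficients of a PIL PERSPECTIVE transform: it sets up the
# standard homography equations relating the skewed quad (new_plane) and the original
# image corners, solves the 8x8 system with a pseudo-inverse (np.linalg.pinv), and passes
# the resulting coefficient vector as the data argument of Image.transform.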


def blur(img, v, **__):
    #assert -3 <= v <= 3
    return img.filter(PIL.ImageFilter.GaussianBlur(v))


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [AutoAugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


def rand_augment_transform(config_str, hparams):
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    config = config_str.split('-')
    assert config[0] == 'rand'
    config = config[1:]
    for c in config:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)

    final_result = RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
    return final_result
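
# rand_augment_transform() parses a dash-separated config string whose first token must be
# 'rand'; the remaining tokens set the magnitude ('m'), ops per image ('n'), magnitude
# noise ('mstd') and the op-weight index ('w'). A plausible call, assuming the usual
# hparams keys consumed by the ops above, would be:
#
#     rand_augment_transform('rand-m9-n2-mstd0.5',
#                            {'translate_const': 100, 'img_mean': (124, 116, 104)})
#
# which yields a RandAugment callable applying 2 randomly chosen ops at magnitude ~9
# (jittered with std 0.5) to each PIL image.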


LEVEL_TO_ARG = {
    'Distort': _distort_level_to_arg,
    'Zoom': _zoom_level_to_arg,
    'Blur': _blur_level_to_arg,
    'Skew': _skew_level_to_arg,
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'PosterizeResearch': _posterize_research_level_to_arg,
    'PosterizeTpu': _posterize_tpu_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
}

NAME_TO_OP = {
    'Distort': distort,
    'Zoom': zoom,
    'Blur': blur,
    'Skew': skew,
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'PosterizeOriginal': posterize,
    'PosterizeResearch': posterize,
    'PosterizeTpu': posterize,
    'Solarize': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'Contrast': contrast,
    'Brightness': brightness,
    'Sharpness': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
}


class AutoAugmentOp:
    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
        )
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        if random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
        return self.aug_fn(img, *level_args, **self.kwargs)


class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        ops = np.random.choice(
            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
        for op in ops:
            img = op(img)
        return img
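
# A minimal usage sketch (assumes a PIL image `img` is already loaded): build the default
# op set and apply two randomly chosen ops per call.
#
#     ops = rand_augment_ops(magnitude=9)
#     augment = RandAugment(ops, num_layers=2)
#     augmented = augment(img)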
@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)


def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_steps) if warmup_steps > 0 else 0
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        steps = math.floor(i / steps_per_epoch)
        cond = 1 if (steps < warmup_steps) else 0
        warmup_lr = warmup_lr_init + steps * self_warmup_delta
        decay_nums = math.floor(steps / decay_steps)
        decay_rate = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_rate
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
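
# get_lr() precomputes one learning rate per training step: a linear per-epoch warmup from
# warmup_lr_init toward base_lr during the first `warmup_steps` epochs, then a staircase
# exponential decay of base_lr by `decay_rate` every `decay_steps` epochs. A small
# illustration (hypothetical numbers, not the shipped config):
#
#     get_lr(base_lr=0.1, total_epochs=3, steps_per_epoch=2, decay_steps=1,
#            decay_rate=0.5, warmup_steps=1, warmup_lr_init=0.01)
#     # -> approximately [0.01, 0.01, 0.05, 0.05, 0.025, 0.025] (per-step values, float32)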


def get_outdir(path, *paths, inc=False):
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir


parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
                    help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
                    default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
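
# The script is parameterized from the command line; a typical single-GPU launch
# (paths and timestamp are illustrative) looks like:
#
#     python train.py --GPU --data_path /path/to/imagenet --cur_time 20200701-120000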


def main():
    args, _ = parser.parse_known_args()
    devid, rank_id, rank_size = 0, 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            init()
            devid = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    else:
        if args.GPU:
            context.set_context(device_target='GPU')

    is_master = not args.distributed or (rank_id == 0)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    cur_time = args.cur_time
    output_base = './output'

    exp_name = '-'.join([
        cur_time,
        cfg.model,
        str(224)
    ])
    time.sleep(rank_id)
    output_dir = get_outdir(output_base, exp_name)

    train_data_url = os.path.join(args.data_path, 'train')
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(
        prefix=cfg.model, directory=output_dir, config=config_ck)

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))
    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )
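
    # Note: cfg.opt is expected to be 'sgd' or 'rmsprop'; any other value would leave
    # `optimizer` unbound and the Model(...) construction below would fail.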

    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []

    if args.resume:
        real_epoch = cfg.epochs - cfg.resume_start_epoch
        model.train(real_epoch, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)
    else:
        model.train(cfg.epochs, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()