!14129 add HourNAS model zoo

From: @xinghaochen
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-29 09:07:26 +08:00 committed by Gitee
commit 8b796b9146
7 changed files with 1336 additions and 0 deletions

View File

@ -0,0 +1,104 @@
# Contents
- [HourNAS Description](#tinynet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [HourNAS Description](#contents)
HourNAS is an efficient neural architecture search method. Only using 3 hours (0.1 days) with one GPU, HourNAS can search an architecture that achieves a 77.0% Top-1 accuracy, which outperforms the state-of-the-art methods.
[Paper](https://arxiv.org/abs/2005.14446): Zhaohui Yang, Yunhe Wang, Xinghao Chen, Jianyuan Guo, Wei Zhang, Chao Xu, Chunjing Xu, Dacheng Tao, Chang Xu. HourNAS: Extremely Fast Neural Architecture Search Through an Hourglass Lens. In CVPR 2021.
# [Model architecture](#contents)
The overall network architecture of HourNAS is show below:
[Link](https://arxiv.org/abs/2005.14446)
# [Dataset](#contents)
Dataset used: [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size175M60,000 32*32 colorful images in 10 classes
- Train146M50,000 images
- Test29M10,000 images
- Data formatbinary files
- NoteData will be processed in src/dataset.py
# [Environment Requirements](#contents)
- Hardware (GPU)
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```markdown
.HourNAS
├── README.md # descriptions about HourNAS
├── src
│ ├── architectures.py # definition of HourNAS-F model
│ ├── dataset.py # data preprocessing
│ ├── hournasnet.py # HourNAS general architecture
│ └── utils.py # utility functions
├── eval.py # evaluation interface
```
### [Training process](#contents)
To Be Done
### [Evaluation Process](#contents)
#### Launch
```bash
# infer example
python eval.py --model hournas_f_c10 --dataset_path [DATA_PATH] --GPU --ckpt [CHECKPOINT_PATH]
```
### Result
```bash
result: {'Top1-Acc': 0.9618389423076923} ckpt= ./hournas_f_cifar10.ckpt
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Model | FLOPs (M) | Params (M) | ImageNet Top-1 |
| --------------- | --------- | ---------- | -------------- |
| MnasNet-A1 | 312 | 3.9 | 75.2% |
| HourNAS-E | 313 | 3.8 | 75.7% |
| EfficientNet-B0 | 390 | 5.3 | 76.8% |
| HourNAS-F | 383 | 5.3 | 77.0% |
More details in [Paper](https://arxiv.org/abs/2005.14446).
# [Description of Random Situation](#contents)
We set the seed inside dataset.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference Interface"""
import sys
import argparse
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore import context
from mindspore import nn
from src.dataset import create_dataset_cifar10
from src.utils import count_params
from src.hournasnet import hournasnet
from easydict import EasyDict as edict
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/workspace/mindspore_dataset/',
metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='hournas_f_c10', type=str, metavar='MODEL',
help='Name of model to train (default: "tinynet_c"')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 10)')
parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',
help='input batch size for training (default: 256)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
parser.add_argument('--ckpt', type=str, default='./ms_hournas_f_c10.ckpt',
help='model checkpoint to load')
parser.add_argument('--GPU', action='store_true', default=True,
help='Use GPU for training (default: True)')
parser.add_argument('--dataset_sink', action='store_true', default=True)
parser.add_argument('--image-size', type=int, default=32, metavar='N',
help='input image size (default: 32)')
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
#context.set_context(mode=context.GRAPH_MODE)
context.set_context(mode=context.PYNATIVE_MODE)
if args.GPU:
context.set_context(device_target='GPU')
# parse model argument
assert args.model.startswith(
"hournas"), "Only Tinynet models are supported."
#_, sub_name = args.model.split("_")
net = hournasnet(args.model,
num_classes=args.num_classes,
drop_rate=0.0,
drop_connect_rate=0.0,
global_pool="avg",
bn_tf=False,
bn_momentum=None,
bn_eps=None)
print(net)
print("Total number of parameters:", count_params(net))
cfg = edict({'image_height': args.image_size, 'image_width': args.image_size,})
cfg.batch_size = args.batch_size
print(cfg)
#input_size = net.default_cfg['input_size'][1]
val_data_url = args.data_path #os.path.join(args.data_path, 'val')
val_dataset = create_dataset_cifar10(val_data_url, repeat_num=1, training=False, cifar_cfg=cfg)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, metrics=eval_metrics)
metrics = model.eval(val_dataset, dataset_sink_mode=False)
print(metrics)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,22 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.hournasnet import hournasnet
def create_network(name, *args, **kwargs):
if name == 'HourNAS':
return hournasnet(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Architecture of HourNAS"""
predefine_archs = {
'hournas_f_c10': {
'genotypes': [
#'conv3bnrelu',
'ir_k3_e1_se',
'ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k5_e1_se', 'ir_k3_e1_se',
'ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k3_e1_se', 'ir_k5_e1_se',
'ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se',
'ir_k5_e6_se', 'ir_k5_e3_se', 'ir_k5_e3_se', 'ir_k5_e3_se',
'ir_k5_e6_se', 'ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k5_e6_se',
'ir_k5_e6_se',
#'conv1', 'adaavgpool'
],
'strides': [
#1,
1,
1, 1, 1, 1,
1, 1, 1, 1,
2, 1, 1, 1,
1, 1, 1, 1,
2, 1, 1, 1,
1,
#1, 1
],
'out_channels': [
#32,
16,
24, 24, 24, 24,
40, 40, 40, 40,
80, 80, 80, 80,
112, 112, 112, 112,
192, 192, 192, 192,
320,
#1280, 1280,
],
'dropout_ratio': 0.2,
'default_init': 'True',
'se_ratio': '0.05'
},
}

View File

@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import math
import os
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size
from mindspore.dataset.vision import Inter
import mindspore.dataset.vision.c_transforms as vision
# values that should remain constant
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# data preprocess configs
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
ds.config.set_seed(1)
def split_imgs_and_labels(imgs, labels, batchInfo):
"""split data into labels and images"""
ret_imgs = []
ret_labels = []
for i, image in enumerate(imgs):
ret_imgs.append(image)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
input_size=224, color_jitter=0.4):
"""Create ImageNet training dataset"""
if not os.path.exists(train_data_url):
raise ValueError('Path not exists')
decode_op = py_vision.Decode()
type_cast_op = c_transforms.TypeCast(mstype.int32)
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
scale=SCALE, ratio=RATIO,
interpolation=Inter.BICUBIC)
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# assemble all the transforms
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
# batch dealing
ds_train = dataset_train.batch(batch_size,
per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
input_size=224):
"""Create ImageNet validation dataset"""
if not os.path.exists(val_data_url):
raise ValueError('Path not exists')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id)
scale_size = None
if isinstance(input_size, tuple):
assert len(input_size) == 2
if input_size[-1] == input_size[-2]:
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
else:
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
else:
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
type_cast_op = c_transforms.TypeCast(mstype.int32)
decode_op = py_vision.Decode()
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
to_tensor, normalize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=image_ops,
num_parallel_workers=workers)
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
dataset = dataset.repeat(1)
return dataset
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id
def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None):
"""Data operations."""
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
rank_size, rank_id = _get_rank_info()
if training:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
resize_height = cifar_cfg.image_height
resize_width = cifar_cfg.image_width
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
#normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.24703233, 0.24348505, 0.26158768))
changeswap_op = vision.HWC2CHW()
type_cast_op = c_transforms.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
# apply map operations on images
data_set = data_set.map(operations=type_cast_op, input_columns="label")
data_set = data_set.map(operations=c_trans, input_columns="image")
# apply batch operations
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set

View File

@ -0,0 +1,766 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tinynet model definition"""
import math
import re
from copy import deepcopy
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal, Zero, One, Uniform
from mindspore import ms_function
from mindspore import Tensor
from src.architectures import predefine_archs
# Imagenet constant values
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
relu = P.ReLU()
sigmoid = P.Sigmoid()
_DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
# Defaults used for Google/Tensorflow training of mobile networks /w
# RMSprop as per papers and TF reference implementations. PT momentum
# equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
"""Google style weight initialization"""
if layer_type not in ('conv', 'bn', 'fc'):
raise ValueError(
'The layer type is not known, the supported are conv, bn and fc')
if bias:
return Zero()
if layer_type == 'conv':
assert isinstance(shape, (tuple, list)) and len(
shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
n = shape[1] * shape[1] * shape[2]
return Normal(math.sqrt(2.0 / n))
if layer_type == 'bn':
return One()
assert isinstance(shape, (tuple, list)) and len(
shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
n = shape[1]
init_range = 1.0 / math.sqrt(n)
return Uniform(init_range)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0,
pad_mode='same', bias=False):
"""convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
"""1x1 convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, 1, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0,
pad_mode='same', bias=False):
"""group convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias)
def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
return nn.BatchNorm2d(channels, eps=eps, momentum=1-momentum, gamma_init=gamma_init,
beta_init=beta_init)
def _dense(in_channels, output_channels, bias=True, activation=None):
weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels),
layer_type='fc')
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
bias_init=bias_init_value, has_bias=bias, activation=activation)
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
has_bias=bias, activation=activation)
def _resolve_bn_args(kwargs):
bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
'bn_tf', False) else _BN_ARGS_PT.copy()
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
channel_min = channel_min or divisor
new_channels = max(
int(channels + divisor / 2) // divisor * divisor,
channel_min)
# Make sure that round down does not go down by more than 10%.
if new_channels < 0.9 * channels:
new_channels += divisor
return new_channels
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str, depth_multiplier=1.0):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v in ('re', 'r6', 'hs', 'sw'):
print('not support')
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type in ('ds', 'dsa'):
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_fn=act_fn,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
"""further decode the architecture definition into model-ready format"""
arch_args = []
for _, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
stack_args.append(ba)
repeats.append(rep)
arch_args.append(_scale_stage_depth(
stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
class Swish(nn.Cell):
"""swish activation function"""
def __init__(self):
super(Swish, self).__init__()
self.sigmoid = P.Sigmoid()
def construct(self, x):
return x * self.sigmoid(x)
@ms_function
def swish(x):
return x * nn.Sigmoid()(x)
class BlockBuilder(nn.Cell):
"""build efficient-net convolution blocks"""
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0,
channel_divisor=8, channel_min=None, pad_type='', act_fn=relu,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
drop_connect_rate=0., verbose=False):
super(BlockBuilder, self).__init__()
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.pad_type = pad_type
self.act_fn = act_fn #Swish()
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
self.bn_args = bn_args
self.drop_connect_rate = drop_connect_rate
self.verbose = verbose
# updated during build
self.in_chs = None
self.block_idx = 0
self.block_count = 0
self.layer = self._make_layer(builder_in_channels, builder_block_args)
def _round_channels(self, chs):
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba):
"""make the current block based on the block argument"""
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# this is a hack to work around mismatch in origin impl input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['bn_args'] = self.bn_args
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate * \
self.block_idx / self.block_count
ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid
block = InvertedResidual(**ba)
elif bt in ('ds', 'dsa'):
ba['drop_connect_rate'] = self.drop_connect_rate * \
self.block_idx / self.block_count
block = DepthwiseSeparableConv(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs']
return block
def _make_stack(self, stack_args):
"""make a stack of blocks"""
blocks = []
# each stack (stage) contains a list of block arguments
for i, ba in enumerate(stack_args):
if i >= 1:
# only the first block in any stack can have a stride > 1
ba['stride'] = 1
block = self._make_block(ba)
blocks.append(block)
self.block_idx += 1 # incr global idx (across all stacks)
return nn.SequentialCell(blocks)
def _make_layer(self, in_chs, block_args):
""" Build the entire layer
Args:
in_chs: Number of input-channels passed to first block
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for _, stack in enumerate(block_args):
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
return nn.SequentialCell(blocks)
def construct(self, x):
return self.layer(x)
class DropConnect(nn.Cell):
"""drop connect implementation"""
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
super(DropConnect, self).__init__()
self.shape = P.Shape()
self.dtype = P.DType()
self.keep_prob = 1 - drop_connect_rate
self.dropout = P.Dropout(keep_prob=self.keep_prob)
self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)
def construct(self, x):
shape = self.shape(x)
dtype = self.dtype(x)
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
_, mask = self.dropout(ones_tensor)
x = x * mask
x = x / self.keep_prob_tensor
return x
def drop_connect(inputs, training=False, drop_connect_rate=0.):
if not training:
return inputs
return DropConnect(drop_connect_rate)(inputs)
class SqueezeExcite(nn.Cell):
"""squeeze-excite implementation"""
def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
super(SqueezeExcite, self).__init__()
self.act_fn = act_fn #Swish()
self.gate_fn = gate_fn
reduce_chs = reduce_chs or in_chs
self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs,
kernel_size=1, has_bias=False, pad_mode='pad')
self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs,
kernel_size=1, has_bias=False, pad_mode='pad')
self.avg_global_pool = P.ReduceMean(keep_dims=True)
def construct(self, x):
x_se = self.avg_global_pool(x, (2, 3))
x_se = self.conv_reduce(x_se)
x_se = self.act_fn(x_se)
x_se = self.conv_expand(x_se)
x_se = self.gate_fn(x_se)
x = x * x_se
return x
class DepthwiseSeparableConv(nn.Cell):
"""depth-wise convolution -> (squeeze-excite) -> point-wise convolution"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=relu, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
bn_args=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2], 'stride must be 1 or 2'
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act
self.act_fn = act_fn #Swish()
self.drop_connect_rate = drop_connect_rate
self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad",
padding=int(dw_kernel_size/2), has_bias=False, group=in_chs,
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
self.bn1 = _fused_bn(in_chs, **bn_args)
if self.has_se:
self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pw = _conv1x1(in_chs, out_chs)
self.bn2 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
"""forward the depthwise separable conv"""
identity = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
if self.has_pw_act:
x = self.act_fn(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class InvertedResidual(nn.Cell):
"""inverted-residual block implementation"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
pad_type='', act_fn=relu, pw_kernel_size=1,
noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
bn_args=None, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn #Swish()
self.drop_connect_rate = drop_connect_rate
self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
self.bn1 = _fused_bn(mid_chs, **bn_args)
self.shuffle_type = shuffle_type
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = None
self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad",
padding=int(dw_kernel_size/2), has_bias=False, group=mid_chs,
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
self.bn2 = _fused_bn(mid_chs, **bn_args)
if self.has_se:
se_base_chs = mid_chs if se_reduce_mid else in_chs
#print(se_base_chs)
self.se = SqueezeExcite(
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn
)
self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
self.bn3 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
"""forward the inverted-residual block"""
identity = x
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class GenEfficientNet(nn.Cell):
"""Generate EfficientNet architecture"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
global_pool='avg', head_conv='default', weight_init='goog'):
super(GenEfficientNet, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
self.num_classes = num_classes
self.drop_rate = drop_rate
self.num_features = num_features
self.conv_stem = _conv(in_chans, stem_size, 3,
stride=1, padding=1, pad_mode='pad')
self.bn1 = _fused_bn(stem_size, **bn_args)
self.act_fn = relu #Swish()
in_chans = stem_size
self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
channel_divisor, channel_min,
pad_type, act_fn, se_gate_fn, se_reduce_mid,
bn_args, drop_connect_rate, verbose=_DEBUG)
in_chs = self.blocks.in_chs
if not head_conv or head_conv == 'none':
self.efficient_head = False
self.conv_head = None
assert in_chs == self.num_features
else:
self.efficient_head = head_conv == 'efficient'
self.conv_head = _conv1x1(in_chs, self.num_features)
self.bn2 = None if self.efficient_head else _fused_bn(
self.num_features, **bn_args)
self.global_pool = P.ReduceMean(keep_dims=True)
self.classifier = _dense(self.num_features, self.num_classes)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.drop_out = nn.Dropout(keep_prob=1-self.drop_rate)
def construct(self, x):
"""efficient net entry point"""
x = self.conv_stem(x)
#aux = x
x = self.bn1(x)
x = self.act_fn(x)
x = self.blocks(x)
if self.efficient_head:
x = self.global_pool(x, (2, 3))
x = self.conv_head(x)
x = self.act_fn(x)
x = self.reshape(self.shape(x)[0], -1)
else:
if self.conv_head is not None:
x = self.conv_head(x)
x = self.bn2(x)
x = self.act_fn(x)
x = self.global_pool(x, (2, 3))
x = self.reshape(x, (self.shape(x)[0], -1))
if self.training and self.drop_rate > 0.:
x = self.drop_out(x)
#print('forward')
#aux = x
return self.classifier(x) #, aux
def _convert_arch_def(genotypes, strides, out_channels, se_ratio):
"""Convert HourNAS architecture to EfficientNet arch_def.
HourNAS style:
'genotypes' : [
#'conv3bnrelu',
'ir_k3_e1_se',
'ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k5_e1_se', 'ir_k3_e1_se',
'ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k3_e1_se', 'ir_k5_e1_se',
'ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se',
'ir_k5_e6_se', 'ir_k5_e3_se', 'ir_k5_e3_se', 'ir_k5_e3_se',
'ir_k5_e6_se', 'ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k5_e6_se',
'ir_k5_e6_se',
#'conv1', 'adaavgpool'
],
EfficientNet style:
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
"""
arch_def = []
for genotype, stride, out_channel in zip(genotypes, strides, out_channels):
arch_str = genotype.replace('se', 'se'+se_ratio)
arch_str = arch_str + '_r1'
arch_str = arch_str + '_c{}'.format(out_channel)
arch_str = arch_str + '_s{}'.format(stride)
arch_def.append([arch_str])
return arch_def
def _gen_efficientnet(genotypes, strides, out_channels, se_ratio, num_classes=1000, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
Args:
channel_multiplier (int): multiplier to number of channels per layer
depth_multiplier (int): multiplier to number of repeats per stage
"""
arch_def = _convert_arch_def(genotypes, strides, out_channels, se_ratio)
print(arch_def)
channel_multiplier = 1.0
depth_multiplier = 1.0
num_features = max(1280, _round_channels(
1280, channel_multiplier, 8, None))
model = GenEfficientNet(
_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
num_classes=num_classes,
stem_size=32,
channel_multiplier=channel_multiplier,
num_features=num_features,
bn_args=_resolve_bn_args(kwargs),
act_fn=relu,
se_reduce_mid=True,
**kwargs)
return model
def hournasnet(arch="hournas_f_c10", num_classes=10, in_chans=3, **kwargs):
""" HourNAS Models """
# choose a sub model
genotypes = predefine_archs[arch]['genotypes']
strides = predefine_archs[arch]['strides']
out_channels = predefine_archs[arch]['out_channels']
se_ratio = predefine_archs[arch]['se_ratio']
model = _gen_efficientnet(
genotypes=genotypes, strides=strides, out_channels=out_channels, se_ratio=se_ratio,
num_classes=num_classes, in_chans=in_chans, **kwargs)
return model

View File

@ -0,0 +1,89 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model utils"""
import math
import argparse
import numpy as np
def str2bool(value):
"""Convert string arguments to bool type"""
if value.lower() in ('yes', 'true', 't', 'y', '1'):
return True
if value.lower() in ('no', 'false', 'f', 'n', '0'):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
"""Get scheduled learning rate"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
global_steps = steps_per_epoch * global_epoch
self_warmup_delta = ((base_lr - warmup_lr_init) / \
warmup_epochs) if warmup_epochs > 0 else 0
self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate
for i in range(total_steps):
epochs = math.floor(i/steps_per_epoch)
cond = 1 if (epochs < warmup_epochs) else 0
warmup_lr = warmup_lr_init + epochs * self_warmup_delta
decay_nums = math.floor(epochs / decay_epochs)
decay_rate = math.pow(self_decay_rate, decay_nums)
decay_lr = base_lr * decay_rate
lr = cond * warmup_lr + (1 - cond) * decay_lr
lr_each_step.append(lr)
lr_each_step = lr_each_step[global_steps:]
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
"""Apply weight decay to only conv and dense layers (len(shape) > =2)
Args:
net (mindspore.nn.Cell): Mindspore network instance
weight_decay (float): weight decay tobe used.
skip_list (tuple): list of parameter names without weight decay
Returns:
A list of group of parameters, separated by different weight decay.
"""
decay = []
no_decay = []
if not skip_list:
skip_list = ()
for param in net.trainable_params():
if len(param.shape) == 1 or \
param.name.endswith(".bias") or \
param.name in skip_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def count_params(net):
"""Count number of parameters in the network
Args:
net (mindspore.nn.Cell): Mindspore network instance
Returns:
total_params (int): Total number of trainable params
"""
total_params = 0
for param in net.trainable_params():
total_params += np.prod(param.shape)
return total_params