Adding TinyNet to Model Zoo

Adding the TinyNet (https://arxiv.org/abs/2010.14819) MindSpore implementation to the Model Zoo
This commit is contained in:
yanglf1121 2020-10-31 19:52:39 +08:00
parent fc5a3b7d97
commit 882301f4b5
12 changed files with 1976 additions and 0 deletions

View File

@ -60,6 +60,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)

View File

@ -0,0 +1,154 @@
# Contents
- [TinyNet Description](#tinynet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [TinyNet Description](#contents)
TinyNets are a series of lightweight models obtained by twisting resolution, depth and width with a data-driven tiny formula. TinyNet outperforms EfficientNet and MobileNetV3.
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
Note: We have only released TinyNet-C for now, and will release other TinyNets soon.
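As a rough illustration of the "tiny formula" (a sketch only; the data-driven search for the multipliers is described in the paper), the per-model (resolution, channel, depth) multipliers shipped in `src/tinynet.py` shrink EfficientNet-B0 roughly as follows (channel rounding shown without the 10% guard in `_round_channels`):
```
# Sketch: how TinyNet's (r, w, d) multipliers shrink EfficientNet-B0.
# Values are TINYNET_CFG["c"] from src/tinynet.py; base numbers are
# EfficientNet-B0's input resolution, a 32-channel stage width, and a
# sample stage depth of 4 blocks.
r, w, d = 0.825, 0.54, 0.85
print("input resolution:", int(r * 224))                     # 184
print("stage channels:", max(8, int(w * 32 + 4) // 8 * 8))   # 16
print("stage repeats (r4 stage):", max(1, round(d * 4)))     # 3
```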
# [Model Architecture](#contents)
The overall network architecture of TinyNet is shown below:
[Link](https://arxiv.org/abs/2010.14819)
# [Dataset](#contents)
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
- Dataset size:
- Train: 1.2 million images in 1,000 classes
- Test: 50,000 validation images in 1,000 classes
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
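
For reference, a minimal sketch of calling the dataset helpers from this directory (paths are placeholders; the `input_size` is taken from the model's `default_cfg`, 184 for tinynet_c):
```
# Minimal sketch, assuming ImageNet laid out as train/ and val/ class folders.
from src.dataset import create_dataset, create_dataset_val

train_ds = create_dataset(batch_size=128, train_data_url="/path_to_ImageNet/train",
                          workers=8, distributed=False, input_size=184)
val_ds = create_dataset_val(batch_size=128, val_data_url="/path_to_ImageNet/val",
                            workers=8, distributed=False, input_size=184)
```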
# [Environment Requirements](#contents)
- Hardware (GPU)
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
.tinynet
├── README.md # descriptions about tinynet
├── script
│ ├── eval.sh # evaluation script
│ ├── train_1p_gpu.sh # training script on single GPU
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
├── src
│ ├── callback.py # loss and checkpoint callbacks
│ ├── dataset.py # data processing
│ ├── loss.py # label-smoothing cross-entropy loss function
│ ├── tinynet.py # tinynet architecture
│ └── utils.py # utility functions
├── eval.py # evaluation interface
└── train.py # training interface
```
## [Training process](#contents)
### Launch
```
# training on single GPU
sh train_1p_gpu.sh
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
sh train_distributed_gpu.sh -n 8
```
Inside the training scripts, there are hyperparameters that can be adjusted, for example:
```
--model tinynet_c model to be used for training
--drop 0.2 dropout rate
--drop-connect 0 drop connect rate
--num-classes 1000 number of classes for training
--opt-eps 0.001 optimizer's epsilon
--lr 0.048 learning rate
--batch-size 128 batch size
--decay-epochs 2.4 learning rate decays every 2.4 epochs
--warmup-lr 1e-6 warm up learning rate
--warmup-epochs 3 number of learning rate warm-up epochs
--decay-rate 0.97 learning rate decay rate
--ema-decay 0.9999 decay factor for model weights moving average
--weight-decay 1e-5 optimizer's weight decay
--epochs 450 number of epochs to be trained
--ckpt_save_epoch 1 checkpoint saving interval
--workers 8 number of processes for loading data
--amp_level O0 auto mixed precision level
--opt rmsprop optimizer; currently SGD and RMSProp are supported
--data_path /path_to_ImageNet/ path to the ImageNet dataset
--GPU use GPU for training
--dataset_sink use dataset sink mode
```
The config above was used to train TinyNets on ImageNet (change --drop-connect to 0.2 to train TinyNet-B).
> Checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
> or the ./device_parallel folder (multiple GPUs).
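
To sanity-check the learning-rate schedule these flags produce, `get_lr` from `src/utils.py` can be called directly; a small sketch (the steps-per-epoch value is illustrative and depends on dataset size and batch size):
```
# Sketch: build the LR schedule from the flags above with get_lr from
# src/utils.py, which applies a linear warmup and then decays the rate
# by decay_rate every decay_epochs.
from src.utils import get_lr

lr_each_step = get_lr(base_lr=0.048, total_epochs=450, steps_per_epoch=100,
                      decay_epochs=2.4, decay_rate=0.97,
                      warmup_epochs=3, warmup_lr_init=1e-6)
print(lr_each_step[0], lr_each_step[-1])  # warmup start vs. final decayed LR
```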
## [Evaluation Process](#contents)
### Launch
```
# infer example
sh eval.sh
```
Inside eval.sh, there are configurations that can be adjusted for inference, for example:
```
--num-classes 1000
--batch-size 128
--workers 8
--data_path /path_to_ImageNet/
--GPU
--ckpt /path_to_EMA_checkpoint/
--dataset_sink > tinynet_c_eval.log 2>&1 &
```
> The checkpoint used here is produced during the training process.
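
`src/callback.py` saves both EMA and raw weights during training (`ema_best.ckpt`, `ema_last.ckpt`, `last.ckpt`, and per-epoch `ema-{epoch}.ckpt` files). A minimal sketch of loading an EMA checkpoint for evaluation, mirroring eval.py:
```
# Minimal sketch: load an EMA checkpoint produced by training into the net.
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.tinynet import tinynet

net = tinynet(sub_model="c", num_classes=1000, drop_rate=0.0, drop_connect_rate=0.0,
              global_pool="avg", bn_tf=False, bn_momentum=None, bn_eps=None)
load_param_into_net(net, load_checkpoint("/path_to_EMA_checkpoint/ema_best.ckpt"))
net.set_train(False)
```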
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Model | FLOPs | Latency* | ImageNet Top-1 |
| ------------------- | ----- | -------- | -------------- |
| EfficientNet-B0 | 387M | 99.85 ms | 76.7% |
| TinyNet-A | 339M | 81.30 ms | 76.8% |
| EfficientNet-B^{-4} | 24M | 11.54 ms | 56.7% |
| TinyNet-E | 24M | 9.18 ms | 59.9% |
*Latency is measured using MS Lite on a Huawei P40 smartphone.
*More details can be found in the [paper](https://arxiv.org/abs/2010.14819).
# [Description of Random Situation](#contents)
We set the seed inside dataset.py and also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference Interface"""
import sys
import os
import argparse
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore import context
from src.dataset import create_dataset_val
from src.utils import count_params
from src.loss import LabelSmoothingCrossEntropy
from src.tinynet import tinynet
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/',
metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
                    help='Name of model to train (default: "tinynet_c")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading processes to use (default: 4)')
parser.add_argument('--ckpt', type=str, default=None,
help='model checkpoint to load')
parser.add_argument('--GPU', action='store_true', default=True,
help='Use GPU for training (default: True)')
parser.add_argument('--dataset_sink', action='store_true', default=True)
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
context.set_context(mode=context.GRAPH_MODE)
if args.GPU:
context.set_context(device_target='GPU')
# parse model argument
assert args.model.startswith(
"tinynet"), "Only Tinynet models are supported."
_, sub_name = args.model.split("_")
net = tinynet(sub_model=sub_name,
num_classes=args.num_classes,
drop_rate=0.0,
drop_connect_rate=0.0,
global_pool="avg",
bn_tf=False,
bn_momentum=None,
bn_eps=None)
print("Total number of parameters:", count_params(net))
input_size = net.default_cfg['input_size'][1]
val_data_url = os.path.join(args.data_path, 'val')
val_dataset = create_dataset_val(args.batch_size,
val_data_url,
workers=args.workers,
distributed=False,
input_size=input_size)
loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
num_classes=args.num_classes)
loss.add_flags_recursive(fp32=True, fp16=False)
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, metrics=eval_metrics)
metrics = model.eval(val_dataset, dataset_sink_mode=False)
print(metrics)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=1
export start=0
export value=$((start + RANK_SIZE))
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo "$curtime"
rm ${current_exec_path}/device${start}_$curtime/ -rf
mkdir ${current_exec_path}/device${start}_$curtime
cd ${current_exec_path}/device${start}_$curtime || exit
export RANK_ID=$start
export DEVICE_ID=$start
time python3 ${current_exec_path}/eval.py \
--model tinynet_c \
--num-classes 1000 \
--batch-size 128 \
--workers 8 \
    --data_path /path_to_ImageNet/ \
--GPU \
--ckpt /path_to_ckpt/ \
--dataset_sink > tinynet_c_eval.log 2>&1 &

View File

@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=1
export start=0
export value=$(($start+$RANK_SIZE))
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo $curtime
echo "rank_id = ${start}"
rm ${current_exec_path}/device_$start/ -rf
mkdir ${current_exec_path}/device_$start
cd ${current_exec_path}/device_$start || exit
export RANK_ID=$start
export DEVICE_ID=$start
time python3 ${current_exec_path}/train.py \
--model tinynet_c \
--drop 0.2 \
--drop-connect 0 \
--num-classes 1000 \
--opt-eps 0.001 \
--lr 0.048 \
--batch-size 128 \
--decay-epochs 2.4 \
--warmup-lr 1e-6 \
--warmup-epochs 3 \
--decay-rate 0.97 \
--ema-decay 0.9999 \
--weight-decay 1e-5 \
    --epochs 100 \
--ckpt_save_epoch 1 \
--workers 8 \
--amp_level O0 \
--opt rmsprop \
--data_path /path_to_ImageNet/ \
--GPU \
--dataset_sink > tinynet_c.log 2>&1 &
cd ${current_exec_path} || exit

View File

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# below help function was adapted from
# https://unix.stackexchange.com/questions/31414/how-can-i-pass-a-command-line-argument-into-a-shell-script
helpFunction()
{
echo ""
echo "Usage: $0 -n num_device"
echo -e "\t-n how many gpus to use for training"
exit 1 # Exit script after printing help
}
while getopts "n:" opt
do
case "$opt" in
n ) num_device="$OPTARG" ;;
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
esac
done
# Print helpFunction in case parameters are empty
if [ -z "$num_device" ]
then
echo "Some or all of the parameters are empty";
helpFunction
fi
# Begin script in case all parameters are correct
echo "$num_device"
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export SLOG_PRINT_TO_STDOUT=0
export RANK_SIZE=$num_device
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo $curtime
echo $curtime >> starttime
rm ${current_exec_path}/device_parallel/ -rf
mkdir ${current_exec_path}/device_parallel
cd ${current_exec_path}/device_parallel || exit
echo $curtime >> starttime
time mpirun -n $RANK_SIZE --allow-run-as-root python3 ${current_exec_path}/train.py \
--model tinynet_c \
--drop 0.2 \
--drop-connect 0 \
--num-classes 1000 \
--opt-eps 0.001 \
--lr 0.048 \
--batch-size 128 \
--decay-epochs 2.4 \
--warmup-lr 1e-6 \
--warmup-epochs 3 \
--decay-rate 0.97 \
--ema-decay 0.9999 \
--weight-decay 1e-5 \
--per_print_times 100 \
--epochs 450 \
--ckpt_save_epoch 1 \
--workers 8 \
--amp_level O0 \
--opt rmsprop \
--distributed \
--data_path /path_to_ImageNet/ \
--GPU \
--dataset_sink > tinynet_c.log 2>&1 &

View File

@ -0,0 +1,203 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""custom callbacks for ema and loss"""
from copy import deepcopy
import numpy as np
from mindspore.train.callback import Callback
from mindspore.common.parameter import Parameter
from mindspore.train.serialization import save_checkpoint
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore.train.model import Model
from mindspore import Tensor
def load_nparray_into_net(net, array_dict):
"""
Loads dictionary of numpy arrays into network.
Args:
net (Cell): Cell network.
array_dict (dict): dictionary of numpy array format model weights.
"""
param_not_load = []
for _, param in net.parameters_and_names():
if param.name in array_dict:
new_param = array_dict[param.name]
param.set_data(Parameter(new_param.copy(), name=param.name))
else:
param_not_load.append(param.name)
return param_not_load
class EmaEvalCallBack(Callback):
"""
Call back that will evaluate the model and save model checkpoint at
the end of training epoch.
Args:
        model: MindSpore model instance.
        ema_network: network holding the step-wise exponential moving
            average of the trained network's weights.
        eval_dataset: the evaluation dataset.
decay (float): ema decay.
save_epoch (int): defines how often to save checkpoint.
dataset_sink_mode (bool): whether to use data sink mode.
start_epoch (int): which epoch to start/resume training.
"""
def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999,
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
self.model = model
self.ema_network = ema_network
self.eval_dataset = eval_dataset
self.loss_fn = loss_fn
self.decay = decay
self.save_epoch = save_epoch
self.shadow = {}
self.ema_accuracy = {}
self.best_ema_accuracy = 0
self.best_accuracy = 0
self.best_ema_epoch = 0
self.best_epoch = 0
self._start_epoch = start_epoch
self.eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
self.dataset_sink_mode = dataset_sink_mode
def begin(self, run_context):
"""Initialize the EMA parameters """
cb_params = run_context.original_args()
for _, param in cb_params.network.parameters_and_names():
self.shadow[param.name] = deepcopy(param.data.asnumpy())
def step_end(self, run_context):
"""Update the EMA parameters"""
cb_params = run_context.original_args()
for _, param in cb_params.network.parameters_and_names():
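            # EMA update: shadow = decay * shadow + (1 - decay) * param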
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
self.decay * self.shadow[param.name]
self.shadow[param.name] = new_average
def epoch_end(self, run_context):
"""evaluate the model and ema-model at the end of each epoch"""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
save_ckpt = (cur_epoch % self.save_epoch == 0)
acc = self.model.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
print("Model Accuracy:", acc)
load_nparray_into_net(self.ema_network, self.shadow)
self.ema_network.set_train(False)
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
metrics=self.eval_metrics)
ema_acc = model_ema.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
print("EMA-Model Accuracy:", ema_acc)
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
output = [{"name": k, "data": Tensor(v)}
for k, v in self.shadow.items()]
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
self.best_ema_accuracy = ema_acc["Top1-Acc"]
self.best_ema_epoch = cur_epoch
save_checkpoint(output, "ema_best.ckpt")
if self.best_accuracy < acc["Top1-Acc"]:
self.best_accuracy = acc["Top1-Acc"]
self.best_epoch = cur_epoch
print("Best Model Accuracy: %s, at epoch %s" %
(self.best_accuracy, self.best_epoch))
print("Best EMA-Model Accuracy: %s, at epoch %s" %
(self.best_ema_accuracy, self.best_ema_epoch))
if save_ckpt:
# Save the ema_model checkpoints
ckpt = "{}-{}.ckpt".format("ema", cur_epoch)
save_checkpoint(output, ckpt)
save_checkpoint(output, "ema_last.ckpt")
# Save the model checkpoints
save_checkpoint(cb_params.train_network, "last.ckpt")
print("Top 10 EMA-Model Accuracies: ")
count = 0
for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get,
reverse=True):
if count == 10:
break
print("epoch: %s, Top-1: %s)" % (epoch, self.ema_accuracy[epoch]))
count += 1
class LossMonitor(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, it will terminate training.
Note:
If per_print_times is 0, do not print loss.
Args:
lr_array (numpy.array): scheduled learning rate.
total_epochs (int): Total number of epochs for training.
        per_print_times (int): print the loss every per_print_times steps. Default: 1.
start_epoch (int): which epoch to start, used when resume from a
certain epoch.
Raises:
        ValueError: If per_print_times is not an integer or is less than zero.
"""
def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0):
super(LossMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self._lr_array = lr_array
self._total_epochs = total_epochs
self._start_epoch = start_epoch
def step_end(self, run_context):
"""log epoch, step, loss and learning rate"""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
global_step = cb_params.cur_step_num - 1
cur_step_in_epoch = global_step % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0:
print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s"
% (cur_epoch_num, self._total_epochs, cur_step_in_epoch,
cb_params.batch_num, loss, self._lr_array[global_step]),
flush=True)

View File

@ -0,0 +1,143 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import math
import os
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size
from mindspore.dataset.vision import Inter
# values that should remain constant
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# data preprocess configs
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
ds.config.set_seed(1)
def split_imgs_and_labels(imgs, labels, batchInfo):
"""split data into labels and images"""
ret_imgs = []
ret_labels = []
for i, image in enumerate(imgs):
ret_imgs.append(image)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
input_size=224, color_jitter=0.4):
"""Creat ImageNet training dataset"""
if not os.path.exists(train_data_url):
        raise ValueError('Path does not exist')
decode_op = py_vision.Decode()
type_cast_op = c_transforms.TypeCast(mstype.int32)
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
scale=SCALE, ratio=RATIO,
interpolation=Inter.BICUBIC)
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
    normalize_op = py_vision.Normalize(
        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    # assemble all the transforms
    image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
                                       random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
# batch dealing
ds_train = dataset_train.batch(batch_size,
per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
input_size=224):
"""Creat ImageNet validation dataset"""
if not os.path.exists(val_data_url):
        raise ValueError('Path does not exist')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id)
scale_size = None
if isinstance(input_size, tuple):
assert len(input_size) == 2
if input_size[-1] == input_size[-2]:
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
else:
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
else:
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
type_cast_op = c_transforms.TypeCast(mstype.int32)
decode_op = py_vision.Decode()
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
    normalize_op = py_vision.Normalize(
        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
                                       to_tensor, normalize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=image_ops,
num_parallel_workers=workers)
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
dataset = dataset.repeat(1)
return dataset

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()
def construct(self, logits, label):
label = self.cast(label, mstype.int32)
one_hot_label = self.onehot(label, F.shape(
logits)[1], self.on_value, self.off_value)
loss_logit = self.ce(logits, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
return loss_logit

View File

@ -0,0 +1,808 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tinynet model definition"""
import math
import re
from copy import deepcopy
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
from mindspore import context, ms_function
from mindspore.common.parameter import Parameter
# Imagenet constant values
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# model structure configurations for TinyNets, values are
# (resolution multiplier, channel multiplier, depth multiplier)
# only tinynet-c is available for now, we will release other tinynet
# models soon
# codes are inspired and partially adapted from
# https://github.com/rwightman/gen-efficientnet-pytorch
TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}
relu = P.ReLU()
sigmoid = P.Sigmoid()
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = {
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
'efficientnet_b1': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'efficientnet_b3': _cfg(
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'efficientnet_b4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
}
_DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
# Defaults used for Google/Tensorflow training of mobile networks /w
# RMSprop as per papers and TF reference implementations. PT momentum
# equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
"""Google style weight initialization"""
if layer_type not in ('conv', 'bn', 'fc'):
raise ValueError(
'The layer type is not known, the supported are conv, bn and fc')
if bias:
return Zero()
if layer_type == 'conv':
assert isinstance(shape, (tuple, list)) and len(
shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
n = shape[1] * shape[1] * shape[2]
return Normal(math.sqrt(2.0 / n))
if layer_type == 'bn':
return One()
assert isinstance(shape, (tuple, list)) and len(
shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
n = shape[1]
init_range = 1.0 / math.sqrt(n)
return Uniform(init_range)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0,
pad_mode='same', bias=False):
"""convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
"""1x1 convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, 1, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
has_bias=bias)
def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0,
pad_mode='same', bias=False):
"""group convolution wrapper"""
weight_init_value = _initialize_weight_goog(
shape=(in_channels, kernel_size, out_channels))
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias, bias_init=bias_init_value)
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
group=group, has_bias=bias)
def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
return nn.BatchNorm2d(channels, eps=eps, momentum=1-momentum, gamma_init=gamma_init,
beta_init=beta_init)
def _dense(in_channels, output_channels, bias=True, activation=None):
weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels),
layer_type='fc')
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
if bias:
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
bias_init=bias_init_value, has_bias=bias, activation=activation)
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
has_bias=bias, activation=activation)
def _resolve_bn_args(kwargs):
bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
'bn_tf', False) else _BN_ARGS_PT.copy()
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
channel_min = channel_min or divisor
new_channels = max(
int(channels + divisor / 2) // divisor * divisor,
channel_min)
# Make sure that round down does not go down by more than 10%.
if new_channels < 0.9 * channels:
new_channels += divisor
return new_channels
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str, depth_multiplier=1.0):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
        ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
        ValueError: if the string definition is not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
if op == 'noskip':
noskip = True
        elif op.startswith('n'):
            # activation fn override ('re', 'r6', 'hs', 'sw'); custom
            # activations are not supported in this port, so the option is
            # ignored and the model default activation is used instead
            v = op[1:]
            if v in ('re', 'r6', 'hs', 'sw'):
                print('activation %s not supported' % v)
            continue
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type in ('ds', 'dsa'):
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_fn=act_fn,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
"""further decode the architecture definition into model-ready format"""
arch_args = []
for _, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
stack_args.append(ba)
repeats.append(rep)
arch_args.append(_scale_stage_depth(
stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
class Swish(nn.Cell):
"""swish activation function"""
def __init__(self):
super(Swish, self).__init__()
self.sigmoid = P.Sigmoid()
def construct(self, x):
return x * self.sigmoid(x)
@ms_function
def swish(x):
    return x * sigmoid(x)
class BlockBuilder(nn.Cell):
"""build efficient-net convolution blocks"""
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0,
channel_divisor=8, channel_min=None, pad_type='', act_fn=None,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
drop_connect_rate=0., verbose=False):
super(BlockBuilder, self).__init__()
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.pad_type = pad_type
self.act_fn = Swish()
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
self.bn_args = bn_args
self.drop_connect_rate = drop_connect_rate
self.verbose = verbose
# updated during build
self.in_chs = None
self.block_idx = 0
self.block_count = 0
self.layer = self._make_layer(builder_in_channels, builder_block_args)
def _round_channels(self, chs):
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba):
"""make the current block based on the block argument"""
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# this is a hack to work around mismatch in origin impl input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['bn_args'] = self.bn_args
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate * \
self.block_idx / self.block_count
ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid
block = InvertedResidual(**ba)
elif bt in ('ds', 'dsa'):
ba['drop_connect_rate'] = self.drop_connect_rate * \
self.block_idx / self.block_count
block = DepthwiseSeparableConv(**ba)
else:
            assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs']
return block
def _make_stack(self, stack_args):
"""make a stack of blocks"""
blocks = []
# each stack (stage) contains a list of block arguments
for i, ba in enumerate(stack_args):
if i >= 1:
# only the first block in any stack can have a stride > 1
ba['stride'] = 1
block = self._make_block(ba)
blocks.append(block)
self.block_idx += 1 # incr global idx (across all stacks)
return nn.SequentialCell(blocks)
def _make_layer(self, in_chs, block_args):
""" Build the entire layer
Args:
in_chs: Number of input-channels passed to first block
block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for _, stack in enumerate(block_args):
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
return nn.SequentialCell(blocks)
def construct(self, x):
return self.layer(x)
class DepthWiseConv(nn.Cell):
"""depth-wise convolution"""
def __init__(self, in_planes, kernel_size, stride):
super(DepthWiseConv, self).__init__()
platform = context.get_context("device_target")
weight_shape = [1, kernel_size, in_planes]
weight_init = _initialize_weight_goog(shape=weight_shape)
if platform == "GPU":
self.depthwise_conv = P.Conv2D(out_channel=in_planes*1,
kernel_size=kernel_size,
stride=stride,
pad=int(kernel_size/2),
pad_mode="pad",
group=in_planes)
self.weight = Parameter(initializer(weight_init,
[in_planes*1, 1, kernel_size, kernel_size]), name='depthwise_weight')
else:
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=kernel_size,
stride=stride, pad_mode='pad',
pad=int(kernel_size/2))
self.weight = Parameter(initializer(weight_init,
[1, in_planes, kernel_size, kernel_size]), name='depthwise_weight')
def construct(self, x):
x = self.depthwise_conv(x, self.weight)
return x
class DropConnect(nn.Cell):
"""drop connect implementation"""
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
super(DropConnect, self).__init__()
self.shape = P.Shape()
self.dtype = P.DType()
self.keep_prob = 1 - drop_connect_rate
self.dropout = P.Dropout(keep_prob=self.keep_prob)
def construct(self, x):
shape = self.shape(x)
dtype = self.dtype(x)
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
_, mask_ = self.dropout(ones_tensor)
x = x * mask_
return x
def drop_connect(inputs, training=False, drop_connect_rate=0.):
if not training:
return inputs
return DropConnect(drop_connect_rate)(inputs)
class SqueezeExcite(nn.Cell):
"""squeeze-excite implementation"""
def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
super(SqueezeExcite, self).__init__()
self.act_fn = Swish()
self.gate_fn = gate_fn
reduce_chs = reduce_chs or in_chs
self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs,
kernel_size=1, has_bias=True, pad_mode='pad')
self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs,
kernel_size=1, has_bias=True, pad_mode='pad')
self.avg_global_pool = P.ReduceMean(keep_dims=True)
def construct(self, x):
x_se = self.avg_global_pool(x, (2, 3))
x_se = self.conv_reduce(x_se)
x_se = self.act_fn(x_se)
x_se = self.conv_expand(x_se)
x_se = self.gate_fn(x_se)
x = x * x_se
return x
class DepthwiseSeparableConv(nn.Cell):
"""depth-wise convolution -> (squeeze-excite) -> point-wise convolution"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, pad_type='', act_fn=relu, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
bn_args=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2], 'stride must be 1 or 2'
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act
self.act_fn = Swish()
self.drop_connect_rate = drop_connect_rate
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
self.bn1 = _fused_bn(in_chs, **bn_args)
if self.has_se:
self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn)
self.conv_pw = _conv1x1(in_chs, out_chs)
self.bn2 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
"""forward the depthwise separable conv"""
identity = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
if self.has_pw_act:
x = self.act_fn(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class InvertedResidual(nn.Cell):
"""inverted-residual block implementation"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
pad_type='', act_fn=relu, pw_kernel_size=1,
noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
bn_args=None, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = Swish()
self.drop_connect_rate = drop_connect_rate
self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
self.bn1 = _fused_bn(mid_chs, **bn_args)
self.shuffle_type = shuffle_type
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = None
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
self.bn2 = _fused_bn(mid_chs, **bn_args)
if self.has_se:
se_base_chs = mid_chs if se_reduce_mid else in_chs
self.se = SqueezeExcite(
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
act_fn=act_fn, gate_fn=se_gate_fn
)
self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
self.bn3 = _fused_bn(out_chs, **bn_args)
def construct(self, x):
"""forward the inverted-residual block"""
identity = x
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act_fn(x)
if self.has_se:
x = self.se(x)
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0:
x = drop_connect(x, self.training, self.drop_connect_rate)
x = x + identity
return x
class GenEfficientNet(nn.Cell):
"""Generate EfficientNet architecture"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
global_pool='avg', head_conv='default', weight_init='goog'):
super(GenEfficientNet, self).__init__()
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
self.num_classes = num_classes
self.drop_rate = drop_rate
self.num_features = num_features
self.conv_stem = _conv(in_chans, stem_size, 3,
stride=2, padding=1, pad_mode='pad')
self.bn1 = _fused_bn(stem_size, **bn_args)
self.act_fn = Swish()
in_chans = stem_size
self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
channel_divisor, channel_min,
pad_type, act_fn, se_gate_fn, se_reduce_mid,
bn_args, drop_connect_rate, verbose=_DEBUG)
in_chs = self.blocks.in_chs
if not head_conv or head_conv == 'none':
self.efficient_head = False
self.conv_head = None
assert in_chs == self.num_features
else:
self.efficient_head = head_conv == 'efficient'
self.conv_head = _conv1x1(in_chs, self.num_features)
self.bn2 = None if self.efficient_head else _fused_bn(
self.num_features, **bn_args)
self.global_pool = P.ReduceMean(keep_dims=True)
self.classifier = _dense(self.num_features, self.num_classes)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.drop_out = nn.Dropout(keep_prob=1-self.drop_rate)
def construct(self, x):
"""efficient net entry point"""
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act_fn(x)
x = self.blocks(x)
if self.efficient_head:
x = self.global_pool(x, (2, 3))
x = self.conv_head(x)
x = self.act_fn(x)
            x = self.reshape(x, (self.shape(x)[0], -1))
else:
if self.conv_head is not None:
x = self.conv_head(x)
x = self.bn2(x)
x = self.act_fn(x)
x = self.global_pool(x, (2, 3))
x = self.reshape(x, (self.shape(x)[0], -1))
if self.training and self.drop_rate > 0.:
x = self.drop_out(x)
return self.classifier(x)
def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
Args:
        channel_multiplier (float): multiplier to number of channels per layer
        depth_multiplier (float): multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
['ir_r2_k5_s2_e6_c40_se0.25'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
num_features = max(1280, _round_channels(
1280, channel_multiplier, 8, None))
model = GenEfficientNet(
_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
num_classes=num_classes,
stem_size=32,
channel_multiplier=channel_multiplier,
num_features=num_features,
bn_args=_resolve_bn_args(kwargs),
act_fn=Swish,
**kwargs)
return model
def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs):
""" TinyNet Models """
# choose a sub model
r, w, d = TINYNET_CFG[sub_model]
default_cfg = default_cfgs['efficientnet_b0']
    assert default_cfg['input_size'] == (3, 224, 224), "All tinynet models are \
evolved from EfficientNet-B0, which has an input dimension of 3*224*224"
channel, height, width = default_cfg['input_size']
height = int(r * height)
width = int(r * width)
default_cfg['input_size'] = (channel, height, width)
print("Data processing configuration for current model + dataset:")
print("input_size:", default_cfg['input_size'])
print("channel mutiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r))
model = _gen_efficientnet(
channel_multiplier=w, depth_multiplier=d,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
return model

View File

@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model utils"""
import math
import argparse
import numpy as np
def str2bool(value):
"""Convert string arguments to bool type"""
if value.lower() in ('yes', 'true', 't', 'y', '1'):
return True
if value.lower() in ('no', 'false', 'f', 'n', '0'):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
"""Get scheduled learning rate"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
global_steps = steps_per_epoch * global_epoch
self_warmup_delta = ((base_lr - warmup_lr_init) / \
warmup_epochs) if warmup_epochs > 0 else 0
self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate
for i in range(total_steps):
epochs = math.floor(i/steps_per_epoch)
cond = 1 if (epochs < warmup_epochs) else 0
warmup_lr = warmup_lr_init + epochs * self_warmup_delta
decay_nums = math.floor(epochs / decay_epochs)
decay_rate = math.pow(self_decay_rate, decay_nums)
decay_lr = base_lr * decay_rate
lr = cond * warmup_lr + (1 - cond) * decay_lr
lr_each_step.append(lr)
lr_each_step = lr_each_step[global_steps:]
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
"""Apply weight decay to only conv and dense layers (len(shape) > =2)
Args:
net (mindspore.nn.Cell): Mindspore network instance
        weight_decay (float): weight decay to be used.
skip_list (tuple): list of parameter names without weight decay
Returns:
A list of group of parameters, separated by different weight decay.
"""
decay = []
no_decay = []
if not skip_list:
skip_list = ()
for param in net.trainable_params():
if len(param.shape) == 1 or \
param.name.endswith(".bias") or \
param.name in skip_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
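# Usage sketch for the grouping above, on a toy stand-in for the backbone:
#   from mindspore import nn
#   groups = add_weight_decay(nn.Dense(8, 2), weight_decay=1e-5)
#   # the 1-D bias falls into the no-decay group; the 2-D weight gets decay
#   optimizer = nn.SGD(groups, learning_rate=0.01, momentum=0.9)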
def count_params(net):
"""Count number of parameters in the network
Args:
net (mindspore.nn.Cell): Mindspore network instance
Returns:
total_params (int): Total number of trainable params
"""
total_params = 0
for param in net.trainable_params():
total_params += np.prod(param.shape)
return total_params
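A quick sanity check of `count_params` on a toy cell, as a sketch (train.py below prints the real TinyNet count):

from mindspore import nn
from src.utils import count_params
print(count_params(nn.Dense(8, 2)))  # 8*2 weights + 2 biases = 18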

View File

@ -0,0 +1,250 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training Interface"""
import sys
import os
import argparse
import copy
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import ParallelMode, Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn import SGD, RMSProp, Loss, Top1CategoricalAccuracy, \
Top5CategoricalAccuracy
from mindspore import context, Tensor
from src.dataset import create_dataset, create_dataset_val
from src.utils import add_weight_decay, count_params, str2bool, get_lr
from src.callback import EmaEvalCallBack, LossMonitor
from src.loss import LabelSmoothingCrossEntropy
from src.tinynet import tinynet
parser = argparse.ArgumentParser(description='Training')
# training parameters
parser.add_argument('--data_path', type=str, default="", metavar="DIR",
help='path to dataset')
parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
                    help='Name of model to train (default: "tinynet_c")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
help='Drop connect rate (default: 0.)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('--ema-decay', type=float, default=0,
                    help='decay factor for model weights moving average \
                    (default: 0; must be set > 0 for TinyNet training)')
parser.add_argument('--amp_level', type=str, default='O0')
parser.add_argument('--per_print_times', type=int, default=100)
# batch norm parameters
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use TensorFlow BatchNorm defaults for models that \
                    support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
# parallel parameters
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading workers to use (default: 4)')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--dataset_sink', type=str2bool, default=True)
# checkpoint config
parser.add_argument('--ckpt', type=str, default=None)
parser.add_argument('--ckpt_save_epoch', type=int, default=1)
parser.add_argument('--loss_scale', type=int,
default=1024, help='static loss scale')
parser.add_argument('--train', type=str2bool, default=1, help='train or eval')
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
devid, args.rank_id, args.rank_size = 0, 0, 1
context.set_context(mode=context.GRAPH_MODE)
if args.distributed:
if args.GPU:
init("nccl")
context.set_context(device_target='GPU')
else:
init()
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_target='Ascend',
device_id=devid,
reserve_class_name_in_scope=False)
context.reset_auto_parallel_context()
args.rank_id = get_rank()
args.rank_size = get_group_size()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
device_num=args.rank_size)
else:
if args.GPU:
context.set_context(device_target='GPU')
is_master = not args.distributed or (args.rank_id == 0)
# parse model argument
assert args.model.startswith(
    "tinynet"), "Only TinyNet models are supported."
_, sub_name = args.model.split("_")
net = tinynet(sub_model=sub_name,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect,
global_pool="avg",
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps)
if is_master:
print("Total number of parameters:", count_params(net))
# input image size of the network
input_size = net.default_cfg['input_size'][1]
train_dataset = val_dataset = None
train_data_url = os.path.join(args.data_path, 'train')
val_data_url = os.path.join(args.data_path, 'val')
val_dataset = create_dataset_val(args.batch_size,
val_data_url,
workers=args.workers,
distributed=False,
input_size=input_size)
if args.train:
train_dataset = create_dataset(args.batch_size,
train_data_url,
workers=args.workers,
distributed=args.distributed,
input_size=input_size)
batches_per_epoch = train_dataset.get_dataset_size()
loss = LabelSmoothingCrossEntropy(
smooth_factor=args.smoothing, num_classes=args.num_classes)
time_cb = TimeMonitor(data_size=batches_per_epoch)
loss_scale_manager = FixedLossScaleManager(
args.loss_scale, drop_overflow_update=False)
lr_array = get_lr(base_lr=args.lr,
total_epochs=args.epochs,
steps_per_epoch=batches_per_epoch,
decay_epochs=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_epochs=args.warmup_epochs,
warmup_lr_init=args.warmup_lr,
global_epoch=0)
lr = Tensor(lr_array)
loss_cb = LossMonitor(lr_array,
args.epochs,
per_print_times=args.per_print_times,
start_epoch=0)
param_group = add_weight_decay(net, weight_decay=args.weight_decay)
if args.opt == 'sgd':
if is_master:
print('Using SGD optimizer')
optimizer = SGD(param_group,
learning_rate=lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
elif args.opt == 'rmsprop':
if is_master:
print('Using rmsprop optimizer')
optimizer = RMSProp(param_group,
learning_rate=lr,
decay=0.9,
weight_decay=args.weight_decay,
momentum=args.momentum,
epsilon=args.opt_eps,
loss_scale=args.loss_scale)
loss.add_flags_recursive(fp32=True, fp16=False)
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
if args.ckpt:
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, optimizer, metrics=eval_metrics,
loss_scale_manager=loss_scale_manager,
amp_level=args.amp_level)
net_ema = copy.deepcopy(net)
net_ema.set_train(False)
assert args.ema_decay > 0, "EMA should be used in tinynet training."
ema_cb = EmaEvalCallBack(model=model,
ema_network=net_ema,
loss_fn=loss,
eval_dataset=val_dataset,
decay=args.ema_decay,
save_epoch=args.ckpt_save_epoch,
dataset_sink_mode=args.dataset_sink,
start_epoch=0)
callbacks = [loss_cb, ema_cb, time_cb] if is_master else []
if is_master:
print("Training on " + args.model
+ " with " + str(args.num_classes) + " classes")
model.train(args.epochs, train_dataset, callbacks=callbacks,
dataset_sink_mode=args.dataset_sink)
if __name__ == '__main__':
main()
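For reference, a hypothetical single-GPU launch using the flags defined above (the dataset path is a placeholder, and `--ema-decay` must be positive, as `main` asserts):

python train.py --model tinynet_c --data_path /path/to/imagenet --GPU --ema-decay 0.999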