Adding TinyNet to Model Zoo
Adding TinyNet (https://arxiv.org/abs/2010.14819) MindSpore implementation to model Zoo
This commit is contained in:
parent
fc5a3b7d97
commit
882301f4b5
|
@ -60,6 +60,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
|
||||
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
|
||||
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
|
||||
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
|
||||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
|
||||
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
|
||||
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# Contents
|
||||
|
||||
- [TinyNet Description](#tinynet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [TinyNet Description](#contents)
|
||||
|
||||
TinyNets are a series of lightweight models obtained by twisting resolution, depth and width with a data-driven tiny formula. TinyNet outperforms EfficientNet and MobileNetV3.
|
||||
|
||||
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
|
||||
|
||||
Note: We have only released TinyNet-C for now, and will release other TinyNets soon.
|
||||
# [Model architecture](#contents)
|
||||
|
||||
The overall network architecture of TinyNet is show below:
|
||||
|
||||
[Link](https://arxiv.org/abs/2010.14819)
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
|
||||
|
||||
- Dataset size:
|
||||
- Train: 1.2 million images in 1,000 classes
|
||||
- Test: 50,000 validation images in 1,000 classes
|
||||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset/dataset.py
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware (GPU)
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script description](#contents)
|
||||
|
||||
## [Script and sample code](#contents)
|
||||
|
||||
```
|
||||
.tinynet
|
||||
├── Readme.md # descriptions about tinynet
|
||||
├── script
|
||||
│ ├── eval.sh # evaluation script
|
||||
│ ├── train_1p_gpu.sh # training script on single GPU
|
||||
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
|
||||
├── src
|
||||
│ ├── callback.py # loss and checkpoint callbacks
|
||||
│ ├── dataset.py # data processing
|
||||
│ ├── loss.py # label-smoothing cross-entropy loss function
|
||||
│ ├── tinynet.py # tinynet architecture
|
||||
│ └── utils.py # utility functions
|
||||
├── eval.py # evaluation interface
|
||||
└── train.py # training interface
|
||||
```
|
||||
## [Training process](#contents)
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# training on single GPU
|
||||
sh train_1p_gpu.sh
|
||||
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
|
||||
sh train_distributed_gpu.sh -n 8
|
||||
```
|
||||
Inside train.sh, there are hyperparameters that can be adjusted during training, for example:
|
||||
```
|
||||
--model tinynet_c model to be used for training
|
||||
--drop 0.2 dropout rate
|
||||
--drop-connect 0 drop connect rate
|
||||
--num-classes 1000 number of classes for training
|
||||
--opt-eps 0.001 optimizer's epsilon
|
||||
--lr 0.048 learning rate
|
||||
--batch-size 128 batch size
|
||||
--decay-epochs 2.4 learning rate decays every 2.4 epoch
|
||||
--warmup-lr 1e-6 warm up learning rate
|
||||
--warmup-epochs 3 learning rate warm up epoch
|
||||
--decay-rate 0.97 learning rate decay rate
|
||||
--ema-decay 0.9999 decay factor for model weights moving average
|
||||
--weight-decay 1e-5 optimizer's weight decay
|
||||
--epochs 450 number of epochs to be trained
|
||||
--ckpt_save_epoch 1 checkpoint saving interval
|
||||
--workers 8 number of processes for loading data
|
||||
--amp_level O0 training auto-mixed precision
|
||||
--opt rmsprop optimizers, currently we support SGD and RMSProp
|
||||
--data_path /path_to_ImageNet/
|
||||
--GPU using GPU for training
|
||||
--dataset_sink using sink mode
|
||||
```
|
||||
The config above was used to train tinynets on ImageNet (change drop-connect to 0.2 for training tinynet-b)
|
||||
|
||||
> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
|
||||
or ./device_parallel folder (multiple GPUs)
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# infer example
|
||||
|
||||
sh eval.sh
|
||||
```
|
||||
Inside the eval.sh, there are configs that can be adjusted during inference, for example:
|
||||
```
|
||||
--num-classes 1000
|
||||
--batch-size 128
|
||||
--workers 8
|
||||
--data_path /path_to_ImageNet/
|
||||
--GPU
|
||||
--ckpt /path_to_EMA_checkpoint/
|
||||
--dataset_sink > tinynet_c_eval.log 2>&1 &
|
||||
```
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
#### Evaluation Performance
|
||||
|
||||
| Model | FLOPs | Latency* | ImageNet Top-1 |
|
||||
| ------------------- | ----- | -------- | -------------- |
|
||||
| EfficientNet-B0 | 387M | 99.85 ms | 76.7% |
|
||||
| TinyNet-A | 339M | 81.30 ms | 76.8% |
|
||||
| EfficientNet-B^{-4} | 24M | 11.54 ms | 56.7% |
|
||||
| TinyNet-E | 24M | 9.18 ms | 59.9% |
|
||||
|
||||
*Latency is measured using MS Lite on Huawei P40 smartphone.
|
||||
|
||||
*More details in [Paper](https://arxiv.org/abs/2010.14819).
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We set the seed inside dataset.py. We also use random seed in train.py.
|
||||
|
||||
# [Model Zoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Inference Interface"""
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
from mindspore import context
|
||||
|
||||
from src.dataset import create_dataset_val
|
||||
from src.utils import count_params
|
||||
from src.loss import LabelSmoothingCrossEntropy
|
||||
from src.tinynet import tinynet
|
||||
|
||||
parser = argparse.ArgumentParser(description='Evaluation')
|
||||
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/',
|
||||
metavar='DIR', help='path to dataset')
|
||||
parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
|
||||
help='Name of model to train (default: "tinynet_c"')
|
||||
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
||||
help='number of label classes (default: 1000)')
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='label smoothing (default: 0.1)')
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
||||
help='how many training processes to use (default: 1)')
|
||||
parser.add_argument('--ckpt', type=str, default=None,
|
||||
help='model checkpoint to load')
|
||||
parser.add_argument('--GPU', action='store_true', default=True,
|
||||
help='Use GPU for training (default: True)')
|
||||
parser.add_argument('--dataset_sink', action='store_true', default=True)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entrance for training"""
|
||||
args = parser.parse_args()
|
||||
print(sys.argv)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
if args.GPU:
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
# parse model argument
|
||||
assert args.model.startswith(
|
||||
"tinynet"), "Only Tinynet models are supported."
|
||||
_, sub_name = args.model.split("_")
|
||||
net = tinynet(sub_model=sub_name,
|
||||
num_classes=args.num_classes,
|
||||
drop_rate=0.0,
|
||||
drop_connect_rate=0.0,
|
||||
global_pool="avg",
|
||||
bn_tf=False,
|
||||
bn_momentum=None,
|
||||
bn_eps=None)
|
||||
print("Total number of parameters:", count_params(net))
|
||||
|
||||
input_size = net.default_cfg['input_size'][1]
|
||||
val_data_url = os.path.join(args.data_path, 'val')
|
||||
val_dataset = create_dataset_val(args.batch_size,
|
||||
val_data_url,
|
||||
workers=args.workers,
|
||||
distributed=False,
|
||||
input_size=input_size)
|
||||
|
||||
loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
|
||||
num_classes=args.num_classes)
|
||||
|
||||
loss.add_flags_recursive(fp32=True, fp16=False)
|
||||
eval_metrics = {'Validation-Loss': Loss(),
|
||||
'Top1-Acc': Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': Top5CategoricalAccuracy()}
|
||||
|
||||
ckpt = load_checkpoint(args.ckpt)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
|
||||
model = Model(net, loss, metrics=eval_metrics)
|
||||
|
||||
metrics = model.eval(val_dataset, dataset_sink_mode=False)
|
||||
print(metrics)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
cd ../ || exit
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_SIZE=1
|
||||
export start=0
|
||||
export value=$((start + RANK_SIZE))
|
||||
export curtime
|
||||
curtime=$(date '+%Y%m%d-%H%M%S')
|
||||
echo "$curtime"
|
||||
|
||||
rm ${current_exec_path}/device${start}_$curtime/ -rf
|
||||
mkdir ${current_exec_path}/device${start}_$curtime
|
||||
cd ${current_exec_path}/device${start}_$curtime || exit
|
||||
|
||||
export RANK_ID=start
|
||||
export DEVICE_ID=start
|
||||
time python3 ${current_exec_path}/eval.py \
|
||||
--model tinynet_c \
|
||||
--num-classes 1000 \
|
||||
--batch-size 128 \
|
||||
--workers 8 \
|
||||
--data_path /path_to_ImageNet/\
|
||||
--GPU \
|
||||
--ckpt /path_to_ckpt/ \
|
||||
--dataset_sink > tinynet_c_eval.log 2>&1 &
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
cd ../ || exit
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_SIZE=1
|
||||
export start=0
|
||||
export value=$(($start+$RANK_SIZE))
|
||||
export curtime
|
||||
curtime=$(date '+%Y%m%d-%H%M%S')
|
||||
|
||||
echo $curtime
|
||||
echo "rank_id = ${start}"
|
||||
rm ${current_exec_path}/device_$start/ -rf
|
||||
mkdir ${current_exec_path}/device_$start
|
||||
cd ${current_exec_path}/device_$start || exit
|
||||
export RANK_ID=$start
|
||||
export DEVICE_ID=$start
|
||||
|
||||
time python3 ${current_exec_path}/train.py \
|
||||
--model tinynet_c \
|
||||
--drop 0.2 \
|
||||
--drop-connect 0 \
|
||||
--num-classes 1000 \
|
||||
--opt-eps 0.001 \
|
||||
--lr 0.048 \
|
||||
--batch-size 128 \
|
||||
--decay-epochs 2.4 \
|
||||
--warmup-lr 1e-6 \
|
||||
--warmup-epochs 3 \
|
||||
--decay-rate 0.97 \
|
||||
--ema-decay 0.9999 \
|
||||
--weight-decay 1e-5 \
|
||||
--epochs 100\
|
||||
--ckpt_save_epoch 1 \
|
||||
--workers 8 \
|
||||
--amp_level O0 \
|
||||
--opt rmsprop \
|
||||
--data_path /path_to_ImageNet/ \
|
||||
--GPU \
|
||||
--dataset_sink > tinynet_c.log 2>&1 &
|
||||
|
||||
|
||||
cd ${current_exec_path} || exit
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# below help function was adapted from
|
||||
# https://unix.stackexchange.com/questions/31414/how-can-i-pass-a-command-line-argument-into-a-shell-script
|
||||
helpFunction()
|
||||
{
|
||||
echo ""
|
||||
echo "Usage: $0 -n num_device"
|
||||
echo -e "\t-n how many gpus to use for training"
|
||||
exit 1 # Exit script after printing help
|
||||
}
|
||||
|
||||
while getopts "n:" opt
|
||||
do
|
||||
case "$opt" in
|
||||
n ) num_device="$OPTARG" ;;
|
||||
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
|
||||
esac
|
||||
done
|
||||
|
||||
# Print helpFunction in case parameters are empty
|
||||
if [ -z "$num_device" ]
|
||||
then
|
||||
echo "Some or all of the parameters are empty";
|
||||
helpFunction
|
||||
fi
|
||||
|
||||
# Begin script in case all parameters are correct
|
||||
echo "$num_device"
|
||||
cd ../ || exit
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export SLOG_PRINT_TO_STDOUT=0
|
||||
export RANK_SIZE=$num_device
|
||||
export curtime
|
||||
curtime=$(date '+%Y%m%d-%H%M%S')
|
||||
echo $curtime
|
||||
echo $curtime >> starttime
|
||||
rm ${current_exec_path}/device_parallel/ -rf
|
||||
mkdir ${current_exec_path}/device_parallel
|
||||
cd ${current_exec_path}/device_parallel || exit
|
||||
echo $curtime >> starttime
|
||||
|
||||
time mpirun -n $RANK_SIZE --allow-run-as-root python3 ${current_exec_path}/train.py \
|
||||
--model tinynet_c \
|
||||
--drop 0.2 \
|
||||
--drop-connect 0 \
|
||||
--num-classes 1000 \
|
||||
--opt-eps 0.001 \
|
||||
--lr 0.048 \
|
||||
--batch-size 128 \
|
||||
--decay-epochs 2.4 \
|
||||
--warmup-lr 1e-6 \
|
||||
--warmup-epochs 3 \
|
||||
--decay-rate 0.97 \
|
||||
--ema-decay 0.9999 \
|
||||
--weight-decay 1e-5 \
|
||||
--per_print_times 100 \
|
||||
--epochs 450 \
|
||||
--ckpt_save_epoch 1 \
|
||||
--workers 8 \
|
||||
--amp_level O0 \
|
||||
--opt rmsprop \
|
||||
--distributed \
|
||||
--data_path /path_to_ImageNet/ \
|
||||
--GPU \
|
||||
--dataset_sink > tinynet_c.log 2>&1 &
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""custom callbacks for ema and loss"""
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
from mindspore.train.model import Model
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def load_nparray_into_net(net, array_dict):
|
||||
"""
|
||||
Loads dictionary of numpy arrays into network.
|
||||
|
||||
Args:
|
||||
net (Cell): Cell network.
|
||||
array_dict (dict): dictionary of numpy array format model weights.
|
||||
"""
|
||||
param_not_load = []
|
||||
for _, param in net.parameters_and_names():
|
||||
if param.name in array_dict:
|
||||
new_param = array_dict[param.name]
|
||||
param.set_data(Parameter(new_param.copy(), name=param.name))
|
||||
else:
|
||||
param_not_load.append(param.name)
|
||||
return param_not_load
|
||||
|
||||
|
||||
class EmaEvalCallBack(Callback):
|
||||
"""
|
||||
Call back that will evaluate the model and save model checkpoint at
|
||||
the end of training epoch.
|
||||
|
||||
Args:
|
||||
model: Mindspore model instance.
|
||||
ema_network: step-wise exponential moving average for ema_network.
|
||||
eval_dataset: the evaluation daatset.
|
||||
decay (float): ema decay.
|
||||
save_epoch (int): defines how often to save checkpoint.
|
||||
dataset_sink_mode (bool): whether to use data sink mode.
|
||||
start_epoch (int): which epoch to start/resume training.
|
||||
"""
|
||||
|
||||
def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999,
|
||||
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
|
||||
self.model = model
|
||||
self.ema_network = ema_network
|
||||
self.eval_dataset = eval_dataset
|
||||
self.loss_fn = loss_fn
|
||||
self.decay = decay
|
||||
self.save_epoch = save_epoch
|
||||
self.shadow = {}
|
||||
self.ema_accuracy = {}
|
||||
|
||||
self.best_ema_accuracy = 0
|
||||
self.best_accuracy = 0
|
||||
self.best_ema_epoch = 0
|
||||
self.best_epoch = 0
|
||||
self._start_epoch = start_epoch
|
||||
self.eval_metrics = {'Validation-Loss': Loss(),
|
||||
'Top1-Acc': Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': Top5CategoricalAccuracy()}
|
||||
self.dataset_sink_mode = dataset_sink_mode
|
||||
|
||||
def begin(self, run_context):
|
||||
"""Initialize the EMA parameters """
|
||||
cb_params = run_context.original_args()
|
||||
for _, param in cb_params.network.parameters_and_names():
|
||||
self.shadow[param.name] = deepcopy(param.data.asnumpy())
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Update the EMA parameters"""
|
||||
cb_params = run_context.original_args()
|
||||
for _, param in cb_params.network.parameters_and_names():
|
||||
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
|
||||
self.decay * self.shadow[param.name]
|
||||
self.shadow[param.name] = new_average
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""evaluate the model and ema-model at the end of each epoch"""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
|
||||
|
||||
save_ckpt = (cur_epoch % self.save_epoch == 0)
|
||||
|
||||
acc = self.model.eval(
|
||||
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
||||
print("Model Accuracy:", acc)
|
||||
|
||||
load_nparray_into_net(self.ema_network, self.shadow)
|
||||
self.ema_network.set_train(False)
|
||||
|
||||
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
|
||||
metrics=self.eval_metrics)
|
||||
ema_acc = model_ema.eval(
|
||||
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
||||
|
||||
print("EMA-Model Accuracy:", ema_acc)
|
||||
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
|
||||
output = [{"name": k, "data": Tensor(v)}
|
||||
for k, v in self.shadow.items()]
|
||||
|
||||
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
|
||||
self.best_ema_accuracy = ema_acc["Top1-Acc"]
|
||||
self.best_ema_epoch = cur_epoch
|
||||
save_checkpoint(output, "ema_best.ckpt")
|
||||
|
||||
if self.best_accuracy < acc["Top1-Acc"]:
|
||||
self.best_accuracy = acc["Top1-Acc"]
|
||||
self.best_epoch = cur_epoch
|
||||
|
||||
print("Best Model Accuracy: %s, at epoch %s" %
|
||||
(self.best_accuracy, self.best_epoch))
|
||||
print("Best EMA-Model Accuracy: %s, at epoch %s" %
|
||||
(self.best_ema_accuracy, self.best_ema_epoch))
|
||||
|
||||
if save_ckpt:
|
||||
# Save the ema_model checkpoints
|
||||
ckpt = "{}-{}.ckpt".format("ema", cur_epoch)
|
||||
save_checkpoint(output, ckpt)
|
||||
save_checkpoint(output, "ema_last.ckpt")
|
||||
|
||||
# Save the model checkpoints
|
||||
save_checkpoint(cb_params.train_network, "last.ckpt")
|
||||
|
||||
print("Top 10 EMA-Model Accuracies: ")
|
||||
count = 0
|
||||
for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get,
|
||||
reverse=True):
|
||||
if count == 10:
|
||||
break
|
||||
print("epoch: %s, Top-1: %s)" % (epoch, self.ema_accuracy[epoch]))
|
||||
count += 1
|
||||
|
||||
|
||||
class LossMonitor(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF, it will terminate training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0, do not print loss.
|
||||
|
||||
Args:
|
||||
lr_array (numpy.array): scheduled learning rate.
|
||||
total_epochs (int): Total number of epochs for training.
|
||||
per_print_times (int): Print the loss every time. Default: 1.
|
||||
start_epoch (int): which epoch to start, used when resume from a
|
||||
certain epoch.
|
||||
|
||||
Raises:
|
||||
ValueError: If print_step is not an integer or less than zero.
|
||||
"""
|
||||
|
||||
def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0):
|
||||
super(LossMonitor, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self._lr_array = lr_array
|
||||
self._total_epochs = total_epochs
|
||||
self._start_epoch = start_epoch
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""log epoch, step, loss and learning rate"""
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1
|
||||
if isinstance(loss, (tuple, list)):
|
||||
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
|
||||
loss = loss[0]
|
||||
|
||||
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
|
||||
loss = np.mean(loss.asnumpy())
|
||||
global_step = cb_params.cur_step_num - 1
|
||||
cur_step_in_epoch = global_step % cb_params.batch_num + 1
|
||||
|
||||
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
|
||||
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
|
||||
cur_epoch_num, cur_step_in_epoch))
|
||||
|
||||
if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0:
|
||||
print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s"
|
||||
% (cur_epoch_num, self._total_epochs, cur_step_in_epoch,
|
||||
cb_params.batch_num, loss, self._lr_array[global_step]),
|
||||
flush=True)
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data operations, will be used in train.py and eval.py"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from mindspore.dataset.vision import Inter
|
||||
|
||||
# values that should remain constant
|
||||
DEFAULT_CROP_PCT = 0.875
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
# data preprocess configs
|
||||
SCALE = (0.08, 1.0)
|
||||
RATIO = (3./4., 4./3.)
|
||||
|
||||
ds.config.set_seed(1)
|
||||
|
||||
|
||||
def split_imgs_and_labels(imgs, labels, batchInfo):
|
||||
"""split data into labels and images"""
|
||||
ret_imgs = []
|
||||
ret_labels = []
|
||||
|
||||
for i, image in enumerate(imgs):
|
||||
ret_imgs.append(image)
|
||||
ret_labels.append(labels[i])
|
||||
return np.array(ret_imgs), np.array(ret_labels)
|
||||
|
||||
|
||||
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
|
||||
input_size=224, color_jitter=0.4):
|
||||
"""Creat ImageNet training dataset"""
|
||||
if not os.path.exists(train_data_url):
|
||||
raise ValueError('Path not exists')
|
||||
decode_op = py_vision.Decode()
|
||||
type_cast_op = c_transforms.TypeCast(mstype.int32)
|
||||
|
||||
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
|
||||
scale=SCALE, ratio=RATIO,
|
||||
interpolation=Inter.BICUBIC)
|
||||
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
|
||||
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
|
||||
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
|
||||
contrast=adjust_range,
|
||||
saturation=adjust_range)
|
||||
to_tensor = py_vision.ToTensor()
|
||||
nromlize_op = py_vision.Normalize(
|
||||
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||
|
||||
# assemble all the transforms
|
||||
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
|
||||
random_horizontal_flip_op, random_color_jitter_op, to_tensor, nromlize_op])
|
||||
|
||||
rank_id = get_rank() if distributed else 0
|
||||
rank_size = get_group_size() if distributed else 1
|
||||
|
||||
dataset_train = ds.ImageFolderDataset(train_data_url,
|
||||
num_parallel_workers=workers,
|
||||
shuffle=True,
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id)
|
||||
|
||||
dataset_train = dataset_train.map(input_columns=["image"],
|
||||
operations=image_ops,
|
||||
num_parallel_workers=workers)
|
||||
|
||||
dataset_train = dataset_train.map(input_columns=["label"],
|
||||
operations=type_cast_op,
|
||||
num_parallel_workers=workers)
|
||||
|
||||
# batch dealing
|
||||
ds_train = dataset_train.batch(batch_size,
|
||||
per_batch_map=split_imgs_and_labels,
|
||||
input_columns=["image", "label"],
|
||||
num_parallel_workers=2,
|
||||
drop_remainder=True)
|
||||
|
||||
ds_train = ds_train.repeat(1)
|
||||
return ds_train
|
||||
|
||||
|
||||
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
|
||||
input_size=224):
|
||||
"""Creat ImageNet validation dataset"""
|
||||
if not os.path.exists(val_data_url):
|
||||
raise ValueError('Path not exists')
|
||||
rank_id = get_rank() if distributed else 0
|
||||
rank_size = get_group_size() if distributed else 1
|
||||
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
|
||||
num_shards=rank_size, shard_id=rank_id)
|
||||
scale_size = None
|
||||
|
||||
if isinstance(input_size, tuple):
|
||||
assert len(input_size) == 2
|
||||
if input_size[-1] == input_size[-2]:
|
||||
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
|
||||
else:
|
||||
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
|
||||
else:
|
||||
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
|
||||
|
||||
type_cast_op = c_transforms.TypeCast(mstype.int32)
|
||||
decode_op = py_vision.Decode()
|
||||
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
|
||||
center_crop = py_vision.CenterCrop(size=input_size)
|
||||
to_tensor = py_vision.ToTensor()
|
||||
nromlize_op = py_vision.Normalize(
|
||||
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||
|
||||
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
|
||||
to_tensor, nromlize_op])
|
||||
|
||||
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
|
||||
num_parallel_workers=workers)
|
||||
dataset = dataset.map(input_columns=["image"], operations=image_ops,
|
||||
num_parallel_workers=workers)
|
||||
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
|
||||
input_columns=["image", "label"],
|
||||
num_parallel_workers=2,
|
||||
drop_remainder=True)
|
||||
dataset = dataset.repeat(1)
|
||||
return dataset
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network."""
|
||||
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class LabelSmoothingCrossEntropy(_Loss):
|
||||
"""cross-entropy with label smoothing"""
|
||||
|
||||
def __init__(self, smooth_factor=0.1, num_classes=1000):
|
||||
super(LabelSmoothingCrossEntropy, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor /
|
||||
(num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, logits, label):
|
||||
label = self.cast(label, mstype.int32)
|
||||
one_hot_label = self.onehot(label, F.shape(
|
||||
logits)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logits, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
return loss_logit
|
|
@ -0,0 +1,808 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Tinynet model definition"""
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
|
||||
from mindspore import context, ms_function
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
# Imagenet constant values
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
# model structure configurations for TinyNets, values are
|
||||
# (resolution multiplier, channel multiplier, depth multiplier)
|
||||
# only tinynet-c is availiable for now, we will release other tinynet
|
||||
# models soon
|
||||
# codes are inspired and partially adapted from
|
||||
# https://github.com/rwightman/gen-efficientnet-pytorch
|
||||
|
||||
TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}
|
||||
|
||||
relu = P.ReLU()
|
||||
sigmoid = P.Sigmoid()
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
|
||||
'efficientnet_b1': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'efficientnet_b2': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'efficientnet_b3': _cfg(
|
||||
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'efficientnet_b4': _cfg(
|
||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
}
|
||||
|
||||
_DEBUG = False
|
||||
|
||||
# Default args for PyTorch BN impl
|
||||
_BN_MOMENTUM_PT_DEFAULT = 0.1
|
||||
_BN_EPS_PT_DEFAULT = 1e-5
|
||||
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
|
||||
|
||||
# Defaults used for Google/Tensorflow training of mobile networks /w
|
||||
# RMSprop as per papers and TF reference implementations. PT momentum
|
||||
# equiv for TF decay is (1 - TF decay)
|
||||
# NOTE: momentum varies btw .99 and .9997 depending on source
|
||||
# .99 in official TF TPU impl
|
||||
# .9997 (/w .999 in search space) for paper
|
||||
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
||||
_BN_EPS_TF_DEFAULT = 1e-3
|
||||
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
|
||||
|
||||
|
||||
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
|
||||
"""Google style weight initialization"""
|
||||
if layer_type not in ('conv', 'bn', 'fc'):
|
||||
raise ValueError(
|
||||
'The layer type is not known, the supported are conv, bn and fc')
|
||||
if bias:
|
||||
return Zero()
|
||||
if layer_type == 'conv':
|
||||
assert isinstance(shape, (tuple, list)) and len(
|
||||
shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
|
||||
n = shape[1] * shape[1] * shape[2]
|
||||
return Normal(math.sqrt(2.0 / n))
|
||||
if layer_type == 'bn':
|
||||
return One()
|
||||
|
||||
assert isinstance(shape, (tuple, list)) and len(
|
||||
shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
|
||||
n = shape[1]
|
||||
init_range = 1.0 / math.sqrt(n)
|
||||
return Uniform(init_range)
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0,
|
||||
pad_mode='same', bias=False):
|
||||
"""convolution wrapper"""
|
||||
weight_init_value = _initialize_weight_goog(
|
||||
shape=(in_channels, kernel_size, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias, bias_init=bias_init_value)
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias)
|
||||
|
||||
|
||||
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
|
||||
"""1x1 convolution wrapper"""
|
||||
weight_init_value = _initialize_weight_goog(
|
||||
shape=(in_channels, 1, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias, bias_init=bias_init_value)
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
has_bias=bias)
|
||||
|
||||
|
||||
def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0,
|
||||
pad_mode='same', bias=False):
|
||||
"""group convolution wrapper"""
|
||||
weight_init_value = _initialize_weight_goog(
|
||||
shape=(in_channels, kernel_size, out_channels))
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
group=group, has_bias=bias, bias_init=bias_init_value)
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
|
||||
group=group, has_bias=bias)
|
||||
|
||||
|
||||
def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
|
||||
return nn.BatchNorm2d(channels, eps=eps, momentum=1-momentum, gamma_init=gamma_init,
|
||||
beta_init=beta_init)
|
||||
|
||||
|
||||
def _dense(in_channels, output_channels, bias=True, activation=None):
|
||||
weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels),
|
||||
layer_type='fc')
|
||||
bias_init_value = _initialize_weight_goog(bias=True) if bias else None
|
||||
if bias:
|
||||
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
|
||||
bias_init=bias_init_value, has_bias=bias, activation=activation)
|
||||
|
||||
return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
|
||||
has_bias=bias, activation=activation)
|
||||
|
||||
|
||||
def _resolve_bn_args(kwargs):
|
||||
bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
|
||||
'bn_tf', False) else _BN_ARGS_PT.copy()
|
||||
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||
if bn_momentum is not None:
|
||||
bn_args['momentum'] = bn_momentum
|
||||
bn_eps = kwargs.pop('bn_eps', None)
|
||||
if bn_eps is not None:
|
||||
bn_args['eps'] = bn_eps
|
||||
return bn_args
|
||||
|
||||
|
||||
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if not multiplier:
|
||||
return channels
|
||||
channels *= multiplier
|
||||
channel_min = channel_min or divisor
|
||||
new_channels = max(
|
||||
int(channels + divisor / 2) // divisor * divisor,
|
||||
channel_min)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_channels < 0.9 * channels:
|
||||
new_channels += divisor
|
||||
return new_channels
|
||||
|
||||
|
||||
def _parse_ksize(ss):
|
||||
if ss.isdigit():
|
||||
return int(ss)
|
||||
return [int(k) for k in ss.split('.')]
|
||||
|
||||
|
||||
def _decode_block_str(block_str, depth_multiplier=1.0):
|
||||
""" Decode block definition string
|
||||
|
||||
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||
|
||||
All args can exist in any order with the exception of the leading string which
|
||||
is assumed to indicate the block type.
|
||||
|
||||
leading string - block type (
|
||||
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
||||
r - number of repeat blocks,
|
||||
k - kernel size,
|
||||
s - strides (1-9),
|
||||
e - expansion ratio,
|
||||
c - output channels,
|
||||
se - squeeze/excitation ratio
|
||||
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||
Args:
|
||||
block_str: a string representation of block arguments.
|
||||
Returns:
|
||||
A list of block args (dicts)
|
||||
Raises:
|
||||
ValueError: if the string def not properly specified (TODO)
|
||||
"""
|
||||
assert isinstance(block_str, str)
|
||||
ops = block_str.split('_')
|
||||
block_type = ops[0] # take the block type off the front
|
||||
ops = ops[1:]
|
||||
options = {}
|
||||
noskip = False
|
||||
for op in ops:
|
||||
if op == 'noskip':
|
||||
noskip = True
|
||||
elif op.startswith('n'):
|
||||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
print('not support')
|
||||
elif v == 'r6':
|
||||
print('not support')
|
||||
elif v == 'hs':
|
||||
print('not support')
|
||||
elif v == 'sw':
|
||||
print('not support')
|
||||
else:
|
||||
continue
|
||||
options[key] = value
|
||||
else:
|
||||
# all numeric options
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
act_fn = options['n'] if 'n' in options else None
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
if block_type == 'ir':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type in ('ds', 'dsa'):
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
pw_act=block_type == 'dsa',
|
||||
noskip=block_type == 'dsa' or noskip,
|
||||
)
|
||||
elif block_type == 'er':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
exp_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type == 'cn':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
kernel_size=int(options['k']),
|
||||
out_chs=int(options['c']),
|
||||
stride=int(options['s']),
|
||||
act_fn=act_fn,
|
||||
)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
"""further decode the architecture definition into model-ready format"""
|
||||
arch_args = []
|
||||
for _, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(
|
||||
stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
class Swish(nn.Cell):
|
||||
"""swish activation function"""
|
||||
|
||||
def __init__(self):
|
||||
super(Swish, self).__init__()
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
def construct(self, x):
|
||||
return x * self.sigmoid(x)
|
||||
|
||||
|
||||
@ms_function
|
||||
def swish(x):
|
||||
return x * nn.Sigmoid()(x)
|
||||
|
||||
|
||||
class BlockBuilder(nn.Cell):
|
||||
"""build efficient-net convolution blocks"""
|
||||
|
||||
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0,
|
||||
channel_divisor=8, channel_min=None, pad_type='', act_fn=None,
|
||||
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
|
||||
drop_connect_rate=0., verbose=False):
|
||||
super(BlockBuilder, self).__init__()
|
||||
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.channel_divisor = channel_divisor
|
||||
self.channel_min = channel_min
|
||||
self.pad_type = pad_type
|
||||
self.act_fn = Swish()
|
||||
self.se_gate_fn = se_gate_fn
|
||||
self.se_reduce_mid = se_reduce_mid
|
||||
self.bn_args = bn_args
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.verbose = verbose
|
||||
|
||||
# updated during build
|
||||
self.in_chs = None
|
||||
self.block_idx = 0
|
||||
self.block_count = 0
|
||||
self.layer = self._make_layer(builder_in_channels, builder_block_args)
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
|
||||
def _make_block(self, ba):
|
||||
"""make the current block based on the block argument"""
|
||||
bt = ba.pop('block_type')
|
||||
ba['in_chs'] = self.in_chs
|
||||
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
||||
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
||||
# this is a hack to work around mismatch in origin impl input filters
|
||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
||||
ba['bn_args'] = self.bn_args
|
||||
ba['pad_type'] = self.pad_type
|
||||
# block act fn overrides the model default
|
||||
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
||||
assert ba['act_fn'] is not None
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * \
|
||||
self.block_idx / self.block_count
|
||||
ba['se_gate_fn'] = self.se_gate_fn
|
||||
ba['se_reduce_mid'] = self.se_reduce_mid
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt in ('ds', 'dsa'):
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate * \
|
||||
self.block_idx / self.block_count
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
else:
|
||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||
self.in_chs = ba['out_chs']
|
||||
|
||||
return block
|
||||
|
||||
def _make_stack(self, stack_args):
|
||||
"""make a stack of blocks"""
|
||||
blocks = []
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for i, ba in enumerate(stack_args):
|
||||
if i >= 1:
|
||||
# only the first block in any stack can have a stride > 1
|
||||
ba['stride'] = 1
|
||||
block = self._make_block(ba)
|
||||
blocks.append(block)
|
||||
self.block_idx += 1 # incr global idx (across all stacks)
|
||||
return nn.SequentialCell(blocks)
|
||||
|
||||
def _make_layer(self, in_chs, block_args):
|
||||
""" Build the entire layer
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
self.in_chs = in_chs
|
||||
self.block_count = sum([len(x) for x in block_args])
|
||||
self.block_idx = 0
|
||||
blocks = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for _, stack in enumerate(block_args):
|
||||
assert isinstance(stack, list)
|
||||
stack = self._make_stack(stack)
|
||||
blocks.append(stack)
|
||||
return nn.SequentialCell(blocks)
|
||||
|
||||
def construct(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
"""depth-wise convolution"""
|
||||
|
||||
def __init__(self, in_planes, kernel_size, stride):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
platform = context.get_context("device_target")
|
||||
weight_shape = [1, kernel_size, in_planes]
|
||||
weight_init = _initialize_weight_goog(shape=weight_shape)
|
||||
|
||||
if platform == "GPU":
|
||||
self.depthwise_conv = P.Conv2D(out_channel=in_planes*1,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad=int(kernel_size/2),
|
||||
pad_mode="pad",
|
||||
group=in_planes)
|
||||
|
||||
self.weight = Parameter(initializer(weight_init,
|
||||
[in_planes*1, 1, kernel_size, kernel_size]), name='depthwise_weight')
|
||||
|
||||
else:
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride, pad_mode='pad',
|
||||
pad=int(kernel_size/2))
|
||||
|
||||
self.weight = Parameter(initializer(weight_init,
|
||||
[1, in_planes, kernel_size, kernel_size]), name='depthwise_weight')
|
||||
|
||||
def construct(self, x):
|
||||
x = self.depthwise_conv(x, self.weight)
|
||||
return x
|
||||
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
"""drop connect implementation"""
|
||||
|
||||
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
|
||||
super(DropConnect, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.keep_prob = 1 - drop_connect_rate
|
||||
self.dropout = P.Dropout(keep_prob=self.keep_prob)
|
||||
|
||||
def construct(self, x):
|
||||
shape = self.shape(x)
|
||||
dtype = self.dtype(x)
|
||||
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
|
||||
_, mask_ = self.dropout(ones_tensor)
|
||||
x = x * mask_
|
||||
return x
|
||||
|
||||
|
||||
def drop_connect(inputs, training=False, drop_connect_rate=0.):
|
||||
if not training:
|
||||
return inputs
|
||||
return DropConnect(drop_connect_rate)(inputs)
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Cell):
|
||||
"""squeeze-excite implementation"""
|
||||
|
||||
def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.act_fn = Swish()
|
||||
self.gate_fn = gate_fn
|
||||
reduce_chs = reduce_chs or in_chs
|
||||
self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs,
|
||||
kernel_size=1, has_bias=True, pad_mode='pad')
|
||||
self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs,
|
||||
kernel_size=1, has_bias=True, pad_mode='pad')
|
||||
self.avg_global_pool = P.ReduceMean(keep_dims=True)
|
||||
|
||||
def construct(self, x):
|
||||
x_se = self.avg_global_pool(x, (2, 3))
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act_fn(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x_se = self.gate_fn(x_se)
|
||||
x = x * x_se
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Cell):
|
||||
"""depth-wise convolution -> (squeeze-excite) -> point-wise convolution"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, pad_type='', act_fn=relu, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
|
||||
bn_args=None, drop_connect_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
assert stride in [1, 2], 'stride must be 1 or 2'
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act
|
||||
self.act_fn = Swish()
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
|
||||
self.bn1 = _fused_bn(in_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
|
||||
act_fn=act_fn, gate_fn=se_gate_fn)
|
||||
self.conv_pw = _conv1x1(in_chs, out_chs)
|
||||
self.bn2 = _fused_bn(out_chs, **bn_args)
|
||||
|
||||
def construct(self, x):
|
||||
"""forward the depthwise separable conv"""
|
||||
identity = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
|
||||
if self.has_pw_act:
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x = x + identity
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
"""inverted-residual block implementation"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
|
||||
pad_type='', act_fn=relu, pw_kernel_size=1,
|
||||
noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
|
||||
se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
|
||||
bn_args=None, drop_connect_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
mid_chs = int(in_chs * exp_ratio)
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.act_fn = Swish()
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
|
||||
self.bn1 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
self.shuffle_type = shuffle_type
|
||||
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
|
||||
self.shuffle = None
|
||||
|
||||
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
|
||||
self.bn2 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
se_base_chs = mid_chs if se_reduce_mid else in_chs
|
||||
self.se = SqueezeExcite(
|
||||
mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
|
||||
act_fn=act_fn, gate_fn=se_gate_fn
|
||||
)
|
||||
|
||||
self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
|
||||
self.bn3 = _fused_bn(out_chs, **bn_args)
|
||||
|
||||
def construct(self, x):
|
||||
"""forward the inverted-residual block"""
|
||||
identity = x
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x = x + identity
|
||||
return x
|
||||
|
||||
|
||||
class GenEfficientNet(nn.Cell):
|
||||
"""Generate EfficientNet architecture"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
|
||||
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
|
||||
se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
|
||||
global_pool='avg', head_conv='default', weight_init='goog'):
|
||||
|
||||
super(GenEfficientNet, self).__init__()
|
||||
bn_args = _BN_ARGS_PT if bn_args is None else bn_args
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
self.num_features = num_features
|
||||
|
||||
self.conv_stem = _conv(in_chans, stem_size, 3,
|
||||
stride=2, padding=1, pad_mode='pad')
|
||||
self.bn1 = _fused_bn(stem_size, **bn_args)
|
||||
self.act_fn = Swish()
|
||||
in_chans = stem_size
|
||||
self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
|
||||
channel_divisor, channel_min,
|
||||
pad_type, act_fn, se_gate_fn, se_reduce_mid,
|
||||
bn_args, drop_connect_rate, verbose=_DEBUG)
|
||||
in_chs = self.blocks.in_chs
|
||||
|
||||
if not head_conv or head_conv == 'none':
|
||||
self.efficient_head = False
|
||||
self.conv_head = None
|
||||
assert in_chs == self.num_features
|
||||
else:
|
||||
self.efficient_head = head_conv == 'efficient'
|
||||
self.conv_head = _conv1x1(in_chs, self.num_features)
|
||||
self.bn2 = None if self.efficient_head else _fused_bn(
|
||||
self.num_features, **bn_args)
|
||||
|
||||
self.global_pool = P.ReduceMean(keep_dims=True)
|
||||
self.classifier = _dense(self.num_features, self.num_classes)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.drop_out = nn.Dropout(keep_prob=1-self.drop_rate)
|
||||
|
||||
def construct(self, x):
|
||||
"""efficient net entry point"""
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.blocks(x)
|
||||
if self.efficient_head:
|
||||
x = self.global_pool(x, (2, 3))
|
||||
x = self.conv_head(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.reshape(self.shape(x)[0], -1)
|
||||
else:
|
||||
if self.conv_head is not None:
|
||||
x = self.conv_head(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.global_pool(x, (2, 3))
|
||||
x = self.reshape(x, (self.shape(x)[0], -1))
|
||||
|
||||
if self.training and self.drop_rate > 0.:
|
||||
x = self.drop_out(x)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
"""Creates an EfficientNet model.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||
Paper: https://arxiv.org/abs/1905.11946
|
||||
|
||||
EfficientNet params
|
||||
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
||||
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
||||
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
||||
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
||||
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
||||
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
||||
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
||||
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
||||
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
||||
|
||||
Args:
|
||||
channel_multiplier (int): multiplier to number of channels per layer
|
||||
depth_multiplier (int): multiplier to number of repeats per stage
|
||||
|
||||
"""
|
||||
arch_def = [
|
||||
['ds_r1_k3_s1_e1_c16_se0.25'],
|
||||
['ir_r2_k3_s2_e6_c24_se0.25'],
|
||||
['ir_r2_k5_s2_e6_c40_se0.25'],
|
||||
['ir_r3_k3_s2_e6_c80_se0.25'],
|
||||
['ir_r3_k5_s1_e6_c112_se0.25'],
|
||||
['ir_r4_k5_s2_e6_c192_se0.25'],
|
||||
['ir_r1_k3_s1_e6_c320_se0.25'],
|
||||
]
|
||||
num_features = max(1280, _round_channels(
|
||||
1280, channel_multiplier, 8, None))
|
||||
model = GenEfficientNet(
|
||||
_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
num_features=num_features,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=Swish,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs):
|
||||
""" TinyNet Models """
|
||||
# choose a sub model
|
||||
r, w, d = TINYNET_CFG[sub_model]
|
||||
default_cfg = default_cfgs['efficientnet_b0']
|
||||
assert default_cfg['input_size'] == (3, 224, 224), "All tinynet models are \
|
||||
evolved from Efficient-B0, which has input dimension of 3*224*224"
|
||||
|
||||
channel, height, width = default_cfg['input_size']
|
||||
height = int(r * height)
|
||||
width = int(r * width)
|
||||
default_cfg['input_size'] = (channel, height, width)
|
||||
|
||||
print("Data processing configuration for current model + dataset:")
|
||||
print("input_size:", default_cfg['input_size'])
|
||||
print("channel mutiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r))
|
||||
|
||||
model = _gen_efficientnet(
|
||||
channel_multiplier=w, depth_multiplier=d,
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
|
||||
return model
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""model utils"""
|
||||
import math
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def str2bool(value):
|
||||
"""Convert string arguments to bool type"""
|
||||
if value.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
if value.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
|
||||
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
|
||||
warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
|
||||
"""Get scheduled learning rate"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
global_steps = steps_per_epoch * global_epoch
|
||||
self_warmup_delta = ((base_lr - warmup_lr_init) / \
|
||||
warmup_epochs) if warmup_epochs > 0 else 0
|
||||
self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate
|
||||
for i in range(total_steps):
|
||||
epochs = math.floor(i/steps_per_epoch)
|
||||
cond = 1 if (epochs < warmup_epochs) else 0
|
||||
warmup_lr = warmup_lr_init + epochs * self_warmup_delta
|
||||
decay_nums = math.floor(epochs / decay_epochs)
|
||||
decay_rate = math.pow(self_decay_rate, decay_nums)
|
||||
decay_lr = base_lr * decay_rate
|
||||
lr = cond * warmup_lr + (1 - cond) * decay_lr
|
||||
lr_each_step.append(lr)
|
||||
lr_each_step = lr_each_step[global_steps:]
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
|
||||
"""Apply weight decay to only conv and dense layers (len(shape) > =2)
|
||||
Args:
|
||||
net (mindspore.nn.Cell): Mindspore network instance
|
||||
weight_decay (float): weight decay tobe used.
|
||||
skip_list (tuple): list of parameter names without weight decay
|
||||
Returns:
|
||||
A list of group of parameters, separated by different weight decay.
|
||||
"""
|
||||
decay = []
|
||||
no_decay = []
|
||||
if not skip_list:
|
||||
skip_list = ()
|
||||
for param in net.trainable_params():
|
||||
if len(param.shape) == 1 or \
|
||||
param.name.endswith(".bias") or \
|
||||
param.name in skip_list:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
return [
|
||||
{'params': no_decay, 'weight_decay': 0.},
|
||||
{'params': decay, 'weight_decay': weight_decay}]
|
||||
|
||||
|
||||
def count_params(net):
|
||||
"""Count number of parameters in the network
|
||||
Args:
|
||||
net (mindspore.nn.Cell): Mindspore network instance
|
||||
Returns:
|
||||
total_params (int): Total number of trainable params
|
||||
"""
|
||||
total_params = 0
|
||||
for param in net.trainable_params():
|
||||
total_params += np.prod(param.shape)
|
||||
return total_params
|
|
@ -0,0 +1,250 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Training Interface"""
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.model import ParallelMode, Model
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.nn import SGD, RMSProp, Loss, Top1CategoricalAccuracy, \
|
||||
Top5CategoricalAccuracy
|
||||
from mindspore import context, Tensor
|
||||
|
||||
from src.dataset import create_dataset, create_dataset_val
|
||||
from src.utils import add_weight_decay, count_params, str2bool, get_lr
|
||||
from src.callback import EmaEvalCallBack, LossMonitor
|
||||
from src.loss import LabelSmoothingCrossEntropy
|
||||
from src.tinynet import tinynet
|
||||
|
||||
parser = argparse.ArgumentParser(description='Training')
|
||||
|
||||
# training parameters
|
||||
parser.add_argument('--data_path', type=str, default="", metavar="DIR",
|
||||
help='path to dataset')
|
||||
parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
|
||||
help='Name of model to train (default: "tinynet_c"')
|
||||
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
||||
help='number of label classes (default: 1000)')
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
|
||||
help='Drop connect rate (default: 0.)')
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "sgd"')
|
||||
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
|
||||
help='Optimizer Epsilon (default: 1e-8)')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
||||
help='SGD momentum (default: 0.9)')
|
||||
parser.add_argument('--weight-decay', type=float, default=0.0001,
|
||||
help='weight decay (default: 0.0001)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
||||
help='warmup learning rate (default: 0.0001)')
|
||||
parser.add_argument('--epochs', type=int, default=200, metavar='N',
|
||||
help='number of epochs to train (default: 2)')
|
||||
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
|
||||
help='epoch interval to decay LR')
|
||||
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
|
||||
help='epochs to warmup LR, if scheduler supports')
|
||||
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
||||
help='LR decay rate (default: 0.1)')
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='label smoothing (default: 0.1)')
|
||||
parser.add_argument('--ema-decay', type=float, default=0,
|
||||
help='decay factor for model weights moving average \
|
||||
(default: 0.999)')
|
||||
parser.add_argument('--amp_level', type=str, default='O0')
|
||||
parser.add_argument('--per_print_times', type=int, default=100)
|
||||
|
||||
# batch norm parameters
|
||||
parser.add_argument('--bn-tf', action='store_true', default=False,
|
||||
help='Use Tensorflow BatchNorm defaults for models that \
|
||||
support it (default: False)')
|
||||
parser.add_argument('--bn-momentum', type=float, default=None,
|
||||
help='BatchNorm momentum override (if not None)')
|
||||
parser.add_argument('--bn-eps', type=float, default=None,
|
||||
help='BatchNorm epsilon override (if not None)')
|
||||
|
||||
# parallel parameters
|
||||
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
||||
help='how many training processes to use (default: 1)')
|
||||
parser.add_argument('--distributed', action='store_true', default=False)
|
||||
parser.add_argument('--dataset_sink', action='store_true', default=True)
|
||||
|
||||
# checkpoint config
|
||||
parser.add_argument('--ckpt', type=str, default=None)
|
||||
parser.add_argument('--ckpt_save_epoch', type=int, default=1)
|
||||
parser.add_argument('--loss_scale', type=int,
|
||||
default=1024, help='static loss scale')
|
||||
parser.add_argument('--train', type=str2bool, default=1, help='train or eval')
|
||||
parser.add_argument('--GPU', action='store_true', default=False,
|
||||
help='Use GPU for training (default: False)')
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entrance for training"""
|
||||
args = parser.parse_args()
|
||||
print(sys.argv)
|
||||
devid, args.rank_id, args.rank_size = 0, 0, 1
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
if args.distributed:
|
||||
if args.GPU:
|
||||
init("nccl")
|
||||
context.set_context(device_target='GPU')
|
||||
else:
|
||||
init()
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_target='Ascend',
|
||||
device_id=devid,
|
||||
reserve_class_name_in_scope=False)
|
||||
context.reset_auto_parallel_context()
|
||||
args.rank_id = get_rank()
|
||||
args.rank_size = get_group_size()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
device_num=args.rank_size)
|
||||
else:
|
||||
if args.GPU:
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
is_master = not args.distributed or (args.rank_id == 0)
|
||||
|
||||
# parse model argument
|
||||
assert args.model.startswith(
|
||||
"tinynet"), "Only Tinynet models are supported."
|
||||
_, sub_name = args.model.split("_")
|
||||
net = tinynet(sub_model=sub_name,
|
||||
num_classes=args.num_classes,
|
||||
drop_rate=args.drop,
|
||||
drop_connect_rate=args.drop_connect,
|
||||
global_pool="avg",
|
||||
bn_tf=args.bn_tf,
|
||||
bn_momentum=args.bn_momentum,
|
||||
bn_eps=args.bn_eps)
|
||||
|
||||
if is_master:
|
||||
print("Total number of parameters:", count_params(net))
|
||||
# input image size of the network
|
||||
input_size = net.default_cfg['input_size'][1]
|
||||
|
||||
train_dataset = val_dataset = None
|
||||
train_data_url = os.path.join(args.data_path, 'train')
|
||||
val_data_url = os.path.join(args.data_path, 'val')
|
||||
val_dataset = create_dataset_val(args.batch_size,
|
||||
val_data_url,
|
||||
workers=args.workers,
|
||||
distributed=False,
|
||||
input_size=input_size)
|
||||
|
||||
if args.train:
|
||||
train_dataset = create_dataset(args.batch_size,
|
||||
train_data_url,
|
||||
workers=args.workers,
|
||||
distributed=args.distributed,
|
||||
input_size=input_size)
|
||||
batches_per_epoch = train_dataset.get_dataset_size()
|
||||
|
||||
loss = LabelSmoothingCrossEntropy(
|
||||
smooth_factor=args.smoothing, num_classes=args.num_classes)
|
||||
time_cb = TimeMonitor(data_size=batches_per_epoch)
|
||||
loss_scale_manager = FixedLossScaleManager(
|
||||
args.loss_scale, drop_overflow_update=False)
|
||||
|
||||
lr_array = get_lr(base_lr=args.lr,
|
||||
total_epochs=args.epochs,
|
||||
steps_per_epoch=batches_per_epoch,
|
||||
decay_epochs=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_epochs=args.warmup_epochs,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
global_epoch=0)
|
||||
lr = Tensor(lr_array)
|
||||
|
||||
loss_cb = LossMonitor(lr_array,
|
||||
args.epochs,
|
||||
per_print_times=args.per_print_times,
|
||||
start_epoch=0)
|
||||
|
||||
param_group = add_weight_decay(net, weight_decay=args.weight_decay)
|
||||
|
||||
if args.opt == 'sgd':
|
||||
if is_master:
|
||||
print('Using SGD optimizer')
|
||||
optimizer = SGD(param_group,
|
||||
learning_rate=lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
loss_scale=args.loss_scale)
|
||||
|
||||
elif args.opt == 'rmsprop':
|
||||
if is_master:
|
||||
print('Using rmsprop optimizer')
|
||||
optimizer = RMSProp(param_group,
|
||||
learning_rate=lr,
|
||||
decay=0.9,
|
||||
weight_decay=args.weight_decay,
|
||||
momentum=args.momentum,
|
||||
epsilon=args.opt_eps,
|
||||
loss_scale=args.loss_scale)
|
||||
|
||||
loss.add_flags_recursive(fp32=True, fp16=False)
|
||||
eval_metrics = {'Validation-Loss': Loss(),
|
||||
'Top1-Acc': Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': Top5CategoricalAccuracy()}
|
||||
|
||||
if args.ckpt:
|
||||
ckpt = load_checkpoint(args.ckpt)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
|
||||
model = Model(net, loss, optimizer, metrics=eval_metrics,
|
||||
loss_scale_manager=loss_scale_manager,
|
||||
amp_level=args.amp_level)
|
||||
|
||||
net_ema = copy.deepcopy(net)
|
||||
net_ema.set_train(False)
|
||||
assert args.ema_decay > 0, "EMA should be used in tinynet training."
|
||||
|
||||
ema_cb = EmaEvalCallBack(model=model,
|
||||
ema_network=net_ema,
|
||||
loss_fn=loss,
|
||||
eval_dataset=val_dataset,
|
||||
decay=args.ema_decay,
|
||||
save_epoch=args.ckpt_save_epoch,
|
||||
dataset_sink_mode=args.dataset_sink,
|
||||
start_epoch=0)
|
||||
|
||||
callbacks = [loss_cb, ema_cb, time_cb] if is_master else []
|
||||
|
||||
if is_master:
|
||||
print("Training on " + args.model
|
||||
+ " with " + str(args.num_classes) + " classes")
|
||||
|
||||
model.train(args.epochs, train_dataset, callbacks=callbacks,
|
||||
dataset_sink_mode=args.dataset_sink)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue