From 882301f4b5b1716b22535d80654fdf7fbd46d8b2 Mon Sep 17 00:00:00 2001
From: yanglf1121
Date: Sat, 31 Oct 2020 19:52:39 +0800
Subject: [PATCH] Adding TinyNet to Model Zoo

Adding the TinyNet (https://arxiv.org/abs/2010.14819) MindSpore implementation
to the Model Zoo.
---
 model_zoo/README.md                           |   1 +
 model_zoo/research/cv/tinynet/README.md       | 154 ++++
 model_zoo/research/cv/tinynet/eval.py         | 101 +++
 model_zoo/research/cv/tinynet/script/eval.sh  |  42 +
 .../cv/tinynet/script/train_1p_gpu.sh         |  59 ++
 .../tinynet/script/train_distributed_gpu.sh   |  82 ++
 model_zoo/research/cv/tinynet/src/callback.py | 203 +++++
 model_zoo/research/cv/tinynet/src/dataset.py  | 143 ++++
 model_zoo/research/cv/tinynet/src/loss.py     |  44 +
 model_zoo/research/cv/tinynet/src/tinynet.py  | 808 ++++++++++++++++++
 model_zoo/research/cv/tinynet/src/utils.py    |  89 ++
 model_zoo/research/cv/tinynet/train.py        | 250 ++++++
 12 files changed, 1976 insertions(+)
 create mode 100644 model_zoo/research/cv/tinynet/README.md
 create mode 100644 model_zoo/research/cv/tinynet/eval.py
 create mode 100755 model_zoo/research/cv/tinynet/script/eval.sh
 create mode 100755 model_zoo/research/cv/tinynet/script/train_1p_gpu.sh
 create mode 100755 model_zoo/research/cv/tinynet/script/train_distributed_gpu.sh
 create mode 100755 model_zoo/research/cv/tinynet/src/callback.py
 create mode 100755 model_zoo/research/cv/tinynet/src/dataset.py
 create mode 100755 model_zoo/research/cv/tinynet/src/loss.py
 create mode 100755 model_zoo/research/cv/tinynet/src/tinynet.py
 create mode 100755 model_zoo/research/cv/tinynet/src/utils.py
 create mode 100755 model_zoo/research/cv/tinynet/train.py

diff --git a/model_zoo/README.md b/model_zoo/README.md
index 48125fd05e4..e28d8b3a160 100644
--- a/model_zoo/README.md
+++ b/model_zoo/README.md
@@ -60,6 +60,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
         - [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
         - [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
         - [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
+        - [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
     - [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
         - [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
     - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
diff --git a/model_zoo/research/cv/tinynet/README.md b/model_zoo/research/cv/tinynet/README.md
new file mode 100644
index 00000000000..d8d49c6200a
--- /dev/null
+++ b/model_zoo/research/cv/tinynet/README.md
@@ -0,0 +1,154 @@
+# Contents
+
+- [TinyNet Description](#tinynet-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Training Process](#training-process)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Evaluation Performance](#evaluation-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [TinyNet Description](#contents)
+
+TinyNets are a series of lightweight models obtained by twisting resolution, depth and width with a data-driven tiny formula. TinyNet outperforms EfficientNet and MobileNetV3.
+
+[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
+
+Note: We have only released TinyNet-C for now, and will release the other TinyNets soon.
+
+# [Model architecture](#contents)
+
+The overall network architecture of TinyNet is shown below:
+
+[Link](https://arxiv.org/abs/2010.14819)
+
+# [Dataset](#contents)
+
+Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
+
+- Dataset size:
+    - Train: 1.2 million images in 1,000 classes
+    - Test: 50,000 validation images in 1,000 classes
+- Data format: RGB images.
+    - Note: Data will be processed in src/dataset.py
+
+# [Environment Requirements](#contents)
+
+- Hardware (GPU)
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
+
+# [Script description](#contents)
+
+## [Script and sample code](#contents)
+
+```
+.tinynet
+├── README.md                      # descriptions about tinynet
+├── script
+│   ├── eval.sh                    # evaluation script
+│   ├── train_1p_gpu.sh            # training script on single GPU
+│   └── train_distributed_gpu.sh   # distributed training script on multiple GPUs
+├── src
+│   ├── callback.py                # loss and checkpoint callbacks
+│   ├── dataset.py                 # data processing
+│   ├── loss.py                    # label-smoothing cross-entropy loss function
+│   ├── tinynet.py                 # tinynet architecture
+│   └── utils.py                   # utility functions
+├── eval.py                        # evaluation interface
+└── train.py                       # training interface
+```
+
+## [Training process](#contents)
+
+### Launch
+
+```
+# training on single GPU
+  sh train_1p_gpu.sh
+# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
+  sh train_distributed_gpu.sh -n 8
+```
+
+Inside train_1p_gpu.sh and train_distributed_gpu.sh, there are hyperparameters that can be adjusted during training, for example:
+
+```
+--model tinynet_c       model to be used for training
+--drop 0.2              dropout rate
+--drop-connect 0        drop connect rate
+--num-classes 1000      number of classes for training
+--opt-eps 0.001         optimizer's epsilon
+--lr 0.048              learning rate
+--batch-size 128        batch size
+--decay-epochs 2.4      learning rate decays every 2.4 epochs
+--warmup-lr 1e-6        warm-up learning rate
+--warmup-epochs 3       number of learning rate warm-up epochs
+--decay-rate 0.97       learning rate decay rate
+--ema-decay 0.9999      decay factor for model weights moving average
+--weight-decay 1e-5     optimizer's weight decay
+--epochs 450            number of epochs to be trained
+--ckpt_save_epoch 1     checkpoint saving interval
+--workers 8             number of processes for loading data
+--amp_level O0          auto mixed-precision level for training
+--opt rmsprop           optimizer; currently SGD and RMSProp are supported
+--data_path /path_to_ImageNet/
+--GPU                   use GPU for training
+--dataset_sink          use dataset sink mode
+```
+
+The configuration above was used to train TinyNets on ImageNet (change --drop-connect to 0.2 when training TinyNet-B).
+
+> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
+> or in the ./device_parallel folder (multiple GPUs)
+
+## [Evaluation Process](#contents)
+
+### Launch
+
+```
+# infer example
+
+sh eval.sh
+```
+
+Inside eval.sh, there are configs
that can be adjusted during inference, for example: +``` +--num-classes 1000 +--batch-size 128 +--workers 8 +--data_path /path_to_ImageNet/ +--GPU +--ckpt /path_to_EMA_checkpoint/ +--dataset_sink > tinynet_c_eval.log 2>&1 & +``` +> checkpoint can be produced in training process. + +# [Model Description](#contents) + +## [Performance](#contents) + +#### Evaluation Performance + +| Model | FLOPs | Latency* | ImageNet Top-1 | +| ------------------- | ----- | -------- | -------------- | +| EfficientNet-B0 | 387M | 99.85 ms | 76.7% | +| TinyNet-A | 339M | 81.30 ms | 76.8% | +| EfficientNet-B^{-4} | 24M | 11.54 ms | 56.7% | +| TinyNet-E | 24M | 9.18 ms | 59.9% | + +*Latency is measured using MS Lite on Huawei P40 smartphone. + +*More details in [Paper](https://arxiv.org/abs/2010.14819). + +# [Description of Random Situation](#contents) + +We set the seed inside dataset.py. We also use random seed in train.py. + +# [Model Zoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/research/cv/tinynet/eval.py b/model_zoo/research/cv/tinynet/eval.py new file mode 100644 index 00000000000..0aa364ae5dd --- /dev/null +++ b/model_zoo/research/cv/tinynet/eval.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Inference Interface"""
+import sys
+import os
+import argparse
+
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
+from mindspore import context
+
+from src.dataset import create_dataset_val
+from src.utils import count_params
+from src.loss import LabelSmoothingCrossEntropy
+from src.tinynet import tinynet
+
+parser = argparse.ArgumentParser(description='Evaluation')
+parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/',
+                    metavar='DIR', help='path to dataset')
+parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
+                    help='name of the model to evaluate (default: "tinynet_c")')
+parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
+                    help='number of label classes (default: 1000)')
+parser.add_argument('--smoothing', type=float, default=0.1,
+                    help='label smoothing (default: 0.1)')
+parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
+                    help='input batch size for evaluation (default: 32)')
+parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
+                    help='number of data loading processes to use (default: 4)')
+parser.add_argument('--ckpt', type=str, default=None,
+                    help='model checkpoint to load')
+parser.add_argument('--GPU', action='store_true', default=True,
+                    help='use GPU for evaluation (default: True)')
+parser.add_argument('--dataset_sink', action='store_true', default=True)
+
+
+def main():
+    """Main entrance for evaluation"""
+    args = parser.parse_args()
+    print(sys.argv)
+
+    context.set_context(mode=context.GRAPH_MODE)
+
+    if args.GPU:
+        context.set_context(device_target='GPU')
+
+    # parse model argument
+    assert args.model.startswith(
+        "tinynet"), "Only Tinynet models are supported."
+    _, sub_name = args.model.split("_")
+    net = tinynet(sub_model=sub_name,
+                  num_classes=args.num_classes,
+                  drop_rate=0.0,
+                  drop_connect_rate=0.0,
+                  global_pool="avg",
+                  bn_tf=False,
+                  bn_momentum=None,
+                  bn_eps=None)
+    print("Total number of parameters:", count_params(net))
+
+    input_size = net.default_cfg['input_size'][1]
+    val_data_url = os.path.join(args.data_path, 'val')
+    val_dataset = create_dataset_val(args.batch_size,
+                                     val_data_url,
+                                     workers=args.workers,
+                                     distributed=False,
+                                     input_size=input_size)
+
+    loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
+                                      num_classes=args.num_classes)
+
+    loss.add_flags_recursive(fp32=True, fp16=False)
+    eval_metrics = {'Validation-Loss': Loss(),
+                    'Top1-Acc': Top1CategoricalAccuracy(),
+                    'Top5-Acc': Top5CategoricalAccuracy()}
+
+    ckpt = load_checkpoint(args.ckpt)
+    load_param_into_net(net, ckpt)
+    net.set_train(False)
+
+    model = Model(net, loss, metrics=eval_metrics)
+
+    # sink mode is disabled for evaluation
+    metrics = model.eval(val_dataset, dataset_sink_mode=False)
+    print(metrics)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/model_zoo/research/cv/tinynet/script/eval.sh b/model_zoo/research/cv/tinynet/script/eval.sh
new file mode 100755
index 00000000000..0282bacf5cd
--- /dev/null
+++ b/model_zoo/research/cv/tinynet/script/eval.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+cd ../ || exit
+current_exec_path=$(pwd)
+echo ${current_exec_path}
+
+export RANK_SIZE=1
+export start=0
+export value=$((start + RANK_SIZE))
+export curtime
+curtime=$(date '+%Y%m%d-%H%M%S')
+echo "$curtime"
+
+rm ${current_exec_path}/device${start}_$curtime/ -rf
+mkdir ${current_exec_path}/device${start}_$curtime
+cd ${current_exec_path}/device${start}_$curtime || exit
+
+# expand $start so RANK_ID/DEVICE_ID get the numeric value, not the literal string "start"
+export RANK_ID=$start
+export DEVICE_ID=$start
+time python3 ${current_exec_path}/eval.py \
+    --model tinynet_c \
+    --num-classes 1000 \
+    --batch-size 128 \
+    --workers 8 \
+    --data_path /path_to_ImageNet/ \
+    --GPU \
+    --ckpt /path_to_ckpt/ \
+    --dataset_sink > tinynet_c_eval.log 2>&1 &
+
diff --git a/model_zoo/research/cv/tinynet/script/train_1p_gpu.sh b/model_zoo/research/cv/tinynet/script/train_1p_gpu.sh
new file mode 100755
index 00000000000..fc982e2bdf5
--- /dev/null
+++ b/model_zoo/research/cv/tinynet/script/train_1p_gpu.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +cd ../ || exit +current_exec_path=$(pwd) +echo ${current_exec_path} + +export RANK_SIZE=1 +export start=0 +export value=$(($start+$RANK_SIZE)) +export curtime +curtime=$(date '+%Y%m%d-%H%M%S') + +echo $curtime +echo "rank_id = ${start}" +rm ${current_exec_path}/device_$start/ -rf +mkdir ${current_exec_path}/device_$start +cd ${current_exec_path}/device_$start || exit +export RANK_ID=$start +export DEVICE_ID=$start + +time python3 ${current_exec_path}/train.py \ + --model tinynet_c \ + --drop 0.2 \ + --drop-connect 0 \ + --num-classes 1000 \ + --opt-eps 0.001 \ + --lr 0.048 \ + --batch-size 128 \ + --decay-epochs 2.4 \ + --warmup-lr 1e-6 \ + --warmup-epochs 3 \ + --decay-rate 0.97 \ + --ema-decay 0.9999 \ + --weight-decay 1e-5 \ + --epochs 100\ + --ckpt_save_epoch 1 \ + --workers 8 \ + --amp_level O0 \ + --opt rmsprop \ + --data_path /path_to_ImageNet/ \ + --GPU \ + --dataset_sink > tinynet_c.log 2>&1 & + + +cd ${current_exec_path} || exit + diff --git a/model_zoo/research/cv/tinynet/script/train_distributed_gpu.sh b/model_zoo/research/cv/tinynet/script/train_distributed_gpu.sh new file mode 100755 index 00000000000..aaed02d25a0 --- /dev/null +++ b/model_zoo/research/cv/tinynet/script/train_distributed_gpu.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# below help function was adapted from +# https://unix.stackexchange.com/questions/31414/how-can-i-pass-a-command-line-argument-into-a-shell-script +helpFunction() +{ + echo "" + echo "Usage: $0 -n num_device" + echo -e "\t-n how many gpus to use for training" + exit 1 # Exit script after printing help +} + +while getopts "n:" opt +do + case "$opt" in + n ) num_device="$OPTARG" ;; + ? 
) helpFunction ;; # Print helpFunction in case parameter is non-existent + esac +done + +# Print helpFunction in case parameters are empty +if [ -z "$num_device" ] +then + echo "Some or all of the parameters are empty"; + helpFunction +fi + +# Begin script in case all parameters are correct +echo "$num_device" +cd ../ || exit +current_exec_path=$(pwd) +echo ${current_exec_path} + +export SLOG_PRINT_TO_STDOUT=0 +export RANK_SIZE=$num_device +export curtime +curtime=$(date '+%Y%m%d-%H%M%S') +echo $curtime +echo $curtime >> starttime +rm ${current_exec_path}/device_parallel/ -rf +mkdir ${current_exec_path}/device_parallel +cd ${current_exec_path}/device_parallel || exit +echo $curtime >> starttime + +time mpirun -n $RANK_SIZE --allow-run-as-root python3 ${current_exec_path}/train.py \ + --model tinynet_c \ + --drop 0.2 \ + --drop-connect 0 \ + --num-classes 1000 \ + --opt-eps 0.001 \ + --lr 0.048 \ + --batch-size 128 \ + --decay-epochs 2.4 \ + --warmup-lr 1e-6 \ + --warmup-epochs 3 \ + --decay-rate 0.97 \ + --ema-decay 0.9999 \ + --weight-decay 1e-5 \ + --per_print_times 100 \ + --epochs 450 \ + --ckpt_save_epoch 1 \ + --workers 8 \ + --amp_level O0 \ + --opt rmsprop \ + --distributed \ + --data_path /path_to_ImageNet/ \ + --GPU \ + --dataset_sink > tinynet_c.log 2>&1 & + diff --git a/model_zoo/research/cv/tinynet/src/callback.py b/model_zoo/research/cv/tinynet/src/callback.py new file mode 100755 index 00000000000..61f45c3ae6b --- /dev/null +++ b/model_zoo/research/cv/tinynet/src/callback.py @@ -0,0 +1,203 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""custom callbacks for ema and loss""" +from copy import deepcopy + +import numpy as np +from mindspore.train.callback import Callback +from mindspore.common.parameter import Parameter +from mindspore.train.serialization import save_checkpoint +from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy +from mindspore.train.model import Model +from mindspore import Tensor + + +def load_nparray_into_net(net, array_dict): + """ + Loads dictionary of numpy arrays into network. + + Args: + net (Cell): Cell network. + array_dict (dict): dictionary of numpy array format model weights. + """ + param_not_load = [] + for _, param in net.parameters_and_names(): + if param.name in array_dict: + new_param = array_dict[param.name] + param.set_data(Parameter(new_param.copy(), name=param.name)) + else: + param_not_load.append(param.name) + return param_not_load + + +class EmaEvalCallBack(Callback): + """ + Call back that will evaluate the model and save model checkpoint at + the end of training epoch. + + Args: + model: Mindspore model instance. + ema_network: step-wise exponential moving average for ema_network. + eval_dataset: the evaluation daatset. + decay (float): ema decay. + save_epoch (int): defines how often to save checkpoint. + dataset_sink_mode (bool): whether to use data sink mode. 
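+        loss_fn: loss function used to build the evaluation Model for the EMA network.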
+ start_epoch (int): which epoch to start/resume training. + """ + + def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999, + save_epoch=1, dataset_sink_mode=True, start_epoch=0): + self.model = model + self.ema_network = ema_network + self.eval_dataset = eval_dataset + self.loss_fn = loss_fn + self.decay = decay + self.save_epoch = save_epoch + self.shadow = {} + self.ema_accuracy = {} + + self.best_ema_accuracy = 0 + self.best_accuracy = 0 + self.best_ema_epoch = 0 + self.best_epoch = 0 + self._start_epoch = start_epoch + self.eval_metrics = {'Validation-Loss': Loss(), + 'Top1-Acc': Top1CategoricalAccuracy(), + 'Top5-Acc': Top5CategoricalAccuracy()} + self.dataset_sink_mode = dataset_sink_mode + + def begin(self, run_context): + """Initialize the EMA parameters """ + cb_params = run_context.original_args() + for _, param in cb_params.network.parameters_and_names(): + self.shadow[param.name] = deepcopy(param.data.asnumpy()) + + def step_end(self, run_context): + """Update the EMA parameters""" + cb_params = run_context.original_args() + for _, param in cb_params.network.parameters_and_names(): + new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \ + self.decay * self.shadow[param.name] + self.shadow[param.name] = new_average + + def epoch_end(self, run_context): + """evaluate the model and ema-model at the end of each epoch""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1 + + save_ckpt = (cur_epoch % self.save_epoch == 0) + + acc = self.model.eval( + self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) + print("Model Accuracy:", acc) + + load_nparray_into_net(self.ema_network, self.shadow) + self.ema_network.set_train(False) + + model_ema = Model(self.ema_network, loss_fn=self.loss_fn, + metrics=self.eval_metrics) + ema_acc = model_ema.eval( + self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) + + print("EMA-Model Accuracy:", ema_acc) + self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"] + output = [{"name": k, "data": Tensor(v)} + for k, v in self.shadow.items()] + + if self.best_ema_accuracy < ema_acc["Top1-Acc"]: + self.best_ema_accuracy = ema_acc["Top1-Acc"] + self.best_ema_epoch = cur_epoch + save_checkpoint(output, "ema_best.ckpt") + + if self.best_accuracy < acc["Top1-Acc"]: + self.best_accuracy = acc["Top1-Acc"] + self.best_epoch = cur_epoch + + print("Best Model Accuracy: %s, at epoch %s" % + (self.best_accuracy, self.best_epoch)) + print("Best EMA-Model Accuracy: %s, at epoch %s" % + (self.best_ema_accuracy, self.best_ema_epoch)) + + if save_ckpt: + # Save the ema_model checkpoints + ckpt = "{}-{}.ckpt".format("ema", cur_epoch) + save_checkpoint(output, ckpt) + save_checkpoint(output, "ema_last.ckpt") + + # Save the model checkpoints + save_checkpoint(cb_params.train_network, "last.ckpt") + + print("Top 10 EMA-Model Accuracies: ") + count = 0 + for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get, + reverse=True): + if count == 10: + break + print("epoch: %s, Top-1: %s)" % (epoch, self.ema_accuracy[epoch])) + count += 1 + + +class LossMonitor(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF, it will terminate training. + + Note: + If per_print_times is 0, do not print loss. + + Args: + lr_array (numpy.array): scheduled learning rate. + total_epochs (int): Total number of epochs for training. + per_print_times (int): Print the loss every time. Default: 1. 
+ start_epoch (int): which epoch to start, used when resume from a + certain epoch. + + Raises: + ValueError: If print_step is not an integer or less than zero. + """ + + def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0): + super(LossMonitor, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self._lr_array = lr_array + self._total_epochs = total_epochs + self._start_epoch = start_epoch + + def step_end(self, run_context): + """log epoch, step, loss and learning rate""" + cb_params = run_context.original_args() + loss = cb_params.net_outputs + cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1 + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] + + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) + global_step = cb_params.cur_step_num - 1 + cur_step_in_epoch = global_step % cb_params.batch_num + 1 + + if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): + raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( + cur_epoch_num, cur_step_in_epoch)) + + if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0: + print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s" + % (cur_epoch_num, self._total_epochs, cur_step_in_epoch, + cb_params.batch_num, loss, self._lr_array[global_step]), + flush=True) diff --git a/model_zoo/research/cv/tinynet/src/dataset.py b/model_zoo/research/cv/tinynet/src/dataset.py new file mode 100755 index 00000000000..8e6486ffa47 --- /dev/null +++ b/model_zoo/research/cv/tinynet/src/dataset.py @@ -0,0 +1,143 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Data operations, will be used in train.py and eval.py""" +import math +import os + +import numpy as np +import mindspore.dataset.vision.py_transforms as py_vision +import mindspore.dataset.transforms.py_transforms as py_transforms +import mindspore.dataset.transforms.c_transforms as c_transforms +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +from mindspore.communication.management import get_rank, get_group_size +from mindspore.dataset.vision import Inter + +# values that should remain constant +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +# data preprocess configs +SCALE = (0.08, 1.0) +RATIO = (3./4., 4./3.) 
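+# SCALE and RATIO above are the standard Inception-style random-resized-crop ranges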
+
+# a fixed dataset seed is set for reproducibility (see README: Description of Random Situation)
+ds.config.set_seed(1)
+
+
+def split_imgs_and_labels(imgs, labels, batchInfo):
+    """split data into labels and images"""
+    # batchInfo is required by the per_batch_map signature but is not used here
+    ret_imgs = []
+    ret_labels = []
+
+    for i, image in enumerate(imgs):
+        ret_imgs.append(image)
+        ret_labels.append(labels[i])
+    return np.array(ret_imgs), np.array(ret_labels)
+
+
+def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
+                   input_size=224, color_jitter=0.4):
+    """Create ImageNet training dataset"""
+    if not os.path.exists(train_data_url):
+        raise ValueError('Path does not exist')
+    decode_op = py_vision.Decode()
+    type_cast_op = c_transforms.TypeCast(mstype.int32)
+
+    random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
+                                                             scale=SCALE, ratio=RATIO,
+                                                             interpolation=Inter.BICUBIC)
+    random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
+    adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
+    random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
+                                                         contrast=adjust_range,
+                                                         saturation=adjust_range)
+    to_tensor = py_vision.ToTensor()
+    normalize_op = py_vision.Normalize(
+        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+    # assemble all the transforms
+    image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
+                                       random_horizontal_flip_op, random_color_jitter_op,
+                                       to_tensor, normalize_op])
+
+    rank_id = get_rank() if distributed else 0
+    rank_size = get_group_size() if distributed else 1
+
+    dataset_train = ds.ImageFolderDataset(train_data_url,
+                                          num_parallel_workers=workers,
+                                          shuffle=True,
+                                          num_shards=rank_size,
+                                          shard_id=rank_id)
+
+    dataset_train = dataset_train.map(input_columns=["image"],
+                                      operations=image_ops,
+                                      num_parallel_workers=workers)
+
+    dataset_train = dataset_train.map(input_columns=["label"],
+                                      operations=type_cast_op,
+                                      num_parallel_workers=workers)
+
+    # batch dealing
+    ds_train = dataset_train.batch(batch_size,
+                                   per_batch_map=split_imgs_and_labels,
+                                   input_columns=["image", "label"],
+                                   num_parallel_workers=2,
+                                   drop_remainder=True)
+
+    ds_train = ds_train.repeat(1)
+    return ds_train
+
+
+def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
+                       input_size=224):
+    """Create ImageNet validation dataset"""
+    if not os.path.exists(val_data_url):
+        raise ValueError('Path does not exist')
+    rank_id = get_rank() if distributed else 0
+    rank_size = get_group_size() if distributed else 1
+    dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
+                                    num_shards=rank_size, shard_id=rank_id)
+    scale_size = None
+
+    if isinstance(input_size, tuple):
+        assert len(input_size) == 2
+        if input_size[-1] == input_size[-2]:
+            scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
+        else:
+            scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
+    else:
+        scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
+
+    type_cast_op = c_transforms.TypeCast(mstype.int32)
+    decode_op = py_vision.Decode()
+    resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
+    center_crop = py_vision.CenterCrop(size=input_size)
+    to_tensor = py_vision.ToTensor()
+    normalize_op = py_vision.Normalize(
+        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+    image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
+                                       to_tensor, normalize_op])
+
+    dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
+                          num_parallel_workers=workers)
+    dataset = dataset.map(input_columns=["image"], operations=image_ops,
+                          num_parallel_workers=workers)
+    dataset = dataset.batch(batch_size,
per_batch_map=split_imgs_and_labels, + input_columns=["image", "label"], + num_parallel_workers=2, + drop_remainder=True) + dataset = dataset.repeat(1) + return dataset diff --git a/model_zoo/research/cv/tinynet/src/loss.py b/model_zoo/research/cv/tinynet/src/loss.py new file mode 100755 index 00000000000..1db8b966a5f --- /dev/null +++ b/model_zoo/research/cv/tinynet/src/loss.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network.""" + +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class LabelSmoothingCrossEntropy(_Loss): + """cross-entropy with label smoothing""" + + def __init__(self, smooth_factor=0.1, num_classes=1000): + super(LabelSmoothingCrossEntropy, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / + (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + self.cast = P.Cast() + + def construct(self, logits, label): + label = self.cast(label, mstype.int32) + one_hot_label = self.onehot(label, F.shape( + logits)[1], self.on_value, self.off_value) + loss_logit = self.ce(logits, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + return loss_logit diff --git a/model_zoo/research/cv/tinynet/src/tinynet.py b/model_zoo/research/cv/tinynet/src/tinynet.py new file mode 100755 index 00000000000..2634802e30a --- /dev/null +++ b/model_zoo/research/cv/tinynet/src/tinynet.py @@ -0,0 +1,808 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tinynet model definition""" +import math +import re +from copy import deepcopy + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform +from mindspore import context, ms_function +from mindspore.common.parameter import Parameter + +# Imagenet constant values +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +# model structure configurations for TinyNets, values are +# (resolution multiplier, channel multiplier, depth multiplier) +# only tinynet-c is availiable for now, we will release other tinynet +# models soon +# codes are inspired and partially adapted from +# https://github.com/rwightman/gen-efficientnet-pytorch + +TINYNET_CFG = {"c": (0.825, 0.54, 0.85)} + +relu = P.ReLU() +sigmoid = P.Sigmoid() + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_b3': _cfg( + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_b4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), +} + +_DEBUG = False + +# Default args for PyTorch BN impl +_BN_MOMENTUM_PT_DEFAULT = 0.1 +_BN_EPS_PT_DEFAULT = 1e-5 +_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) + +# Defaults used for Google/Tensorflow training of mobile networks /w +# RMSprop as per papers and TF reference implementations. 
PT momentum +# equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +_BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) + + +def _initialize_weight_goog(shape=None, layer_type='conv', bias=False): + """Google style weight initialization""" + if layer_type not in ('conv', 'bn', 'fc'): + raise ValueError( + 'The layer type is not known, the supported are conv, bn and fc') + if bias: + return Zero() + if layer_type == 'conv': + assert isinstance(shape, (tuple, list)) and len( + shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively' + n = shape[1] * shape[1] * shape[2] + return Normal(math.sqrt(2.0 / n)) + if layer_type == 'bn': + return One() + + assert isinstance(shape, (tuple, list)) and len( + shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively' + n = shape[1] + init_range = 1.0 / math.sqrt(n) + return Uniform(init_range) + + +def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, + pad_mode='same', bias=False): + """convolution wrapper""" + weight_init_value = _initialize_weight_goog( + shape=(in_channels, kernel_size, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias, bias_init=bias_init_value) + + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias) + + +def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False): + """1x1 convolution wrapper""" + weight_init_value = _initialize_weight_goog( + shape=(in_channels, 1, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias, bias_init=bias_init_value) + + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias) + + +def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, + pad_mode='same', bias=False): + """group convolution wrapper""" + weight_init_value = _initialize_weight_goog( + shape=(in_channels, kernel_size, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + group=group, has_bias=bias, bias_init=bias_init_value) + + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + group=group, has_bias=bias) + + +def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0): + return nn.BatchNorm2d(channels, eps=eps, momentum=1-momentum, gamma_init=gamma_init, + beta_init=beta_init) + + +def _dense(in_channels, output_channels, bias=True, activation=None): + weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), + layer_type='fc') 
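+    # dense weights use uniform init with range 1/sqrt(out_channels); biases are zero-initialized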
+    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
+    if bias:
+        return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
+                        bias_init=bias_init_value, has_bias=bias, activation=activation)
+
+    return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
+                    has_bias=bias, activation=activation)
+
+
+def _resolve_bn_args(kwargs):
+    bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
+        'bn_tf', False) else _BN_ARGS_PT.copy()
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
+    """Round number of filters based on the channel multiplier."""
+    if not multiplier:
+        return channels
+    channels *= multiplier
+    channel_min = channel_min or divisor
+    new_channels = max(
+        int(channels + divisor / 2) // divisor * divisor,
+        channel_min)
+    # Make sure that round down does not go down by more than 10%.
+    if new_channels < 0.9 * channels:
+        new_channels += divisor
+    return new_channels
+
+
+def _parse_ksize(ss):
+    if ss.isdigit():
+        return int(ss)
+    return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str, depth_multiplier=1.0):
+    """ Decode block definition string
+
+    Gets a list of block args (dicts) through a string notation of arguments.
+    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
+
+    All args can exist in any order, with the exception of the leading string which
+    is assumed to indicate the block type.
+
+    leading string - block type (
+      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
+    r - number of repeat blocks,
+    k - kernel size,
+    s - strides (1-9),
+    e - expansion ratio,
+    c - output channels,
+    se - squeeze/excitation ratio
+    n - activation fn ('re', 'r6', 'hs', or 'sw')
+    Args:
+        block_str: a string representation of block arguments.
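+        depth_multiplier: unused here; stage repeats are scaled later by _scale_stage_depth.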
+    Returns:
+        A list of block args (dicts)
+    Raises:
+        ValueError: if the string definition is not properly specified (TODO)
+    """
+    assert isinstance(block_str, str)
+    ops = block_str.split('_')
+    block_type = ops[0]  # take the block type off the front
+    ops = ops[1:]
+    options = {}
+    noskip = False
+    for op in ops:
+        if op == 'noskip':
+            noskip = True
+        elif op.startswith('n'):
+            # activation fn overrides ('re', 'r6', 'hs', 'sw') are not supported
+            # in this implementation, so the option is skipped rather than stored
+            print('activation fn override (%s) is not supported' % op[1:])
+            continue
+        else:
+            # all numeric options
+            splits = re.split(r'(\d.*)', op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+    act_fn = options['n'] if 'n' in options else None
+    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    fake_in_chs = int(options['fc']) if 'fc' in options else 0
+
+    num_repeat = int(options['r'])
+    # each type of block has different valid arguments, fill accordingly
+    if block_type == 'ir':
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_fn=act_fn,
+            noskip=noskip,
+        )
+    elif block_type in ('ds', 'dsa'):
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_fn=act_fn,
+            pw_act=block_type == 'dsa',
+            noskip=block_type == 'dsa' or noskip,
+        )
+    elif block_type == 'er':
+        block_args = dict(
+            block_type=block_type,
+            exp_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            fake_in_chs=fake_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_fn=act_fn,
+            noskip=noskip,
+        )
+    elif block_type == 'cn':
+        block_args = dict(
+            block_type=block_type,
+            kernel_size=int(options['k']),
+            out_chs=int(options['c']),
+            stride=int(options['s']),
+            act_fn=act_fn,
+        )
+    else:
+        assert False, 'Unknown block type (%s)' % block_type
+
+    return block_args, num_repeat
+
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage; there may be multiple
+    # block arg defs per stage, so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single-repeat stages that we'd prefer to keep that way as long as possible.
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+ num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): + """further decode the architecture definition into model-ready format""" + arch_args = [] + for _, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + stack_args.append(ba) + repeats.append(rep) + arch_args.append(_scale_stage_depth( + stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +class Swish(nn.Cell): + """swish activation function""" + + def __init__(self): + super(Swish, self).__init__() + self.sigmoid = P.Sigmoid() + + def construct(self, x): + return x * self.sigmoid(x) + + +@ms_function +def swish(x): + return x * nn.Sigmoid()(x) + + +class BlockBuilder(nn.Cell): + """build efficient-net convolution blocks""" + + def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, + channel_divisor=8, channel_min=None, pad_type='', act_fn=None, + se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None, + drop_connect_rate=0., verbose=False): + super(BlockBuilder, self).__init__() + + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_fn = Swish() + self.se_gate_fn = se_gate_fn + self.se_reduce_mid = se_reduce_mid + self.bn_args = bn_args + self.drop_connect_rate = drop_connect_rate + self.verbose = verbose + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + self.layer = self._make_layer(builder_in_channels, builder_block_args) + + def _round_channels(self, chs): + return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + """make the current block based on the block argument""" + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # this is a hack to work around mismatch in origin impl input filters + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['bn_args'] = self.bn_args + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn + assert ba['act_fn'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * \ + self.block_idx / self.block_count + ba['se_gate_fn'] = self.se_gate_fn + ba['se_reduce_mid'] = self.se_reduce_mid + block = InvertedResidual(**ba) + elif bt in ('ds', 'dsa'): + ba['drop_connect_rate'] = self.drop_connect_rate * \ + self.block_idx / self.block_count + block = 
DepthwiseSeparableConv(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] + + return block + + def _make_stack(self, stack_args): + """make a stack of blocks""" + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.SequentialCell(blocks) + + def _make_layer(self, in_chs, block_args): + """ Build the entire layer + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for _, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return nn.SequentialCell(blocks) + + def construct(self, x): + return self.layer(x) + + +class DepthWiseConv(nn.Cell): + """depth-wise convolution""" + + def __init__(self, in_planes, kernel_size, stride): + super(DepthWiseConv, self).__init__() + platform = context.get_context("device_target") + weight_shape = [1, kernel_size, in_planes] + weight_init = _initialize_weight_goog(shape=weight_shape) + + if platform == "GPU": + self.depthwise_conv = P.Conv2D(out_channel=in_planes*1, + kernel_size=kernel_size, + stride=stride, + pad=int(kernel_size/2), + pad_mode="pad", + group=in_planes) + + self.weight = Parameter(initializer(weight_init, + [in_planes*1, 1, kernel_size, kernel_size]), name='depthwise_weight') + + else: + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, + kernel_size=kernel_size, + stride=stride, pad_mode='pad', + pad=int(kernel_size/2)) + + self.weight = Parameter(initializer(weight_init, + [1, in_planes, kernel_size, kernel_size]), name='depthwise_weight') + + def construct(self, x): + x = self.depthwise_conv(x, self.weight) + return x + + +class DropConnect(nn.Cell): + """drop connect implementation""" + + def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): + super(DropConnect, self).__init__() + self.shape = P.Shape() + self.dtype = P.DType() + self.keep_prob = 1 - drop_connect_rate + self.dropout = P.Dropout(keep_prob=self.keep_prob) + + def construct(self, x): + shape = self.shape(x) + dtype = self.dtype(x) + ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1) + _, mask_ = self.dropout(ones_tensor) + x = x * mask_ + return x + + +def drop_connect(inputs, training=False, drop_connect_rate=0.): + if not training: + return inputs + return DropConnect(drop_connect_rate)(inputs) + + +class SqueezeExcite(nn.Cell): + """squeeze-excite implementation""" + + def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid): + super(SqueezeExcite, self).__init__() + self.act_fn = Swish() + self.gate_fn = gate_fn + reduce_chs = reduce_chs or in_chs + self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs, + kernel_size=1, has_bias=True, pad_mode='pad') + self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs, + kernel_size=1, has_bias=True, pad_mode='pad') + self.avg_global_pool = 
P.ReduceMean(keep_dims=True) + + def construct(self, x): + x_se = self.avg_global_pool(x, (2, 3)) + x_se = self.conv_reduce(x_se) + x_se = self.act_fn(x_se) + x_se = self.conv_expand(x_se) + x_se = self.gate_fn(x_se) + x = x * x_se + return x + + +class DepthwiseSeparableConv(nn.Cell): + """depth-wise convolution -> (squeeze-excite) -> point-wise convolution""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_fn=relu, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, + bn_args=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2], 'stride must be 1 or 2' + self.has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act + self.act_fn = Swish() + self.drop_connect_rate = drop_connect_rate + self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) + self.bn1 = _fused_bn(in_chs, **bn_args) + + if self.has_se: + self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), + act_fn=act_fn, gate_fn=se_gate_fn) + self.conv_pw = _conv1x1(in_chs, out_chs) + self.bn2 = _fused_bn(out_chs, **bn_args) + + def construct(self, x): + """forward the depthwise separable conv""" + identity = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act_fn(x) + + if self.has_se: + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + + if self.has_pw_act: + x = self.act_fn(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x = x + identity + + return x + + +class InvertedResidual(nn.Cell): + """inverted-residual block implementation""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, + pad_type='', act_fn=relu, pw_kernel_size=1, + noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0., + se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None, + bn_args=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + mid_chs = int(in_chs * exp_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. 
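+        # the residual (skip) connection is only valid when input/output shapes match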
+        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+        self.act_fn = Swish()
+        self.drop_connect_rate = drop_connect_rate
+
+        self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
+        self.bn1 = _fused_bn(mid_chs, **bn_args)
+
+        self.shuffle_type = shuffle_type
+        if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
+            self.shuffle = None
+
+        self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
+        self.bn2 = _fused_bn(mid_chs, **bn_args)
+
+        if self.has_se:
+            se_base_chs = mid_chs if se_reduce_mid else in_chs
+            self.se = SqueezeExcite(
+                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
+                act_fn=act_fn, gate_fn=se_gate_fn
+            )
+
+        self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
+        self.bn3 = _fused_bn(out_chs, **bn_args)
+
+    def construct(self, x):
+        """forward the inverted-residual block"""
+        identity = x
+
+        x = self.conv_pw(x)
+        x = self.bn1(x)
+        x = self.act_fn(x)
+
+        x = self.conv_dw(x)
+        x = self.bn2(x)
+        x = self.act_fn(x)
+
+        if self.has_se:
+            x = self.se(x)
+
+        x = self.conv_pwl(x)
+        x = self.bn3(x)
+
+        if self.has_residual:
+            if self.drop_connect_rate > 0:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x = x + identity
+        return x
+
+
+class GenEfficientNet(nn.Cell):
+    """Generate EfficientNet architecture"""
+
+    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
+                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
+                 pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
+                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
+                 global_pool='avg', head_conv='default', weight_init='goog'):
+
+        super(GenEfficientNet, self).__init__()
+        bn_args = _BN_ARGS_PT if bn_args is None else bn_args
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        self.num_features = num_features
+
+        self.conv_stem = _conv(in_chans, stem_size, 3,
+                               stride=2, padding=1, pad_mode='pad')
+        self.bn1 = _fused_bn(stem_size, **bn_args)
+        self.act_fn = Swish()
+        in_chans = stem_size
+        self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
+                                   channel_divisor, channel_min,
+                                   pad_type, act_fn, se_gate_fn, se_reduce_mid,
+                                   bn_args, drop_connect_rate, verbose=_DEBUG)
+        in_chs = self.blocks.in_chs
+
+        if not head_conv or head_conv == 'none':
+            self.efficient_head = False
+            self.conv_head = None
+            assert in_chs == self.num_features
+        else:
+            self.efficient_head = head_conv == 'efficient'
+            self.conv_head = _conv1x1(in_chs, self.num_features)
+            self.bn2 = None if self.efficient_head else _fused_bn(
+                self.num_features, **bn_args)
+
+        self.global_pool = P.ReduceMean(keep_dims=True)
+        self.classifier = _dense(self.num_features, self.num_classes)
+        self.reshape = P.Reshape()
+        self.shape = P.Shape()
+        self.drop_out = nn.Dropout(keep_prob=1-self.drop_rate)
+
+    def construct(self, x):
+        """efficient net entry point"""
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        x = self.act_fn(x)
+        x = self.blocks(x)
+        if self.efficient_head:
+            # efficient head: pool before the head conv, then flatten
+            x = self.global_pool(x, (2, 3))
+            x = self.conv_head(x)
+            x = self.act_fn(x)
+            x = self.reshape(x, (self.shape(x)[0], -1))
+        else:
+            if self.conv_head is not None:
+                x = self.conv_head(x)
+                x = self.bn2(x)
+                x = self.act_fn(x)
+            x = self.global_pool(x, (2, 3))
+            x = self.reshape(x, (self.shape(x)[0], -1))
+
+        if self.training and self.drop_rate > 0.:
+            x = self.drop_out(x)
+        return self.classifier(x)
+
+
+def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
+    """Creates an EfficientNet model.
+
+    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    Paper: https://arxiv.org/abs/1905.11946
+
+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+
+    Args:
+      channel_multiplier (float): multiplier to number of channels per layer
+      depth_multiplier (float): multiplier to number of repeats per stage
+
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25'],
+        ['ir_r4_k5_s2_e6_c192_se0.25'],
+        ['ir_r1_k3_s1_e6_c320_se0.25'],
+    ]
+    num_features = max(1280, _round_channels(
+        1280, channel_multiplier, 8, None))
+    model = GenEfficientNet(
+        _decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+        num_classes=num_classes,
+        stem_size=32,
+        channel_multiplier=channel_multiplier,
+        num_features=num_features,
+        bn_args=_resolve_bn_args(kwargs),
+        act_fn=Swish,
+        **kwargs)
+    return model
+
+
+def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs):
+    """ TinyNet Models """
+    # choose a sub model
+    r, w, d = TINYNET_CFG[sub_model]
+    default_cfg = default_cfgs['efficientnet_b0']
+    assert default_cfg['input_size'] == (3, 224, 224), "All tinynet models are \
+        evolved from EfficientNet-B0, which has input dimension of 3*224*224"
+
+    channel, height, width = default_cfg['input_size']
+    height = int(r * height)
+    width = int(r * width)
+    default_cfg['input_size'] = (channel, height, width)
+
+    print("Data processing configuration for current model + dataset:")
+    print("input_size:", default_cfg['input_size'])
+    print("channel multiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r))
+
+    model = _gen_efficientnet(
+        channel_multiplier=w, depth_multiplier=d,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+
+    return model
diff --git a/model_zoo/research/cv/tinynet/src/utils.py b/model_zoo/research/cv/tinynet/src/utils.py
new file mode 100755
index 00000000000..daf4e8990a7
--- /dev/null
+++ b/model_zoo/research/cv/tinynet/src/utils.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model utils"""
+import math
+import argparse
+
+import numpy as np
+
+
+def str2bool(value):
+    """Convert string arguments to bool type"""
+    if value.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    if value.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
+           warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
+    """Get scheduled learning rate"""
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    global_steps = steps_per_epoch * global_epoch
+    self_warmup_delta = ((base_lr - warmup_lr_init) /
+                         warmup_epochs) if warmup_epochs > 0 else 0
+    self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate
+    for i in range(total_steps):
+        epochs = math.floor(i/steps_per_epoch)
+        cond = 1 if (epochs < warmup_epochs) else 0
+        warmup_lr = warmup_lr_init + epochs * self_warmup_delta
+        decay_nums = math.floor(epochs / decay_epochs)
+        decayed_rate = math.pow(self_decay_rate, decay_nums)
+        decay_lr = base_lr * decayed_rate
+        lr = cond * warmup_lr + (1 - cond) * decay_lr
+        lr_each_step.append(lr)
+    lr_each_step = lr_each_step[global_steps:]
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    return lr_each_step
+
+
+def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
+    """Apply weight decay only to conv and dense layers (len(shape) >= 2)
+    Args:
+        net (mindspore.nn.Cell): Mindspore network instance
+        weight_decay (float): weight decay to be used.
+        skip_list (tuple): list of parameter names without weight decay
+    Returns:
+        A list of parameter groups, each with its own weight decay.
+    """
+    decay = []
+    no_decay = []
+    if not skip_list:
+        skip_list = ()
+    for param in net.trainable_params():
+        if len(param.shape) == 1 or \
+                param.name.endswith(".bias") or \
+                param.name in skip_list:
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [
+        {'params': no_decay, 'weight_decay': 0.},
+        {'params': decay, 'weight_decay': weight_decay}]
+
+
+def count_params(net):
+    """Count number of parameters in the network
+    Args:
+        net (mindspore.nn.Cell): Mindspore network instance
+    Returns:
+        total_params (int): Total number of trainable params
+    """
+    total_params = 0
+    for param in net.trainable_params():
+        total_params += np.prod(param.shape)
+    return total_params
diff --git a/model_zoo/research/cv/tinynet/train.py b/model_zoo/research/cv/tinynet/train.py
new file mode 100755
index 00000000000..e27afe00219
--- /dev/null
+++ b/model_zoo/research/cv/tinynet/train.py
@@ -0,0 +1,250 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Training Interface"""
+import sys
+import os
+import argparse
+import copy
+
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.train.model import ParallelMode, Model
+from mindspore.train.callback import TimeMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.nn import SGD, RMSProp, Loss, Top1CategoricalAccuracy, \
+    Top5CategoricalAccuracy
+from mindspore import context, Tensor
+
+from src.dataset import create_dataset, create_dataset_val
+from src.utils import add_weight_decay, count_params, str2bool, get_lr
+from src.callback import EmaEvalCallBack, LossMonitor
+from src.loss import LabelSmoothingCrossEntropy
+from src.tinynet import tinynet
+
+parser = argparse.ArgumentParser(description='Training')
+
+# training parameters
+parser.add_argument('--data_path', type=str, default="", metavar="DIR",
+                    help='path to dataset')
+parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
+                    help='Name of model to train (default: "tinynet_c")')
+parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
+                    help='number of label classes (default: 1000)')
+parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
+                    help='input batch size for training (default: 32)')
+parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+                    help='Dropout rate (default: 0.)')
+parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
+                    help='Drop connect rate (default: 0.)')
+parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
+                    help='Optimizer (default: "sgd")')
+parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: 1e-8)')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--weight-decay', type=float, default=0.0001,
+                    help='weight decay (default: 0.0001)')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
+parser.add_argument('--epochs', type=int, default=200, metavar='N',
+                    help='number of epochs to train (default: 200)')
+parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                    help='epoch interval to decay LR')
+parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+                    help='epochs to warmup LR, if scheduler supports')
+parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                    help='LR decay rate (default: 0.1)')
+parser.add_argument('--smoothing', type=float, default=0.1,
+                    help='label smoothing (default: 0.1)')
+parser.add_argument('--ema-decay', type=float, default=0,
+                    help='decay factor for model weights moving average \
+                    (default: 0, training requires a value > 0)')
+parser.add_argument('--amp_level', type=str, default='O0')
+parser.add_argument('--per_print_times', type=int, default=100)
+
+# batch norm parameters
+parser.add_argument('--bn-tf', action='store_true', default=False,
+                    help='Use Tensorflow BatchNorm defaults for models that \
+                    support it (default: False)')
+parser.add_argument('--bn-momentum', type=float, default=None,
+                    help='BatchNorm momentum override (if not None)')
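+# The BatchNorm overrides (--bn-tf, --bn-momentum, --bn-eps) are forwarded to
+# tinynet() and resolved into the arguments of the fused BatchNorm layers
+# built in src/tinynet.py.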
+parser.add_argument('--bn-eps', type=float, default=None,
+                    help='BatchNorm epsilon override (if not None)')
+
+# parallel parameters
+parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
+                    help='how many training processes to use (default: 4)')
+parser.add_argument('--distributed', action='store_true', default=False)
+# dataset sink mode feeds data through a device-side queue; enabled by default
+parser.add_argument('--dataset_sink', action='store_true', default=True)
+
+# checkpoint config
+parser.add_argument('--ckpt', type=str, default=None)
+parser.add_argument('--ckpt_save_epoch', type=int, default=1)
+parser.add_argument('--loss_scale', type=int,
+                    default=1024, help='static loss scale')
+parser.add_argument('--train', type=str2bool, default=1, help='train or eval')
+parser.add_argument('--GPU', action='store_true', default=False,
+                    help='Use GPU for training (default: False)')
+
+
+def main():
+    """Main entrance for training"""
+    args = parser.parse_args()
+    print(sys.argv)
+    devid, args.rank_id, args.rank_size = 0, 0, 1
+
+    context.set_context(mode=context.GRAPH_MODE)
+
+    if args.distributed:
+        if args.GPU:
+            init("nccl")
+            context.set_context(device_target='GPU')
+        else:
+            init()
+            devid = int(os.getenv('DEVICE_ID'))
+            context.set_context(device_target='Ascend',
+                                device_id=devid,
+                                reserve_class_name_in_scope=False)
+        context.reset_auto_parallel_context()
+        args.rank_id = get_rank()
+        args.rank_size = get_group_size()
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True,
+                                          device_num=args.rank_size)
+    else:
+        if args.GPU:
+            context.set_context(device_target='GPU')
+
+    is_master = not args.distributed or (args.rank_id == 0)
+
+    # parse model argument
+    assert args.model.startswith(
+        "tinynet"), "Only Tinynet models are supported."
+    _, sub_name = args.model.split("_")
+    net = tinynet(sub_model=sub_name,
+                  num_classes=args.num_classes,
+                  drop_rate=args.drop,
+                  drop_connect_rate=args.drop_connect,
+                  global_pool="avg",
+                  bn_tf=args.bn_tf,
+                  bn_momentum=args.bn_momentum,
+                  bn_eps=args.bn_eps)
+
+    if is_master:
+        print("Total number of parameters:", count_params(net))
+    # input image size of the network
+    input_size = net.default_cfg['input_size'][1]
+
+    train_dataset = val_dataset = None
+    train_data_url = os.path.join(args.data_path, 'train')
+    val_data_url = os.path.join(args.data_path, 'val')
+    val_dataset = create_dataset_val(args.batch_size,
+                                     val_data_url,
+                                     workers=args.workers,
+                                     distributed=False,
+                                     input_size=input_size)
+
+    if args.train:
+        train_dataset = create_dataset(args.batch_size,
+                                       train_data_url,
+                                       workers=args.workers,
+                                       distributed=args.distributed,
+                                       input_size=input_size)
+        batches_per_epoch = train_dataset.get_dataset_size()
+
+    loss = LabelSmoothingCrossEntropy(
+        smooth_factor=args.smoothing, num_classes=args.num_classes)
+    time_cb = TimeMonitor(data_size=batches_per_epoch)
+    loss_scale_manager = FixedLossScaleManager(
+        args.loss_scale, drop_overflow_update=False)
+
+    lr_array = get_lr(base_lr=args.lr,
+                      total_epochs=args.epochs,
+                      steps_per_epoch=batches_per_epoch,
+                      decay_epochs=args.decay_epochs,
+                      decay_rate=args.decay_rate,
+                      warmup_epochs=args.warmup_epochs,
+                      warmup_lr_init=args.warmup_lr,
+                      global_epoch=0)
+    lr = Tensor(lr_array)
+
+    loss_cb = LossMonitor(lr_array,
+                          args.epochs,
+                          per_print_times=args.per_print_times,
+                          start_epoch=0)
+
+    param_group = add_weight_decay(net, weight_decay=args.weight_decay)
+
+    if args.opt == 'sgd':
+        if is_master:
+            print('Using SGD optimizer')
+        optimizer = SGD(param_group,
+                        learning_rate=lr,
+                        momentum=args.momentum,
+                        weight_decay=args.weight_decay,
+                        loss_scale=args.loss_scale)
+
+    elif args.opt == 'rmsprop':
+        if is_master:
+            print('Using RMSProp optimizer')
+        optimizer = RMSProp(param_group,
+                            learning_rate=lr,
+                            decay=0.9,
+                            weight_decay=args.weight_decay,
+                            momentum=args.momentum,
+                            epsilon=args.opt_eps,
+                            loss_scale=args.loss_scale)
+
+    else:
+        raise ValueError('Only SGD and RMSProp optimizers are supported.')
+
+    loss.add_flags_recursive(fp32=True, fp16=False)
+    eval_metrics = {'Validation-Loss': Loss(),
+                    'Top1-Acc': Top1CategoricalAccuracy(),
+                    'Top5-Acc': Top5CategoricalAccuracy()}
+
+    if args.ckpt:
+        ckpt = load_checkpoint(args.ckpt)
+        load_param_into_net(net, ckpt)
+        net.set_train(False)
+
+    model = Model(net, loss, optimizer, metrics=eval_metrics,
+                  loss_scale_manager=loss_scale_manager,
+                  amp_level=args.amp_level)
+
+    net_ema = copy.deepcopy(net)
+    net_ema.set_train(False)
+    assert args.ema_decay > 0, "EMA should be used in tinynet training."
+
+    ema_cb = EmaEvalCallBack(model=model,
+                             ema_network=net_ema,
+                             loss_fn=loss,
+                             eval_dataset=val_dataset,
+                             decay=args.ema_decay,
+                             save_epoch=args.ckpt_save_epoch,
+                             dataset_sink_mode=args.dataset_sink,
+                             start_epoch=0)
+
+    callbacks = [loss_cb, ema_cb, time_cb] if is_master else []
+
+    if is_master:
+        print("Training on " + args.model +
+              " with " + str(args.num_classes) + " classes")
+
+    model.train(args.epochs, train_dataset, callbacks=callbacks,
+                dataset_sink_mode=args.dataset_sink)
+
+
+if __name__ == '__main__':
+    main()