add TinyNet-A, B, D, E

This commit is contained in:
yanglinfeng 2020-11-05 14:30:07 +08:00 committed by wangrao
parent 882301f4b5
commit 357bf44bf7
5 changed files with 77 additions and 73 deletions

View File

@ -5,14 +5,12 @@
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
@ -22,7 +20,6 @@ TinyNets are a series of lightweight models obtained by twisting resolution, dep
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
# [Model architecture](#contents)
The overall network architecture of TinyNet is shown below:
@ -33,53 +30,56 @@ The overall network architecture of TinyNet is show below:
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
- Dataset size:
    - Train: 1.2 million images in 1,000 classes
    - Test: 50,000 validation images in 1,000 classes
- Data format: RGB images.
    - Note: Data will be processed in src/dataset.py (see the sketch below).
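For reference, the pipelines can also be built directly in Python. The following is a minimal sketch, assuming the `create_dataset`/`create_dataset_val` signatures defined in src/dataset.py; the paths and the train/val sub-folders are placeholders, not values from this repository.
```python
# Minimal sketch: build the ImageNet train/val pipelines with the helpers
# from src/dataset.py. Paths are placeholders; keyword names follow the
# create_dataset / create_dataset_val signatures.
from src.dataset import create_dataset, create_dataset_val

train_dataset = create_dataset(batch_size=128,
                               train_data_url="/path_to_ImageNet/train",
                               workers=8,
                               distributed=False)
val_dataset = create_dataset_val(batch_size=128,
                                 val_data_url="/path_to_ImageNet/val",
                                 workers=8,
                                 distributed=False)
```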
# [Environment Requirements](#contents)
- Hardware (GPU)
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```markdown
.tinynet
├── README.md # descriptions about tinynet
├── script
│ ├── eval.sh # evaluation script
│ ├── train_1p_gpu.sh # training script on single GPU
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
├── src
│ ├── callback.py # loss, ema, and checkpoint callbacks
│ ├── dataset.py # data preprocessing
│ ├── loss.py # label-smoothing cross-entropy loss function
│ ├── tinynet.py # tinynet architecture
│ └── utils.py # utility functions
├── eval.py # evaluation interface
└── train.py # training interface
```
### [Training Process](#contents)
#### Launch
```bash
# training on single GPU
sh train_1p_gpu.sh
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
sh train_distributed_gpu.sh -n 8
```
Inside train_1p_gpu.sh and train_distributed_gpu.sh, there are hyperparameters that can be adjusted during training, for example:
```python
--model tinynet_c model to be used for training
--drop 0.2 dropout rate
--drop-connect 0 drop connect rate
@ -88,51 +88,55 @@ Inside train.sh, there are hyperparameters that can be adjusted during training,
--lr 0.048 learning rate
--batch-size 128 batch size
--decay-epochs 2.4 learning rate decays every 2.4 epochs
--warmup-lr 1e-6 warm up learning rate
--warmup-epochs 3 number of learning rate warm up epochs
--decay-rate 0.97 learning rate decay rate
--ema-decay 0.9999 decay factor for model weights moving average
--weight-decay 1e-5 optimizer's weight decay
--epochs 450 number of epochs to be trained
--ckpt_save_epoch 1 checkpoint saving interval
--workers 8 number of processes for loading data
--amp_level O0 auto mixed precision level for training
--opt rmsprop optimizer, currently SGD and RMSProp are supported
--data_path /path_to_ImageNet/
--GPU using GPU for training
--dataset_sink using sink mode
```
The config above was used to train tinynets on ImageNet (change drop-connect to 0.1 for training tinynet_b).
> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
or ./device_parallel folder (multiple GPUs)
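To make the schedule flags above concrete, here is a hedged sketch of how `--lr`, `--warmup-lr`, `--warmup-epochs`, `--decay-rate` and `--decay-epochs` combine under the usual warmup-plus-exponential-decay schedule; the exact formula lives in train.py, so treat this only as an illustration.
```python
def lr_at_epoch(epoch, lr=0.048, warmup_lr=1e-6, warmup_epochs=3,
                decay_rate=0.97, decay_epochs=2.4):
    """Illustrative schedule only; the authoritative implementation is in train.py."""
    if epoch < warmup_epochs:
        # linear warmup from warmup_lr to the base learning rate
        return warmup_lr + (lr - warmup_lr) * epoch / warmup_epochs
    # after warmup, multiply by decay_rate once every decay_epochs
    return lr * decay_rate ** ((epoch - warmup_epochs) / decay_epochs)
```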
### [Evaluation Process](#contents)
#### Launch
```bash
# infer example
sh eval.sh
```
Inside eval.sh, there are configs that can be adjusted during inference, for example:
```python
--num-classes 1000
--batch-size 128
--workers 8
--data_path /path_to_ImageNet/
--GPU
--ckpt /path_to_EMA_checkpoint/
--dataset_sink > tinynet_c_eval.log 2>&1 &
```
> checkpoints can be produced during the training process.
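For illustration, a rough sketch of loading an EMA checkpoint and evaluating it with the standard MindSpore APIs is shown below (eval.sh/eval.py remain the supported entry point). The `tinynet` factory import, the checkpoint file name, and the plain cross-entropy loss are assumptions; the repository's own label-smoothing loss lives in src/loss.py.
```python
# Hedged sketch, not part of eval.py: restore an EMA checkpoint and evaluate it.
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset_val
from src.tinynet import tinynet  # hypothetical factory name

net = tinynet()  # build the same architecture that was trained, e.g. tinynet_c
load_param_into_net(net, load_checkpoint("/path_to_EMA_checkpoint/tinynet_c.ckpt"))
net.set_train(False)

model = Model(net,
              loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True),
              metrics={"Top1-Acc": nn.Top1CategoricalAccuracy()})
val_dataset = create_dataset_val(batch_size=128, val_data_url="/path_to_ImageNet/val")
print(model.eval(val_dataset, dataset_sink_mode=False))
```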
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Model | FLOPs | Latency* | ImageNet Top-1 |
| ------------------- | ----- | -------- | -------------- |
@ -149,6 +153,6 @@ Inside the eval.sh, there are configs that can be adjusted during inference, for
We set the seed inside dataset.py. We also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -36,7 +36,7 @@ def load_nparray_into_net(net, array_dict):
for _, param in net.parameters_and_names():
if param.name in array_dict:
new_param = array_dict[param.name]
param.set_data(Parameter(Tensor(deepcopy(new_param)), name=param.name))
else:
param_not_load.append(param.name)
return param_not_load
@ -48,8 +48,8 @@ class EmaEvalCallBack(Callback):
the end of training epoch.
Args:
network: tinynet network instance.
ema_network: step-wise exponential moving average of network.
eval_dataset: the evaluation dataset.
decay (float): ema decay.
save_epoch (int): defines how often to save checkpoint.
@ -57,9 +57,9 @@ class EmaEvalCallBack(Callback):
start_epoch (int): which epoch to start/resume training.
"""
def __init__(self, network, ema_network, eval_dataset, loss_fn, decay=0.999,
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
self.network = network
self.ema_network = ema_network
self.eval_dataset = eval_dataset
self.loss_fn = loss_fn
@ -80,14 +80,12 @@ class EmaEvalCallBack(Callback):
def begin(self, run_context):
"""Initialize the EMA parameters """
for _, param in self.network.parameters_and_names():
self.shadow[param.name] = deepcopy(param.data.asnumpy())
def step_end(self, run_context):
"""Update the EMA parameters"""
for _, param in self.network.parameters_and_names():
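# exponential moving average update: shadow <- decay * shadow + (1 - decay) * param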
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
self.decay * self.shadow[param.name]
self.shadow[param.name] = new_average
@ -98,24 +96,20 @@ class EmaEvalCallBack(Callback):
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
save_ckpt = (cur_epoch % self.save_epoch == 0)
load_nparray_into_net(self.ema_network, self.shadow)
self.ema_network.set_train(False)
model = Model(self.network, loss_fn=self.loss_fn, metrics=self.eval_metrics)
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
metrics=self.eval_metrics)
acc = model.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
ema_acc = model_ema.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
print("Model Accuracy:", acc)
print("EMA-Model Accuracy:", ema_acc)
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
output = [{"name": k, "data": Tensor(v)}
for k, v in self.shadow.items()]
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
self.best_ema_accuracy = ema_acc["Top1-Acc"]
self.best_ema_epoch = cur_epoch

View File

@ -65,12 +65,12 @@ def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# assemble all the transforms
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
@ -125,11 +125,11 @@ def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=F
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
to_tensor, normalize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)

View File

@ -18,10 +18,12 @@ import re
from copy import deepcopy
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
from mindspore import context, ms_function
from mindspore.common.parameter import Parameter
from mindspore import Tensor
# Imagenet constant values
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
@ -29,12 +31,14 @@ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# model structure configurations for TinyNets, values are
# (resolution multiplier, channel multiplier, depth multiplier)
# codes are inspired and partially adapted from
# https://github.com/rwightman/gen-efficientnet-pytorch
TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}
TINYNET_CFG = {"a": (0.86, 1.0, 1.2),
"b": (0.84, 0.75, 1.1),
"c": (0.825, 0.54, 0.85),
"d": (0.68, 0.54, 0.695),
"e": (0.475, 0.51, 0.60)}
relu = P.ReLU()
sigmoid = P.Sigmoid()
@ -524,13 +528,15 @@ class DropConnect(nn.Cell):
self.dtype = P.DType()
self.keep_prob = 1 - drop_connect_rate
self.dropout = P.Dropout(keep_prob=self.keep_prob)
self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)
def construct(self, x):
shape = self.shape(x)
dtype = self.dtype(x)
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
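# per-example (batch, 1, 1, 1) mask: randomly zeroes the whole feature map of some samples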
_, mask = self.dropout(ones_tensor)
x = x * mask
x = x / self.keep_prob_tensor
return x

View File

@ -227,7 +227,7 @@ def main():
net_ema.set_train(False)
assert args.ema_decay > 0, "EMA should be used in tinynet training."
ema_cb = EmaEvalCallBack(network=net,
ema_network=net_ema,
loss_fn=loss,
eval_dataset=val_dataset,