!8570 add TinyNet-A, B, D, E to model_zoo
From: @wangrao124 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
16b34e5e58
|
@ -5,14 +5,12 @@
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [Environment Requirements](#environment-requirements)
|
- [Environment Requirements](#environment-requirements)
|
||||||
- [Script Description](#script-description)
|
- [Script Description](#script-description)
|
||||||
- [Script and Sample Code](#script-and-sample-code)
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
- [Training Process](#training-process)
|
- [Training Process](#training-process)
|
||||||
- [Evaluation Process](#evaluation-process)
|
- [Evaluation Process](#evaluation-process)
|
||||||
- [Evaluation](#evaluation)
|
|
||||||
- [Model Description](#model-description)
|
- [Model Description](#model-description)
|
||||||
- [Performance](#performance)
|
- [Performance](#performance)
|
||||||
- [Training Performance](#evaluation-performance)
|
- [Evaluation Performance](#evaluation-performance)
|
||||||
- [Inference Performance](#evaluation-performance)
|
|
||||||
- [Description of Random Situation](#description-of-random-situation)
|
- [Description of Random Situation](#description-of-random-situation)
|
||||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
@ -22,7 +20,6 @@ TinyNets are a series of lightweight models obtained by twisting resolution, dep
|
||||||
|
|
||||||
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
|
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
|
||||||
|
|
||||||
Note: We have only released TinyNet-C for now, and will release other TinyNets soon.
|
|
||||||
# [Model architecture](#contents)
|
# [Model architecture](#contents)
|
||||||
|
|
||||||
The overall network architecture of TinyNet is show below:
|
The overall network architecture of TinyNet is show below:
|
||||||
|
@ -33,53 +30,56 @@ The overall network architecture of TinyNet is show below:
|
||||||
|
|
||||||
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
|
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
|
||||||
|
|
||||||
- Dataset size:
|
- Dataset size:
|
||||||
- Train: 1.2 million images in 1,000 classes
|
- Train: 1.2 million images in 1,000 classes
|
||||||
- Test: 50,000 validation images in 1,000 classes
|
- Test: 50,000 validation images in 1,000 classes
|
||||||
- Data format: RGB images.
|
- Data format: RGB images.
|
||||||
- Note: Data will be processed in src/dataset/dataset.py
|
- Note: Data will be processed in src/dataset/dataset.py
|
||||||
|
|
||||||
# [Environment Requirements](#contents)
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
- Hardware (GPU)
|
- Hardware (GPU)
|
||||||
- Framework
|
- Framework
|
||||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
- For more information, please check the resources below:
|
- For more information, please check the resources below:
|
||||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
# [Script description](#contents)
|
# [Script Description](#contents)
|
||||||
|
|
||||||
## [Script and sample code](#contents)
|
## [Script and Sample Code](#contents)
|
||||||
|
|
||||||
```
|
```markdown
|
||||||
.tinynet
|
.tinynet
|
||||||
├── Readme.md # descriptions about tinynet
|
├── README.md # descriptions about tinynet
|
||||||
├── script
|
├── script
|
||||||
│ ├── eval.sh # evaluation script
|
│ ├── eval.sh # evaluation script
|
||||||
│ ├── train_1p_gpu.sh # training script on single GPU
|
│ ├── train_1p_gpu.sh # training script on single GPU
|
||||||
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
|
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
|
||||||
├── src
|
├── src
|
||||||
│ ├── callback.py # loss and checkpoint callbacks
|
│ ├── callback.py # loss, ema, and checkpoint callbacks
|
||||||
│ ├── dataset.py # data processing
|
│ ├── dataset.py # data preprocessing
|
||||||
│ ├── loss.py # label-smoothing cross-entropy loss function
|
│ ├── loss.py # label-smoothing cross-entropy loss function
|
||||||
│ ├── tinynet.py # tinynet architecture
|
│ ├── tinynet.py # tinynet architecture
|
||||||
│ └── utils.py # utility functions
|
│ └── utils.py # utility functions
|
||||||
├── eval.py # evaluation interface
|
├── eval.py # evaluation interface
|
||||||
└── train.py # training interface
|
└── train.py # training interface
|
||||||
```
|
```
|
||||||
## [Training process](#contents)
|
|
||||||
|
|
||||||
### Launch
|
### [Training process](#contents)
|
||||||
|
|
||||||
```
|
#### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
# training on single GPU
|
# training on single GPU
|
||||||
sh train_1p_gpu.sh
|
sh train_1p_gpu.sh
|
||||||
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
|
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
|
||||||
sh train_distributed_gpu.sh -n 8
|
sh train_distributed_gpu.sh -n 8
|
||||||
```
|
```
|
||||||
|
|
||||||
Inside train.sh, there are hyperparameters that can be adjusted during training, for example:
|
Inside train.sh, there are hyperparameters that can be adjusted during training, for example:
|
||||||
```
|
|
||||||
|
```python
|
||||||
--model tinynet_c model to be used for training
|
--model tinynet_c model to be used for training
|
||||||
--drop 0.2 dropout rate
|
--drop 0.2 dropout rate
|
||||||
--drop-connect 0 drop connect rate
|
--drop-connect 0 drop connect rate
|
||||||
|
@ -88,51 +88,55 @@ Inside train.sh, there are hyperparameters that can be adjusted during training,
|
||||||
--lr 0.048 learning rate
|
--lr 0.048 learning rate
|
||||||
--batch-size 128 batch size
|
--batch-size 128 batch size
|
||||||
--decay-epochs 2.4 learning rate decays every 2.4 epoch
|
--decay-epochs 2.4 learning rate decays every 2.4 epoch
|
||||||
--warmup-lr 1e-6 warm up learning rate
|
--warmup-lr 1e-6 warm up learning rate
|
||||||
--warmup-epochs 3 learning rate warm up epoch
|
--warmup-epochs 3 learning rate warm up epoch
|
||||||
--decay-rate 0.97 learning rate decay rate
|
--decay-rate 0.97 learning rate decay rate
|
||||||
--ema-decay 0.9999 decay factor for model weights moving average
|
--ema-decay 0.9999 decay factor for model weights moving average
|
||||||
--weight-decay 1e-5 optimizer's weight decay
|
--weight-decay 1e-5 optimizer's weight decay
|
||||||
--epochs 450 number of epochs to be trained
|
--epochs 450 number of epochs to be trained
|
||||||
--ckpt_save_epoch 1 checkpoint saving interval
|
--ckpt_save_epoch 1 checkpoint saving interval
|
||||||
--workers 8 number of processes for loading data
|
--workers 8 number of processes for loading data
|
||||||
--amp_level O0 training auto-mixed precision
|
--amp_level O0 training auto-mixed precision
|
||||||
--opt rmsprop optimizers, currently we support SGD and RMSProp
|
--opt rmsprop optimizers, currently we support SGD and RMSProp
|
||||||
--data_path /path_to_ImageNet/
|
--data_path /path_to_ImageNet/
|
||||||
--GPU using GPU for training
|
--GPU using GPU for training
|
||||||
--dataset_sink using sink mode
|
--dataset_sink using sink mode
|
||||||
```
|
```
|
||||||
The config above was used to train tinynets on ImageNet (change drop-connect to 0.2 for training tinynet-b)
|
|
||||||
|
The config above was used to train tinynets on ImageNet (change drop-connect to 0.1 for training tinynet_b)
|
||||||
|
|
||||||
> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
|
> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
|
||||||
or ./device_parallel folder (multiple GPUs)
|
or ./device_parallel folder (multiple GPUs)
|
||||||
|
|
||||||
## [Eval process](#contents)
|
### [Evaluation Process](#contents)
|
||||||
|
|
||||||
### Launch
|
#### Launch
|
||||||
|
|
||||||
```
|
```bash
|
||||||
# infer example
|
# infer example
|
||||||
|
|
||||||
sh eval.sh
|
sh eval.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Inside the eval.sh, there are configs that can be adjusted during inference, for example:
|
Inside the eval.sh, there are configs that can be adjusted during inference, for example:
|
||||||
```
|
|
||||||
--num-classes 1000
|
```python
|
||||||
--batch-size 128
|
--num-classes 1000
|
||||||
--workers 8
|
--batch-size 128
|
||||||
--data_path /path_to_ImageNet/
|
--workers 8
|
||||||
--GPU
|
--data_path /path_to_ImageNet/
|
||||||
--ckpt /path_to_EMA_checkpoint/
|
--GPU
|
||||||
|
--ckpt /path_to_EMA_checkpoint/
|
||||||
--dataset_sink > tinynet_c_eval.log 2>&1 &
|
--dataset_sink > tinynet_c_eval.log 2>&1 &
|
||||||
```
|
```
|
||||||
|
|
||||||
> checkpoint can be produced in training process.
|
> checkpoint can be produced in training process.
|
||||||
|
|
||||||
# [Model Description](#contents)
|
# [Model Description](#contents)
|
||||||
|
|
||||||
## [Performance](#contents)
|
## [Performance](#contents)
|
||||||
|
|
||||||
#### Evaluation Performance
|
### Evaluation Performance
|
||||||
|
|
||||||
| Model | FLOPs | Latency* | ImageNet Top-1 |
|
| Model | FLOPs | Latency* | ImageNet Top-1 |
|
||||||
| ------------------- | ----- | -------- | -------------- |
|
| ------------------- | ----- | -------- | -------------- |
|
||||||
|
@ -149,6 +153,6 @@ Inside the eval.sh, there are configs that can be adjusted during inference, for
|
||||||
|
|
||||||
We set the seed inside dataset.py. We also use random seed in train.py.
|
We set the seed inside dataset.py. We also use random seed in train.py.
|
||||||
|
|
||||||
# [Model Zoo Homepage](#contents)
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||||
|
|
|
@ -36,7 +36,7 @@ def load_nparray_into_net(net, array_dict):
|
||||||
for _, param in net.parameters_and_names():
|
for _, param in net.parameters_and_names():
|
||||||
if param.name in array_dict:
|
if param.name in array_dict:
|
||||||
new_param = array_dict[param.name]
|
new_param = array_dict[param.name]
|
||||||
param.set_data(Parameter(new_param.copy(), name=param.name))
|
param.set_data(Parameter(Tensor(deepcopy(new_param)), name=param.name))
|
||||||
else:
|
else:
|
||||||
param_not_load.append(param.name)
|
param_not_load.append(param.name)
|
||||||
return param_not_load
|
return param_not_load
|
||||||
|
@ -48,8 +48,8 @@ class EmaEvalCallBack(Callback):
|
||||||
the end of training epoch.
|
the end of training epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Mindspore model instance.
|
network: tinynet network instance.
|
||||||
ema_network: step-wise exponential moving average for ema_network.
|
ema_network: step-wise exponential moving average of network.
|
||||||
eval_dataset: the evaluation daatset.
|
eval_dataset: the evaluation daatset.
|
||||||
decay (float): ema decay.
|
decay (float): ema decay.
|
||||||
save_epoch (int): defines how often to save checkpoint.
|
save_epoch (int): defines how often to save checkpoint.
|
||||||
|
@ -57,9 +57,9 @@ class EmaEvalCallBack(Callback):
|
||||||
start_epoch (int): which epoch to start/resume training.
|
start_epoch (int): which epoch to start/resume training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999,
|
def __init__(self, network, ema_network, eval_dataset, loss_fn, decay=0.999,
|
||||||
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
|
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
|
||||||
self.model = model
|
self.network = network
|
||||||
self.ema_network = ema_network
|
self.ema_network = ema_network
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
|
@ -80,14 +80,12 @@ class EmaEvalCallBack(Callback):
|
||||||
|
|
||||||
def begin(self, run_context):
|
def begin(self, run_context):
|
||||||
"""Initialize the EMA parameters """
|
"""Initialize the EMA parameters """
|
||||||
cb_params = run_context.original_args()
|
for _, param in self.network.parameters_and_names():
|
||||||
for _, param in cb_params.network.parameters_and_names():
|
|
||||||
self.shadow[param.name] = deepcopy(param.data.asnumpy())
|
self.shadow[param.name] = deepcopy(param.data.asnumpy())
|
||||||
|
|
||||||
def step_end(self, run_context):
|
def step_end(self, run_context):
|
||||||
"""Update the EMA parameters"""
|
"""Update the EMA parameters"""
|
||||||
cb_params = run_context.original_args()
|
for _, param in self.network.parameters_and_names():
|
||||||
for _, param in cb_params.network.parameters_and_names():
|
|
||||||
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
|
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
|
||||||
self.decay * self.shadow[param.name]
|
self.decay * self.shadow[param.name]
|
||||||
self.shadow[param.name] = new_average
|
self.shadow[param.name] = new_average
|
||||||
|
@ -98,24 +96,20 @@ class EmaEvalCallBack(Callback):
|
||||||
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
|
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
|
||||||
|
|
||||||
save_ckpt = (cur_epoch % self.save_epoch == 0)
|
save_ckpt = (cur_epoch % self.save_epoch == 0)
|
||||||
|
|
||||||
acc = self.model.eval(
|
|
||||||
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
|
||||||
print("Model Accuracy:", acc)
|
|
||||||
|
|
||||||
load_nparray_into_net(self.ema_network, self.shadow)
|
load_nparray_into_net(self.ema_network, self.shadow)
|
||||||
self.ema_network.set_train(False)
|
model = Model(self.network, loss_fn=self.loss_fn, metrics=self.eval_metrics)
|
||||||
|
|
||||||
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
|
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
|
||||||
metrics=self.eval_metrics)
|
metrics=self.eval_metrics)
|
||||||
|
acc = model.eval(
|
||||||
|
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
||||||
ema_acc = model_ema.eval(
|
ema_acc = model_ema.eval(
|
||||||
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
|
||||||
|
print("Model Accuracy:", acc)
|
||||||
print("EMA-Model Accuracy:", ema_acc)
|
print("EMA-Model Accuracy:", ema_acc)
|
||||||
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
|
|
||||||
output = [{"name": k, "data": Tensor(v)}
|
output = [{"name": k, "data": Tensor(v)}
|
||||||
for k, v in self.shadow.items()]
|
for k, v in self.shadow.items()]
|
||||||
|
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
|
||||||
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
|
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
|
||||||
self.best_ema_accuracy = ema_acc["Top1-Acc"]
|
self.best_ema_accuracy = ema_acc["Top1-Acc"]
|
||||||
self.best_ema_epoch = cur_epoch
|
self.best_ema_epoch = cur_epoch
|
||||||
|
|
|
@ -65,12 +65,12 @@ def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
|
||||||
contrast=adjust_range,
|
contrast=adjust_range,
|
||||||
saturation=adjust_range)
|
saturation=adjust_range)
|
||||||
to_tensor = py_vision.ToTensor()
|
to_tensor = py_vision.ToTensor()
|
||||||
nromlize_op = py_vision.Normalize(
|
normalize_op = py_vision.Normalize(
|
||||||
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||||
|
|
||||||
# assemble all the transforms
|
# assemble all the transforms
|
||||||
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
|
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
|
||||||
random_horizontal_flip_op, random_color_jitter_op, to_tensor, nromlize_op])
|
random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
|
||||||
|
|
||||||
rank_id = get_rank() if distributed else 0
|
rank_id = get_rank() if distributed else 0
|
||||||
rank_size = get_group_size() if distributed else 1
|
rank_size = get_group_size() if distributed else 1
|
||||||
|
@ -125,11 +125,11 @@ def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=F
|
||||||
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
|
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
|
||||||
center_crop = py_vision.CenterCrop(size=input_size)
|
center_crop = py_vision.CenterCrop(size=input_size)
|
||||||
to_tensor = py_vision.ToTensor()
|
to_tensor = py_vision.ToTensor()
|
||||||
nromlize_op = py_vision.Normalize(
|
normalize_op = py_vision.Normalize(
|
||||||
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
||||||
|
|
||||||
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
|
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
|
||||||
to_tensor, nromlize_op])
|
to_tensor, normalize_op])
|
||||||
|
|
||||||
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
|
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
|
||||||
num_parallel_workers=workers)
|
num_parallel_workers=workers)
|
||||||
|
|
|
@ -18,10 +18,12 @@ import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
|
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
|
||||||
from mindspore import context, ms_function
|
from mindspore import context, ms_function
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
# Imagenet constant values
|
# Imagenet constant values
|
||||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||||
|
@ -29,12 +31,14 @@ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||||
|
|
||||||
# model structure configurations for TinyNets, values are
|
# model structure configurations for TinyNets, values are
|
||||||
# (resolution multiplier, channel multiplier, depth multiplier)
|
# (resolution multiplier, channel multiplier, depth multiplier)
|
||||||
# only tinynet-c is availiable for now, we will release other tinynet
|
|
||||||
# models soon
|
|
||||||
# codes are inspired and partially adapted from
|
# codes are inspired and partially adapted from
|
||||||
# https://github.com/rwightman/gen-efficientnet-pytorch
|
# https://github.com/rwightman/gen-efficientnet-pytorch
|
||||||
|
|
||||||
TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}
|
TINYNET_CFG = {"a": (0.86, 1.0, 1.2),
|
||||||
|
"b": (0.84, 0.75, 1.1),
|
||||||
|
"c": (0.825, 0.54, 0.85),
|
||||||
|
"d": (0.68, 0.54, 0.695),
|
||||||
|
"e": (0.475, 0.51, 0.60)}
|
||||||
|
|
||||||
relu = P.ReLU()
|
relu = P.ReLU()
|
||||||
sigmoid = P.Sigmoid()
|
sigmoid = P.Sigmoid()
|
||||||
|
@ -524,13 +528,15 @@ class DropConnect(nn.Cell):
|
||||||
self.dtype = P.DType()
|
self.dtype = P.DType()
|
||||||
self.keep_prob = 1 - drop_connect_rate
|
self.keep_prob = 1 - drop_connect_rate
|
||||||
self.dropout = P.Dropout(keep_prob=self.keep_prob)
|
self.dropout = P.Dropout(keep_prob=self.keep_prob)
|
||||||
|
self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
shape = self.shape(x)
|
shape = self.shape(x)
|
||||||
dtype = self.dtype(x)
|
dtype = self.dtype(x)
|
||||||
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
|
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
|
||||||
_, mask_ = self.dropout(ones_tensor)
|
_, mask = self.dropout(ones_tensor)
|
||||||
x = x * mask_
|
x = x * mask
|
||||||
|
x = x / self.keep_prob_tensor
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -227,7 +227,7 @@ def main():
|
||||||
net_ema.set_train(False)
|
net_ema.set_train(False)
|
||||||
assert args.ema_decay > 0, "EMA should be used in tinynet training."
|
assert args.ema_decay > 0, "EMA should be used in tinynet training."
|
||||||
|
|
||||||
ema_cb = EmaEvalCallBack(model=model,
|
ema_cb = EmaEvalCallBack(network=net,
|
||||||
ema_network=net_ema,
|
ema_network=net_ema,
|
||||||
loss_fn=loss,
|
loss_fn=loss,
|
||||||
eval_dataset=val_dataset,
|
eval_dataset=val_dataset,
|
||||||
|
|
Loading…
Reference in New Issue