!12028 Add MobileNetV3 CPU scripts
From: @wuxuejian Reviewed-by: @c_34,@liangchenghui,@oacjiewen,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
07e9c79de8
|
@ -50,8 +50,8 @@ MobileNetV3总体网络架构如下:
|
|||
|
||||
# 环境要求
|
||||
|
||||
- 硬件:GPU
|
||||
- 准备GPU处理器搭建硬件环境。
|
||||
- 硬件:GPU/CPU
|
||||
- 准备GPU/CPU处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
|
@ -86,6 +86,7 @@ MobileNetV3总体网络架构如下:
|
|||
使用python或shell脚本开始训练。shell脚本的使用方法如下:
|
||||
|
||||
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
|
||||
- CPU: sh run_trian.sh CPU [DATASET_PATH]
|
||||
|
||||
### 启动
|
||||
|
||||
|
@ -93,8 +94,10 @@ MobileNetV3总体网络架构如下:
|
|||
# 训练示例
|
||||
python:
|
||||
GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
|
||||
CPU: python train.py --dataset_path ~/cifar10/train/ --device_targe CPU
|
||||
shell:
|
||||
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
|
||||
CPU: sh run_train.sh CPU ~/cifar10/train/
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
@ -115,6 +118,7 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
|
|||
使用python或shell脚本开始训练。shell脚本的使用方法如下:
|
||||
|
||||
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
|
||||
### 启动
|
||||
|
||||
|
@ -122,9 +126,11 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
|
|||
# 推理示例
|
||||
python:
|
||||
GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
|
||||
CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_targe CPU
|
||||
|
||||
shell:
|
||||
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
|
||||
CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
|
||||
```
|
||||
|
||||
> 训练过程中可以生成检查点。
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
# [MobileNetV3 Description](#contents)
|
||||
|
||||
|
||||
MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.
|
||||
|
@ -35,37 +34,35 @@ The overall network architecture of MobileNetV3 is show below:
|
|||
Dataset used: [imagenet](http://www.image-net.org/)
|
||||
|
||||
- Dataset size: ~125G, 1.2W colorful images in 1000 classes
|
||||
- Train: 120G, 1.2W images
|
||||
- Test: 5G, 50000 images
|
||||
- Train: 120G, 1.2W images
|
||||
- Test: 5G, 50000 images
|
||||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Hardware(GPU/CPU)
|
||||
- Prepare hardware environment with GPU/CPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script description](#contents)
|
||||
|
||||
## [Script and sample code](#contents)
|
||||
|
||||
```python
|
||||
├── MobileNetV3
|
||||
├── Readme.md # descriptions about MobileNetV3
|
||||
├── scripts
|
||||
│ ├──run_train.sh # shell script for train
|
||||
│ ├──run_eval.sh # shell script for evaluation
|
||||
├── src
|
||||
│ ├──config.py # parameter configuration
|
||||
├── MobileNetV3
|
||||
├── Readme.md # descriptions about MobileNetV3
|
||||
├── scripts
|
||||
│ ├──run_train.sh # shell script for train
|
||||
│ ├──run_eval.sh # shell script for evaluation
|
||||
├── src
|
||||
│ ├──config.py # parameter configuration
|
||||
│ ├──dataset.py # creating dataset
|
||||
│ ├──lr_generator.py # learning rate config
|
||||
│ ├──lr_generator.py # learning rate config
|
||||
│ ├──mobilenetV3.py # MobileNetV3 architecture
|
||||
├── train.py # training script
|
||||
├── eval.py # evaluation script
|
||||
|
@ -80,22 +77,25 @@ Dataset used: [imagenet](http://www.image-net.org/)
|
|||
You can start training using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
|
||||
- CPU: sh run_trian.sh CPU [DATASET_PATH]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
```shell
|
||||
# training example
|
||||
python:
|
||||
GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
|
||||
CPU: python train.py --dataset_path ~/cifar10/train/ --device_targe CPU
|
||||
shell:
|
||||
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
|
||||
CPU: sh run_train.sh CPU ~/cifar10/train/
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
|
||||
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
|
||||
|
||||
```
|
||||
```bash
|
||||
epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
|
||||
epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
|
||||
epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
|
||||
|
@ -109,25 +109,28 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
|
|||
You can start training using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
```shell
|
||||
# infer example
|
||||
python:
|
||||
GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
|
||||
CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_targe CPU
|
||||
|
||||
shell:
|
||||
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
|
||||
CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
### Result
|
||||
|
||||
Inference result will be stored in the example path, you can find result like the followings in `val.log`.
|
||||
Inference result will be stored in the example path, you can find result like the followings in `val.log`.
|
||||
|
||||
```
|
||||
```bash
|
||||
result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
|
||||
```
|
||||
|
||||
|
@ -135,7 +138,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
|
|||
|
||||
Change the export mode and export file in `src/config.py`, and run `export.py`.
|
||||
|
||||
```
|
||||
```python
|
||||
python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
|
||||
```
|
||||
|
||||
|
@ -168,5 +171,5 @@ python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
|
|||
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
|
@ -21,7 +21,9 @@ from mindspore import nn
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import create_dataset
|
||||
from src.dataset import create_dataset_cifar
|
||||
from src.config import config_gpu
|
||||
from src.config import config_cpu
|
||||
from src.mobilenetV3 import mobilenet_v3_large
|
||||
|
||||
|
||||
|
@ -38,17 +40,24 @@ if __name__ == '__main__':
|
|||
config = config_gpu
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="GPU", save_graphs=False)
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=False,
|
||||
config=config,
|
||||
device_target=args_opt.device_target,
|
||||
batch_size=config.batch_size)
|
||||
elif args_opt.device_target == "CPU":
|
||||
config = config_cpu
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="CPU", save_graphs=False)
|
||||
dataset = create_dataset_cifar(dataset_path=args_opt.dataset_path,
|
||||
do_train=False,
|
||||
batch_size=config.batch_size)
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax")
|
||||
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=False,
|
||||
config=config,
|
||||
device_target=args_opt.device_target,
|
||||
batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
if args_opt.checkpoint_path:
|
||||
|
|
|
@ -19,6 +19,7 @@ import argparse
|
|||
import numpy as np
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.config import config_gpu
|
||||
from src.config import config_cpu
|
||||
from src.mobilenetV3 import mobilenet_v3_large
|
||||
|
||||
|
||||
|
@ -32,6 +33,9 @@ if __name__ == '__main__':
|
|||
if args_opt.device_target == "GPU":
|
||||
cfg = config_gpu
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
elif args_opt.device_target == "CPU":
|
||||
cfg = config_cpu
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
if [ $# != 3 ]
|
||||
then
|
||||
echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
echo "CPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
@ -16,6 +16,14 @@
|
|||
|
||||
run_gpu()
|
||||
{
|
||||
if [ $# -gt 5 ] || [ $# -lt 4 ]
|
||||
then
|
||||
echo "Usage:\n \
|
||||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
|
||||
CPU: sh run_train.sh CPU [DATASET_PATH]\n \
|
||||
"
|
||||
exit 1
|
||||
fi
|
||||
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
|
||||
then
|
||||
echo "error: DEVICE_NUM=$2 is not in (1-8)"
|
||||
|
@ -45,16 +53,42 @@ run_gpu()
|
|||
&> ../train.log & # dataset train folder
|
||||
}
|
||||
|
||||
if [ $# -gt 5 ] || [ $# -lt 4 ]
|
||||
then
|
||||
echo "Usage:\n \
|
||||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
|
||||
"
|
||||
exit 1
|
||||
fi
|
||||
run_cpu()
|
||||
{
|
||||
if [ $# -gt 3 ] || [ $# -lt 2 ]
|
||||
then
|
||||
echo "Usage:\n \
|
||||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
|
||||
CPU: sh run_train.sh CPU [DATASET_PATH]\n \
|
||||
"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $2 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||
if [ -d "../train" ];
|
||||
then
|
||||
rm -rf ../train
|
||||
fi
|
||||
mkdir ../train
|
||||
cd ../train || exit
|
||||
|
||||
python ${BASEPATH}/../train.py \
|
||||
--dataset_path=$2 \
|
||||
--device_target=$1 \
|
||||
&> ../train.log & # dataset train folder
|
||||
}
|
||||
|
||||
if [ $1 = "GPU" ] ; then
|
||||
run_gpu "$@"
|
||||
elif [ $1 = "CPU" ] ; then
|
||||
run_cpu "$@"
|
||||
else
|
||||
echo "Unsupported device_target"
|
||||
fi;
|
||||
|
|
|
@ -36,3 +36,23 @@ config_gpu = ed({
|
|||
"export_format": "MINDIR",
|
||||
"export_file": "mobilenetv3"
|
||||
})
|
||||
|
||||
config_cpu = ed({
|
||||
"num_classes": 10,
|
||||
"image_height": 224,
|
||||
"image_width": 224,
|
||||
"batch_size": 32,
|
||||
"epoch_size": 120,
|
||||
"warmup_epochs": 5,
|
||||
"lr": 0.1,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"label_smooth": 0.1,
|
||||
"loss_scale": 1024,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 500,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"export_format": "MINDIR",
|
||||
"export_file": "mobilenetv3"
|
||||
})
|
||||
|
|
|
@ -83,3 +83,60 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
|
|||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
def create_dataset_cifar(dataset_path,
|
||||
do_train,
|
||||
repeat_num=1,
|
||||
batch_size=32,
|
||||
target="CPU"):
|
||||
"""
|
||||
create a train or evaluate cifar10 dataset
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
data_set = ds.Cifar10Dataset(dataset_path,
|
||||
num_parallel_workers=8,
|
||||
shuffle=True)
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
C.RandomCrop((32, 32), (4, 4, 4, 4)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
|
||||
C.Resize((224, 224)),
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.CutOut(112),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Resize((224, 224)),
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(operations=type_cast_op,
|
||||
input_columns="label",
|
||||
num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans,
|
||||
input_columns="image",
|
||||
num_parallel_workers=8)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
|
|
@ -37,8 +37,10 @@ from mindspore.common import set_seed
|
|||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.dataset import create_dataset_cifar
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import config_gpu
|
||||
from src.config import config_cpu
|
||||
from src.mobilenetV3 import mobilenet_v3_large
|
||||
|
||||
set_seed(1)
|
||||
|
@ -59,6 +61,10 @@ if args_opt.device_target == "GPU":
|
|||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
elif args_opt.device_target == "CPU":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="CPU",
|
||||
save_graphs=False)
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
|
||||
|
@ -151,58 +157,71 @@ class Monitor(Callback):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_ = None
|
||||
if args_opt.device_target == "GPU":
|
||||
# train on gpu
|
||||
print("train args: ", args_opt)
|
||||
print("cfg: ", config_gpu)
|
||||
config_ = config_gpu
|
||||
elif args_opt.device_target == "CPU":
|
||||
config_ = config_cpu
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
# train on device
|
||||
print("train args: ", args_opt)
|
||||
print("cfg: ", config_)
|
||||
|
||||
# define net
|
||||
net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
|
||||
# define loss
|
||||
if config_gpu.label_smooth > 0:
|
||||
loss = CrossEntropyWithLabelSmooth(
|
||||
smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
# define dataset
|
||||
epoch_size = config_gpu.epoch_size
|
||||
# define net
|
||||
net = mobilenet_v3_large(num_classes=config_.num_classes)
|
||||
# define loss
|
||||
if config_.label_smooth > 0:
|
||||
loss = CrossEntropyWithLabelSmooth(
|
||||
smooth_factor=config_.label_smooth, num_classes=config_.num_classes)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
# define dataset
|
||||
epoch_size = config_.epoch_size
|
||||
if args_opt.device_target == "GPU":
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=True,
|
||||
config=config_gpu,
|
||||
config=config_,
|
||||
device_target=args_opt.device_target,
|
||||
repeat_num=1,
|
||||
batch_size=config_gpu.batch_size,
|
||||
run_distribute=args_opt.run_distribute)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# resume
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
# define optimizer
|
||||
loss_scale = FixedLossScaleManager(
|
||||
config_gpu.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(global_step=0,
|
||||
lr_init=0,
|
||||
lr_end=0,
|
||||
lr_max=config_gpu.lr,
|
||||
warmup_epochs=config_gpu.warmup_epochs,
|
||||
total_epochs=epoch_size,
|
||||
steps_per_epoch=step_size))
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
|
||||
config_gpu.weight_decay, config_gpu.loss_scale)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, optimizer=opt,
|
||||
loss_scale_manager=loss_scale)
|
||||
batch_size=config_.batch_size,
|
||||
run_distribute=False)
|
||||
elif args_opt.device_target == "CPU":
|
||||
dataset = create_dataset_cifar(args_opt.dataset_path,
|
||||
do_train=True,
|
||||
batch_size=config_.batch_size)
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
step_size = dataset.get_dataset_size()
|
||||
# resume
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
# define optimizer
|
||||
loss_scale = FixedLossScaleManager(
|
||||
config_.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(global_step=0,
|
||||
lr_init=0,
|
||||
lr_end=0,
|
||||
lr_max=config_.lr,
|
||||
warmup_epochs=config_.warmup_epochs,
|
||||
total_epochs=epoch_size,
|
||||
steps_per_epoch=step_size))
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_.momentum,
|
||||
config_.weight_decay, config_.loss_scale)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, optimizer=opt,
|
||||
loss_scale_manager=loss_scale)
|
||||
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
if args_opt.run_distribute:
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
else:
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
|
||||
if config_gpu.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# begine train
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
if args_opt.run_distribute:
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
else:
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
|
||||
if config_.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config_.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config_.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# begine train
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
|
|
Loading…
Reference in New Issue