forked from mindspore-Ecosystem/mindspore

optimize the vgg script

commit abbd7b50db (parent 12a150bb5d)
This example is for VGG16 model training and evaluation.

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the CIFAR-10 or ImageNet2012 dataset.

CIFAR-10

> Unzip the CIFAR-10 dataset to any path you want; the folder structure should be as follows:
> ```
> └── cifar-10-verify-bin  # infer dataset
> ```

ImageNet2012

> Unzip the ImageNet2012 dataset to any path you want; the folder should include the train and evaluation datasets as follows:
>
> ```
> .
> └─dataset
>    ├─ilsvrc                 # train dataset
>    └─validation_preprocess  # evaluate dataset
> ```

## Parameter configuration

Parameters for both training and evaluation can be set in config.py.

- config for vgg16, CIFAR-10 dataset

```
"num_classes": 10,                  # dataset class num
"lr": 0.01,                         # learning rate
"lr_init": 0.01,                    # initial learning rate
"lr_max": 0.1,                      # max learning rate
"lr_epochs": '30,60,90,120',        # epochs at which lr changes
"lr_scheduler": "step",             # learning rate mode
"warmup_epochs": 5,                 # number of warmup epochs
"batch_size": 64,                   # batch size of input tensor
"max_epoch": 70,                    # only valid for training, which is always 1 for inference
"momentum": 0.9,                    # momentum
"weight_decay": 5e-4,               # weight decay
"loss_scale": 1.0,                  # loss scale
"label_smooth": 0,                  # label smooth
"label_smooth_factor": 0,           # label smooth factor
"buffer_size": 10,                  # shuffle buffer size
"image_size": '224,224',            # image size
"pad_mode": 'same',                 # pad mode for conv2d
"padding": 0,                       # padding value for conv2d
"has_bias": False,                  # whether conv2d has bias
"batch_norm": True,                 # whether conv2d is followed by batch_norm
"keep_checkpoint_max": 10,          # only keep the last keep_checkpoint_max checkpoints
"initialize_mode": "XavierUniform", # conv2d init mode
"has_dropout": False                # whether to use the Dropout layer
```
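With `"lr_scheduler": "step"`, the learning rate is typically dropped at each epoch listed in `lr_epochs`. The sketch below illustrates a schedule of that shape; it is not the repo's actual lr utility, and the 10x decay factor and the omission of warmup are assumptions:

```python
# Illustrative step schedule: decay the lr by 10x at each milestone epoch.
def step_lr(lr_max, lr_epochs, total_epochs, steps_per_epoch):
    milestones = [int(e) for e in lr_epochs.split(',')]
    lrs = []
    for epoch in range(total_epochs):
        decay = sum(epoch >= m for m in milestones)  # milestones already passed
        lrs += [lr_max * 0.1 ** decay] * steps_per_epoch
    return lrs

# e.g. step_lr(0.1, '30,60,90,120', 70, 781) for the CIFAR-10 config above
```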

- config for vgg16, ImageNet2012 dataset

```
"num_classes": 1000,                 # dataset class num
"lr": 0.01,                          # learning rate
"lr_init": 0.01,                     # initial learning rate
"lr_max": 0.1,                       # max learning rate
"lr_epochs": '30,60,90,120',         # epochs at which lr changes
"lr_scheduler": "cosine_annealing",  # learning rate mode
"warmup_epochs": 0,                  # number of warmup epochs
"batch_size": 32,                    # batch size of input tensor
"max_epoch": 150,                    # only valid for training, which is always 1 for inference
"momentum": 0.9,                     # momentum
"weight_decay": 1e-4,                # weight decay
"loss_scale": 1024,                  # loss scale
"label_smooth": 1,                   # label smooth
"label_smooth_factor": 0.1,          # label smooth factor
"buffer_size": 10,                   # shuffle buffer size
"image_size": '224,224',             # image size
"pad_mode": 'pad',                   # pad mode for conv2d
"padding": 1,                        # padding value for conv2d
"has_bias": True,                    # whether conv2d has bias
"batch_norm": False,                 # whether conv2d is followed by batch_norm
"keep_checkpoint_max": 10,           # only keep the last keep_checkpoint_max checkpoints
"initialize_mode": "KaimingNormal",  # conv2d init mode
"has_dropout": True                  # whether to use the Dropout layer
```
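With `"lr_scheduler": "cosine_annealing"`, the learning rate follows a half cosine from `lr_max` towards a minimum. A sketch of the standard formula (illustrative only; the repo's version may handle warmup and restarts differently):

```python
import math

# Standard cosine annealing from lr_max down to lr_min over all training steps.
def cosine_lr(lr_max, lr_min, total_epochs, steps_per_epoch):
    total_steps = total_epochs * steps_per_epoch
    return [lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / total_steps))
            for t in range(total_steps)]

# e.g. cosine_lr(0.1, 0.0, 150, 5004) for the ImageNet2012 config above
```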

## Running the Example

### Training

**Run vgg16, using CIFAR-10 dataset**

- Training using single device (1p)
```
python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 &
```
The python command above will run in the background; you can view the results through the file `out.train.log`.

After training, you'll get some checkpoint files in the specified ckpt_path, by default in the ./output directory.

You will get the loss value as follows:
```
epoch: 2 step: 781, loss is 1.827582
...
```

- Distribute Training
```
sh run_distribute_train.sh rank_table.json your_data_path
```
```
train_parallel1/log:epoch: 2 step: 97, loss is 1.7133579
...
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).

**Run vgg16, using imagenet2012 dataset**

- Training using single device (1p)
```
python train.py --device_target="GPU" --dataset="imagenet2012" --is_distributed=0 --data_path=$DATA_PATH > output.train.log 2>&1 &
```

- Distribute Training
```
# distributed training (8p)
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train
```

### Evaluation

- Do eval as follows; the dataset type must be specified as "cifar10" or "imagenet2012".
```
# when using cifar10 dataset
python eval.py --data_path=your_data_path --dataset="cifar10" --device_target="Ascend" --pre_trained=./*-70-781.ckpt > out.eval.log 2>&1 &

# when using imagenet2012 dataset
python eval.py --data_path=your_data_path --dataset="imagenet2012" --device_target="GPU" --pre_trained=./*-150-5004.ckpt > out.eval.log 2>&1 &
```
The above python command will run in the background; you can view the results through the file `out.eval.log`.

You will get the accuracy as follows:
```
# when using cifar10 dataset
# grep "result: " out.eval.log
result: {'acc': 0.92}

# when using the imagenet2012 dataset
after allreduce eval: top1_correct=36636, tot=50000, acc=73.27%
after allreduce eval: top5_correct=45582, tot=50000, acc=91.16%
```
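For reference, the top-1/top-5 numbers above are the usual top-k accuracy metric. A self-contained sketch of that computation (not the repo's eval code):

```python
import numpy as np

# Standalone top-k accuracy sketch (not the repo's eval code).
def topk_accuracy(logits, labels, k):
    topk = np.argsort(logits, axis=1)[:, -k:]  # indices of the k largest logits
    hits = sum(label in row for row, label in zip(topk, labels))
    return hits / len(labels)

logits = np.random.rand(8, 1000)
labels = np.random.randint(0, 1000, size=8)
print(topk_accuracy(logits, labels, 1), topk_accuracy(logits, labels, 5))
```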
## Usage:

### Training
```
usage: train.py [--device_target TARGET][--data_path DATA_PATH]
                [--dataset DATASET_TYPE][--is_distributed VALUE]
                [--device_id DEVICE_ID][--pre_trained PRE_TRAINED]
                [--ckpt_path CHECKPOINT_PATH][--ckpt_interval INTERVAL_STEP]

parameters/options:
  --device_target   the training backend type, Ascend or GPU, default is Ascend.
  --dataset         the dataset type, cifar10 or imagenet2012.
  --is_distributed  whether to run distributed training, value can be 0 or 1.
  --data_path       the storage path of the dataset.
  --device_id       the device used to train the model.
  --pre_trained     the pretrained checkpoint file path.
  --ckpt_path       the path to save checkpoints.
  --ckpt_interval   the epoch interval for saving checkpoints.
```

### Evaluation
```
usage: eval.py [--device_target TARGET][--data_path DATA_PATH]
               [--dataset DATASET_TYPE][--pre_trained PRE_TRAINED]
               [--device_id DEVICE_ID]

parameters/options:
  --device_target  the evaluation backend type, Ascend or GPU, default is Ascend.
  --dataset        the dataset type, cifar10 or imagenet2012.
  --data_path      the storage path of the dataset.
  --device_id      the device used to evaluate the model.
  --pre_trained    the checkpoint file path used to evaluate the model.
```

### Distribute Training
- Train on Ascend.

```
Usage: sh script/run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]

parameters/options:
    MINDSPORE_HCCL_CONFIG_PATH  HCCL configuration file path.
    DATA_PATH                   the storage path of the dataset.
```

- Train on GPU.
```
Usage: bash run_distribute_train_gpu.sh [DATA_PATH]

parameters/options:
    DATA_PATH  the storage path of the dataset.
```

@@ -86,6 +86,8 @@ def parse_args(cloud_args=None):
    args_opt.padding = cfg.padding
    args_opt.has_bias = cfg.has_bias
    args_opt.batch_norm = cfg.batch_norm
    args_opt.initialize_mode = cfg.initialize_mode
    args_opt.has_dropout = cfg.has_dropout

    args_opt.image_size = list(map(int, args_opt.image_size.split(',')))  # '224,224' -> [224, 224]
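These assignments copy the config.py fields onto the parsed argument namespace, so the rest of the script reads every setting, whether it came from the command line or from config.py, off the single `args_opt` object.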

@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh DATA_PATH"
echo "for example: bash run_distribute_train_gpu.sh /path/ImageNet2012/train"
echo "=============================================================================================================="

DATA_PATH=$1

mpirun -n 8 python train.py \
    --device_target="GPU" \
    --dataset="imagenet2012" \
    --is_distributed=1 \
    --data_path=$DATA_PATH > output.train.log 2>&1 &
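`mpirun -n 8` launches eight copies of train.py, one per GPU; with `--is_distributed=1`, each process is expected to join the communication group derived from its MPI rank.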

@@ -19,50 +19,54 @@ from easydict import EasyDict as edict

# config for vgg16, cifar10
cifar_cfg = edict({
    "num_classes": 10,
    "lr": 0.01,
    "lr_init": 0.01,
    "lr_max": 0.1,
    "lr_epochs": '30,60,90,120',
    "lr_scheduler": "step",
    "warmup_epochs": 5,
    "batch_size": 64,
    "max_epoch": 70,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "loss_scale": 1.0,
    "label_smooth": 0,
    "label_smooth_factor": 0,
    "buffer_size": 10,
    "image_size": '224,224',
    "pad_mode": 'same',
    "padding": 0,
    "has_bias": False,
    "batch_norm": True,
    "keep_checkpoint_max": 10,
    "initialize_mode": "XavierUniform",
    "has_dropout": False
})

# config for vgg16, imagenet2012
imagenet_cfg = edict({
    "num_classes": 1000,
    "lr": 0.01,
    "lr_init": 0.01,
    "lr_max": 0.1,
    "lr_epochs": '30,60,90,120',
    "lr_scheduler": 'cosine_annealing',
    "warmup_epochs": 0,
    "batch_size": 32,
    "max_epoch": 150,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "loss_scale": 1024,
    "label_smooth": 1,
    "label_smooth_factor": 0.1,
    "buffer_size": 10,
    "image_size": '224,224',
    "pad_mode": 'pad',
    "padding": 1,
    "has_bias": True,
    "batch_norm": False,
    "keep_checkpoint_max": 10,
    "initialize_mode": "KaimingNormal",
    "has_dropout": True
})
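Both dicts support attribute access via EasyDict. A hypothetical selector showing how a script might pick between them (the real choice is made inside train.py/eval.py based on `--dataset`):

```python
# Hypothetical helper; assumes the src/ package layout referenced in this diff.
from src.config import cifar_cfg, imagenet_cfg

def get_config(dataset):
    """Return the config edict matching the --dataset argument."""
    return cifar_cfg if dataset == "cifar10" else imagenet_cfg

cfg = get_config("imagenet2012")
print(cfg.lr_scheduler, cfg.batch_size)  # cosine_annealing 32
```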

@@ -33,7 +33,7 @@ def _make_layer(base, args, batch_norm):
        else:
            weight_shape = (v, in_channels, 3, 3)
            weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
            if args.initialize_mode == "KaimingNormal":
                weight = 'normal'  # placeholder; real init is applied later via custom_init_weight()
            conv2d = nn.Conv2d(in_channels=in_channels,
                               out_channels=v,

@@ -74,7 +74,7 @@ class Vgg(nn.Cell):
        self.layers = _make_layer(base, args, batch_norm=batch_norm)
        self.flatten = nn.Flatten()
        dropout_ratio = 0.5
        if not args.has_dropout or phase == "test":
            dropout_ratio = 1.0  # nn.Dropout takes keep_prob here, so 1.0 disables dropout
        self.classifier = nn.SequentialCell([
            nn.Dense(512 * 7 * 7, 4096),

@@ -84,7 +84,7 @@ class Vgg(nn.Cell):
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Dense(4096, num_classes)])
        if args.initialize_mode == "KaimingNormal":
            default_recurisive_init(self)
            self.custom_init_weight()

@@ -128,14 +128,14 @@ def vgg16(num_classes=1000, args=None, phase="train"):

    Args:
        num_classes (int): Class numbers. Default: 1000.
        args(namespace): param for net init.
        phase(str): train or test mode.

    Returns:
        Cell, cell instance of Vgg16 neural network with batch normalization.

    Examples:
        >>> vgg16(num_classes=1000, args=args)
    """

    net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
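Given that signature, constructing and running the network from a config dict might look as follows; this mirrors the unit test at the end of this diff, and the `src.*` import paths are assumptions:

```python
# Sketch: build VGG16 for CIFAR-10 (cifar_cfg is an EasyDict, so attribute
# access works directly) and run one forward pass at the expected 224x224 size.
import numpy as np
from mindspore import Tensor
from src.config import cifar_cfg
from src.vgg import vgg16

net = vgg16(num_classes=cifar_cfg.num_classes, args=cifar_cfg, phase="test")
out = net(Tensor(np.random.rand(1, 3, 224, 224).astype(np.float32)))
print(out.shape)  # (1, 10)
```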

@@ -161,6 +161,8 @@ def parse_args(cloud_args=None):
    args_opt.padding = cfg.padding
    args_opt.has_bias = cfg.has_bias
    args_opt.batch_norm = cfg.batch_norm
    args_opt.initialize_mode = cfg.initialize_mode
    args_opt.has_dropout = cfg.has_dropout

    args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
    args_opt.image_size = list(map(int, cfg.image_size.split(',')))

@@ -0,0 +1,30 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_vgg"""
import numpy as np
import pytest

from mindspore import Tensor
from model_zoo.official.cv.vgg16.src.vgg import vgg16
from model_zoo.official.cv.vgg16.src.config import cifar_cfg as cfg
from ..ut_filter import non_graph_engine


@non_graph_engine
def test_vgg16():
    inputs = Tensor(np.random.rand(1, 3, 112, 112).astype(np.float32))
    net = vgg16(args=cfg)
    with pytest.raises(ValueError):
        print(net.construct(inputs))
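The 112×112 input is deliberately wrong: after VGG16's five 2× down-samplings, only a 224×224 image produces the 7×7 feature maps that `nn.Dense(512 * 7 * 7, 4096)` expects, so the flattened size no longer matches and the forward pass fails, which is what the `pytest.raises(ValueError)` assertion checks.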