!19072 EfficientNet supports training and inference on CPU and the CIFAR-10 dataset

Merge pull request !19072 from liangxhao/efficientnet
This commit is contained in:
i-robot 2021-06-30 07:06:04 +00:00 committed by Gitee
commit e19e6b00b8
13 changed files with 466 additions and 166 deletions

View File

@ -28,18 +28,26 @@ The overall network architecture of EfficientNet-B0 is show below:
# [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/)
Dataset used:
- Dataset size: ~125G, 1.2W colorful images in 1000 classes
- Train: 120G, 1.2W images
- Test: 5G, 50000 images
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
1. [ImageNet](http://www.image-net.org/)
- Dataset size: ~125G, 1.33 million color images in 1000 classes
- Train: 120G, 1.28 million images
- Test: 5G, 50,000 images
- Data format: RGB images
2. [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size: ~180MB, 60,000 color images in 10 classes
- Train: 150MB, 50,000 images
- Test: 30MB, 10,000 images
- Data format: RGB images (binary version)
Note: Data will be processed in src/dataset.py
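For orientation, the snippet below shows how these datasets are consumed through src/dataset.py after this change. It is a minimal sketch assuming the create_dataset/create_dataset_val signatures introduced later in this commit; the paths are placeholders.
```python
from src.dataset import create_dataset, create_dataset_val

# The dataset type is matched case-insensitively ('imagenet' or 'cifar10').
train_ds = create_dataset("cifar10", "/path/to/cifar-10-batches-bin",
                          batch_size=32, workers=8, distributed=False)
val_ds = create_dataset_val("cifar10", "/path/to/cifar-10-batches-bin",
                            batch_size=128, workers=8, distributed=False)
print(train_ds.get_dataset_size())  # number of batches per epoch
```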
# [Environment Requirements](#contents)
- Hardware GPU
- Prepare hardware environment with GPU processor.
- Hardware CPU/GPU
- Prepare the hardware environment with a CPU or GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
@ -50,14 +58,16 @@ Dataset used: [imagenet](http://www.image-net.org/)
## [Script and sample code](#contents)
```python
```text
.
└─efficientnet
├─README.md
├─scripts
├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p)
├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p)
└─run_eval_for_gpu.sh # launch evaluating with gpu platform
├─run_train_cpu.sh # launch training with cpu platform
├─run_standalone_train_gpu.sh # launch standalone training with gpu platform
├─run_distribute_train_gpu.sh # launch distributed training with gpu platform
├─run_eval_cpu.sh # launch evaluating with cpu platform
└─run_eval_gpu.sh # launch evaluating with gpu platform
├─src
├─config.py # parameter configuration
├─dataset.py # data preprocessing
@ -74,6 +84,8 @@ Dataset used: [imagenet](http://www.image-net.org/)
Parameters for both training and evaluation can be set in src/config.py.
1. ImageNet Config for GPU:
```python
'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name
@ -101,27 +113,68 @@ Parameters for both training and evaluating can be set in config.py.
'resume_start_epoch': 0, # resume start epoch
```
2. CIFAR-10 Config for CPU/GPU:
```python
'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name
'drop': 0.2, # dropout rate
'drop_connect': 0.2, # drop connect rate
'opt_eps': 0.0001, # optimizer epsilon
'lr': 0.0002, # learning rate
'batch_size': 32, # batch size
'decay_epochs': 2.4, # epoch interval to decay LR
'warmup_epochs': 5, # epochs to warmup LR
'decay_rate': 0.97, # LR decay rate
'weight_decay': 1e-5, # weight decay
'epochs': 150, # number of epochs to train
'workers': 8, # number of data processing processes
'amp_level': 'O0', # amp level
'opt': 'rmsprop', # optimizer
'num_classes': 10, # number of classes
'gp': 'avg', # type of global pool, "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9, # optimizer momentum
'warmup_lr_init': 0.0001, # init warmup LR
'smoothing': 0.1, # label smoothing factor
'bn_tf': False, # use Tensorflow BatchNorm defaults
'keep_checkpoint_max': 10, # maximum number of checkpoints to keep
'loss_scale': 1024, # loss scale
'resume_start_epoch': 0, # resume start epoch
```
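At runtime, the scripts look the chosen dataset up in dataset_config and read its cfg field, mirroring the eval.py/train.py changes later in this commit:
```python
# Sketch of how a config block above is resolved at runtime.
from src.config import dataset_config

dataset_type = "CIFAR10".lower()          # value of the --dataset argument
cfg = dataset_config[dataset_type].cfg    # EasyDict holding the values listed above
print(cfg.num_classes, cfg.batch_size)    # -> 10 32
```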
## [Training Process](#contents)
### Usage
```python
GPU:
# distribute training example(8p)
sh run_distribute_train_for_gpu.sh
1. GPU
```bash
# distribute training
sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
sh run_standalone_train_gpu.sh [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
### Launch
2. CPU
```bash
sh run_train_cpu.sh [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
### Launch Example
```bash
# distributed training example(8p) for GPU
cd scripts
sh run_distribute_train_for_gpu.sh 8 0,1,2,3,4,5,6,7 /dataset/train
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 ImageNet /dataset/train
# standalone training example for GPU
cd scripts
sh run_standalone_train_for_gpu.sh 0 /dataset/train
sh run_standalone_train_gpu.sh ImageNet /dataset/train
# training example for CPU
cd scripts
sh run_train_cpu.sh ImageNet /dataset/train
```
You can find the checkpoint files together with the results in the log.
@ -130,30 +183,41 @@ You can find checkpoint file together with result in log.
### Usage
1. CPU
```bash
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
sh run_eval_cpu.sh [DATASET_TYPE] [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
2. GPU
```bash
# Evaluation with checkpoint
sh run_eval_gpu.sh [DATASET_TYPE] [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch Example
```bash
# Evaluation with checkpoint for GPU
cd scripts
sh run_eval_for_gpu.sh /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
sh run_eval_gpu.sh ImageNet /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
# Evaluation with checkpoint for CPU
cd scripts
sh run_eval_cpu.sh ImageNet /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
```
#### Result
Evaluation results are stored in the scripts path. There you can find results like the following in the log.
```python
```text
acc=76.96%(TOP1)
```
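The Top-1 figure is produced by MindSpore's built-in metric (see the eval_metrics dict in eval.py below); a minimal, self-contained illustration of what it computes:
```python
import numpy as np
from mindspore import Tensor, nn

# Top-1 accuracy: argmax of each probability vector vs. the integer label.
metric = nn.Top1CategoricalAccuracy()
metric.clear()
metric.update(Tensor(np.array([[0.2, 0.5, 0.3],
                               [0.9, 0.05, 0.05]], np.float32)),  # outputs
              Tensor(np.array([1, 0], np.int32)))                 # labels
print(metric.eval())  # 1.0
```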
# [Model description](#contents)
## [Performance](#contents)
## [Performance on ImageNet](#contents)
### Training Performance
@ -165,23 +229,53 @@ acc=76.96%(TOP1)
| Dataset | ImageNet |
| Training Parameters | src/config.py |
| Optimizer | rmsprop |
| Loss Function | LabelSmoothingCrossEntropy |
| Loss Function | LabelSmoothingCrossEntropy|
| Loss | 1.8886 |
| Accuracy | 76.96%(TOP1) |
| Accuracy | 76.96%(TOP1) |
| Total time | 132 h (8 GPUs) |
| Checkpoint for Fine tuning | 64 M(.ckpt file) |
| Checkpoint for Fine tuning | 64 M(.ckpt file) |
### Inference Performance
| Parameters | |
| -------------------------- | ------------------------- |
| Resource | NV SMX2 V100-32G |
| uploaded Date | 10/26/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | ImageNet, 1.2W |
| batch_size | 128 |
| outputs | probability |
| Accuracy | acc=76.96%(TOP1) |
| Parameters | |
| ----------------- | ---------------- |
| Resource | NV SMX2 V100-32G |
| uploaded Date | 10/26/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | ImageNet |
| batch_size | 128 |
| outputs | probability |
| Accuracy | acc=76.96%(TOP1) |
## [Performance on CIFAR-10](#contents)
### Training Performance
| Parameters | efficientnet_b0 |
| -------------------------- | -------------------------- |
| Resource | NV GTX 1080Ti-12G |
| uploaded Date | 06/28/2021 |
| MindSpore Version | 1.3.0 |
| Dataset | CIFAR-10 |
| Training Parameters | src/config.py |
| Optimizer | rmsprop |
| Loss Function | LabelSmoothingCrossEntropy |
| Loss | 1.2773 |
| Accuracy | 97.75%(TOP1) |
| Total time | 2 h (4 GPUs) |
| Checkpoint for Fine tuning | 47 M(.ckpt file) |
### Inference Performance
| Parameters | |
| ----------------- | ---------------- |
| Resource | NV GTX 1080Ti-12G |
| uploaded Date | 06/28/2021 |
| MindSpore Version | 1.3.0 |
| Dataset | CIFAR-10 |
| batch_size | 128 |
| outputs | probability |
| Accuracy | acc=93.12%(TOP1) |
# [ModelZoo Homepage](#contents)

View File

@ -20,24 +20,28 @@ from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.config import dataset_config
from src.dataset import create_dataset_val
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification evaluation')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)')
parser.add_argument('--data_path', type=str, default='', help='Dataset path')
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
parser.add_argument('--checkpoint', type=str, required=True, help='checkpoint of efficientnet')
parser.add_argument('--data_path', type=str, required=True, help='Dataset path')
parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'CIFAR10'],
help='ImageNet or CIFAR10')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'), help='run platform')
args_opt = parser.parse_args()
if args_opt.platform != 'GPU':
raise ValueError("Only supported GPU training.")
print(args_opt)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
dataset_type = args_opt.dataset.lower()
cfg = dataset_config[dataset_type].cfg
net = efficientnet_b0(num_classes=cfg.num_classes,
cfg=dataset_config[dataset_type],
drop_rate=cfg.drop,
drop_connect_rate=cfg.drop_connect,
global_pool=cfg.gp,
@ -48,12 +52,15 @@ if __name__ == '__main__':
load_param_into_net(net, ckpt)
net.set_train(False)
val_data_url = args_opt.data_path
dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
dataset = create_dataset_val(dataset_type, val_data_url, cfg.batch_size, workers=cfg.workers, distributed=False)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing, num_classes=cfg.num_classes)
eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
metrics = model.eval(dataset)
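# dataset sink mode is not supported on the CPU backend, so it is disabled there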
dataset_sink_mode = args_opt.platform != "CPU"
metrics = model.eval(dataset, dataset_sink_mode=dataset_sink_mode)
print("metric: ", metrics)

View File

@ -18,27 +18,30 @@ import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.efficientnet import efficientnet_b0
from src.config import efficientnet_b0_config_gpu as cfg
from src.config import dataset_config
parser = argparse.ArgumentParser(description="efficientnet export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--width", type=int, default=224, help="input width")
parser.add_argument("--height", type=int, default=224, help="input height")
parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'CIFAR10'],
help='ImageNet or CIFAR10')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="efficientnet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"],
default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="GPU",
parser.add_argument("--device_target", type=str, choices=["GPU", "CPU"], default="GPU",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
if args.device_target != "GPU":
raise ValueError("Only supported GPU now.")
if args.device_target not in ("GPU", "CPU"):
raise ValueError("Only supported CPU and GPU now.")
dataset_type = args.dataset.lower()
cfg = dataset_config[dataset_type].cfg
net = efficientnet_b0(num_classes=cfg.num_classes,
drop_rate=cfg.drop,
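The tail of export.py is cut off by the diff view. For context, here is a hypothetical sketch of the remaining export flow, assuming the standard MindSpore export API and the arguments defined above (the committed tail may differ):
```python
# Hypothetical continuation (illustrative only, not necessarily the committed
# code); Tensor, export, load_checkpoint, load_param_into_net and np come from
# the imports shown above.
from mindspore import dtype as mstype

param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
net.set_train(False)
# Dummy NCHW input matching --height/--width.
input_arr = Tensor(np.ones([1, 3, args.height, args.width]), mstype.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
```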

View File

@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] && [ $# != 4 ]
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage:
sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
"
exit 1
fi
@ -27,10 +27,17 @@ then
exit 1
fi
# check dataset file
if [ ! -d $3 ]
# check dataset type
if [[ $3 != "ImageNet" ]] && [[ $3 != "CIFAR10" ]]
then
echo "error: DATASET_PATH=$3 is not a directory"
echo "error: Only supported for ImageNet and CIFAR10, but DATASET_TYPE=$3."
exit 1
fi
# check dataset file
if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
@ -48,22 +55,24 @@ cd ../train || exit
export CUDA_VISIBLE_DEVICES="$2"
if [ $# == 3 ]
then
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py \
--GPU \
--distributed \
--data_path $3 > train.log 2>&1 &
fi
if [ $# == 4 ]
then
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py \
--GPU \
--distributed \
--data_path $3 \
--resume $4 > train.log 2>&1 &
--platform GPU \
--dataset $3 \
--data_path $4 > train.log 2>&1 &
fi
if [ $# == 5 ]
then
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py \
--platform GPU \
--distributed \
--dataset $3 \
--data_path $4 \
--resume $5 > train.log 2>&1 &
fi

View File

@ -0,0 +1,56 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage:
sh run_eval_cpu.sh [DATASET_TYPE] [DATASET_PATH] [CHECKPOINT_PATH]
"
exit 1
fi
# check dataset type
if [[ $1 != "ImageNet" ]] && [[ $1 != "CIFAR10" ]]
then
echo "error: Only supported for ImageNet and CIFAR10, but DATASET_TYPE=$1."
exit 1
fi
# check dataset file
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory."
exit 1
fi
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit
python ${BASEPATH}/../eval.py --dataset $1 --data_path $2 --platform CPU --checkpoint=$3 > ./eval.log 2>&1 &

View File

@ -13,23 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
if [ $# != 3 ]
then
echo "GPU: sh run_eval_for_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage:
sh run_eval_gpu.sh [DATASET_TYPE] [DATASET_PATH] [CHECKPOINT_PATH]
"
exit 1
fi
# check dataset type
if [[ $1 != "ImageNet" ]] && [[ $1 != "CIFAR10" ]]
then
echo "error: Only supported for ImageNet and CIFAR10, but DATASET_TYPE=$1."
exit 1
fi
# check dataset file
if [ ! -d $1 ]
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
echo "error: DATASET_PATH=$2 is not a directory."
exit 1
fi
# check checkpoint file
if [ ! -f $2 ]
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi
@ -43,4 +53,4 @@ fi
mkdir ../eval
cd ../eval || exit
python ${BASEPATH}/../eval.py --platform 'GPU' --data_path $1 --checkpoint=$2 > ./eval.log 2>&1 &
python ${BASEPATH}/../eval.py --dataset $1 --data_path $2 --platform GPU --checkpoint=$3 > ./eval.log 2>&1 &

View File

@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage:
sh run_standalone_train_gpu.sh [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
"
exit 1
fi
# check dataset type
if [[ $1 != "ImageNet" ]] && [[ $1 != "CIFAR10" ]]
then
echo "error: Only supported for ImageNet and CIFAR10, but DATASET_TYPE=$1."
exit 1
fi
# check dataset file
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory."
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
if [ $# == 2 ]
then
python ${BASEPATH}/../train.py --dataset $1 --data_path $2 --platform GPU > train.log 2>&1 &
fi
if [ $# == 3 ]
then
python ${BASEPATH}/../train.py --dataset $1 --data_path $2 --platform GPU --resume $3 > train.log 2>&1 &
fi

View File

@ -16,18 +16,27 @@
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage:
sh run_standalone_train_for_gpu.sh [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
sh run_train_cpu.sh [DATASET_TYPE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
"
exit 1
fi
# check dataset type
if [[ $1 != "ImageNet" ]] && [[ $1 != "CIFAR10" ]]
then
echo "error: Only supported for ImageNet and CIFAR10, but DATASET_TYPE=$1."
exit 1
fi
# check dataset file
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory"
echo "error: DATASET_PATH=$2 is not a directory."
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
@ -37,14 +46,13 @@ fi
mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES=$1
if [ $# == 2 ]
then
python ${BASEPATH}/../train.py --GPU --data_path $2 > train.log 2>&1 &
python ${BASEPATH}/../train.py --dataset $1 --data_path $2 --platform CPU > train.log 2>&1 &
fi
if [ $# == 3 ]
then
python ${BASEPATH}/../train.py --GPU --data_path $2 --resume $3 > train.log 2>&1 &
python ${BASEPATH}/../train.py --dataset $1 --data_path $2 --platform CPU --resume $3 > train.log 2>&1 &
fi

View File

@ -15,10 +15,45 @@
"""
network config setting
"""
from easydict import EasyDict as edict
from easydict import EasyDict
efficientnet_b0_config_gpu = edict({
'random_seed': 1,
resize_value = 224 # image resize
basic_config = EasyDict({
'random_seed': 1
})
efficientnet_b0_config_cifar10 = EasyDict({
'model': 'efficientnet_b0',
'drop': 0.2,
'drop_connect': 0.2,
'opt_eps': 0.0001,
'lr': 0.0002,
'batch_size': 32,
'decay_epochs': 2.4,
'warmup_epochs': 5,
'decay_rate': 0.97,
'weight_decay': 1e-5,
'epochs': 150,
'workers': 8,
'amp_level': 'O0',
'opt': 'rmsprop',
'num_classes': 10,
# Type of global pool: "avg", "max", "avgmax", "avgmaxc"
'gp': 'avg',
'momentum': 0.9,
'warmup_lr_init': 0.0001,
'smoothing': 0.1,
# Use Tensorflow BatchNorm defaults for models that support it
'bn_tf': False,
'save_checkpoint': True,
'keep_checkpoint_max': 10,
'loss_scale': 1024,
'resume_start_epoch': 0,
})
efficientnet_b0_config_imagenet = EasyDict({
'model': 'efficientnet_b0',
'drop': 0.2,
'drop_connect': 0.2,
@ -33,7 +68,7 @@ efficientnet_b0_config_gpu = edict({
'workers': 8,
'amp_level': 'O0',
'opt': 'rmsprop',
'num_classes': 1000,
'num_classes': 10,
# Type of global pool: "avg", "max", "avgmax", "avgmaxc"
'gp': 'avg',
'momentum': 0.9,
@ -46,3 +81,19 @@ efficientnet_b0_config_gpu = edict({
'loss_scale': 1024,
'resume_start_epoch': 0,
})
dataset_config = {
"imagenet": EasyDict({
"size": (224, 224),
"mean": (0.485, 0.456, 0.406),
"std": (0.229, 0.224, 0.225),
"cfg": efficientnet_b0_config_imagenet
}),
"cifar10": EasyDict({
"size": (32, 32),
"mean": (0.4914, 0.4822, 0.4465),
"std": (0.247, 0.2435, 0.2616),
"cfg": efficientnet_b0_config_cifar10
})
}
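This mapping keeps all per-dataset preprocessing constants in one place; a new dataset could be wired in with a single entry. A hypothetical example (SVHN is not part of this commit, and its statistics below are the commonly cited ones, not verified here):
```python
# Hypothetical extension (not part of this commit): a new dataset needs only
# its input size, channel statistics, and a hyper-parameter config.
dataset_config["svhn"] = EasyDict({
    "size": (32, 32),
    "mean": (0.4377, 0.4438, 0.4728),   # commonly used SVHN statistics
    "std": (0.1980, 0.2010, 0.1970),
    "cfg": efficientnet_b0_config_cifar10,  # reuse the 10-class config
})
```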

View File

@ -25,21 +25,16 @@ import mindspore.dataset.vision.c_transforms as C
from mindspore.communication.management import get_group_size, get_rank
from mindspore.dataset.vision import Inter
from src.config import efficientnet_b0_config_gpu as cfg
from src.config import basic_config, dataset_config, resize_value
from src.transform import RandAugment
ds.config.set_seed(cfg.random_seed)
ds.config.set_seed(basic_config.random_seed)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
img_size = (224, 224)
crop_pct = 0.875
rescale = 1.0 / 255.0
shift = 0.0
inter_method = 'bilinear'
resize_value = 224 # img_size
scale = (0.08, 1.0)
ratio = (3./4., 4./3.)
inter_str = 'bicubic'
@ -51,9 +46,10 @@ def str2MsInter(method):
return Inter.NEAREST
return Inter.BILINEAR
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
def create_dataset(dataset_type, train_data_url, batch_size, workers=8, distributed=False):
if not os.path.exists(train_data_url):
raise ValueError('Path does not exist')
interpolation = str2MsInter(inter_str)
c_decode_op = C.Decode()
@ -61,19 +57,31 @@ def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio,
interpolation=interpolation)
random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)
efficient_rand_augment = RandAugment(dataset_config[dataset_type])
efficient_rand_augment = RandAugment()
image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]
# load dataset
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
if dataset_type.lower() == 'imagenet':
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]
elif dataset_type.lower() == 'cifar10':
dataset_train = ds.Cifar10Dataset(train_data_url,
usage="train",
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
image_ops = [random_resize_crop_op, random_horizontal_flip_op]
else:
raise NotImplementedError("Only supported for ImageNet or CIFAR10 dataset")
# build dataset
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
@ -83,21 +91,18 @@ def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
ds_train = dataset_train.batch(batch_size,
per_batch_map=efficient_rand_augment,
input_columns=["image", "label"],
num_parallel_workers=2,
num_parallel_workers=workers,
drop_remainder=True)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False):
def create_dataset_val(dataset_type, val_data_url, batch_size=128, workers=8, distributed=False):
if not os.path.exists(val_data_url):
raise ValueError('Path does not exist')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
scale_size = None
interpolation = str2MsInter(inter_method)
img_size = resize_value
if isinstance(img_size, tuple):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
@ -110,12 +115,25 @@ def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=F
type_cast_op = C2.TypeCast(mstype.int32)
decode_op = C.Decode()
resize_op = C.Resize(size=scale_size, interpolation=interpolation)
center_crop = C.CenterCrop(size=224)
center_crop = C.CenterCrop(size=resize_value)
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
normalize_op = C.Normalize(dataset_config[dataset_type].mean, dataset_config[dataset_type].std)
changeswap_op = C.HWC2CHW()
ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]
# load dataset
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
if dataset_type.lower() == 'imagenet':
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]
elif dataset_type.lower() == 'cifar10':
dataset = ds.Cifar10Dataset(val_data_url, usage="test", num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
ctrans = [resize_op, center_crop, rescale_op, normalize_op, changeswap_op]
else:
raise NotImplementedError("Only supported for ImageNet or CIFAR10 dataset")
dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers)

View File

@ -708,9 +708,12 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
return model
def efficientnet_b0(num_classes=1000, in_chans=3, **kwargs):
def efficientnet_b0(num_classes=1000, in_chans=3, cfg=None, **kwargs):
""" EfficientNet-B0 """
default_cfg = default_cfgs['efficientnet_b0']
default_cfg["num_classes"] = num_classes
if cfg:
default_cfg.update({k: v for k, v in cfg.items() if k in default_cfg})
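# only keys already present in default_cfg are overridden; extra cfg keys are ignored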
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)

View File

@ -19,16 +19,17 @@ import numpy as np
import mindspore.dataset.vision.py_transforms as P
from src import transform_utils
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class RandAugment:
# img_info provides the dataset mean and std properties
# config_str is a policy string, e.g. "rand-m9-mstd0.5"
# hparams is a dict of extra hyper-parameters
def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
def __init__(self, img_info, config_str="rand-m9-mstd0.5", hparams=None):
hparams = hparams if hparams is not None else {}
self.config_str = config_str
self.hparams = hparams
self.mean = img_info.mean
self.std = img_info.std
def __call__(self, imgs, labels, batchInfo):
# assert the imgs object are pil_images
@ -36,7 +37,7 @@ class RandAugment:
ret_labels = []
py_to_pil_op = P.ToPIL()
to_tensor = P.ToTensor()
normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
normalize_op = P.Normalize(self.mean, self.std)
rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
for i, image in enumerate(imgs):
img_pil = py_to_pil_op(image)

View File

@ -15,7 +15,6 @@
"""train imagenet."""
import argparse
import math
import os
import random
import numpy as np
@ -28,14 +27,14 @@ from mindspore.train.callback import (CheckpointConfig, LossMonitor,
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.config import basic_config, dataset_config
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
mindspore.common.set_seed(basic_config.random_seed)
random.seed(basic_config.random_seed)
np.random.seed(basic_config.random_seed)
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
@ -60,60 +59,48 @@ def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
return lr_each_step
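The body of get_lr is elided by the diff view. Below is a self-contained sketch of an equivalent schedule, assuming linear warmup followed by stepwise exponential decay driven by the config keys shown earlier (warmup_epochs, warmup_lr_init, decay_epochs, decay_rate); the committed implementation may differ in detail:
```python
import numpy as np

def get_lr_sketch(base_lr, total_epochs, steps_per_epoch,
                  decay_epochs=2.4, decay_rate=0.97,
                  warmup_epochs=5, warmup_lr_init=0.0001):
    """Per-step LR: linear warmup, then stepwise exponential decay."""
    total_steps = int(total_epochs * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    lr_each_step = []
    for step in range(total_steps):
        if step < warmup_steps:
            # ramp linearly from warmup_lr_init up to base_lr
            lr = warmup_lr_init + (base_lr - warmup_lr_init) * step / max(1, warmup_steps)
        else:
            # decay by decay_rate every decay_epochs epochs
            epoch = step / steps_per_epoch
            lr = base_lr * decay_rate ** (epoch // decay_epochs)
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)
```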
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
os.makedirs(outdir)
elif inc:
count = 1
outdir_inc = outdir + '-' + str(count)
while os.path.exists(outdir_inc):
count = count + 1
outdir_inc = outdir + '-' + str(count)
assert count < 100
outdir = outdir_inc
os.makedirs(outdir)
return outdir
parser = argparse.ArgumentParser(
description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
help='path to dataset')
parser = argparse.ArgumentParser(description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, metavar='DIR', required=True, help='path to dataset')
parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'CIFAR10'],
help='ImageNet or CIFAR10')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
default='19701010-000000', help='current time')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'), help='run platform')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
def main():
args, _ = parser.parse_known_args()
rank_id, rank_size = 0, 1
print(args)
rank_id, rank_size = 0, 1
context.set_context(mode=context.GRAPH_MODE)
if args.platform == "GPU":
dataset_sink_mode = True
context.set_context(device_target='GPU', enable_graph_kernel=True)
elif args.platform == "CPU":
dataset_sink_mode = False
context.set_context(device_target='CPU')
else:
raise NotImplementedError("Training only supported for CPU and GPU.")
if args.distributed:
if args.GPU:
if args.platform == "GPU":
init("nccl")
context.set_context(device_target='GPU')
else:
raise ValueError("Only supported GPU training.")
raise NotImplementedError("Distributed Training only supported for GPU.")
context.reset_auto_parallel_context()
rank_id = get_rank()
rank_size = get_group_size()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, device_num=rank_size)
else:
if args.GPU:
context.set_context(device_target='GPU')
else:
raise ValueError("Only supported GPU training.")
dataset_type = args.dataset.lower()
cfg = dataset_config[dataset_type].cfg
net = efficientnet_b0(num_classes=cfg.num_classes,
cfg=dataset_config[dataset_type],
drop_rate=cfg.drop,
drop_connect_rate=cfg.drop_connect,
global_pool=cfg.gp,
@ -122,11 +109,12 @@ def main():
train_data_url = args.data_path
train_dataset = create_dataset(
cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
dataset_type, train_data_url, cfg.batch_size, workers=cfg.workers, distributed=args.distributed)
batches_per_epoch = train_dataset.get_dataset_size()
print("Batches_per_epoch: ", batches_per_epoch)
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
loss_cb = LossMonitor(per_print_times=1 if args.platform == "CPU" else batches_per_epoch)
loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing, num_classes=cfg.num_classes)
time_cb = TimeMonitor(data_size=batches_per_epoch)
loss_scale_manager = FixedLossScaleManager(
cfg.loss_scale, drop_overflow_update=False)
@ -165,18 +153,13 @@ def main():
amp_level=cfg.amp_level
)
if args.GPU:
context.set_context(enable_graph_kernel=True)
# callbacks = callbacks if is_master else []
if args.resume:
real_epoch = cfg.epochs - cfg.resume_start_epoch
model.train(real_epoch, train_dataset,
callbacks=callbacks, dataset_sink_mode=True)
callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
else:
model.train(cfg.epochs, train_dataset,
callbacks=callbacks, dataset_sink_mode=True)
callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
if __name__ == '__main__':