add mobilenet v2 quant and resnet50 quant to model_zoo

This commit is contained in:
chenzomi 2020-06-30 14:04:15 +08:00
parent 622b97f3b6
commit c530e15e09
32 changed files with 2659 additions and 16 deletions

View File

@ -33,7 +33,7 @@ Then you will get the following display
```bash
>>> Found existing installation: mindspore-ascend
>>> Uninstalling mindspore-ascend:
>>> Successfully uninstalled mindspore-ascend.
>>> Successfully uninstalled mindspore-ascend.
```
### Prepare Dataset
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
### train quantization aware model
Also, you can just run this command instread.
Also, you can just run this command instead.
```python
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
Here are some optional parameters:
```bash
--device_target {Ascend,GPU,CPU}
--device_target {Ascend,GPU}
device where the code will be implemented (default: Ascend)
--data_path DATA_PATH
path where the dataset is saved

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')

View File

@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
@ -61,7 +61,7 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
print("============== Starting Testing ==============")

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
@ -56,8 +56,7 @@ if __name__ == "__main__":
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type=network.type)
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model

View File

@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
@ -50,11 +50,13 @@ if __name__ == "__main__":
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@ -64,8 +66,7 @@ if __name__ == "__main__":
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type="quant")
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model

View File

@ -30,7 +30,7 @@ run_ascend()
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ];
if [ -d "../train" ];
then
rm -rf ../train
fi

View File

@ -0,0 +1,142 @@
# MobileNetV2 Quantization Aware Training
MobileNetV2 is a significant improvement over MobileNetV1 and pushes the state of the art for mobile visual recognition including classification, object detection and semantic segmentation.
MobileNetV2 builds upon the ideas from MobileNetV1, using depthwise separable convolution as efficient building blocks. However, V2 introduces two new features to the architecture: 1) linear bottlenecks between the layers, and 2) shortcut connections between the bottlenecks1.
Training MobileNetV2 with ImageNet dataset in MindSpore with quantization aware training.
This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware.
In this readme tutorial, you will:
1. Train a MindSpore fusion MobileNetV2 model for ImageNet from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
[Paper](https://arxiv.org/pdf/1801.04381) Sandler, Mark, et al. "Mobilenetv2: Inverted residuals and linear bottlenecks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
# Dataset
Dataset use: ImageNet
- Dataset size: about 125G
- Train: 120G, 1281167 images: 1000 directories
- Test: 5G, 50000 images: images should be classified into 1000 directories firstly, just like train images
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
# Environment Requirements
- HardwareAscend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# Script description
## Script and sample code
```python
├── mobilenetv2_quant
├── Readme.md
├── scripts
│ ├──run_train.sh
│ ├──run_infer.sh
│ ├──run_train_quant.sh
│ ├──run_infer_quant.sh
├── src
│ ├──config.py
│ ├──dataset.py
│ ├──luanch.py
│ ├──lr_generator.py
│ ├──mobilenetV2.py
├── train.py
├── eval.py
```
## Training process
### Train MobileNetV2 model
Train a MindSpore fusion MobileNetV2 model for ImageNet, like:
- sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
You can just run this command instead.
``` bash
>>> sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt
```
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
```
>>> epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
>>> epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
```
### Evaluate MobileNetV2 model
Evaluate a MindSpore fusion MobileNetV2 model for ImageNet, like:
- sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
You can just run this command instead.
``` bash
>>> sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
```
Inference result will be stored in the example path, you can find result like the followings in `val.log`.
```
>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
```
### Fine-tune for quantization aware training
Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
- sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
You can just run this command instead.
``` bash
>>> sh run_train_quant.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt
```
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
```
>>> epoch: [ 0/60], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
>>> epoch: [ 1/60], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
```
### Evaluate quantization aware training model
Evaluate a MindSpore fusion MobileNetV2 model for ImageNet by applying the quantization aware training, like:
- sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
You can just run this command instead.
``` bash
>>> sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_625.ckpt
```
Inference result will be stored in the example path, you can find result like the followings in `val.log`.
```
>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-60_625.ckpt
```
# ModelZoo Homepage
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)

View File

@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluate MobilenetV2 on ImageNet"""
import os
import argparse
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from src.mobilenetV2 import mobilenetV2
from src.dataset import create_dataset
from src.config import config_ascend
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
args_opt = parser.parse_args()
if __name__ == '__main__':
config_device_target = None
if args_opt.device_target == "Ascend":
config_device_target = config_ascend
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
# define fusion network
network = mobilenetV2(num_classes=config_device_target.num_classes)
if args_opt.quantization_aware:
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
config=config_device_target,
device_target=args_opt.device_target,
batch_size=config_device_target.batch_size)
step_size = dataset.get_dataset_size()
# load checkpoint
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(network, param_dict)
network.set_train(False)
# define model
model = Model(network, loss_fn=loss, metrics={'acc'})
print("============== Starting Validation ==============")
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
print("============== End Validation ==============")

View File

@ -0,0 +1,53 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
# check dataset path
if [ ! -d $2 ] && [ ! -f $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory or file"
exit 1
fi
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi
# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit
# launch
python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
&> infer.log & # dataset val folder path

View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
# check dataset path
if [ ! -d $2 ] && [ ! -f $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory or file"
exit 1
fi
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi
# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit
# launch
python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
--quantization_aware=True \
&> infer.log & # dataset val folder path

View File

@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
run_ascend()
{
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-9)"
exit 1
fi
if [ ! -d $5 ] && [ ! -f $5 ]
then
echo "error: DATASET_PATH=$5 is not a directory or file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
python ${BASEPATH}/../src/launch.py \
--nproc_per_node=$2 \
--visible_devices=$4 \
--server_id=$3 \
--training_script=${BASEPATH}/../train.py \
--dataset_path=$5 \
--pre_trained=$6 \
--device_target=$1 &> train.log & # dataset train folder
}
if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
else
echo "Unsupported device target."
fi;

View File

@ -0,0 +1,63 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
run_ascend()
{
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-9)"
exit 1
fi
if [ ! -d $5 ] && [ ! -f $5 ]
then
echo "error: DATASET_PATH=$5 is not a directory or file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
python ${BASEPATH}/../src/launch.py \
--nproc_per_node=$2 \
--visible_devices=$4 \
--server_id=$3 \
--training_script=${BASEPATH}/../train.py \
--dataset_path=$5 \
--pre_trained=$6 \
--quantization_aware=True \
--device_target=$1 &> train.log & # dataset train folder
}
if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
else
echo "Unsupported device target."
fi;

View File

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config_ascend = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 256,
"data_load_mode": "mindrecord",
"epoch_size": 200,
"start_epoch": 0,
"warmup_epochs": 4,
"lr": 0.4,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"quantization_aware": False,
})
config_ascend_quant = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 192,
"data_load_mode": "mindrecord",
"epoch_size": 60,
"start_epoch": 200,
"warmup_epochs": 1,
"lr": 0.3,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"quantization_aware": True,
})

View File

@ -0,0 +1,156 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
from functools import partial
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from src.config import config_ascend
def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1.
batch_size(int): the batch size of dataset. Default: 32.
Returns:
dataset
"""
if device_target == "Ascend":
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
columns_list = ['image', 'label']
if config_ascend.data_load_mode == "mindrecord":
load_func = partial(de.MindDataset, dataset_path, columns_list)
else:
load_func = partial(de.ImageFolderDatasetV2, dataset_path)
if do_train:
if rank_size == 1:
ds = load_func(num_parallel_workers=8, shuffle=True)
else:
ds = load_func(num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
else:
ds = load_func(num_parallel_workers=8, shuffle=False)
else:
raise ValueError("Unsupport device_target.")
resize_height = config.image_height
if do_train:
buffer_size = 20480
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
# define map operations
decode_op = C.Decode()
resize_crop_decode_op = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
resize_op = C.Resize(256)
center_crop = C.CenterCrop(resize_height)
normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
change_swap_op = C.HWC2CHW()
if do_train:
trans = [resize_crop_decode_op, horizontal_flip_op, normalize_op, change_swap_op]
else:
trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=16)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def create_dataset_py(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1.
batch_size(int): the batch size of dataset. Default: 32.
Returns:
dataset
"""
if device_target == "Ascend":
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if do_train:
if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
else:
raise ValueError("Unsupported device target.")
resize_height = config.image_height
if do_train:
buffer_size = 20480
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
# define map operations
decode_op = P.Decode()
resize_crop_op = P.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)
resize_op = P.Resize(256)
center_crop = P.CenterCrop(resize_height)
to_tensor = P.ToTensor()
normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if do_train:
trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
else:
trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]
compose = P.ComposeOp(trans)
ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,166 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""launch train script"""
import os
import sys
import json
import subprocess
import shutil
import platform
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training launch "
"helper utilty that will spawn up "
"multiple distributed processes")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "
"for D training, this is recommended to be set "
"to the number of D in your system so that "
"each process can be bound to a single D.")
parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
help="will use the visible devices sequentially")
parser.add_argument("--server_id", type=str, default="",
help="server ip")
parser.add_argument("--training_script", type=str,
help="The full path to the single D training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
args, unknown = parser.parse_known_args()
args.training_script_args = unknown
return args
def main():
print("start", __file__)
args = parse_args()
print(args)
visible_devices = args.visible_devices.split(',')
assert os.path.isfile(args.training_script)
assert len(visible_devices) >= args.nproc_per_node
print('visible_devices:{}'.format(visible_devices))
if not args.server_id:
print('pleaser input server ip!!!')
exit(0)
print('server_id:{}'.format(args.server_id))
# construct hccn_table
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
arch = platform.processor()
hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch]
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
usable_dev = ''
for instance_id in range(args.nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
usable_dev += str(device_id)
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = args.server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(args.nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(args.nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(args.nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
hccn_table['status'] = 'completed'
# save hccn_table to file
table_path = os.getcwd()
if not os.path.exists(table_path):
os.mkdir(table_path)
table_fn = os.path.join(table_path,
'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
with open(table_fn, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
sys.stdout.flush()
# spawn the processes
processes = []
cmds = []
log_files = []
env = os.environ.copy()
env['RANK_SIZE'] = str(args.nproc_per_node)
cur_path = os.getcwd()
for rank_id in range(0, args.nproc_per_node):
os.chdir(cur_path)
device_id = visible_devices[rank_id]
device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
env['RANK_ID'] = str(rank_id)
env['DEVICE_ID'] = str(device_id)
if args.nproc_per_node > 1:
env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn
env['RANK_TABLE_FILE'] = table_fn
if os.path.exists(device_dir):
shutil.rmtree(device_dir)
os.mkdir(device_dir)
os.chdir(device_dir)
cmd = [sys.executable, '-u']
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process)
cmds.append(cmd)
log_files.append(log_file)
for process, cmd, log_file in zip(processes, cmds, log_files):
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(returncode=process, cmd=cmd)
log_file.close()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate

View File

@ -0,0 +1,231 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 Quant model define"""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
__all__ = ['mobilenetV2']
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class GlobalAvgPooling(nn.Cell):
"""
Global avg pooling definition.
Args:
Returns:
Tensor, output tensor.
Examples:
>>> GlobalAvgPooling()
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(keep_dims=False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
stride=stride,
pad_mode='pad',
padding=padding,
group=groups,
has_bn=True,
activation='relu')
def construct(self, x):
x = self.conv(x)
return x
class InvertedResidual(nn.Cell):
"""
Mobilenetv2 residual block definition.
Args:
inp (int): Input channel.
oup (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
expand_ratio (int): expand ration of input channel
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, 1, 1)
"""
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
])
self.conv = nn.SequentialCell(layers)
self.add = P.TensorAdd()
def construct(self, x):
out = self.conv(x)
if self.use_res_connect:
out = self.add(out, x)
return out
class mobilenetV2(nn.Cell):
"""
mobilenetV2 fusion architecture.
Args:
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns:
Tensor, output tensor.
Examples:
>>> mobilenetV2(num_classes=1000)
"""
def __init__(self, num_classes=1000, width_mult=1.,
has_dropout=False, inverted_residual_setting=None, round_nearest=8):
super(mobilenetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
# setting of inverted residual blocks
self.cfgs = inverted_residual_setting
if inverted_residual_setting is None:
self.cfgs = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in self.cfgs:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
# make it nn.CellList
self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(),
nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
] if not has_dropout else
[GlobalAvgPooling(),
nn.Dropout(0.2),
nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
])
self.head = nn.SequentialCell(head)
# init weights
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.head(x)
return x
def _initialize_weights(self):
"""
Initialize weights.
Args:
Returns:
None.
Examples:
>>> _initialize_weights()
"""
for _, m in self.cells_and_names():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))
m.weight.set_parameter_data(w)
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

View File

@ -0,0 +1,113 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 utils"""
import time
import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
class CrossEntropyWithLabelSmooth(_Loss):
"""
CrossEntropyWith LabelSmooth.
Args:
smooth_factor (float): smooth factor, default=0.
num_classes (int): num classes
Returns:
None.
Examples:
>>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropyWithLabelSmooth, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()
def construct(self, logit, label):
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
self.on_value, self.off_value)
out_loss = self.ce(logit, one_hot_label)
out_loss = self.mean(out_loss, 0)
return out_loss

View File

@ -0,0 +1,131 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train mobilenetV2 on ImageNet"""
import os
import argparse
import random
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.quant import quant
import mindspore.dataset.engine as de
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.utils import Monitor, CrossEntropyWithLabelSmooth
from src.config import config_ascend, config_ascend_quant
from src.mobilenetV2 import mobilenetV2
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
args_opt = parser.parse_args()
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
rank_id = int(os.getenv('RANK_ID'))
rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
else:
raise ValueError("Unsupported device target.")
if __name__ == '__main__':
# train on ascend
config = config_ascend_quant if args_opt.quantization_aware else config_ascend
print("training args: {}".format(args_opt))
print("training configure: {}".format(config))
print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
epoch_size = config.epoch_size
# distribute init
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True,
mirror_mean=True)
init()
# define network
network = mobilenetV2(num_classes=config.num_classes)
# define loss
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
else:
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config,
device_target=args_opt.device_target,
repeat_num=epoch_size,
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
# load pre trained ckpt
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
if config.quantization_aware:
network = quant.convert_quant_network(network,
bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
# get learning rate
lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
lr_init=0,
lr_end=0,
lr_max=config.lr,
warmup_epochs=config.warmup_epochs,
total_epochs=epoch_size + config.start_epoch,
steps_per_epoch=step_size))
# define optimization
opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
config.weight_decay)
# define model
model = Model(network, loss_fn=loss, optimizer=opt)
print("============== Starting Training ==============")
callback = None
if rank_id == 0:
callback = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
directory=config.save_checkpoint_path,
config=config_ck)
callback += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=callback)
print("============== End Training ==============")

View File

@ -29,7 +29,7 @@ run_ascend()
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ];
if [ -d "../train" ];
then
rm -rf ../train
fi

View File

@ -0,0 +1,122 @@
# ResNet-50_quant Example
## Description
This is an example of training ResNet-50_quant with ImageNet2012 dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset ImageNet2012
> Unzip the ImageNet2012 dataset to any path you want and the folder structure should include train and eval dataset as follows:
> ```
> .
> ├── ilsvrc # train dataset
> └── ilsvrc_eval # infer dataset: images should be classified into 1000 directories firstly, just like train images
> ```
## Example structure
```shell
.
├── Resnet50_quant
├── Readme.md
├── scripts
│ ├──run_train.sh
│ ├──run_eval.sh
├── src
│ ├──config.py
│ ├──crossentropy.py
│ ├──dataset.py
│ ├──luanch.py
│ ├──lr_generator.py
│ ├──utils.py
├── models
│ ├──resnet_quant.py
├── train.py
├── eval.py
```
## Parameter configuration
Parameters for both training and inference can be set in config.py.
```
"class_num": 1001, # dataset class number
"batch_size": 32, # batch size of input tensor
"loss_scale": 1024, # loss scale
"momentum": 0.9, # momentum optimizer
"weight_decay": 1e-4, # weight decay
"epoch_size": 120, # only valid for taining, which is always 1 for inference
"pretrained_epoch_size": 90, # epoch size that model has been trained before load pretrained checkpoint
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_width": 224, # image width
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch
"keep_checkpoint_max": 50, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
"warmup_epochs": 0, # number of warmup epoch
"lr_decay_mode": "cosine", # decay mode for generating learning rate
"label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0, # initial learning rate
"lr_max": 0.005, # maximum learning rate
```
## Running the example
### Train
### Usage
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]
### Launch
```
# training example
Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/
```
### Result
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
```
epoch: 1 step: 5004, loss is 4.8995576
epoch: 2 step: 5004, loss is 3.9235563
epoch: 3 step: 5004, loss is 3.833077
epoch: 4 step: 5004, loss is 3.2795618
epoch: 5 step: 5004, loss is 3.1978393
```
## Eval process
### Usage
- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
### Launch
```
# infer example
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/checkpoint/resnet50-110_5004.ckpt
```
> checkpoint can be produced in training process.
#### Result
Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find result like the followings in log.
```
result: {'acc': 0.75.252054737516005} ckpt=train_parallel0/resnet-110_5004.ckpt
```

View File

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluate Resnet50 on ImageNet"""
import os
import argparse
from src.config import quant_set, config_quant, config_noquant
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net
from models.resnet_quant import resnet50_quant
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
config = config_quant if quant_set.quantization_aware else config_noquant
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
# define fusion network
net = resnet50_quant(class_num=config.class_num)
if quant_set.quantization_aware:
# convert fusion network to quantization aware network
net = quant.convert_quant_network(net,
bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
# define network loss
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
batch_size=config.batch_size,
target=args_opt.device_target)
step_size = dataset.get_dataset_size()
# load checkpoint
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
_load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'acc'})
print("============== Starting Validation ==============")
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
print("============== End Validation ==============")

View File

@ -0,0 +1,251 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import mindspore.nn as nn
from mindspore.ops import operations as P
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, has_bn=True, activation='relu')
self.features = conv
def construct(self, x):
output = self.features(x)
return output
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
has_bn=True, activation='relu')
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
kernel_size=1, stride=stride,
pad_mode='same', padding=0, has_bn=True, activation='relu')
self.add = P.TensorAdd()
self.relu = P.ReLU()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseBnAct(out_channels[3], num_classes, has_bias=True, has_bn=False)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def resnet50_quant(class_num=10001):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50_quant(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def resnet101_quant(class_num=1001):
"""
Get ResNet101 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet101 neural network.
Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
# check dataset path
if [ ! -d $2 ] && [ ! -f $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory or file"
exit 1
fi
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi
# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit
# luanch
python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
&> infer.log & # dataset val folder path

View File

@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
run_ascend()
{
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
fi
if [ ! -d $5 ] && [ ! -f $5 ]
then
echo "error: DATASET_PATH=$5 is not a directory or file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
python ${BASEPATH}/../src/launch.py \
--nproc_per_node=$2 \
--visible_devices=$4 \
--server_id=$3 \
--training_script=${BASEPATH}/../train.py \
--dataset_path=$5 \
--pre_trained=$6 \
--device_target=$1 &> train.log & # dataset train folder
}
if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi
if [ $1 = "Ascend" ] ; then
run_ascend "$@"
else
echo "not support platform"
fi;

View File

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
quant_set = ed({
"quantization_aware": True,
})
config_noquant = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 90,
"pretrained_epoch_size": 1,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"data_load_mode": "mindrecord",
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 50,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.1,
})
config_quant = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 120,
"pretrained_epoch_size": 90,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"data_load_mode": "mindrecord",
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 50,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.005,
})

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0, num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

View File

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
from functools import partial
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from mindspore.communication.management import init, get_rank, get_group_size
from src.config import quant_set, config_quant, config_noquant
config = config_quant if quant_set.quantization_aware else config_noquant
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
columns_list = ['image', 'label']
if config.data_load_mode == "mindrecord":
load_func = partial(de.MindDataset, dataset_path, columns_list)
else:
load_func = partial(de.ImageFolderDatasetV2, dataset_path)
if device_num == 1:
ds = load_func(num_parallel_workers=8, shuffle=True)
else:
ds = load_func(num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = config.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
if do_train:
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
image_size = 224
# define map operations
decode_op = P.Decode()
resize_crop_op = P.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333))
horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)
resize_op = P.Resize(256)
center_crop = P.CenterCrop(image_size)
to_tensor = P.ToTensor()
normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# define map operations
if do_train:
trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
else:
trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]
compose = P.ComposeOp(trans)
ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""launch train script"""
import os
import sys
import json
import subprocess
import shutil
import platform
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training launch "
"helper utilty that will spawn up "
"multiple distributed processes")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "
"for D training, this is recommended to be set "
"to the number of D in your system so that "
"each process can be bound to a single D.")
parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
help="will use the visible devices sequentially")
parser.add_argument("--server_id", type=str, default="",
help="server ip")
parser.add_argument("--training_script", type=str,
help="The full path to the single D training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
args, unknown = parser.parse_known_args()
args.training_script_args = unknown
return args
def main():
print("start", __file__)
args = parse_args()
print(args)
visible_devices = args.visible_devices.split(',')
assert os.path.isfile(args.training_script)
assert len(visible_devices) >= args.nproc_per_node
print('visible_devices:{}'.format(visible_devices))
if not args.server_id:
print('pleaser input server ip!!!')
exit(0)
print('server_id:{}'.format(args.server_id))
# construct hccn_table
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
arch = platform.processor()
hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch]
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
usable_dev = ''
for instance_id in range(args.nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
usable_dev += str(device_id)
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = args.server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(args.nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(args.nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(args.nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
hccn_table['status'] = 'completed'
# save hccn_table to file
table_path = os.getcwd()
if not os.path.exists(table_path):
os.mkdir(table_path)
table_fn = os.path.join(table_path,
'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
with open(table_fn, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
sys.stdout.flush()
# spawn the processes
processes = []
cmds = []
log_files = []
env = os.environ.copy()
env['RANK_SIZE'] = str(args.nproc_per_node)
cur_path = os.getcwd()
for rank_id in range(0, args.nproc_per_node):
os.chdir(cur_path)
device_id = visible_devices[rank_id]
device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
env['RANK_ID'] = str(rank_id)
env['DEVICE_ID'] = str(device_id)
if args.nproc_per_node > 1:
env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn
env['RANK_TABLE_FILE'] = table_fn
if os.path.exists(device_dir):
shutil.rmtree(device_dir)
os.mkdir(device_dir)
os.chdir(device_dir)
cmd = [sys.executable, '-u']
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process)
cmds.append(cmd)
log_files.append(log_file)
for process, cmd, log_file in zip(processes, cmds, log_files):
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(returncode=process, cmd=cmd)
log_file.close()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
learning_rate = np.array(lr_each_step).astype(np.float32)
return learning_rate

View File

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils script"""
def _load_param_into_net(model, params_dict):
"""
load fp32 model parameters to quantization model.
Args:
model: quantization model
params_dict: f32 param
Returns:
None
"""
iterable_dict = {
'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
'moving_variance': iter(
[item for item in params_dict.items() if item[0].endswith('moving_variance')]),
'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
}
for name, param in model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
continue
value_param = next(iterable_dict[key_name], None)
if value_param is not None:
param.set_parameter_data(value_param[1].data)
print(f'init model param {name} with checkpoint param {value_param[0]}')

153
model_zoo/resnet50_quant/train.py Executable file
View File

@ -0,0 +1,153 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Resnet50 on ImageNet"""
import os
import argparse
from mindspore import context
from mindspore import Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from models.resnet_quant import resnet50_quant
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import quant_set, config_quant, config_noquant
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
args_opt = parser.parse_args()
config = config_quant if quant_set.quantization_aware else config_noquant
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
rank_id = int(os.getenv('RANK_ID'))
rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id,
enable_auto_mixed_precision=True)
else:
raise ValueError("Unsupported device target.")
if __name__ == '__main__':
# train on ascend
print("training args: {}".format(args_opt))
print("training configure: {}".format(config))
print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
epoch_size = config.epoch_size
# distribute init
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True,
mirror_mean=True)
init()
context.set_auto_parallel_context(device_num=args_opt.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
# define network
net = resnet50_quant(class_num=config.class_num)
net.set_train(True)
# weight init and load checkpoint file
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
_load_param_into_net(net, param_dict)
epoch_size = config.epoch_size - config.pretrained_epoch_size
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
repeat_num=epoch_size,
batch_size=config.batch_size,
target=args_opt.device_target)
step_size = dataset.get_dataset_size()
if quant_set.quantization_aware:
# convert fusion network to quantization aware network
net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# get learning rate
lr = get_lr(lr_init=config.lr_init,
lr_end=0.0,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode='cosine')
if args_opt.pre_trained:
lr = lr[config.pretrained_epoch_size * step_size:]
lr = Tensor(lr)
# define optimization
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
# define model
if quant_set.quantization_aware:
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
else:
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2")
print("============== Starting Training ==============")
time_callback = TimeMonitor(data_size=step_size)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
if rank_id == 0:
if config.save_checkpoint:
config_ckpt = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="ResNet50",
directory=config.save_checkpoint_path,
config=config_ckpt)
callbacks += [ckpt_callback]
model.train(epoch_size, dataset, callbacks=callbacks)
print("============== End Training ==============")