!9571 Add Inceptionv4 net to model_zoo/official/cv/

From: @zhanghuiyao
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-12-09 12:53:48 +08:00 committed by Gitee
commit b2a164b1c2
12 changed files with 1115 additions and 0 deletions


@ -0,0 +1,240 @@
# InceptionV4 for Ascend
- [InceptionV4 Description](#inceptionv4-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [InceptionV4 Description](#contents)
Inception-v4 is a convolutional neural network architecture that builds on previous iterations of the Inception family by simplifying the architecture and using more inception modules than Inception-v3. This idea was proposed in the paper Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, published in 2016.
[Paper](https://arxiv.org/pdf/1602.07261.pdf) Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi. Computer Vision and Pattern Recognition[J]. 2016.
# [Model architecture](#contents)
The overall network architecture of InceptionV4 is shown in the paper:
[Link](https://arxiv.org/pdf/1602.07261.pdf)
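The feature-map sizes and channel counts noted in the comments of `src/inceptionv4.py` can be checked with plain arithmetic. The sketch below assumes a 299 x 299 input and that every stride-2 layer is unpadded, as the layers in that file appear to be; the helper name `valid_out` is made up for illustration and is not part of the repository.

```python
def valid_out(size, kernel, stride):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1


# (layer name, kernel, stride) along the spatial path of the network.
# conv2d_2b_3x3 uses padding=1 with a 3x3 kernel, which keeps the size,
# so it is modelled here as a 1x1 window.
layers = [
    ("conv2d_1a_3x3", 3, 2),
    ("conv2d_2a_3x3", 3, 1),
    ("conv2d_2b_3x3", 1, 1),
    ("mixed_3a", 3, 2),
    ("mixed_4a", 3, 1),
    ("mixed_5a", 3, 2),
    ("ReductionA", 3, 2),
    ("ReductionB", 3, 2),
]

size = 299
sizes = []
for name, kernel, stride in layers:
    size = valid_out(size, kernel, stride)
    sizes.append(size)

# Channel counts after each block: the concatenated branch outputs.
stem_channels = 192 + 192                      # conv branch + max-pool branch
inception_a = 96 * 4
reduction_a = 384 + 256 + 384                  # n + m + pass-through max pool
inception_b = 384 + 256 + 256 + 128
reduction_b = 192 + 320 + 1024
inception_c = 256 + (256 + 256) + (256 + 256) + 256

print(sizes)
print(stem_channels, inception_a, reduction_a, inception_b, reduction_b, inception_c)
```

The final 8 x 8 x 1536 feature map is what the global average pooling and the 1536-input dense layer in `Inceptionv4` operate on.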
# [Dataset](#contents)
Dataset used: ImageNet (see the paper for details).
- Dataset size: 125G, 1250k color images in 1000 classes
    - Train: 120G, 1200k images
    - Test: 5G, 50k images
- Data format: RGB images.
    - Note: Data will be processed in src/dataset.py
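As a rough illustration of the last steps in `src/dataset.py` (rescale to [0, 1], normalize per channel, convert HWC to CHW), here is a NumPy stand-in. It reuses the mean and standard deviation values from that file, but the function name `normalize_hwc_to_chw` is hypothetical and the MindSpore operators themselves are not called.

```python
import numpy as np

# Per-channel statistics, copied from the Normalize call in src/dataset.py.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_hwc_to_chw(image_u8):
    """Rescale a uint8 HWC image to [0, 1], normalize it, and return CHW."""
    x = image_u8.astype(np.float32) / 255.0   # same effect as Rescale(1/255, 0)
    x = (x - MEAN) / STD                      # broadcast over the channel axis
    return x.transpose(2, 0, 1)               # HWC -> CHW


white = np.full((299, 299, 3), 255, dtype=np.uint8)
out = normalize_hwc_to_chw(white)
print(out.shape, out[:, 0, 0])
```

A pure white pixel maps to `(1 - mean) / std` in each channel, which is a quick sanity check for the constants.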
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend automatically handles it with reduced precision. You can check which operators run with reduced precision by enabling the INFO log and searching for `reduce precision`.
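Mixed precision is usually paired with loss scaling, which is why `src/config.py` carries a `loss_scale` value. The NumPy sketch below only illustrates the numerical idea (small gradients underflow in FP16 unless they are scaled first); it is not how MindSpore implements it, and the helper name `fp16_grad` is an assumption made for this example.

```python
import numpy as np

LOSS_SCALE = 1024  # same value as 'loss_scale' in src/config.py


def fp16_grad(grad_fp32, scale=1.0):
    """Store a scaled gradient in FP16, then unscale it again in FP32."""
    scaled = np.float16(grad_fp32 * scale)
    return np.float32(scaled) / np.float32(scale)


tiny_grad = np.float32(1e-8)
unscaled = fp16_grad(tiny_grad)              # too small for FP16, becomes 0
rescued = fp16_grad(tiny_grad, LOSS_SCALE)   # survives the FP16 round trip

print(unscaled, rescued)
```

The scaled value lands in the representable FP16 range, so dividing by the scale afterwards recovers the gradient to within rounding error.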
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```shell
.
└─Inception-v4
  ├─README.md
  ├─scripts
    ├─run_standalone_train_ascend.sh   # launch standalone training on the Ascend platform (1p)
    ├─run_distribute_train_ascend.sh   # launch distributed training on the Ascend platform (8p)
    └─run_eval_ascend.sh               # launch evaluation on the Ascend platform
  ├─src
    ├─config.py                        # parameter configuration
    ├─dataset.py                       # data preprocessing
    ├─inceptionv4.py                   # network definition
    └─callback.py                      # eval callback function
  ├─eval.py                            # evaluate the network
  ├─export.py                          # export a checkpoint; supports .onnx, .air, .mindir conversion
  └─train.py                           # train the network
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and config.py are:
'is_save_on_master' # save checkpoint only on master device
'batch_size' # input batchsize
'epoch_size' # total epoch numbers
'num_classes' # dataset class numbers
'work_nums' # number of workers to read data
'loss_scale' # loss scale
'smooth_factor' # label smoothing factor
'weight_decay' # weight decay
'momentum' # momentum
'amp_level' # mixed precision level, supports [O0, O2, O3]
'decay' # decay used in the optimizer
'epsilon' # epsilon used in the optimizer
'keep_checkpoint_max' # max number of checkpoints to keep
'save_checkpoint_epochs' # save a checkpoint every n epochs
'lr_init' # initial learning rate
'lr_end' # end of learning rate
'lr_max' # max bound of learning rate
'warmup_epochs' # warmup epoch numbers
'start_epoch' # number of start epoch range[1, epoch_size]
```
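To show how `lr_init`, `lr_max`, `lr_end` and `warmup_epochs` could interact, here is a sketch of a linear warmup followed by cosine decay. This is an assumption about the general shape of the schedule: the body of `generate_cosine_lr` in `train.py` is not reproduced here and may differ in detail, and the function name `cosine_lr_sketch` is made up for this example.

```python
import math


def cosine_lr_sketch(steps_per_epoch, total_epochs,
                     lr_init=0.00004, lr_end=0.000004, lr_max=0.4,
                     warmup_epochs=1):
    """Return one learning rate per step: linear warmup, then cosine decay."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Ramp linearly from lr_init up to lr_max.
            lr = lr_init + (lr_max - lr_init) * (step + 1) / warmup_steps
        else:
            # Cosine curve from lr_max down towards lr_end.
            progress = (step - warmup_steps) / decay_steps
            lr = lr_end + 0.5 * (lr_max - lr_end) * (1 + math.cos(math.pi * progress))
        lrs.append(lr)
    return lrs


demo = cosine_lr_sketch(steps_per_epoch=10, total_epochs=5)
print(demo[0], demo[9], demo[-1])
```

The default arguments mirror the values in `src/config.py`; the tiny 10 x 5 run is only there to make the shape easy to inspect.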
## [Training process](#contents)
### Usage
You can start training using Python or shell scripts. The shell scripts are used as follows:
- Ascend:
```bash
# distribute training example(8p)
sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_DIR
# standalone training
sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
```
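
The distributed launch script pins each rank to a contiguous range of CPU cores with `taskset`. The arithmetic it performs with `expr` can be sketched in Python as follows; the function name `core_ranges` is hypothetical, and the 192-core figure is taken from the resource description in the performance table.

```python
def core_ranges(total_cores, rank_size=8):
    """Return the 'start-end' core range handed to taskset for each rank."""
    per_rank = total_cores // rank_size          # avg_core_per_rank in the script
    ranges = []
    for rank in range(rank_size):
        start = rank * per_rank
        end = start + per_rank - 1               # start + core_gap in the script
        ranges.append(f"{start}-{end}")
    return ranges


bindings = core_ranges(192)
print(bindings)
```

On a machine whose core count is not a multiple of the rank count, the integer division leaves the last few cores unused.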
> Notes:
> For RANK_TABLE_FILE, refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); the device_ip can be obtained with [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it is better to set the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compile time increases with model size.
>
> The distributed script binds processor cores to each device according to `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train_ascend.sh`.
### Launch
```bash
# training example (shell, Ascend)
# distribute training example(8p)
sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_DIR
# standalone training
sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
```
### Result
Training results are stored in the example path. Checkpoints are stored at `ckpt_path` by default, and the training log is redirected to `./log.txt`, as follows.
```python
epoch: 1 step: 1251, loss is 5.861846
Epoch time: 701416.649, per step time: 560.685
epoch: 2 step: 1251, loss is 4.295785
Epoch time: 472524.154, per step time: 377.717
epoch: 3 step: 1251, loss is 3.691987
Epoch time: 472505.767, per step time: 377.702
```
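The step times in this log can be turned into throughput. The sketch below assumes the times are in milliseconds and that the log comes from the 8-device run with a per-device batch size of 128 (1251 steps x 128 x 8 is roughly the 1200k training images); the function name `images_per_second` is made up for illustration.

```python
import re

LOG = """\
epoch: 1 step: 1251, loss is 5.861846
Epoch time: 701416.649, per step time: 560.685
epoch: 2 step: 1251, loss is 4.295785
Epoch time: 472524.154, per step time: 377.717
epoch: 3 step: 1251, loss is 3.691987
Epoch time: 472505.767, per step time: 377.702
"""

STEP_TIME = re.compile(r"per step time: ([0-9.]+)")


def images_per_second(log_text, batch_size=128, device_num=8):
    """Convert each 'per step time' (assumed ms) into images per second."""
    step_ms = [float(value) for value in STEP_TIME.findall(log_text)]
    return [batch_size * device_num / (ms / 1000.0) for ms in step_ms]


rates = images_per_second(LOG)
print([round(rate) for rate in rates])
```

The first epoch is slower because it includes graph compilation; the later epochs are close to the 8p figure reported in the performance section.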
## [Eval process](#contents)
### Usage
You can start evaluation using Python or shell scripts. The shell scripts are used as follows:
- Ascend:
```bash
sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
```
### Launch
```bash
# eval example (shell, Ascend)
sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
```
> The checkpoint is produced during the training process.
### Result
Evaluation results are stored in the example path; you can find results like the following in `eval.log`.
```python
metric: {'Loss': 0.9849, 'Top1-Acc':0.7985, 'Top5-Acc':0.9460}
```
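For readers unfamiliar with the two accuracy figures, the following NumPy sketch shows what Top-1 and Top-k accuracy measure. It is an illustration only: the evaluation script uses MindSpore's own metric classes, and the function name `topk_accuracy` and the toy scores are assumptions made for this example.

```python
import numpy as np


def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    topk = np.argsort(-scores, axis=1)[:, :k]        # indices of the k best classes
    hits = (topk == labels[:, None]).any(axis=1)     # is the label among them?
    return float(hits.mean())


scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
print(top1, top2)
```

Top-k accuracy can never be lower than Top-1, which is why the reported Top5-Acc is the larger of the two numbers.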
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | InceptionV4 |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| Uploaded Date              | 11/04/2020                                                   |
| MindSpore Version | 1.0.0 |
| Dataset | 1200k images |
| Batch_size | 128 |
| Training Parameters | src/config.py |
| Optimizer | RMSProp |
| Loss Function | SoftmaxCrossEntropyWithLogits |
| Outputs | probability |
| Loss | 0.98486 |
| Accuracy (8p) | ACC1[79.85%] ACC5[94.60%] |
| Total time (8p) | 33h |
| Params (M) | 153M |
| Checkpoint for Fine tuning | 2135M |
| Scripts | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | InceptionV4 |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| Uploaded Date | 11/04/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | 50k images |
| Batch_size | 128 |
| Outputs | probability |
| Accuracy | ACC1[79.85%] ACC5[94.60%] |
| Total time | 2mins |
| Model for inference | 2135M (.ckpt file) |
#### Training performance results
| **Ascend** | train performance |
| :--------: | :---------------: |
|     1p     |     345 img/s     |
|     8p     |    2708 img/s     |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4
from src.config import config_ascend as config
def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()

    if args.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    net = Inceptionv4(classes=config.num_classes)
    ckpt = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
                             repeat_num=1, batch_size=config.batch_size)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    print('='*20, 'Evaluate start', '='*20)
    metrics = model.eval(dataset)
    print("metric: ", metrics)


@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_ascend as config
from src.inceptionv4 import Inceptionv4
def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--model_name', type=str, default='inceptionV4.air', help='convert model name of inceptionv4')
    parser.add_argument('--format', type=str, default='AIR', help='export file format of inceptionv4')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inceptionv4')
    _args_opt = parser.parse_args()
    return _args_opt


if __name__ == '__main__':
    args_opt = parse_args()
    net = Inceptionv4(classes=config.num_classes)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)
    # dummy input that only fixes the exported input shape (N, C, H, W)
    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
    export(net, input_arr, file_name=args_opt.model_name, file_format=args_opt.format)


@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export RANK_TABLE_FILE=$1
DATA_DIR=$2
export RANK_SIZE=8
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end
    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID rank_id $RANK_ID"
    env > env.log
    taskset -c $cmdopt python -u ../train.py \
        --device_id $i \
        --dataset_path=$DATA_DIR > log.txt 2>&1 &
    cd ../
done


@ -0,0 +1,28 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
CHECKPOINT_PATH=$3
export RANK_SIZE=1
rm -rf evaluation_ascend
mkdir ./evaluation_ascend
cd ./evaluation_ascend || exit
echo "start evaluation for device id $DEVICE_ID"
env > env.log
python ../eval.py --platform=Ascend --dataset_path=$DATA_DIR --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
cd ../


@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export RANK_SIZE=1
export DEVICE_ID=$1
DATA_DIR=$2
rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python -u ../train.py \
    --device_id=$1 \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../


@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callback function"""
from mindspore.train.callback import Callback
class EvaluateCallBack(Callback):
    """EvaluateCallBack"""

    def __init__(self, model, eval_dataset, per_print_time=1000):
        super(EvaluateCallBack, self).__init__()
        self.model = model
        self.per_print_time = per_print_time
        self.eval_dataset = eval_dataset

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.per_print_time == 0:
            result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(cb_params.cur_epoch_num,
                                                                                          cb_params.cur_step_num,
                                                                                          result['top_1_accuracy'],
                                                                                          result['top_5_accuracy']))

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
        print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(cb_params.cur_epoch_num,
                                                                                      cb_params.cur_step_num,
                                                                                      result['top_1_accuracy'],
                                                                                      result['top_5_accuracy']))


@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, used in train.py, eval.py, export.py and src/dataset.py
"""
from easydict import EasyDict as edict
config_ascend = edict({
    'is_save_on_master': False,
    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'work_nums': 8,
    'loss_scale': 1024,
    'smooth_factor': 0.1,
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'amp_level': 'O3',
    'decay': 0.9,
    'epsilon': 1.0,
    'keep_checkpoint_max': 10,
    'save_checkpoint_epochs': 10,
    'lr_init': 0.00004,
    'lr_end': 0.000004,
    'lr_max': 0.4,
    'warmup_epochs': 1,
    'start_epoch': 1,
    'onnx_filename': 'inceptionv4.onnx',
    'air_filename': 'inceptionv4.air'
})


@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create train or eval dataset."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.config import config_ascend as config
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (str): The path of dataset.
        do_train (bool): Whether dataset is used for train or eval.
        repeat_num (int): The repeat times of dataset. Default: 1.
        batch_size (int): The batch size of dataset. Default: 32.

    Returns:
        Dataset.
    """
    do_shuffle = bool(do_train)

    # shard the data across devices only for distributed training
    if device_num == 1 or not do_train:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
                                   shuffle=do_shuffle, num_shards=device_num, shard_id=device_id)

    image_length = 299
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(image_length),
            C.CenterCrop(image_length)
        ]
    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds


@ -0,0 +1,328 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InceptionV4"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import Initializer
class Avginitializer(Initializer):
    """
    Initialize the weight to 1/(m*n), where (m, n) is the shape of the kernel.
    """

    def _initialize(self, arr):
        arr[:] = 0
        # only the diagonal (channel i -> channel i) kernels are non-zero,
        # so every channel is averaged independently
        for i in range(arr.shape[0]):
            for j in range(arr.shape[2]):
                for k in range(arr.shape[3]):
                    arr[i][i][j][k] = 1/(arr.shape[2]*arr.shape[3])


class Avgpool(nn.Cell):
    """
    Average pooling over 2D feature maps.

    Using a custom initializer to turn conv2d into avgpool2d. The weights are not meant to be trained.
    """

    def __init__(self, channel, kernel_size, stride=1, pad_mode='same'):
        super(Avgpool, self).__init__()
        self.init = Avginitializer()
        self.conv = nn.Conv2d(channel, channel, kernel_size,
                              stride=stride, pad_mode=pad_mode, weight_init=self.init)
        self.conv.set_train(False)

    def construct(self, x):
        x = self.conv(x)
        return x


class Conv2d(nn.Cell):
    """
    Set the default configuration for Conv2dBnAct
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='valid', padding=0,
                 has_bias=False, weight_init="XavierUniform", bias_init='zeros'):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2dBnAct(in_channels, out_channels, kernel_size, stride=stride, pad_mode=pad_mode,
                                   padding=padding, weight_init=weight_init, bias_init=bias_init, has_bias=has_bias,
                                   has_bn=True, activation="relu")

    def construct(self, x):
        x = self.conv(x)
        return x
class Stem(nn.Cell):
    """
    Inceptionv4 stem
    """

    def __init__(self, in_channels):
        super(Stem, self).__init__()
        self.conv2d_1a_3x3 = Conv2d(
            in_channels, 32, 3, stride=2, padding=0, has_bias=False)

        self.conv2d_2a_3x3 = Conv2d(
            32, 32, 3, stride=1, padding=0, has_bias=False)
        self.conv2d_2b_3x3 = Conv2d(
            32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)

        self.mixed_3a_branch_0 = nn.MaxPool2d(3, stride=2)
        self.mixed_3a_branch_1 = Conv2d(
            64, 96, 3, stride=2, padding=0, has_bias=False)

        self.mixed_4a_branch_0 = nn.SequentialCell([
            Conv2d(160, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, padding=0, pad_mode='valid', has_bias=False)])

        self.mixed_4a_branch_1 = nn.SequentialCell([
            Conv2d(160, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 64, (1, 7), pad_mode='same', stride=1, has_bias=False),
            Conv2d(64, 64, (7, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(64, 96, 3, stride=1, padding=0, pad_mode='valid', has_bias=False)])

        self.mixed_5a_branch_0 = Conv2d(
            192, 192, 3, stride=2, padding=0, has_bias=False)
        self.mixed_5a_branch_1 = nn.MaxPool2d(3, stride=2)
        self.concat0 = P.Concat(1)
        self.concat1 = P.Concat(1)
        self.concat2 = P.Concat(1)

    def construct(self, x):
        """construct"""
        x = self.conv2d_1a_3x3(x)  # 149 x 149 x 32
        x = self.conv2d_2a_3x3(x)  # 147 x 147 x 32
        x = self.conv2d_2b_3x3(x)  # 147 x 147 x 64

        x0 = self.mixed_3a_branch_0(x)
        x1 = self.mixed_3a_branch_1(x)
        x = self.concat0((x0, x1))  # 73 x 73 x 160

        x0 = self.mixed_4a_branch_0(x)
        x1 = self.mixed_4a_branch_1(x)
        x = self.concat1((x0, x1))  # 71 x 71 x 192

        x0 = self.mixed_5a_branch_0(x)
        x1 = self.mixed_5a_branch_1(x)
        x = self.concat2((x0, x1))  # 35 x 35 x 384
        return x
class InceptionA(nn.Cell):
    """InceptionA"""

    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        self.branch_0 = Conv2d(
            in_channels, 96, 1, stride=1, padding=0, has_bias=False)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)])

        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, pad_mode='pad',
                   padding=1, has_bias=False),
            Conv2d(96, 96, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)])

        self.branch_3 = nn.SequentialCell([
            Avgpool(384, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(384, 96, 1, stride=1, padding=0, has_bias=False)])

        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        x4 = self.concat((x0, x1, x2, x3))
        return x4
class InceptionB(nn.Cell):
    """InceptionB"""

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch_0 = Conv2d(in_channels, 384, 1,
                               stride=1, padding=0, has_bias=False)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 224, (1, 7), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(224, 256, (7, 1), pad_mode='same',
                   stride=1, has_bias=False),
        ])
        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 192, (7, 1), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(192, 224, (1, 7), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(224, 224, (7, 1), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(224, 256, (1, 7), pad_mode='same', stride=1, has_bias=False)
        ])
        self.branch_3 = nn.SequentialCell([
            Avgpool(in_channels, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(in_channels, 128, 1, stride=1, padding=0, has_bias=False)
        ])
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        x4 = self.concat((x0, x1, x2, x3))
        return x4
class ReductionA(nn.Cell):
    """ReductionA"""

    def __init__(self, in_channels, k, l, m, n):
        super(ReductionA, self).__init__()
        self.branch_0 = Conv2d(in_channels, n, 3, stride=2, padding=0)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, k, 1, stride=1, padding=0, has_bias=False),
            Conv2d(k, l, 3, stride=1, pad_mode='pad',
                   padding=1, has_bias=False),
            Conv2d(l, m, 3, stride=2, padding=0, has_bias=False),
        ])
        self.branch_2 = nn.MaxPool2d(3, stride=2)
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.concat((x0, x1, x2))
        return x3  # 17 x 17 x 1024
class ReductionB(nn.Cell):
    """ReductionB"""

    def __init__(self, in_channels):
        super(ReductionB, self).__init__()
        self.branch_0 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 192, 3, stride=2, padding=0, has_bias=False),
        ])
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 256, 1, stride=1, padding=0, has_bias=False),
            Conv2d(256, 256, (1, 7), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(256, 320, (7, 1), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(320, 320, 3, stride=2, padding=0, has_bias=False)
        ])
        self.branch_2 = nn.MaxPool2d(3, stride=2)
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.concat((x0, x1, x2))
        return x3  # 8 x 8 x 1536
class InceptionC(nn.Cell):
    """InceptionC"""

    def __init__(self, in_channels):
        super(InceptionC, self).__init__()
        self.branch_0 = Conv2d(in_channels, 256, 1,
                               stride=1, padding=0, has_bias=False)

        self.branch_1 = Conv2d(in_channels, 384, 1,
                               stride=1, padding=0, has_bias=False)
        self.branch_1_1 = Conv2d(
            384, 256, (1, 3), pad_mode='same', stride=1, has_bias=False)
        self.branch_1_2 = Conv2d(
            384, 256, (3, 1), pad_mode='same', stride=1, has_bias=False)

        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 384, 1, stride=1, padding=0, has_bias=False),
            Conv2d(384, 448, (3, 1), pad_mode='same',
                   stride=1, has_bias=False),
            Conv2d(448, 512, (1, 3), pad_mode='same',
                   stride=1, has_bias=False),
        ])
        self.branch_2_1 = Conv2d(
            512, 256, (1, 3), pad_mode='same', stride=1, has_bias=False)
        self.branch_2_2 = Conv2d(
            512, 256, (3, 1), pad_mode='same', stride=1, has_bias=False)

        self.branch_3 = nn.SequentialCell([
            Avgpool(in_channels, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(in_channels, 256, 1, stride=1, padding=0, has_bias=False)
        ])
        self.concat0 = P.Concat(1)
        self.concat1 = P.Concat(1)
        self.concat2 = P.Concat(1)

    def construct(self, x):
        """construct"""
        x0 = self.branch_0(x)

        x1 = self.branch_1(x)
        x1_1 = self.branch_1_1(x1)
        x1_2 = self.branch_1_2(x1)
        x1 = self.concat0((x1_1, x1_2))

        x2 = self.branch_2(x)
        x2_1 = self.branch_2_1(x2)
        x2_2 = self.branch_2_2(x2)
        x2 = self.concat1((x2_1, x2_2))

        x3 = self.branch_3(x)
        return self.concat2((x0, x1, x2, x3))  # 8 x 8 x 1536


class Inceptionv4(nn.Cell):
    """
    Inceptionv4 architecture

    Args:
        in_channels (int): number of channels of the input image.
        classes (int): number of output classes.
        k, l, m, n (int): filter counts passed to ReductionA.
        is_train (bool): if True (training mode), dropout is turned on.
    """

    def __init__(self, in_channels=3, classes=1000, k=192, l=224, m=256, n=384, is_train=True):
        super(Inceptionv4, self).__init__()
        blocks = []
        blocks.append(Stem(in_channels))
        for _ in range(4):
            blocks.append(InceptionA(384))
        blocks.append(ReductionA(384, k, l, m, n))
        for _ in range(7):
            blocks.append(InceptionB(1024))
        blocks.append(ReductionB(1024))
        for _ in range(3):
            blocks.append(InceptionC(1536))
        self.features = nn.SequentialCell(blocks)

        self.avgpool = P.ReduceMean(keep_dims=False)
        self.softmax = nn.DenseBnAct(
            1536, classes, weight_init="XavierUniform", has_bias=True, has_bn=True, activation="logsoftmax")

        # Assuming the MindSpore 1.x signature nn.Dropout(keep_prob, ...):
        # the argument is the keep probability, so 1 disables dropout.
        if is_train:
            self.dropout = nn.Dropout(0.20)
        else:
            self.dropout = nn.Dropout(1)
        self.bn0 = nn.BatchNorm1d(1536, eps=0.001, momentum=0.1)

    def construct(self, x):
        x = self.features(x)
        # Global average pooling over the spatial axes (H, W).
        x = self.avgpool(x, (2, 3))
        x = self.bn0(x)
        x = self.dropout(x)
        x = self.softmax(x)
        return x
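
As a sanity check on the `# 8 x 8 x 1536` comment in `InceptionC.construct`, the channel counts of the concatenated branches can be tallied by hand. The sketch below is plain Python and not part of the commit; `branch_0` is defined outside this excerpt, so its width is inferred from the total rather than read from the code, and the variable names are illustrative only.

```python
# Output channels of each InceptionC branch as it reaches the final concat.
# branch_1 and branch_2 each end in a (1, 3) conv and a (3, 1) conv of 256
# channels, whose outputs are concatenated along the channel axis.
branch_1_out = 256 + 256
branch_2_out = 256 + 256
branch_3_out = 256  # average pool followed by a 1x1 conv to 256 channels

# Total taken from the "8 x 8 x 1536" comment and from InceptionC(1536).
total_out = 1536

# branch_0 is not visible in this excerpt; infer what it must contribute.
branch_0_out = total_out - (branch_1_out + branch_2_out + branch_3_out)
print(branch_0_out)
```

The same 1536 is the input width of `nn.BatchNorm1d` and `nn.DenseBnAct` in `Inceptionv4`, which is why the pooled feature vector can be fed to them directly.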

@ -0,0 +1,167 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet"""
import os
import argparse
import math
import numpy as np
from mindspore.communication import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.model import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn import RMSProp
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num
from src.config import config_ascend as config
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1)


def generate_cosine_lr(steps_per_epoch, total_epochs,
                       lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs):
    """
    Applies linear warmup followed by cosine decay to generate the learning rate array.

    Args:
        steps_per_epoch(int): number of steps per epoch.
        total_epochs(int): total number of training epochs.
        lr_init(float): initial learning rate at the start of warmup.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate, reached at the end of warmup.
        warmup_epochs(int): number of warmup epochs.

    Returns:
        np.array, learning rate array, starting at the first step of config.start_epoch.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            # Linear warmup from lr_init to lr_max.
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            # Cosine decay from lr_max towards lr_end.
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    # Skip the steps of epochs that were already trained when resuming.
    current_step = steps_per_epoch * (config.start_epoch - 1)
    learning_rate = learning_rate[current_step:]
    return learning_rate


def inception_v4_train():
    """
    Train Inceptionv4 in data parallelism
    """
    print('epoch_size: {} batch_size: {} class_num {}'.format(config.epoch_size, config.batch_size, config.num_classes))

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=args.device_id)
    context.set_context(enable_graph_kernel=False)
    rank = 0
    if device_num > 1:
        init(backend_name='hccl')
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          all_reduce_fusion_config=[200, 400])

    # create dataset
    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
                                   repeat_num=1, batch_size=config.batch_size)
    train_step_size = train_dataset.get_dataset_size()

    # create model
    net = Inceptionv4(classes=config.num_classes)
    # loss
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # learning rate
    lr = Tensor(generate_cosine_lr(steps_per_epoch=train_step_size, total_epochs=config.epoch_size))

    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            param.set_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    opt = RMSProp(group_params, lr, decay=config.decay, epsilon=config.epsilon, weight_decay=config.weight_decay,
                  momentum=config.momentum, loss_scale=config.loss_scale)

    if args.device_id == 0:
        print(lr)
        print(train_step_size)
    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={
        'acc', 'top_1_accuracy', 'top_5_accuracy'}, loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)

    # define callbacks
    performance_cb = TimeMonitor(data_size=train_step_size)
    loss_cb = LossMonitor(per_print_times=train_step_size)
    ckp_save_step = config.save_checkpoint_epochs * train_step_size
    config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
                                 directory='ckpts_rank_' + str(rank), config=config_ck)
    callbacks = [performance_cb, loss_cb]
    if device_num > 1 and config.is_save_on_master:
        if args.device_id == 0:
            callbacks.append(ckpoint_cb)
    else:
        callbacks.append(ckpoint_cb)

    # train model
    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)


def parse_args():
    '''parse_args'''
    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
    arg_parser.add_argument('--resume', type=str, default='', help='resume training from an existing checkpoint')
    args_opt = arg_parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()
    inception_v4_train()
    print('Inceptionv4 training success!')
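
The warmup-plus-cosine schedule built by `generate_cosine_lr` can be examined without MindSpore or the project config. The following is a minimal pure-Python sketch of the same formula; the function name `cosine_lr` and all numeric values are hypothetical and chosen only for illustration, not taken from `src/config.py`.

```python
import math


def cosine_lr(steps_per_epoch, total_epochs, lr_init, lr_end, lr_max, warmup_epochs):
    """Linear warmup to lr_max, then cosine decay towards lr_end."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    schedule = []
    for i in range(total_steps):
        if i < warmup_steps:
            # Linear ramp: step i gets lr_init plus (i + 1) increments.
            lr = lr_init + (lr_max - lr_init) * (i + 1) / warmup_steps
        else:
            # Half cosine from lr_max down towards lr_end.
            cosine = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine + lr_end
        schedule.append(lr)
    return schedule


# Hypothetical settings: 5 epochs of 10 steps, 1 warmup epoch.
lrs = cosine_lr(steps_per_epoch=10, total_epochs=5,
                lr_init=0.0, lr_end=0.001, lr_max=0.1, warmup_epochs=1)
print(len(lrs), lrs[0], lrs[9], lrs[-1])
```

With this formula the last warmup step and the first decay step both sit at `lr_max`, and the final step stays slightly above `lr_end` because the cosine argument stops one step short of pi.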

@ -0,0 +1 @@
# recommend