From: @Gogery
Reviewed-by: @guoqi1024
Signed-off-by: @guoqi1024
This commit is contained in:
mindspore-ci-bot 2020-12-11 18:26:30 +08:00 committed by Gitee
commit 45f07683e7
11 changed files with 985 additions and 0 deletions

View File

@ -0,0 +1,234 @@
# Contents
- [Contents](#contents)
- [Xception Description](#xception-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision(Ascend)](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training process](#training-process)
- [Usage](#usage)
- [Launch](#launch)
- [Result](#result)
- [Eval process](#eval-process)
- [Usage](#usage-1)
- [Launch](#launch-1)
- [Result](#result-1)
- [Model description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Xception Description](#contents)
Xception by Google is extreme version of Inception. With a modified depthwise separable convolution, it is even better than Inception-v3. This paper was published in 2017.
[Paper](https://arxiv.org/pdf/1610.02357v3.pdf) Franois Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) IEEE, 2017.
# [Model architecture](#contents)
The overall network architecture of Xception is show below:
[Link](https://arxiv.org/pdf/1610.02357v3.pdf)
# [Dataset](#contents)
Dataset used can refer to paper.
- Dataset size: 125G, 1250k colorful images in 1000 classes
- Train: 120G, 1200k images
- Test: 5G, 50k images
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching reduce precision.
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```shell
.
└─Xception
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training with ascend platform(1p)
├─run_distribute_train.sh # launch distributed training with ascend platform(8p)
└─run_eval.sh # launch evaluating with ascend platform
├─src
├─config.py # parameter configuration
├─dataset.py # data preprocessing
├─Xception.py # network definition
├─CrossEntropySmooth.py # Customized CrossEntropy loss function
└─lr_generator.py # learning rate generator
├─train.py # train net
└─eval.py # eval net
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and config.py are:
'num_classes': 1000 # dataset class numbers
'batch_size': 128 # input batchsize
'loss_scale': 1024 # loss scale
'momentum': 0.9 # momentum
'weight_decay': 1e-4 # weight decay
'epoch_size': 250 # total epoch numbers
'save_checkpoint': True # save checkpoint
'save_checkpoint_epochs': 1 # save checkpoint epochs
'keep_checkpoint_max': 5 # max numbers to keep checkpoints
'save_checkpoint_path': "./" # save checkpoint path
'warmup_epochs': 1 # warmup epoch numbers
'lr_decay_mode': "liner" # lr decay mode
'use_label_smooth': True # use label smooth
'finish_epoch': 0 # finished epochs numbers
'label_smooth_factor': 0.1 # label smoothing factor
'lr_init': 0.00004 # initiate learning rate
'lr_max': 0.4 # max bound of learning rate
'lr_end': 0.00004 # min bound of learning rate
"weight_init": 'xavier_uniform' # Weight initialization mode
```
## [Training process](#contents)
### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- Ascend:
```shell
# distribute training example(8p)
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# standalone training
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```
> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh`
### Launch
``` shell
# training example
python:
Ascend:
python train.py --device_target Ascend --dataset_path /dataset/train
shell:
Ascend:
# distribute training example(8p)
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# standalone training
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```
### Result
Training result will be stored in the example path. Checkpoints will be stored at `. /model_0` by default, and training log will be redirected to `log.txt` like followings.
``` shell
epoch: [ 0/250], step:[ 1250/ 1251], loss:[4.761/5.613], time:[529.305], lr:[0.400]
epoch time: 1128662.862, per step time: 902.209, avg loss: 5.609
epoch: [ 1/250], step:[ 1250/ 1251], loss:[4.164/4.318], time:[503.708], lr:[0.398]
epoch time: 889163.081, per step time: 710.762, avg loss: 4.312
```
## [Eval process](#contents)
### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
```shell
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
### Launch
```shell
# eval example
python:
Ascend: python eval.py --device_target Ascend --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR
shell:
Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
> checkpoint can be produced in training process.
### Result
Evaluation result will be stored in the example path, you can find result like the followings in `eval.log`.
```shell
result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc': 0.9485777243589744}
```
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Ascend |
| -------------------------- | ---------------------------------------------- |
| Model Version | Xception |
| Resource | HUAWEI CLOUD Modelarts |
| uploaded Date | 11/15/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | 1200k images |
| Batch_size | 128 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | CrossEntropySmooth |
| Loss | 1.78 |
| Accuracy (8p) | Top1[79.9%] Top5[94.9%] |
| Total time (8p) | 63h |
| Params (M) | 180M |
| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Xception) |
#### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | Xception |
| Resource | HUAWEI CLOUD Modelarts |
| Uploaded Date | 11/15/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | 50k images |
| Batch_size | 128 |
| Accuracy | Top1[79.9%] Top5[94.9%] |
| Total time | 3mins |
# [Description of Random Situation](#contents)
In `dataset.py`, we set the seed inside `create_dataset` function. We also use random seed in `train.py`.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval Xception."""
import argparse
from mindspore import context, nn
from mindspore.train.model import Model
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.Xception import xception
from src.config import config
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth
set_seed(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
context.set_context(device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
# create dataset
dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
step_size = dataset.get_dataset_size()
# define net
net = xception(class_num=config.class_num)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
# define model
eval_metrics = {'Loss': nn.Loss(),
'Top_1_Acc': nn.Top1CategoricalAccuracy(),
'Top_5_Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss_fn=loss, metrics=eval_metrics)
# eval model
res = model.eval(dataset, dataset_sink_mode=False)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,50 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf train_parallel$i
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../train.py \
--is_distributed \
--device_target=Ascend \
--dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../
done

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
python ./eval.py \
--device_target=Ascend \
--device_id=$DEVICE_ID \
--checkpoint_path=$PATH_CHECKPOINT \
--dataset_path=$DATA_DIR > eval.log 2>&1 &

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
python ./train.py \
--device_target=Ascend \
--dataset_path=$DATA_DIR > log.txt 2>&1 &

View File

@ -0,0 +1,186 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Xception."""
import mindspore.nn as nn
import mindspore.ops.operations as P
from src.config import config
class SeparableConv2d(nn.Cell):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels, pad_mode='pad',
padding=padding, weight_init=config.weight_init)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='valid',
weight_init=config.weight_init)
def construct(self, x):
x = self.conv1(x)
x = self.pointwise(x)
return x
class Block(nn.Cell):
def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
super(Block, self).__init__()
if out_filters != in_filters or strides != 1:
self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid', has_bias=False,
weight_init=config.weight_init)
self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
else:
self.skip = None
self.relu = nn.ReLU()
rep = []
filters = in_filters
if grow_first:
rep.append(nn.ReLU())
rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
filters = out_filters
for _ in range(reps - 1):
rep.append(nn.ReLU())
rep.append(SeparableConv2d(filters, filters, kernel_size=3, stride=1, padding=1))
rep.append(nn.BatchNorm2d(filters, momentum=0.9))
if not grow_first:
rep.append(nn.ReLU())
rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
if not start_with_relu:
rep = rep[1:]
else:
rep[0] = nn.ReLU()
if strides != 1:
rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
self.rep = nn.SequentialCell(*rep)
self.add = P.TensorAdd()
def construct(self, inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x = self.add(x, skip)
return x
class Xception(nn.Cell):
"""
Xception optimized for the ImageNet dataset, as specified in
https://arxiv.org/abs/1610.02357.pdf
"""
def __init__(self, num_classes=1000):
""" Constructor
Args:
num_classes: number of classes.
"""
super(Xception, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init=config.weight_init)
self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init=config.weight_init)
self.bn2 = nn.BatchNorm2d(64, momentum=0.9)
# Entry flow
self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
# Middle flow
self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
# Exit flow
self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
self.bn3 = nn.BatchNorm2d(1536, momentum=0.9)
self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
self.bn4 = nn.BatchNorm2d(2048, momentum=0.9)
self.avg_pool = nn.AvgPool2d(10)
self.dropout = nn.Dropout()
self.fc = nn.Dense(2048, num_classes)
def construct(self, x):
shape = P.Shape()
reshape = P.Reshape()
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.block11(x)
x = self.block12(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.avg_pool(x)
x = self.dropout(x)
x = reshape(x, (shape(x)[0], -1))
x = self.fc(x)
return x
def xception(class_num=1000):
"""
Get Xception neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of Xception neural network.
Examples:
>>> net = xception(1000)
"""
return Xception(class_num)

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config for Xception, imagenet2012.
config = ed({
"class_num": 1000,
"batch_size": 128,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 250,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./",
"warmup_epochs": 1,
"lr_decay_mode": "liner",
"use_label_smooth": True,
"finish_epoch": 0,
"label_smooth_factor": 0.1,
"lr_init": 0.00004,
"lr_max": 0.4,
"lr_end": 0.00004,
"weight_init": 'xavier_uniform'
})

View File

@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
batch_size(int): the batch size of dataset. Default: 16.
device_num (int): Number of shards that the dataset should be divided into (default=1).
rank (int): The shard ID within num_shards (default=0).
Returns:
dataset
"""
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank)
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(299),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
]
else:
trans = [
C.Decode(),
C.Resize(320),
C.CenterCrop(299)
]
trans += [
C.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
C.HWC2CHW(),
C2.TypeCast(mstype.float32)
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step

View File

@ -0,0 +1,173 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train Xception."""
import os
import time
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import dtype as mstype
from mindspore.common import set_seed
from src.lr_generator import get_lr
from src.Xception import xception
from src.config import config
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth
set_seed(1)
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num - 1 + config.finish_epoch, cb_params.epoch_num + config.finish_epoch,
cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification training')
parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
parser.add_argument('--device_target', type=str, default='Ascend', help='run platform')
parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
args_opt = parser.parse_args()
# init distributed
if args_opt.is_distributed:
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
rank = get_rank()
group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
init()
else:
rank = 0
group_size = 1
context.set_context(device_id=0)
if args_opt.device_target == "Ascend":
#train on Ascend
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
# define network
net = xception(class_num=config.class_num)
net.to_float(mstype.float16)
# define loss
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
# define dataset
dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
device_num=group_size, rank=rank)
step_size = dataset.get_dataset_size()
# resume
if args_opt.resume:
ckpt = load_checkpoint(args_opt.resume)
load_param_into_net(net, ckpt)
# get learning rate
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode))
# define optimization
opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)
# define model
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level='O3', keep_batchnorm_fp32=True)
# define callbacks
cb = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint:
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(rank) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
# begin train
if args_opt.is_distributed:
if rank == 0:
cb += [ckpt_cb]
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=False)
else:
cb += [ckpt_cb]
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=False)
print("train success")
else:
raise ValueError("Unsupported device_target.")