forked from mindspore-Ecosystem/mindspore
!9571 Add Inceptionv4 net to model_zoo/official/cv/
From: @zhanghuiyao
This commit is contained in: commit b2a164b1c2

README.md
@ -0,0 +1,240 @@
# InceptionV4 for Ascend

- [InceptionV4 Description](#inceptionv4-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [InceptionV4 Description](#contents)

Inception-v4 is a convolutional neural network architecture that builds on previous iterations of the Inception family by simplifying the architecture and using more Inception modules than Inception-v3. The idea was proposed in the paper "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning", published in 2016.

[Paper](https://arxiv.org/pdf/1602.07261.pdf): Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi. Computer Vision and Pattern Recognition[J]. 2016.

# [Model architecture](#contents)

The overall network architecture of InceptionV4 is shown below:

[Link](https://arxiv.org/pdf/1602.07261.pdf)

# [Dataset](#contents)

For the dataset used, refer to the paper; the statistics below correspond to the ImageNet-1k classification dataset.

- Dataset size: 125G, 1250k color images in 1000 classes
    - Train: 120G, 1200k images
    - Test: 5G, 50k images
- Data format: RGB images
- Note: Data will be processed in src/dataset.py

# [Features](#contents)

## [Mixed Precision(Ascend)](#contents)

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
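
In this repository, mixed precision is enabled through the `amp_level` argument of `Model`, together with a fixed loss scale (see `src/config.py` and `train.py` below). A minimal sketch of the pattern, with the learning rate chosen arbitrarily for illustration:

```python
import mindspore.nn as nn
from mindspore import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.inceptionv4 import Inceptionv4

# Values mirror src/config.py: amp_level='O3' runs the network in float16,
# and a fixed loss scale compensates for float16's reduced dynamic range.
net = Inceptionv4(classes=1000)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.RMSProp(net.trainable_params(), learning_rate=0.1,  # placeholder lr
                 decay=0.9, epsilon=1.0, momentum=0.9, loss_scale=1024)
model = Model(net, loss_fn=loss, optimizer=opt,
              loss_scale_manager=FixedLossScaleManager(1024, drop_overflow_update=False),
              amp_level='O3')
```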

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Script description](#contents)

## [Script and sample code](#contents)

```shell
.
└─Inception-v4
  ├─README.md
  ├─scripts
  │ ├─run_standalone_train_ascend.sh    # launch standalone training with Ascend platform (1p)
  │ ├─run_distribute_train_ascend.sh    # launch distributed training with Ascend platform (8p)
  │ └─run_eval_ascend.sh                # launch evaluating with Ascend platform
  ├─src
  │ ├─config.py                         # parameter configuration
  │ ├─dataset.py                        # data preprocessing
  │ ├─inceptionv4.py                    # network definition
  │ └─callback.py                       # eval callback function
  ├─eval.py                             # eval net
  ├─export.py                           # export checkpoint; supports .onnx, .air and .mindir conversion
  └─train.py                            # train net
```

## [Script Parameters](#contents)

```python
Major parameters in train.py and config.py are:
'is_save_on_master'          # save checkpoint only on the master device
'batch_size'                 # input batch size
'epoch_size'                 # total number of epochs
'num_classes'                # number of dataset classes
'work_nums'                  # number of workers to read data
'loss_scale'                 # loss scale
'smooth_factor'              # label smoothing factor
'weight_decay'               # weight decay
'momentum'                   # momentum
'amp_level'                  # precision training, supports [O0, O2, O3]
'decay'                      # decay used in the optimizer
'epsilon'                    # epsilon used in the optimizer
'keep_checkpoint_max'        # max number of checkpoints to keep
'save_checkpoint_epochs'     # save checkpoints every n epochs
'lr_init'                    # initial learning rate
'lr_end'                     # final learning rate
'lr_max'                     # maximum learning rate
'warmup_epochs'              # number of warmup epochs
'start_epoch'                # start epoch number, range [1, epoch_size]
```
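
These names map one-to-one to the `config_ascend` EasyDict defined in `src/config.py` (shown further below), which both `train.py` and `eval.py` import. A small sketch of how they are read, with the printed fields chosen just as examples:

```python
from src.config import config_ascend as config

# The EasyDict exposes every entry as an attribute, so hyper-parameters can
# be inspected or overridden in code before the dataset and model are built.
print(config.batch_size, config.epoch_size, config.amp_level)
config.work_nums = 4  # example override; edit src/config.py for a permanent change
```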

## [Training process](#contents)

### Usage

You can start training using python or shell scripts. The usage of the shell scripts is as follows:

- Ascend:

```bash
# distributed training example (8p)
sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_DIR
# standalone training
sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
```

> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it is better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection check time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compile time increases with model size.
>
> The script binds processor cores according to `device_num` and the total number of processors. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train_ascend.sh`.

### Launch

```bash
# training example
shell:
    Ascend:
        # distributed training example (8p)
        sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_DIR
        # standalone training
        sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
```

### Result

Training results will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and the training log will be redirected to `./log.txt`, as follows.

```python
epoch: 1 step: 1251, loss is 5.861846
Epoch time: 701416.649, per step time: 560.685
epoch: 2 step: 1251, loss is 4.295785
Epoch time: 472524.154, per step time: 377.717
epoch: 3 step: 1251, loss is 3.691987
Epoch time: 472505.767, per step time: 377.702
```

## [Eval process](#contents)

### Usage

You can start evaluation using python or shell scripts. The usage of the shell scripts is as follows:

- Ascend:

```bash
sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
```

### Launch

```bash
# eval example
shell:
    Ascend:
        sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
```

> The checkpoint is produced during the training process.

### Result

Evaluation results will be stored in the example path; you can find results like the following in `eval.log`.

```python
metric: {'Loss': 0.9849, 'Top1-Acc': 0.7985, 'Top5-Acc': 0.9460}
```

# [Model description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | InceptionV4                                                  |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G              |
| Uploaded Date              | 11/04/2020                                                   |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | 1200k images                                                 |
| Batch_size                 | 128                                                          |
| Training Parameters        | src/config.py                                                |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | SoftmaxCrossEntropyWithLogits                                |
| Outputs                    | probability                                                  |
| Loss                       | 0.98486                                                      |
| Accuracy (8p)              | ACC1[79.85%] ACC5[94.60%]                                    |
| Total time (8p)            | 33h                                                          |
| Params (M)                 | 153M                                                         |
| Checkpoint for Fine tuning | 2135M                                                        |
| Scripts                    | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) |

#### Inference Performance

| Parameters          | Ascend                                          |
| ------------------- | ----------------------------------------------- |
| Model Version       | InceptionV4                                     |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded Date       | 11/04/2020                                      |
| MindSpore Version   | 1.0.0                                           |
| Dataset             | 50k images                                      |
| Batch_size          | 128                                             |
| Outputs             | probability                                     |
| Accuracy            | ACC1[79.85%] ACC5[94.60%]                       |
| Total time          | 2 mins                                          |
| Model for inference | 2135M (.ckpt file)                              |

#### Training performance results

| **Ascend** | train performance |
| :--------: | :---------------: |
|     1p     |     345 img/s     |

| **Ascend** | train performance |
| :--------: | :---------------: |
|     8p     |    2708 img/s     |

# [Description of Random Situation](#contents)

In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py.
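
For reference, `train.py` fixes the global seed with MindSpore's `set_seed`; a minimal sketch:

```python
from mindspore.common import set_seed

# Fixing the global seed (train.py calls set_seed(1)) makes weight
# initialization and other random operations reproducible across runs.
set_seed(1)
```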

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import os

import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4
from src.config import config_ascend as config


def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()

    if args.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    net = Inceptionv4(classes=config.num_classes)
    ckpt = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
                             repeat_num=1, batch_size=config.batch_size)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    print('=' * 20, 'Evaluate start', '=' * 20)
    metrics = model.eval(dataset)
    print("metric: ", metrics)

export.py
@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import config_ascend as config
from src.inceptionv4 import Inceptionv4


def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--model_name', type=str, default='inceptionV4.air', help='output file name of the converted model')
    parser.add_argument('--format', type=str, default='AIR', help='file format to export')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inceptionv4')
    _args_opt = parser.parse_args()
    return _args_opt


if __name__ == '__main__':
    args_opt = parse_args()

    net = Inceptionv4(classes=config.num_classes)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)

    # a random input with the expected shape (N, C, H, W) = (1, 3, 299, 299)
    # is enough to trace the graph for export
    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
    export(net, input_arr, file_name=args_opt.model_name, file_format=args_opt.format)

scripts/run_distribute_train_ascend.sh
@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export RANK_TABLE_FILE=$1
DATA_DIR=$2
export RANK_SIZE=8

cores=`cat /proc/cpuinfo | grep "processor" | wc -l`
echo "the number of logical cores:" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID, rank_id $RANK_ID"

    env > env.log
    # bind each training process to its own range of cores with taskset
    taskset -c $cmdopt python -u ../train.py \
        --device_id $i \
        --dataset_path=$DATA_DIR > log.txt 2>&1 &
    cd ../
done

scripts/run_eval_ascend.sh
@ -0,0 +1,28 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
CHECKPOINT_PATH=$3
export RANK_SIZE=1

rm -rf evaluation_ascend
mkdir ./evaluation_ascend
cd ./evaluation_ascend || exit
echo "start evaluating for device id $DEVICE_ID"
env > env.log
python ../eval.py --platform=Ascend --dataset_path=$DATA_DIR --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
cd ../

scripts/run_standalone_train_ascend.sh
@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export RANK_SIZE=1
export DEVICE_ID=$1
DATA_DIR=$2

rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python -u ../train.py \
    --device_id=$1 \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../

src/callback.py
@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callback function"""
from mindspore.train.callback import Callback


class EvaluateCallBack(Callback):
    """EvaluateCallBack"""

    def __init__(self, model, eval_dataset, per_print_time=1000):
        super(EvaluateCallBack, self).__init__()
        self.model = model
        self.per_print_time = per_print_time
        self.eval_dataset = eval_dataset

    def step_end(self, run_context):
        # run a full evaluation every `per_print_time` steps
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.per_print_time == 0:
            result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(
                cb_params.cur_epoch_num, cb_params.cur_step_num,
                result['top_1_accuracy'], result['top_5_accuracy']))

    def epoch_end(self, run_context):
        # run a full evaluation at the end of every epoch
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
        print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(
            cb_params.cur_epoch_num, cb_params.cur_step_num,
            result['top_1_accuracy'], result['top_5_accuracy']))

src/config.py
@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict

config_ascend = edict({
    'is_save_on_master': False,

    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'work_nums': 8,

    'loss_scale': 1024,
    'smooth_factor': 0.1,
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'amp_level': 'O3',
    'decay': 0.9,
    'epsilon': 1.0,

    'keep_checkpoint_max': 10,
    'save_checkpoint_epochs': 10,

    'lr_init': 0.00004,
    'lr_end': 0.000004,
    'lr_max': 0.4,
    'warmup_epochs': 1,
    'start_epoch': 1,

    'onnx_filename': 'inceptionv4.onnx',
    'air_filename': 'inceptionv4.air'
})

src/dataset.py
@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create train or eval dataset."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.config import config_ascend as config


device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (str): The path of dataset.
        do_train (bool): Whether dataset is used for train or eval.
        repeat_num (int): The repeat times of dataset. Default: 1.
        batch_size (int): The batch size of dataset. Default: 32.

    Returns:
        Dataset.
    """

    do_shuffle = bool(do_train)

    if device_num == 1 or not do_train:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
                                   shuffle=do_shuffle, num_shards=device_num, shard_id=device_id)

    image_length = 299
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(image_length),
            C.CenterCrop(image_length)
        ]
    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds

src/inceptionv4.py
@ -0,0 +1,328 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InceptionV4"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import Initializer


class Avginitializer(Initializer):
    """
    Initialize the weight to 1/(m*n), where (m, n) is the shape of the kernel.
    """

    def _initialize(self, arr):
        arr[:] = 0
        for i in range(arr.shape[0]):
            for j in range(arr.shape[2]):
                for k in range(arr.shape[3]):
                    arr[i][i][j][k] = 1 / (arr.shape[2] * arr.shape[3])


class Avgpool(nn.Cell):
    """
    Average pooling for temporal data.

    Using a custom initializer to turn conv2d into avgpool2d. The weights won't be trained.
    """

    def __init__(self, channel, kernel_size, stride=1, pad_mode='same'):
        super(Avgpool, self).__init__()
        self.init = Avginitializer()
        self.conv = nn.Conv2d(channel, channel, kernel_size,
                              stride=stride, pad_mode=pad_mode, weight_init=self.init)
        self.conv.set_train(False)

    def construct(self, x):
        x = self.conv(x)
        return x


class Conv2d(nn.Cell):
    """
    Set the default configuration for Conv2dBnAct
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='valid', padding=0,
                 has_bias=False, weight_init="XavierUniform", bias_init='zeros'):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2dBnAct(in_channels, out_channels, kernel_size, stride=stride, pad_mode=pad_mode,
                                   padding=padding, weight_init=weight_init, bias_init=bias_init, has_bias=has_bias,
                                   has_bn=True, activation="relu")

    def construct(self, x):
        x = self.conv(x)
        return x


class Stem(nn.Cell):
    """
    Inceptionv4 stem
    """

    def __init__(self, in_channels):
        super(Stem, self).__init__()
        self.conv2d_1a_3x3 = Conv2d(
            in_channels, 32, 3, stride=2, padding=0, has_bias=False)

        self.conv2d_2a_3x3 = Conv2d(
            32, 32, 3, stride=1, padding=0, has_bias=False)
        self.conv2d_2b_3x3 = Conv2d(
            32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)

        self.mixed_3a_branch_0 = nn.MaxPool2d(3, stride=2)
        self.mixed_3a_branch_1 = Conv2d(
            64, 96, 3, stride=2, padding=0, has_bias=False)

        self.mixed_4a_branch_0 = nn.SequentialCell([
            Conv2d(160, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, padding=0, pad_mode='valid', has_bias=False)])

        self.mixed_4a_branch_1 = nn.SequentialCell([
            Conv2d(160, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 64, (1, 7), pad_mode='same', stride=1, has_bias=False),
            Conv2d(64, 64, (7, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(64, 96, 3, stride=1, padding=0, pad_mode='valid', has_bias=False)])

        self.mixed_5a_branch_0 = Conv2d(
            192, 192, 3, stride=2, padding=0, has_bias=False)
        self.mixed_5a_branch_1 = nn.MaxPool2d(3, stride=2)
        self.concat0 = P.Concat(1)
        self.concat1 = P.Concat(1)
        self.concat2 = P.Concat(1)

    def construct(self, x):
        """construct"""
        x = self.conv2d_1a_3x3(x)  # 149 x 149 x 32
        x = self.conv2d_2a_3x3(x)  # 147 x 147 x 32
        x = self.conv2d_2b_3x3(x)  # 147 x 147 x 64

        x0 = self.mixed_3a_branch_0(x)
        x1 = self.mixed_3a_branch_1(x)
        x = self.concat0((x0, x1))  # 73 x 73 x 160

        x0 = self.mixed_4a_branch_0(x)
        x1 = self.mixed_4a_branch_1(x)
        x = self.concat1((x0, x1))  # 71 x 71 x 192

        x0 = self.mixed_5a_branch_0(x)
        x1 = self.mixed_5a_branch_1(x)
        x = self.concat2((x0, x1))  # 35 x 35 x 384
        return x


class InceptionA(nn.Cell):
    """InceptionA"""

    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        self.branch_0 = Conv2d(
            in_channels, 96, 1, stride=1, padding=0, has_bias=False)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)])

        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 64, 1, stride=1, padding=0, has_bias=False),
            Conv2d(64, 96, 3, stride=1, pad_mode='pad', padding=1, has_bias=False),
            Conv2d(96, 96, 3, stride=1, pad_mode='pad', padding=1, has_bias=False)])

        self.branch_3 = nn.SequentialCell([
            Avgpool(384, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(384, 96, 1, stride=1, padding=0, has_bias=False)])

        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        x4 = self.concat((x0, x1, x2, x3))
        return x4


class InceptionB(nn.Cell):
    """InceptionB"""

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch_0 = Conv2d(in_channels, 384, 1,
                               stride=1, padding=0, has_bias=False)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 224, (1, 7), pad_mode='same', stride=1, has_bias=False),
            Conv2d(224, 256, (7, 1), pad_mode='same', stride=1, has_bias=False),
        ])
        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 192, (7, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(192, 224, (1, 7), pad_mode='same', stride=1, has_bias=False),
            Conv2d(224, 224, (7, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(224, 256, (1, 7), pad_mode='same', stride=1, has_bias=False)
        ])
        self.branch_3 = nn.SequentialCell([
            Avgpool(in_channels, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(in_channels, 128, 1, stride=1, padding=0, has_bias=False)
        ])
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        x4 = self.concat((x0, x1, x2, x3))
        return x4


class ReductionA(nn.Cell):
    """ReductionA"""

    def __init__(self, in_channels, k, l, m, n):
        super(ReductionA, self).__init__()
        self.branch_0 = Conv2d(in_channels, n, 3, stride=2, padding=0)
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, k, 1, stride=1, padding=0, has_bias=False),
            Conv2d(k, l, 3, stride=1, pad_mode='pad', padding=1, has_bias=False),
            Conv2d(l, m, 3, stride=2, padding=0, has_bias=False),
        ])
        self.branch_2 = nn.MaxPool2d(3, stride=2)
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.concat((x0, x1, x2))
        return x3  # 17 x 17 x 1024


class ReductionB(nn.Cell):
    """ReductionB"""

    def __init__(self, in_channels):
        super(ReductionB, self).__init__()
        self.branch_0 = nn.SequentialCell([
            Conv2d(in_channels, 192, 1, stride=1, padding=0, has_bias=False),
            Conv2d(192, 192, 3, stride=2, padding=0, has_bias=False),
        ])
        self.branch_1 = nn.SequentialCell([
            Conv2d(in_channels, 256, 1, stride=1, padding=0, has_bias=False),
            Conv2d(256, 256, (1, 7), pad_mode='same', stride=1, has_bias=False),
            Conv2d(256, 320, (7, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(320, 320, 3, stride=2, padding=0, has_bias=False)
        ])
        self.branch_2 = nn.MaxPool2d(3, stride=2)
        self.concat = P.Concat(1)

    def construct(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.concat((x0, x1, x2))
        return x3  # 8 x 8 x 1536


class InceptionC(nn.Cell):
    """InceptionC"""

    def __init__(self, in_channels):
        super(InceptionC, self).__init__()
        self.branch_0 = Conv2d(in_channels, 256, 1,
                               stride=1, padding=0, has_bias=False)

        self.branch_1 = Conv2d(in_channels, 384, 1,
                               stride=1, padding=0, has_bias=False)
        self.branch_1_1 = Conv2d(
            384, 256, (1, 3), pad_mode='same', stride=1, has_bias=False)
        self.branch_1_2 = Conv2d(
            384, 256, (3, 1), pad_mode='same', stride=1, has_bias=False)

        self.branch_2 = nn.SequentialCell([
            Conv2d(in_channels, 384, 1, stride=1, padding=0, has_bias=False),
            Conv2d(384, 448, (3, 1), pad_mode='same', stride=1, has_bias=False),
            Conv2d(448, 512, (1, 3), pad_mode='same', stride=1, has_bias=False),
        ])
        self.branch_2_1 = Conv2d(
            512, 256, (1, 3), pad_mode='same', stride=1, has_bias=False)
        self.branch_2_2 = Conv2d(
            512, 256, (3, 1), pad_mode='same', stride=1, has_bias=False)

        self.branch_3 = nn.SequentialCell([
            Avgpool(in_channels, kernel_size=3, stride=1, pad_mode='same'),
            Conv2d(in_channels, 256, 1, stride=1, padding=0, has_bias=False)
        ])
        self.concat0 = P.Concat(1)
        self.concat1 = P.Concat(1)
        self.concat2 = P.Concat(1)

    def construct(self, x):
        """construct"""
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x1_1 = self.branch_1_1(x1)
        x1_2 = self.branch_1_2(x1)
        x1 = self.concat0((x1_1, x1_2))
        x2 = self.branch_2(x)
        x2_1 = self.branch_2_1(x2)
        x2_2 = self.branch_2_2(x2)
        x2 = self.concat1((x2_1, x2_2))
        x3 = self.branch_3(x)
        return self.concat2((x0, x1, x2, x3))  # 8 x 8 x 1536


class Inceptionv4(nn.Cell):
    """
    Inceptionv4 architecture

    Args:
        is_train: in train mode, turn on the dropout.
    """

    def __init__(self, in_channels=3, classes=1000, k=192, l=224, m=256, n=384, is_train=True):
        super(Inceptionv4, self).__init__()
        blocks = []
        blocks.append(Stem(in_channels))
        for _ in range(4):
            blocks.append(InceptionA(384))
        blocks.append(ReductionA(384, k, l, m, n))
        for _ in range(7):
            blocks.append(InceptionB(1024))
        blocks.append(ReductionB(1024))
        for _ in range(3):
            blocks.append(InceptionC(1536))
        self.features = nn.SequentialCell(blocks)

        self.avgpool = P.ReduceMean(keep_dims=False)
        self.softmax = nn.DenseBnAct(
            1536, classes, weight_init="XavierUniform", has_bias=True, has_bn=True, activation="logsoftmax")

        if is_train:
            self.dropout = nn.Dropout(0.20)
        else:
            self.dropout = nn.Dropout(1)
        self.bn0 = nn.BatchNorm1d(1536, eps=0.001, momentum=0.1)

    def construct(self, x):
        x = self.features(x)
        x = self.avgpool(x, (2, 3))
        x = self.bn0(x)
        x = self.dropout(x)
        x = self.softmax(x)
        return x

train.py
@ -0,0 +1,167 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet"""
import os
import argparse
import math
import numpy as np

from mindspore.communication import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.model import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn import RMSProp
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num

from src.config import config_ascend as config

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1)


def generate_cosine_lr(steps_per_epoch, total_epochs,
                       lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs):
    """
    Applies cosine decay to generate the learning rate array.

    Args:
        steps_per_epoch(int): number of steps per epoch.
        total_epochs(int): total number of epochs in training.
        lr_init(float): initial learning rate.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate.
        warmup_epochs(int): number of warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            # linear warmup from lr_init to lr_max
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            # cosine decay from lr_max down to lr_end
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    # skip the part of the schedule already covered when resuming from start_epoch
    current_step = steps_per_epoch * (config.start_epoch - 1)
    learning_rate = learning_rate[current_step:]
    return learning_rate


def inception_v4_train():
    """
    Train Inceptionv4 in data parallelism
    """
    print('epoch_size: {} batch_size: {} class_num {}'.format(config.epoch_size, config.batch_size, config.num_classes))

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=args.device_id)
    context.set_context(enable_graph_kernel=False)
    rank = 0
    if device_num > 1:
        init(backend_name='hccl')
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          all_reduce_fusion_config=[200, 400])

    # create dataset
    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
                                   repeat_num=1, batch_size=config.batch_size)
    train_step_size = train_dataset.get_dataset_size()

    # create model
    net = Inceptionv4(classes=config.num_classes)
    # loss
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # learning rate
    lr = Tensor(generate_cosine_lr(steps_per_epoch=train_step_size, total_epochs=config.epoch_size))

    # apply weight decay only to conv/dense weights, not to BN/bias parameters
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    # re-initialize the decayed weights with XavierUniform
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            param.set_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    opt = RMSProp(group_params, lr, decay=config.decay, epsilon=config.epsilon, weight_decay=config.weight_decay,
                  momentum=config.momentum, loss_scale=config.loss_scale)

    if args.device_id == 0:
        print(lr)
        print(train_step_size)
    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={
        'acc', 'top_1_accuracy', 'top_5_accuracy'}, loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)

    # define callbacks
    performance_cb = TimeMonitor(data_size=train_step_size)
    loss_cb = LossMonitor(per_print_times=train_step_size)
    ckp_save_step = config.save_checkpoint_epochs * train_step_size
    config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
                                 directory='ckpts_rank_' + str(rank), config=config_ck)
    callbacks = [performance_cb, loss_cb]
    if device_num > 1 and config.is_save_on_master:
        if args.device_id == 0:
            callbacks.append(ckpoint_cb)
    else:
        callbacks.append(ckpoint_cb)

    # train model
    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)


def parse_args():
    '''parse_args'''
    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existing checkpoint')
    args_opt = arg_parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()
    inception_v4_train()
    print('Inceptionv4 training success!')

@ -0,0 +1 @@
# recommend