!12995 MindSpore Community Network Model Collection Activity: InceptionV4
From: @yyyzzzhao Reviewed-by: Signed-off-by:
Commit: 6aacbc7221
@@ -1,4 +1,4 @@
# InceptionV4 for Ascend
# InceptionV4 for Ascend/GPU

- [InceptionV4 Description](#InceptionV4-description)
- [Model Architecture](#model-architecture)
@@ -12,7 +12,7 @@
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
@@ -50,8 +50,9 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will

# [Environment Requirements](#contents)

- Hardware(Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Hardware(Ascend/GPU)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
    - Or prepare a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
@@ -67,6 +68,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
└─Inception-v4
  ├─README.md
  ├─scripts
    ├─run_distribute_train_gpu.sh       # launch distributed training with gpu platform(8p)
    ├─run_eval_gpu.sh                   # launch evaluating with gpu platform
    ├─run_standalone_train_ascend.sh    # launch standalone training with ascend platform(1p)
    ├─run_distribute_train_ascend.sh    # launch distributed training with ascend platform(8p)
    └─run_eval_ascend.sh                # launch evaluating with ascend platform
@@ -125,6 +128,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
>
> This is a processor core binding operation based on `device_num` and the total number of processors. If you do not want it, remove the `taskset` operations in `scripts/run_distribute_train.sh`.

- GPU:

```bash
# distribute training example(8p)
sh scripts/run_distribute_train_gpu.sh DATA_PATH
```

### Launch

```bash
@@ -135,11 +145,16 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_PATH DATA_DIR
# standalone training
sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
GPU:
# distribute training example(8p)
sh scripts/run_distribute_train_gpu.sh DATA_PATH
```

### Result

Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and the training log will be redirected to `./log.txt` as follows.

- Ascend

```python
epoch: 1 step: 1251, loss is 5.4833196
@@ -150,6 +165,17 @@ epoch: 3 step: 1251, loss is 3.6242008
Epoch time: 288507.506, per step time: 230.622
```

- GPU

```python
epoch: 1 step: 1251, loss is 6.49775
Epoch time: 1487493.604, per step time: 1189.044
epoch: 2 step: 1251, loss is 5.6884665
Epoch time: 1421838.433, per step time: 1136.561
epoch: 3 step: 1251, loss is 5.5168786
Epoch time: 1423009.501, per step time: 1137.498
```

## [Eval process](#contents)

### Usage
@@ -162,6 +188,12 @@ You can start training using python or shell scripts. The usage of shell scripts
sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
```

- GPU

```bash
sh scripts/run_eval_gpu.sh DATA_DIR CHECKPOINT_PATH
```

### Launch

```bash
@@ -169,57 +201,67 @@ You can start training using python or shell scripts. The usage of shell scripts
shell:
Ascend:
sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH
GPU:
sh scripts/run_eval_gpu.sh DATA_DIR CHECKPOINT_PATH
```

> The checkpoint can be produced during the training process.

### Result

The evaluation result will be stored in the example path; you can find results like the following in `eval.log`.

- Ascend

```python
metric: {'Loss': 0.9849, 'Top1-Acc':0.7985, 'Top5-Acc':0.9460}
```

- GPU(8p)

```python
metric: {'Loss': 0.8144, 'Top1-Acc': 0.8009, 'Top5-Acc': 0.9457}
```
# [Model description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | Ascend                                        | GPU                              |
| -------------------------- | --------------------------------------------- | -------------------------------- |
| Model Version              | InceptionV4                                   | InceptionV4                      |
| Resource                   | Ascend 910, cpu:2.60GHz 192cores, memory:755G | NV SMX2 V100-32G                 |
| Uploaded Date              | 11/04/2020                                    | 03/05/2021                       |
| MindSpore Version          | 1.0.0                                         | 1.0.0                            |
| Dataset                    | 1200k images                                  | 1200k images                     |
| Batch_size                 | 128                                           | 128                              |
| Training Parameters        | src/config.py (Ascend)                        | src/config.py (GPU)              |
| Optimizer                  | RMSProp                                       | RMSProp                          |
| Loss Function              | SoftmaxCrossEntropyWithLogits                 | SoftmaxCrossEntropyWithLogits    |
| Outputs                    | probability                                   | probability                      |
| Loss                       | 0.98486                                       | 0.8144                           |
| Accuracy (8p)              | ACC1[79.85%] ACC5[94.60%]                     | ACC1[80.09%] ACC5[94.57%]        |
| Total time (8p)            | 20h                                           | 95h                              |
| Params (M)                 | 153M                                          | 153M                             |
| Checkpoint for Fine tuning | 2135M                                         | 489M                             |
| Scripts                    | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) |
#### Inference Performance

| Parameters          | Ascend                                        | GPU                       |
| ------------------- | --------------------------------------------- | ------------------------- |
| Model Version       | InceptionV4                                   | InceptionV4               |
| Resource            | Ascend 910, cpu:2.60GHz 192cores, memory:755G | NV SMX2 V100-32G          |
| Uploaded Date       | 11/04/2020                                    | 03/05/2021                |
| MindSpore Version   | 1.0.0                                         | 1.0.0                     |
| Dataset             | 50k images                                    | 50k images                |
| Batch_size          | 128                                           | 128                       |
| Outputs             | probability                                   | probability               |
| Accuracy            | ACC1[79.85%] ACC5[94.60%]                     | ACC1[80.09%] ACC5[94.57%] |
| Total time          | 2mins                                         | 2mins                     |
| Model for inference | 2135M (.ckpt file)                            | 489M (.ckpt file)         |
#### Training performance results

@@ -229,7 +271,11 @@ metric: {'Loss': 0.9849, 'Top1-Acc':0.7985, 'Top5-Acc':0.9460}
| **Ascend** | train performance |
| :--------: | :---------------: |
|     8p     |    4430 img/s     |

|  **GPU**   | train performance |
| :--------: | :---------------: |
|     8p     |     906 img/s     |
# [Description of Random Situation](#contents)

@@ -237,4 +283,4 @@ In dataset.py, we set the seed inside "create_dataset" function. We also use r
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -24,7 +24,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4
from src.config import config_ascend as config
from src.config import config


def parse_args():
    '''parse_args'''
@@ -39,7 +39,7 @@ if __name__ == '__main__':
    args = parse_args()

    if args.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        device_id = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
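For context, the diff above only switches eval.py to the shared `config` and gives `DEVICE_ID` a default of `0`. A minimal sketch of how such an evaluation entry point can be wired together is shown below; it is an illustration, not the full eval.py of this commit, and it assumes the `create_dataset`, `Inceptionv4`, and `config` definitions introduced elsewhere in this change. The metric names follow the `eval.log` output quoted in the README.

```python
# Hypothetical evaluation sketch (not the exact eval.py in this commit).
import argparse
import os

from mindspore import context, nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4


def main():
    parser = argparse.ArgumentParser(description='InceptionV4 eval sketch')
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'))
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--checkpoint_path', type=str, required=True)
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    if args.platform == 'Ascend':
        # DEVICE_ID falls back to 0 when the variable is not exported, matching the change above.
        context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))

    net = Inceptionv4(classes=config.num_classes)
    load_param_into_net(net, load_checkpoint(args.checkpoint_path))
    net.set_train(False)

    dataset = create_dataset(args.dataset_path, do_train=False, batch_size=config.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss_fn=loss, metrics={'Loss': nn.Loss(),
                                              'Top1-Acc': nn.Top1CategoricalAccuracy(),
                                              'Top5-Acc': nn.Top5CategoricalAccuracy()})
    print('metric:', model.eval(dataset))


if __name__ == '__main__':
    main()
```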
@@ -20,7 +20,7 @@ import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export, context

from src.config import config_ascend as config
from src.config import config
from src.inceptionv4 import Inceptionv4

parser = argparse.ArgumentParser(description='inceptionv4 export')
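Only the config import changes in export.py. For orientation, a typical export flow built on these pieces might look like the sketch below; the checkpoint name, batch shape, and output format are illustrative assumptions, not values taken from this commit.

```python
# Hypothetical export sketch; file names and input shape are assumptions.
import numpy as np

import mindspore as ms
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import config
from src.inceptionv4 import Inceptionv4

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')

net = Inceptionv4(classes=config.num_classes)
load_param_into_net(net, load_checkpoint('inceptionv4.ckpt'))  # assumed checkpoint path
net.set_train(False)

# InceptionV4 consumes 299x299 images (image_length = 299 in src/dataset.py).
dummy_input = Tensor(np.zeros([config.batch_size, 3, 299, 299]), ms.float32)
export(net, dummy_input, file_name='inceptionv4', file_format='MINDIR')
```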
@@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# create a clean working copy of the code for this run
rm -rf device
mkdir device
cp ./*.py ./device
cp -r ./src ./device
cd ./device

DATA_DIR=$1

export DEVICE_ID=0
export RANK_SIZE=8

echo "start training"

# launch 8 data-parallel workers with OpenMPI; output is redirected to train.log
mpirun -n $RANK_SIZE --allow-run-as-root python train.py --dataset_path=$DATA_DIR --platform='GPU' > train.log 2>&1 &
@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# create a clean working copy of the code for evaluation
rm -rf evaluation
mkdir evaluation
cp ./*.py ./evaluation
cp -r ./src ./evaluation
cd ./evaluation

export DEVICE_ID=0
export RANK_SIZE=1

DATA_DIR=$1
CKPT_DIR=$2

echo "start evaluation"

# single-device evaluation; output is redirected to eval.log
python eval.py --dataset_path=$DATA_DIR --checkpoint_path=$CKPT_DIR --platform='GPU' > eval.log 2>&1 &
@@ -26,4 +26,4 @@ env > env.log
python -u ../train.py \
    --device_id=$1 \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../
@@ -17,7 +17,7 @@ network config setting, will be used in main.py
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
config_ascend = edict({
|
||||
config = edict({
|
||||
'is_save_on_master': False,
|
||||
|
||||
'batch_size': 128,
|
||||
|
|
|
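Renaming `config_ascend` to a single `config` means both platforms read the same EasyDict. A minimal sketch of such a shared config follows; only `batch_size` and `is_save_on_master` are taken from the diff, and the other keys and values are illustrative assumptions based on fields referenced elsewhere in this commit (`work_nums`, `loss_scale`, `amp_level`, `num_classes`, `epoch_size`, `ckpt_path`).

```python
# Illustrative sketch of a shared, platform-agnostic config; values other than
# batch_size and is_save_on_master are assumptions, not the exact src/config.py.
from easydict import EasyDict as edict

config = edict({
    'is_save_on_master': False,
    'batch_size': 128,        # shown in the diff above
    'epoch_size': 250,        # assumed
    'num_classes': 1000,      # assumed (ImageNet-1k)
    'work_nums': 8,           # dataset workers, referenced by src/dataset.py
    'loss_scale': 1024,       # consumed by FixedLossScaleManager in train.py
    'amp_level': 'O3',        # mixed-precision level applied on Ascend
    'ckpt_path': './ckpt/',   # checkpoint directory mentioned in the README
})

# Usage: every module imports the same object, e.g.
#   from src.config import config
#   print(config.batch_size)
```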
@@ -18,14 +18,14 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.config import config_ascend as config
from src.config import config


device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0):
    """
    Create a train or eval dataset.
@@ -45,7 +45,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
                                   shuffle=do_shuffle, num_shards=device_num, shard_id=device_id)
                                   shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id)

    image_length = 299
    if do_train:
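The key change here is that the shard index now comes from the caller (`shard_id`) instead of the globally read `DEVICE_ID`, which is what makes the same dataset code usable under an `mpirun` GPU launch. Below is a minimal, self-contained sketch of that sharding pattern; the transform list is deliberately simplified and does not reproduce the full augmentation pipeline of `src/dataset.py`.

```python
# Minimal sketch of rank-sharded dataset creation (simplified transforms).
import os

import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2

device_num = int(os.getenv('RANK_SIZE', '1'))


def create_dataset_sketch(dataset_path, do_train, batch_size=128, shard_id=0):
    """Build an ImageFolder pipeline, sharded across ranks when RANK_SIZE > 1."""
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=do_train)
    else:
        # Each rank reads a disjoint 1/device_num slice of the data.
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=do_train,
                                   num_shards=device_num, shard_id=shard_id)

    trans = [C.Decode(), C.Resize(299), C.CenterCrop(299), C.HWC2CHW()]
    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
    ds = ds.map(operations=C2.TypeCast(mstype.int32), input_columns="label")
    return ds.batch(batch_size, drop_remainder=True)
```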
@@ -286,7 +286,6 @@ class Inceptionv4(nn.Cell):
        self.avgpool = P.ReduceMean(keep_dims=False)
        self.softmax = nn.DenseBnAct(
            1536, classes, weight_init="XavierUniform", has_bias=True, has_bn=True, activation="logsoftmax")

        if is_train:
            self.dropout = nn.Dropout(0.20)
        else:
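For orientation only: the head defined above (global average pooling, optional dropout, and a `DenseBnAct` classifier with log-softmax activation) is typically applied at the end of `construct` roughly as sketched below. The real `construct` body is not part of this diff, so the attribute names and control flow here are assumptions.

```python
# Hypothetical tail of Inceptionv4.construct; not shown in this diff.
def construct_tail(self, x):
    # x: feature map of shape (N, 1536, 8, 8) after the last Inception-C block
    x = self.avgpool(x, (2, 3))   # global average pooling over H and W -> (N, 1536)
    if self.is_train:             # assumed flag; dropout is only built when is_train=True
        x = self.dropout(x)
    x = self.softmax(x)           # DenseBnAct: dense -> batch norm -> log-softmax output
    return x
```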
@@ -34,7 +34,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num

from src.config import config_ascend as config
from src.config import config

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1)
@@ -82,12 +82,20 @@ def inception_v4_train():
"""
|
||||
print('epoch_size: {} batch_size: {} class_num {}'.format(config.epoch_size, config.batch_size, config.num_classes))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(enable_graph_kernel=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||
if args.platform == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(enable_graph_kernel=False)
|
||||
|
||||
rank = 0
|
||||
if device_num > 1:
|
||||
init(backend_name='hccl')
|
||||
if args.platform == "Ascend":
|
||||
init(backend_name='hccl')
|
||||
elif args.platform == "GPU":
|
||||
init()
|
||||
else:
|
||||
raise ValueError("Unsupported device target.")
|
||||
|
||||
rank = get_rank()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
|
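The change above selects the communication backend by platform: HCCL on Ascend, and NCCL (via a plain `init()`) on GPU when the job is launched with `mpirun` as in the new script. A condensed sketch of that initialization pattern, written as a standalone helper for illustration (the `gradients_mean` flag is an assumption, not taken from this diff):

```python
# Condensed sketch of platform-dependent distributed initialization.
import os

from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode

device_num = int(os.getenv('RANK_SIZE', '1'))


def init_distributed(platform):
    """Return this process's rank after setting up data-parallel training."""
    context.set_context(mode=context.GRAPH_MODE, device_target=platform)
    rank = 0
    if device_num > 1:
        if platform == "Ascend":
            init(backend_name='hccl')   # HCCL relies on the rank table / DEVICE_ID from the launcher
        elif platform == "GPU":
            init()                      # defaults to NCCL when launched with mpirun
        else:
            raise ValueError("Unsupported device target.")
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    return rank
```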
@@ -96,7 +104,7 @@ def inception_v4_train():

    # create dataset
    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
                                   repeat_num=1, batch_size=config.batch_size)
                                   repeat_num=1, batch_size=config.batch_size, shard_id=rank)
    train_step_size = train_dataset.get_dataset_size()

    # create model
@@ -131,8 +139,16 @@ def inception_v4_train():
    load_param_into_net(net, ckpt)

    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={
        'acc', 'top_1_accuracy', 'top_5_accuracy'}, loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)

    if args.platform == "Ascend":
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                      loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
    elif args.platform == "GPU":
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                      loss_scale_manager=loss_scale_manager, amp_level='O0')
    else:
        raise ValueError("Unsupported device target.")

    # define callbacks
    performance_cb = TimeMonitor(data_size=train_step_size)
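The branch above keeps automatic mixed precision (`config.amp_level`) on Ascend but falls back to full precision (`'O0'`) on GPU, while sharing the same fixed loss scale. The same choice can be expressed more compactly, shown here only as an illustration of the design rather than the code this commit lands:

```python
# Illustrative condensation of the platform-dependent Model construction above.
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model


def build_model(net, loss, opt, platform, config):
    if platform not in ("Ascend", "GPU"):
        raise ValueError("Unsupported device target.")
    # Ascend trains with the amp level from config; GPU runs in full precision (O0).
    amp_level = config.amp_level if platform == "Ascend" else 'O0'
    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    return Model(net, loss_fn=loss, optimizer=opt,
                 metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                 loss_scale_manager=loss_scale_manager, amp_level=amp_level)
```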
@@ -156,6 +172,8 @@ def parse_args():
    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
                            help='Platform, support Ascend, GPU.')
    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existing checkpoint')
    args_opt = arg_parser.parse_args()
    return args_opt