!20310 Added GPU support for Model Zoo SE-Net

Merge pull request !20310 from alashkari/se-net
i-robot 2021-07-16 04:39:12 +00:00 committed by Gitee
commit b97a9ba5a7
8 changed files with 220 additions and 40 deletions

README.md

@@ -58,8 +58,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
# [Environment Requirements](#contents)
- Hardware (Ascend)
  - Prepare hardware environment with Ascend processor.
- Hardware (Ascend/GPU)
  - Prepare hardware environment with Ascend or GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
@@ -88,6 +88,22 @@ export DEVICE_ID=0
python eval.py --net=se-resnet50 --dataset=imagenet2012 --checkpoint_path=[CHECKPOINT_PATH] --dataset_path=[DATASET_PATH]
```
- Running on GPU
```bash
# distributed training
Usage:
sh run_distribute_train_gpu.sh se-resnet50 imagenet2012 [DATASET_PATH]
# standalone training
Usage:
sh run_standalone_train_gpu.sh se-resnet50 imagenet2012 [DATASET_PATH]
# run evaluation example
Usage:
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
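Note: `run_distribute_train_gpu.sh` launches training through `mpirun` with the NCCL backend (see the script further below), so an MPI runtime such as OpenMPI is assumed to be installed in addition to CUDA.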
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@@ -99,8 +115,11 @@ python eval.py --net=se-resnet50 --dataset=imagenet2012 --checkpoint_path=[CHECKPOINT_PATH]
├── ascend310_infer # application for 310 inference
├── scripts
├── run_distribute_train.sh # launch ascend distributed training(8 pcs)
├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs)
├── run_eval.sh # launch ascend evaluation
├── run_eval_gpu.sh # launch gpu evaluation
├── run_standalone_train.sh # launch ascend standalone training(1 pcs)
├── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs)
└─ run_infer_310.sh # shell script for 310 inference on Ascend
├── src
├── config.py # parameter configuration
@@ -159,6 +178,18 @@ export DEVICE_ID=0
bash run_standalone_train.sh se-resnet50 imagenet2012 /data/imagenet/train/
```
#### Running on GPU
```bash
# distributed training
Usage:
sh run_distribute_train_gpu.sh se-resnet50 imagenet2012 [DATASET_PATH]
# standalone training
Usage:
sh run_standalone_train_gpu.sh se-resnet50 imagenet2012 [DATASET_PATH]
```
For distributed training on Ascend, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
@@ -195,6 +226,12 @@ export DEVICE_ID=0
bash run_eval.sh /imagenet/val/ /path/to/resnet-90_625.ckpt
```
#### Running on GPU
```bash
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
### Result
- Evaluating SE-ResNet50 with ImageNet2012 dataset
@@ -243,38 +280,38 @@ result: {'top_5_accuracy': 93.86%, 'top_1_accuracy': 77.80%}
#### SE-ResNet50 on ImageNet2012
| Parameters                 | Ascend 910                                                                        |
| -------------------------- | --------------------------------------------------------------------------------- |
| Model Version              | SE-ResNet50                                                                        |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB                     |
| Uploaded Date              | 03/19/2021 (month/day/year)                                                        |
| MindSpore Version          | 0.7.0-alpha                                                                        |
| Dataset                    | ImageNet2012                                                                       |
| Training Parameters        | epoch=90, steps per epoch=5004, batch_size = 256                                   |
| Optimizer                  | Momentum                                                                           |
| Loss Function              | Softmax Cross Entropy                                                              |
| Outputs                    | probability                                                                        |
| Loss                       | 1.5931969                                                                          |
| Speed                      | # ms/step (8pcs)                                                                   |
| Total time                 | # mins                                                                             |
| Parameters (M)             | 285M                                                                               |
| Checkpoint for Fine tuning | # M (.ckpt file)                                                                   |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/SE-Net>   |
| Parameters                 | Ascend                                                          | GPU                                               |
| -------------------------- | ---------------------------------------------------------------- | ------------------------------------------------- |
| Model Version              | SE-ResNet50                                                      | SE-ResNet50                                        |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB   | V100-PCIE 32G                                      |
| Uploaded Date              | 03/19/2021 (month/day/year)                                      | 07/14/2021 (month/day/year)                        |
| MindSpore Version          | 0.7.0-alpha                                                      | 1.3.0                                              |
| Dataset                    | ImageNet2012                                                     | ImageNet2012                                       |
| Training Parameters        | epoch=90, steps per epoch=5004, batch_size = 256                 | epoch=90, steps per epoch=5004, batch_size = 256   |
| Optimizer                  | Momentum                                                         | Momentum                                           |
| Loss Function              | Softmax Cross Entropy                                            | Softmax Cross Entropy                              |
| Outputs                    | probability                                                      | probability                                        |
| Loss                       | 1.5931969                                                        | 1.6664593                                          |
| Speed                      | # ms/step (8pcs)                                                 | 8pcs: 1016.9 ms/step                               |
| Total time                 | # mins                                                           | 8pcs: 15.9 hours                                   |
| Parameters (M)             | 285M                                                             | 285M                                               |
| Checkpoint for Fine tuning | # M (.ckpt file)                                                 | # M (.ckpt file)                                   |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/SE-Net> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/SE-Net> |
### Inference Performance
#### SE-ResNet50 on ImageNet2012
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | SE-ResNet50 |
| Resource | Ascend 910 |
| Uploaded Date | 03/19/2021 (month/day/year) |
| MindSpore Version | 0.7.0-alpha |
| Dataset | ImageNet2012 |
| batch_size | 256 |
| Accuracy | 77.74% |
| Model for inference | # (.air file) |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | SE-ResNet50 | SE-ResNet50 |
| Resource | Ascend 910 | V100-PCIE 32G |
| Uploaded Date | 03/19/2021 (month/day/year) | 07/14/2021 (month/day/year) |
| MindSpore Version | 0.7.0-alpha | 1.3.0 |
| Dataset | ImageNet2012 | ImageNet2012 |
| batch_size | 256 | 256 |
| Accuracy | 77.74% | 77.66% |
| Model for inference | # (.air file) | ------ |
### 310 Inference Performance

scripts/run_distribute_train_gpu.sh

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Usage: sh run_distribute_train_gpu.sh [NET] [DATASET_NAME] [DATASET_PATH]
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export NET=$1
export DATASET=$2
export DATASET_PATH=$3
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
echo "start distributed training with $DEVICE_NUM GPUs."
mpirun --allow-run-as-root -n $DEVICE_NUM \
python train.py \
--device_target="GPU" \
--net=$NET \
--dataset=$DATASET \
--run_distribute=True \
--device_num=$DEVICE_NUM \
--dataset_path=$DATASET_PATH > log 2>&1 &

scripts/run_eval_gpu.sh

@@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
export DATA_PATH=$1
export CKPT_PATH=$2
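# CKPT_PATH should point at a checkpoint produced by train.py (saved with prefix "resnet")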
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --device_target="GPU" --dataset_path=$DATA_PATH --checkpoint_path=$CKPT_PATH > log 2>&1 &
cd ..

scripts/run_standalone_train_gpu.sh

@@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [NET] [DATASET_NAME] [DATASET_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
export NET=$1
export DATASET=$2
export DATASET_PATH=$3
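# standalone training uses a single GPU (DEVICE_ID=0); no mpirun is involved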
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for GPU device $DEVICE_ID"
env > env.log
python train.py --device_target="GPU" --net=$NET --dataset=$DATASET --dataset_path=$DATASET_PATH > log 2>&1 &
cd ..

src/CrossEntropySmooth.py

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import LossBase
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(LossBase):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
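The rest of the class body is untouched by this diff. For context, a typical label-smoothing completion consistent with the imports shown above would look like the sketch below (the exact body in the repository may differ slightly):

```python
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import LossBase
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class CrossEntropySmooth(LossBase):
    """CrossEntropy"""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        # the true class gets 1 - smooth_factor; the rest share smooth_factor evenly
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            # expand integer labels into smoothed one-hot vectors
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        return self.ce(logit, label)
```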

src/config.py

@@ -27,7 +27,7 @@ config2 = ed({
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"keep_checkpoint_max": 90,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "linear",

src/resnet.py

@@ -116,7 +116,7 @@ class Se_ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])#use_se=self.use_se
self.add = P.TensorAdd()
self.add = P.Add()
self.se = SELayer(out_channel, reduction)
def construct(self, x):
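`SELayer` itself is defined elsewhere in src/resnet.py and is not shown in this hunk. As a rough, hypothetical sketch of what a squeeze-and-excitation layer looks like in MindSpore (the repository's exact implementation may differ):

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class SELayer(nn.Cell):
    """Hypothetical squeeze-and-excitation block: rescale channels by learned weights."""
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.mean = P.ReduceMean(keep_dims=True)       # squeeze: global average pool
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.fc = nn.SequentialCell([
            nn.Dense(channel, channel // reduction),   # bottleneck
            nn.ReLU(),
            nn.Dense(channel // reduction, channel),
            nn.Sigmoid()])                             # per-channel gates in (0, 1)

    def construct(self, x):
        b, c, _, _ = self.shape(x)
        y = self.mean(x, (2, 3))                       # (b, c, 1, 1)
        y = self.fc(self.reshape(y, (b, c)))           # excitation weights (b, c)
        y = self.reshape(y, (b, c, 1, 1))
        return x * y                                   # channel-wise rescaling
```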

train.py

@@ -24,7 +24,7 @@ from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
from mindspore.parallel import set_algo_parameters
import mindspore.nn as nn
@@ -67,12 +67,19 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
if args_opt.net == "se-resnet50":
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313])
init()
elif target == "GPU":
init('nccl')
context.reset_auto_parallel_context()
rank = get_rank()
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
if args_opt.net == "se-resnet50":
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313])
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
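Condensed, the GPU branch added above follows MindSpore's standard data-parallel bootstrap; a minimal sketch of the same flow (argument handling omitted):

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init("nccl")                   # join the NCCL group set up by mpirun
rank = get_rank()              # this process's index, used e.g. for checkpoint dirs
device_num = get_group_size()  # world size = number of mpirun processes
context.set_auto_parallel_context(device_num=device_num,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
```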
@@ -120,7 +127,7 @@ if __name__ == '__main__':
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model
if target == "Ascend":
if target in ["Ascend", "GPU"]:
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
@@ -137,6 +144,8 @@ if __name__ == '__main__':
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
if target == "GPU" and args_opt.run_distribute:
ckpt_save_dir = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
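Giving each rank its own `ckpt_<rank>/` directory during distributed GPU training keeps the mpirun-launched processes from overwriting one another's checkpoint files.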