mobilenetv2 mobilenetv3 readme normalize, delete mobilenetv3 ascend

This commit is contained in:
zhaoting 2020-08-19 17:30:20 +08:00
parent 29808eb128
commit 37f78ec3e7
8 changed files with 172 additions and 121 deletions

View File

@ -1,17 +1,37 @@
# MobileNetV2 Description # Contents
- [MobileNetV2 Description](#mobilenetv2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [MobileNetV2 Description](#contents)
MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019.
[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.
# Model architecture # [Model architecture](#contents)
The overall network architecture of MobileNetV2 is show below: The overall network architecture of MobileNetV2 is show below:
[Link](https://arxiv.org/pdf/1905.02244) [Link](https://arxiv.org/pdf/1905.02244)
# Dataset # [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/) Dataset used: [imagenet](http://www.image-net.org/)
@ -22,10 +42,14 @@ Dataset used: [imagenet](http://www.image-net.org/)
- Note: Data will be processed in src/dataset.py - Note: Data will be processed in src/dataset.py
# Features # [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
# Environment Requirements The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching reduce precision.
# [Environment Requirements](#contents)
- HardwareAscend/GPU - HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
@ -36,30 +60,33 @@ Dataset used: [imagenet](http://www.image-net.org/)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# Script description # [Script description](#contents)
## Script and sample code ## [Script and sample code](#contents)
```python ```python
├── MobileNetV2 ├── MobileNetV2
├── Readme.md ├── Readme.md # descriptions about MobileNetV2
├── scripts ├── scripts
│ ├──run_train.sh │ ├──run_train.sh # shell script for train
│ ├──run_eval.sh │ ├──run_eval.sh # shell script for evaluation
├── src ├── src
│ ├──config.py │ ├──config.py # parameter configuration
│ ├──dataset.py │ ├──dataset.py # creating dataset
│ ├──luanch.py │ ├──launch.py # start python script
│ ├──lr_generator.py │ ├──lr_generator.py # learning rate config
│ ├──mobilenetV2.py │ ├──mobilenetV2.py # MobileNetV2 architecture
├── train.py ├── train.py # training script
├── eval.py ├── eval.py # evaluation script
``` ```
## Training process ## [Training process](#contents)
### Usage ### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
@ -67,8 +94,13 @@ Dataset used: [imagenet](http://www.image-net.org/)
``` ```
# training example # training example
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt python:
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ Ascend: python train.py --dataset_path ~/imagenet/train/ --device_targe Ascend
GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
``` ```
### Result ### Result
@ -82,10 +114,12 @@ epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:
epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
``` ```
## Eval process ## [Eval process](#contents)
### Usage ### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] - Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
@ -93,8 +127,13 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
``` ```
# infer example # infer example
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt python:
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt Ascend: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe Ascend
GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
shell:
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
``` ```
> checkpoint can be produced in training process. > checkpoint can be produced in training process.
@ -107,9 +146,9 @@ Inference result will be stored in the example path, you can find result like th
result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
``` ```
# Model description # [Model description](#contents)
## Performance ## [Performance](#contents)
### Training Performance ### Training Performance
@ -147,5 +186,11 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
| Total time | | | | | Total time | | | |
| Model for inference | | | | | Model for inference | | | |
# ModelZoo Homepage # [Description of Random Situation](#contents)
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -39,7 +39,7 @@ if __name__ == '__main__':
net = None net = None
if args_opt.device_target == "Ascend": if args_opt.device_target == "Ascend":
config = config_ascend config = config_ascend
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)
net = mobilenet_v2(num_classes=config.num_classes, device_target="Ascend") net = mobilenet_v2(num_classes=config.num_classes, device_target="Ascend")

View File

@ -47,7 +47,7 @@ fi
mkdir ../eval mkdir ../eval
cd ../eval || exit cd ../eval || exit
# luanch # launch
python ${BASEPATH}/../eval.py \ python ${BASEPATH}/../eval.py \
--device_target=$1 \ --device_target=$1 \
--dataset_path=$2 \ --dataset_path=$2 \

View File

@ -35,8 +35,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
dataset dataset
""" """
if device_target == "Ascend": if device_target == "Ascend":
rank_size = int(os.getenv("RANK_SIZE")) rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID")) rank_id = int(os.getenv("RANK_ID", '0'))
if rank_size == 1: if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:

View File

@ -53,11 +53,10 @@ parser.add_argument('--device_target', type=str, default=None, help='run device_
args_opt = parser.parse_args() args_opt = parser.parse_args()
if args_opt.device_target == "Ascend": if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID')) rank_id = int(os.getenv('RANK_ID', '0'))
rank_size = int(os.getenv('RANK_SIZE')) rank_size = int(os.getenv('RANK_SIZE', '1'))
run_distribute = rank_size > 1 run_distribute = rank_size > 1
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target="Ascend",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)

View File

@ -1,17 +1,35 @@
# MobileNetV3 Description # Contents
- [MobileNetV3 Description](#mobilenetv3-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [MobileNetV3 Description](#contents)
MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019.
[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.
# Model architecture # [Model architecture](#contents)
The overall network architecture of MobileNetV3 is show below: The overall network architecture of MobileNetV3 is show below:
[Link](https://arxiv.org/pdf/1905.02244) [Link](https://arxiv.org/pdf/1905.02244)
# Dataset # [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/) Dataset used: [imagenet](http://www.image-net.org/)
@ -22,10 +40,7 @@ Dataset used: [imagenet](http://www.image-net.org/)
- Note: Data will be processed in src/dataset.py - Note: Data will be processed in src/dataset.py
# Features # [Environment Requirements](#contents)
# Environment Requirements
- HardwareGPU - HardwareGPU
- Prepare hardware environment with GPU processor. - Prepare hardware environment with GPU processor.
@ -36,37 +51,42 @@ Dataset used: [imagenet](http://www.image-net.org/)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# Script description # [Script description](#contents)
## Script and sample code ## [Script and sample code](#contents)
```python ```python
├── MobilenetV3 ├── MobileNetV3
├── Readme.md ├── Readme.md # descriptions about MobileNetV3
├── scripts ├── scripts
│ ├──run_train.sh │ ├──run_train.sh # shell script for train
│ ├──run_eval.sh │ ├──run_eval.sh # shell script for evaluation
├── src ├── src
│ ├──config.py │ ├──config.py # parameter configuration
│ ├──dataset.py │ ├──dataset.py # creating dataset
│ ├──luanch.py │ ├──launch.py # start python script
│ ├──lr_generator.py │ ├──lr_generator.py # learning rate config
│ ├──mobilenetV2.py │ ├──mobilenetV3.py # MobileNetV3 architecture
├── train.py ├── train.py # training script
├── eval.py ├── eval.py # evaluation script
``` ```
## Training process ## [Training process](#contents)
### Usage ### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
### Launch ### Launch
``` ```
# training example # training example
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ python:
GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
shell:
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
``` ```
### Result ### Result
@ -80,16 +100,22 @@ epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:
epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
``` ```
## Eval process ## [Eval process](#contents)
### Usage ### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
### Launch ### Launch
``` ```
# infer example # infer example
python:
GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
shell:
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
``` ```
@ -103,46 +129,50 @@ Inference result will be stored in the example path, you can find result like th
result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
``` ```
# Model description # [Model description](#contents)
## Performance ## [Performance](#contents)
### Training Performance ### Training Performance
| Parameters | MobilenetV3 | | | Parameters | MobilenetV3 |
| -------------------------- | ---------------------------------------------------------- | ------------------------- | | -------------------------- | ------------------------- |
| Model Version | | large | | Model Version | large |
| Resource | Ascend 910, cpu:2.60GHz 56cores, memory:314G | NV SMX2 V100-32G | | Resource | NV SMX2 V100-32G |
| uploaded Date | 05/06/2020 | 05/06/2020 | | uploaded Date | 05/06/2020 |
| MindSpore Version | 0.3.0 | 0.3.0 | | MindSpore Version | 0.3.0 |
| Dataset | ImageNet | ImageNet | | Dataset | ImageNet |
| Training Parameters | src/config.py | src/config.py | | Training Parameters | src/config.py |
| Optimizer | Momentum | Momentum | | Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropy | SoftmaxCrossEntropy | | Loss Function | SoftmaxCrossEntropy |
| outputs | | | | outputs | |
| Loss | | 1.913 | | Loss | 1.913 |
| Accuracy | | ACC1[77.57%] ACC5[92.51%] | | Accuracy | ACC1[77.57%] ACC5[92.51%] |
| Total time | | | | Total time | |
| Params (M) | | | | Params (M) | |
| Checkpoint for Fine tuning | | | | Checkpoint for Fine tuning | |
| Model for inference | | | | Model for inference | |
#### Inference Performance #### Inference Performance
| Parameters | | | | | Parameters | |
| -------------------------- | ----------------------------- | ------------------------- | -------------------- | | -------------------------- | -------------------- |
| Model Version | V1 | | | | Model Version | |
| Resource | Huawei 910 | NV SMX2 V100-32G | Huawei 310 | | Resource | NV SMX2 V100-32G |
| uploaded Date | 05/06/2020 | 05/22/2020 | | | uploaded Date | 05/22/2020 |
| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 | | MindSpore Version | 0.2.0 |
| Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W | | Dataset | ImageNet, 1.2W |
| batch_size | | 130(8P) | | | batch_size | 130(8P) |
| outputs | | | | | outputs | |
| Accuracy | | ACC1[75.43%] ACC5[92.51%] | | | Accuracy | ACC1[75.43%] ACC5[92.51%] |
| Speed | | | | | Speed | |
| Total time | | | | | Total time | |
| Model for inference | | | | | Model for inference | |
# [Description of Random Situation](#contents)
# ModelZoo Homepage In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -15,33 +15,26 @@
""" """
eval. eval.
""" """
import os
import argparse import argparse
from mindspore import context from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
from src.dataset import create_dataset from src.dataset import create_dataset
from src.config import config_ascend, config_gpu from src.config import config_gpu
from src.mobilenetV3 import mobilenet_v3_large from src.mobilenetV3 import mobilenet_v3_large
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default=None, help='run device_target') parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
args_opt = parser.parse_args() args_opt = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
config = None config = None
if args_opt.device_target == "Ascend": if args_opt.device_target == "GPU":
config = config_ascend
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.device_target == "GPU":
config = config_gpu config = config_gpu
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False) device_target="GPU", save_graphs=False)
@ -52,12 +45,6 @@ if __name__ == '__main__':
is_grad=False, sparse=True, reduction='mean') is_grad=False, sparse=True, reduction='mean')
net = mobilenet_v3_large(num_classes=config.num_classes) net = mobilenet_v3_large(num_classes=config.num_classes)
if args_opt.device_target == "Ascend":
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)
dataset = create_dataset(dataset_path=args_opt.dataset_path, dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False, do_train=False,
config=config, config=config,

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""train_imagenet.""" """train_imagenet."""
import os
import time import time
import argparse import argparse
import random import random
@ -47,20 +47,10 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--device_target', type=str, default=None, help='run device_target') parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
args_opt = parser.parse_args() args_opt = parser.parse_args()
if args_opt.device_target == "Ascend": if args_opt.device_target == "GPU":
device_id = int(os.getenv('DEVICE_ID'))
rank_id = int(os.getenv('RANK_ID'))
rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id,
save_graphs=False)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", device_target="GPU",
save_graphs=False) save_graphs=False)