forked from mindspore-Ecosystem/mindspore
!22323 gpu for squeezenet
Merge pull request !22323 from 郑彬/squeezenet_gpu
This commit is contained in:
commit
b021a58caa
|
@ -92,6 +92,19 @@ After installing MindSpore via the official website, you can start training and
|
|||
Usage: bash scripts/run_eval.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
- running on GPU
|
||||
|
||||
```bash
|
||||
# distributed training
|
||||
Usage: bash scripts/run_distribute_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
|
||||
|
||||
# standalone training
|
||||
Usage: bash scripts/run_standalone_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [PRETRAINED_CKPT_PATH](optional)
|
||||
|
||||
# run evaluation example
|
||||
Usage: bash scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
- running on CPU
|
||||
|
||||
```bash
|
||||
|
@ -139,46 +152,53 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
```text
|
||||
.
|
||||
└── squeezenet
|
||||
├── README.md
|
||||
├── ascend310_infer # application for 310 inference
|
||||
├── ascend310_infer # application for 310 inference
|
||||
├── scripts
|
||||
├── run_distribute_train.sh # launch ascend distributed training(8 pcs)
|
||||
├── run_standalone_train.sh # launch ascend standalone training(1 pcs)
|
||||
├── run_eval.sh # launch ascend evaluation
|
||||
├── run_infer_310.sh # shell script for 310 infer
|
||||
├── run_distribute_train.sh # launch ascend distributed training(8 pcs)
|
||||
├── run_distribute_train_gpu.sh # launch GPU distributed training(8 pcs)
|
||||
├── run_standalone_train.sh # launch ascend standalone training(1 pcs)
|
||||
├── run_standalone_train_gpu.sh # launch GPU standalone training(1 pcs)
|
||||
├── run_train_cpu.sh # launch CPU training
|
||||
├── run_eval.sh # launch ascend evaluation
|
||||
├── run_eval_gpu.sh # launch GPU evaluation
|
||||
├── run_eval_cpu.sh # launch CPU evaluation
|
||||
├── run_infer_310.sh # shell script for 310 infer
|
||||
├── src
|
||||
├── dataset.py # data preprocessing
|
||||
├── CrossEntropySmooth.py # loss definition for ImageNet dataset
|
||||
├── lr_generator.py # generate learning rate for each step
|
||||
└── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual
|
||||
├── dataset.py # data preprocessing
|
||||
├── CrossEntropySmooth.py # loss definition for ImageNet dataset
|
||||
├── lr_generator.py # generate learning rate for each step
|
||||
└── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual
|
||||
├── model_utils
|
||||
│ ├── device_adapter.py # device adapter
|
||||
│ ├── local_adapter.py # local adapter
|
||||
│ ├── moxing_adapter.py # moxing adapter
|
||||
│ ├── config.py # parameter analysis
|
||||
│ ├── device_adapter.py # device adapter
|
||||
│ ├── local_adapter.py # local adapter
|
||||
│ ├── moxing_adapter.py # moxing adapter
|
||||
│ └── config.py # parameter analysis
|
||||
├── squeezenet_cifar10_config.yaml # parameter configuration
|
||||
├── squeezenet_imagenet_config.yaml # parameter configuration
|
||||
├── squeezenet_residual_cifar10_config.yaml # parameter configuration
|
||||
├── squeezenet_residual_imagenet_config.yaml # parameter configuration
|
||||
├── train.py # train net
|
||||
├── eval.py # eval net
|
||||
└── export.py # export checkpoint files into geir/onnx
|
||||
├── postprocess.py # postprocess script
|
||||
├── preprocess.py # preprocess script
|
||||
├── export.py # export checkpoint files into geir/onnx
|
||||
├── postprocess.py # postprocess script
|
||||
├── preprocess.py # preprocess script
|
||||
├── requirements.txt
|
||||
└── mindspore_hub_conf.py # mindspore hub interface
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
Parameters for both training and evaluation can be set in *.yaml
|
||||
|
||||
- config for SqueezeNet, CIFAR-10 dataset
|
||||
|
||||
```py
|
||||
"class_num": 10, # dataset class num
|
||||
"batch_size": 32, # batch size of input tensor
|
||||
"global_batch_size": 32, # the total batch_size for training and evaluation
|
||||
"loss_scale": 1024, # loss scale
|
||||
"momentum": 0.9, # momentum
|
||||
"weight_decay": 1e-4, # weight decay
|
||||
|
@ -199,7 +219,7 @@ Parameters for both training and evaluation can be set in config.py
|
|||
|
||||
```py
|
||||
"class_num": 1000, # dataset class num
|
||||
"batch_size": 32, # batch size of input tensor
|
||||
"global_batch_size": 256, # the total batch_size for training and evaluation
|
||||
"loss_scale": 1024, # loss scale
|
||||
"momentum": 0.9, # momentum
|
||||
"weight_decay": 7e-5, # weight decay
|
||||
|
@ -222,7 +242,7 @@ Parameters for both training and evaluation can be set in config.py
|
|||
|
||||
```py
|
||||
"class_num": 10, # dataset class num
|
||||
"batch_size": 32, # batch size of input tensor
|
||||
"global_batch_size": 32, # the total batch_size for training and evaluation
|
||||
"loss_scale": 1024, # loss scale
|
||||
"momentum": 0.9, # momentum
|
||||
"weight_decay": 1e-4, # weight decay
|
||||
|
@ -243,7 +263,7 @@ Parameters for both training and evaluation can be set in config.py
|
|||
|
||||
```py
|
||||
"class_num": 1000, # dataset class num
|
||||
"batch_size": 32, # batch size of input tensor
|
||||
"global_batch_size": 256, # the total batch_size for training and evaluation
|
||||
"loss_scale": 1024, # loss scale
|
||||
"momentum": 0.9, # momentum
|
||||
"weight_decay": 7e-5, # weight decay
|
||||
|
@ -262,7 +282,7 @@ Parameters for both training and evaluation can be set in config.py
|
|||
"lr_max": 0.01, # maximum learning rate
|
||||
```
|
||||
|
||||
For more configuration details, please refer the script `config.py`.
|
||||
For more configuration details, please refer to the `*.yaml` files.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
|
@ -469,137 +489,137 @@ Inference result is saved in current path, you can find result like this in acc.
|
|||
|
||||
#### SqueezeNet on CIFAR-10
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | SqueezeNet |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | CIFAR-10 |
|
||||
| Training Parameters | epoch=120, steps=195, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Loss | 0.0496 |
|
||||
| Speed | 1pc: 16.7 ms/step; 8pcs: 17.0 ms/step |
|
||||
| Total time | 1pc: 55.5 mins; 8pcs: 15.0 mins |
|
||||
| Parameters (M) | 4.8 |
|
||||
| Checkpoint for Fine tuning | 6.4M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | --- |
|
||||
| Model Version | SqueezeNet | SqueezeNet |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | CIFAR-10 | CIFAR-10 |
|
||||
| Training Parameters | epoch=120, steps=195, batch_size=32, lr=0.01 | 1pc:epoch=120, steps=1562, batch_size=32, lr=0.01; 8pcs:epoch=120, steps=1562, batch_size=4, lr=0.01|
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.0496 | 1pc:0.0892, 8pcs:0.0130 |
|
||||
| Speed | 1pc: 16.7 ms/step; 8pcs: 17.0 ms/step | 1pc: 28.6 ms/step; 8pcs: 10.8 ms/step |
|
||||
| Total time | 1pc: 55.5 mins; 8pcs: 15.0 mins | 1pc: 90mins; 8pcs: 34mins |
|
||||
| Parameters (M) | 4.8 | 0.74 |
|
||||
| Checkpoint for Fine tuning | 6.4M (.ckpt file) | 6.4M (.ckpt file)|
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
|
||||
#### SqueezeNet on ImageNet
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | SqueezeNet |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | ImageNet |
|
||||
| Training Parameters | epoch=200, steps=5004, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Loss | 2.9150 |
|
||||
| Speed | 8pcs: 19.9 ms/step |
|
||||
| Total time | 8pcs: 5.2 hours |
|
||||
| Parameters (M) | 4.8 |
|
||||
| Checkpoint for Fine tuning | 13.3M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | --- |
|
||||
| Model Version | SqueezeNet | SqueezeNet |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | ImageNet | ImageNet |
|
||||
| Training Parameters | epoch=200, steps=5004, batch_size=32, lr=0.01 | epoch=200, steps=5004, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 2.9150 | 3.009 |
|
||||
| Speed | 8pcs: 19.9 ms/step | 8pcs: 43.5ms/step|
|
||||
| Total time | 8pcs: 5.2 hours | 8pcs: 12.1 hours |
|
||||
| Parameters (M) | 4.8 | 1.25 |
|
||||
| Checkpoint for Fine tuning | 13.3M (.ckpt file) | 13.3M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
|
||||
#### SqueezeNet_Residual on CIFAR-10
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | CIFAR-10 |
|
||||
| Training Parameters | epoch=150, steps=195, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Loss | 0.0641 |
|
||||
| Speed | 1pc: 16.9 ms/step; 8pcs: 17.3 ms/step |
|
||||
| Total time | 1pc: 68.6 mins; 8pcs: 20.9 mins |
|
||||
| Parameters (M) | 4.8 |
|
||||
| Checkpoint for Fine tuning | 6.5M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | --- |
|
||||
| Model Version | SqueezeNet_Residual | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | CIFAR-10 | CIFAR-10 |
|
||||
| Training Parameters | epoch=150, steps=195, batch_size=32, lr=0.01 | 1pc:epoch=150, steps=1562, batch_size=32, lr=0.01; 8pcs: epoch=150, steps=1562, batch_size=4|
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.0641 | 1pc: 0.0402; 8pcs:0.004 |
|
||||
| Speed | 1pc: 16.9 ms/step; 8pcs: 17.3 ms/step | 1pc: 29.4 ms/step; 8pcs:11.0 ms/step |
|
||||
| Total time | 1pc: 68.6 mins; 8pcs: 20.9 mins | 1pc: 115 mins; 8pcs: 43.5 mins |
|
||||
| Parameters (M) | 4.8 | 0.74 |
|
||||
| Checkpoint for Fine tuning | 6.5M (.ckpt file) | 6.5M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
|
||||
#### SqueezeNet_Residual on ImageNet
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | ImageNet |
|
||||
| Training Parameters | epoch=300, steps=5004, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Loss | 2.9040 |
|
||||
| Speed | 8pcs: 20.2 ms/step |
|
||||
| Total time | 8pcs: 8.0 hours |
|
||||
| Parameters (M) | 4.8 |
|
||||
| Checkpoint for Fine tuning | 15.3M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | --- |
|
||||
| Model Version | SqueezeNet_Residual | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | ImageNet | ImageNet |
|
||||
| Training Parameters | epoch=300, steps=5004, batch_size=32, lr=0.01 | epoch=300, steps=5004, batch_size=32, lr=0.01 |
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 2.9040 | 2.969 |
|
||||
| Speed | 8pcs: 20.2 ms/step | 8pcs: 44.1 ms/step |
|
||||
| Total time | 8pcs: 8.0 hours | 8pcs: 18.4 hours |
|
||||
| Parameters (M) | 4.8 | 1.25 |
|
||||
| Checkpoint for Fine tuning | 15.3M (.ckpt file) | 15.3M (.ckpt file) |
|
||||
| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
#### SqueezeNet on CIFAR-10
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | SqueezeNet |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | CIFAR-10 |
|
||||
| batch_size | 32 |
|
||||
| outputs | probability |
|
||||
| Accuracy | 1pc: 89.0%; 8pcs: 84.4% |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --- |
|
||||
| Model Version | SqueezeNet | SqueezeNet |
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | CIFAR-10 | CIFAR-10 |
|
||||
| batch_size | 32 | 1pc:32; 8pcs:4 |
|
||||
| outputs | probability | probability |
|
||||
| Accuracy | 1pc: 89.0%; 8pcs: 84.4% | 1pc: 89.0%; 8pcs: 88.8%|
|
||||
|
||||
#### SqueezeNet on ImageNet
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | SqueezeNet |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | ImageNet |
|
||||
| batch_size | 32 |
|
||||
| outputs | probability |
|
||||
| Accuracy | 8pcs: 58.5%(TOP1), 81.1%(TOP5) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --- |
|
||||
| Model Version | SqueezeNet | SqueezeNet |
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | ImageNet | ImageNet |
|
||||
| batch_size | 32 | 32 |
|
||||
| outputs | probability | probability |
|
||||
| Accuracy | 8pcs: 58.5%(TOP1), 81.1%(TOP5) | 8pcs: 58.5%(TOP1), 80.7%(TOP5) |
|
||||
|
||||
#### SqueezeNet_Residual on CIFAR-10
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | CIFAR-10 |
|
||||
| batch_size | 32 |
|
||||
| outputs | probability |
|
||||
| Accuracy | 1pc: 90.8%; 8pcs: 87.4% |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --- |
|
||||
| Model Version | SqueezeNet_Residual | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | CIFAR-10 | CIFAR-10 |
|
||||
| batch_size | 32 | 1pc:32; 8pcs:4 |
|
||||
| outputs | probability | probability |
|
||||
| Accuracy | 1pc: 90.8%; 8pcs: 87.4% | 1pc: 90.7%; 8pcs: 90.5% |
|
||||
|
||||
#### SqueezeNet_Residual on ImageNet
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | ImageNet |
|
||||
| batch_size | 32 |
|
||||
| outputs | probability |
|
||||
| Accuracy | 8pcs: 60.9%(TOP1), 82.6%(TOP5) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --- |
|
||||
| Model Version | SqueezeNet_Residual | SqueezeNet_Residual |
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) | 8/24/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.4.0 |
|
||||
| Dataset | ImageNet | ImageNet |
|
||||
| batch_size | 32 | 32 |
|
||||
| outputs | probability | probability |
|
||||
| Accuracy | 8pcs: 60.9%(TOP1), 82.6%(TOP5) | 8pcs: 60.2%(TOP1), 82.3%(TOP5)|
|
||||
|
||||
### 310 Inference Performance
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Launch 8-GPU distributed training of SqueezeNet through mpirun.
#
# Positional arguments:
#   $1  network name: squeezenet | squeezenet_residual
#   $2  dataset name: cifar10 | imagenet
#   $3  dataset directory (absolute, or relative to the current directory)
#   $4  optional pretrained checkpoint file
#
# The run happens in a fresh ${BASE_PATH}/train_parallel_<net>_<dataset>
# directory (an existing one is deleted first); training is started in the
# background with its combined output redirected to "log" in that directory.

if [ $# != 3 ] && [ $# != 4 ]
then
    echo "Usage: bash scripts/run_distribute_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

if [ "$1" != "squeezenet" ] && [ "$1" != "squeezenet_residual" ]
then
    echo "error: the selected net is neither squeezenet nor squeezenet_residual"
    exit 1
fi

if [ "$2" != "cifar10" ] && [ "$2" != "imagenet" ]
then
    echo "error: the selected dataset is neither cifar10 nor imagenet"
    exit 1
fi

# Print an absolute form of $1: an absolute path is echoed unchanged, a
# relative one is resolved against $PWD (realpath -m tolerates missing parts).
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

PATH1=$(get_real_path "$3")

if [ $# == 4 ]
then
    PATH2=$(get_real_path "$4")
fi

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 4 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export RANK_SIZE=8

# BASE_PATH is the parent of the directory that holds this script.
BASE_PATH=$(dirname "$(dirname "$(readlink -f "$0")")")

# $1 and $2 were validated above, so the config file name can be derived
# directly: <net>_<dataset>_config.yaml.
CONFIG_FILE="${BASE_PATH}/${1}_${2}_config.yaml"

TRAIN_OUTPUT="${BASE_PATH}/train_parallel_${1}_${2}"
if [ -d "$TRAIN_OUTPUT" ]; then
    rm -rf "$TRAIN_OUTPUT"
fi
mkdir "$TRAIN_OUTPUT"

# Copy from BASE_PATH rather than "./" so the script does not depend on the
# directory it is launched from.
cp "${BASE_PATH}/train.py" "$TRAIN_OUTPUT"
cp -r "${BASE_PATH}/src" "$TRAIN_OUTPUT"
cp -r "${BASE_PATH}/model_utils" "$TRAIN_OUTPUT"
cp "$CONFIG_FILE" "$TRAIN_OUTPUT"
cd "$TRAIN_OUTPUT" || exit

# The pretrained checkpoint is only passed through when it was supplied.
PRETRAINED_ARGS=()
if [ $# == 4 ]
then
    PRETRAINED_ARGS=(--pre_trained="$PATH2")
fi

mpirun --allow-run-as-root -n "$RANK_SIZE" --output-filename log_output --merge-stderr-to-stdout \
    python train.py --net_name="$1" --dataset="$2" --run_distribute=True --output_path='./output' \
    --device_target="GPU" --data_path="$PATH1" "${PRETRAINED_ARGS[@]}" \
    --config_path="${CONFIG_FILE##*/}" &> log &
|
|
@ -0,0 +1,99 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Evaluate a SqueezeNet checkpoint on a single GPU.
#
# Positional arguments:
#   $1  network name: squeezenet | squeezenet_residual
#   $2  dataset name: cifar10 | imagenet
#   $3  DEVICE_ID, exported as CUDA_VISIBLE_DEVICES
#   $4  dataset directory (absolute, or relative to the current directory)
#   $5  checkpoint file to evaluate
#
# The run happens in a fresh ${BASE_PATH}/eval_<id>_<net>_<dataset> directory
# (an existing one is deleted first); evaluation is started in the background
# with its combined output redirected to "log" in that directory.

if [ $# != 5 ]
then
    echo "Usage: bash scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

if [ "$1" != "squeezenet" ] && [ "$1" != "squeezenet_residual" ]
then
    echo "error: the selected net is neither squeezenet nor squeezenet_residual"
    exit 1
fi

if [ "$2" != "cifar10" ] && [ "$2" != "imagenet" ]
then
    echo "error: the selected dataset is neither cifar10 nor imagenet"
    exit 1
fi

# Print an absolute form of $1: an absolute path is echoed unchanged, a
# relative one is resolved against $PWD (realpath -m tolerates missing parts).
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

PATH1=$(get_real_path "$4")
PATH2=$(get_real_path "$5")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

# expr exits with status 1 when the expression evaluates to 0, so a plain
# "status != 0" test would reject DEVICE_ID=0. Only status 2 or higher
# (invalid expression / error) means the argument is not an integer.
expr "$3" + 0 &>/dev/null
if [ $? -ge 2 ]; then
    echo "DEVICE_ID=$3 is not an integer!"
    exit 1
fi

ulimit -u unlimited
export CUDA_VISIBLE_DEVICES=$3

# BASE_PATH is the parent of the directory that holds this script.
BASE_PATH=$(dirname "$(dirname "$(readlink -f "$0")")")

# $1 and $2 were validated above, so the config file name can be derived
# directly: <net>_<dataset>_config.yaml.
CONFIG_FILE="${BASE_PATH}/${1}_${2}_config.yaml"

EVAL_OUTPUT="${BASE_PATH}/eval_${3}_${1}_${2}"
if [ -d "$EVAL_OUTPUT" ]; then
    rm -rf "$EVAL_OUTPUT"
fi
mkdir "$EVAL_OUTPUT"

# Copy from BASE_PATH rather than "./" so the script does not depend on the
# directory it is launched from.
cp "${BASE_PATH}/eval.py" "$EVAL_OUTPUT"
cp -r "${BASE_PATH}/src" "$EVAL_OUTPUT"
cp -r "${BASE_PATH}/model_utils" "$EVAL_OUTPUT"
cp "$CONFIG_FILE" "$EVAL_OUTPUT"
cd "$EVAL_OUTPUT" || exit

# Snapshot the environment next to the log.
env > env.log
echo "start evaluation for device $3"
python eval.py --net_name="$1" --dataset="$2" --data_path="$PATH1" --checkpoint_file_path="$PATH2" --device_target="GPU" \
    --config_path="${CONFIG_FILE##*/}" --output_path='./output' &> log &
|
|
@ -0,0 +1,109 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Train SqueezeNet on a single GPU.
#
# Positional arguments:
#   $1  network name: squeezenet | squeezenet_residual
#   $2  dataset name: cifar10 | imagenet
#   $3  DEVICE_ID, exported as CUDA_VISIBLE_DEVICES
#   $4  dataset directory (absolute, or relative to the current directory)
#   $5  optional pretrained checkpoint file
#
# The run happens in a fresh ${BASE_PATH}/train_standalone<id>_<net>_<dataset>
# directory (an existing one is deleted first); training is started in the
# background with its combined output redirected to "log" in that directory.

if [ $# != 4 ] && [ $# != 5 ]
then
    # Name this script in the usage text (the original named the Ascend one).
    echo "Usage: bash scripts/run_standalone_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

if [ "$1" != "squeezenet" ] && [ "$1" != "squeezenet_residual" ]
then
    echo "error: the selected net is neither squeezenet nor squeezenet_residual"
    exit 1
fi

if [ "$2" != "cifar10" ] && [ "$2" != "imagenet" ]
then
    echo "error: the selected dataset is neither cifar10 nor imagenet"
    exit 1
fi

# Print an absolute form of $1: an absolute path is echoed unchanged, a
# relative one is resolved against $PWD (realpath -m tolerates missing parts).
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

PATH1=$(get_real_path "$4")

if [ $# == 5 ]
then
    PATH2=$(get_real_path "$5")
fi

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 5 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

# expr exits with status 1 when the expression evaluates to 0 (a valid
# DEVICE_ID), so only status 2 or higher marks a non-integer argument.
expr "$3" + 0 &>/dev/null
if [ $? -ge 2 ]; then
    echo "DEVICE_ID=$3 is not an integer!"
    exit 1
fi

ulimit -u unlimited
export CUDA_VISIBLE_DEVICES=$3

# BASE_PATH is the parent of the directory that holds this script.
BASE_PATH=$(dirname "$(dirname "$(readlink -f "$0")")")

# $1 and $2 were validated above, so the config file name can be derived
# directly: <net>_<dataset>_config.yaml.
CONFIG_FILE="${BASE_PATH}/${1}_${2}_config.yaml"

TRAIN_OUTPUT="${BASE_PATH}/train_standalone${3}_${1}_${2}"
if [ -d "$TRAIN_OUTPUT" ]; then
    rm -rf "$TRAIN_OUTPUT"
fi
mkdir "$TRAIN_OUTPUT"

# Copy from BASE_PATH rather than "./" so the script does not depend on the
# directory it is launched from.
cp "${BASE_PATH}/train.py" "$TRAIN_OUTPUT"
cp -r "${BASE_PATH}/src" "$TRAIN_OUTPUT"
cp -r "${BASE_PATH}/model_utils" "$TRAIN_OUTPUT"
cp "$CONFIG_FILE" "$TRAIN_OUTPUT"
cd "$TRAIN_OUTPUT" || exit

echo "start training for device $3"
# Snapshot the environment next to the log.
env > env.log

# The pretrained checkpoint is only passed through when it was supplied.
PRETRAINED_ARGS=()
if [ $# == 5 ]
then
    PRETRAINED_ARGS=(--pre_trained="$PATH2")
fi

python train.py --net_name="$1" --dataset="$2" --data_path="$PATH1" "${PRETRAINED_ARGS[@]}" \
    --config_path="${CONFIG_FILE##*/}" --output_path='./output' --device_target='GPU' &> log &
|
|
@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_cifar10-120_195.ckpt"
|
|||
net_name: "suqeezenet"
|
||||
dataset : "cifar10"
|
||||
class_num: 10
|
||||
batch_size: 32
|
||||
global_batch_size: 32
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001
|
||||
|
@ -55,7 +55,7 @@ load_path: "The location of checkpoint for obs"
|
|||
device_target: "Target device type, available: [Ascend, GPU, CPU]"
|
||||
enable_profiling: "Whether enable profiling while training, default: False"
|
||||
num_classes: "Class for dataset"
|
||||
batch_size: "Batch size for training and evaluation"
|
||||
global_batch_size: "The total batch_size for training and evaluation"
|
||||
epoch_size: "Total training epochs."
|
||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
|
|
|
@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_imagenet-200_5004.ckpt"
|
|||
net_name: "suqeezenet"
|
||||
dataset : "imagenet"
|
||||
class_num: 1000
|
||||
batch_size: 32
|
||||
global_batch_size: 256
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.00007
|
||||
|
@ -57,7 +57,7 @@ load_path: 'The location of checkpoint for obs'
|
|||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
num_classes: 'Number of classes in the dataset'
|
||||
batch_size: "Batch size for training and evaluation"
|
||||
global_batch_size: "The total batch_size for training and evaluation"
|
||||
epoch_size: "Total training epochs."
|
||||
keep_checkpoint_max: "Keep the last keep_checkpoint_max checkpoints"
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
|
|
|
@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_residual_cifar10-150_195.ckpt"
|
|||
net_name: "suqeezenet_residual"
|
||||
dataset : "cifar10"
|
||||
class_num: 10
|
||||
batch_size: 32
|
||||
global_batch_size: 32
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001
|
||||
|
@ -55,7 +55,7 @@ load_path: "The location of checkpoint for obs"
|
|||
device_target: "Target device type, available: [Ascend, GPU, CPU]"
|
||||
enable_profiling: "Whether enable profiling while training, default: False"
|
||||
num_classes: "Number of classes in the dataset"
|
||||
batch_size: "Batch size for training and evaluation"
|
||||
global_batch_size: "The total batch_size for training and evaluation"
|
||||
epoch_size: "Total training epochs."
|
||||
keep_checkpoint_max: "Keep the last keep_checkpoint_max checkpoints"
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
|
|
|
@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_residual_imagenet-300_5004.ckpt"
|
|||
net_name: "suqeezenet_residual"
|
||||
dataset : "imagenet"
|
||||
class_num: 1000
|
||||
batch_size: 32
|
||||
global_batch_size: 256
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.00007
|
||||
|
@ -57,7 +57,7 @@ load_path: "The location of checkpoint for obs"
|
|||
device_target: "Target device type, available: [Ascend, GPU, CPU]"
|
||||
enable_profiling: "Whether enable profiling while training, default: False"
|
||||
num_classes: "Number of classes in the dataset"
|
||||
batch_size: "Batch size for training and evaluation"
|
||||
global_batch_size: "The total batch_size for training and evaluation"
|
||||
epoch_size: "Total training epochs."
|
||||
keep_checkpoint_max: "Keep the last keep_checkpoint_max checkpoints"
|
||||
checkpoint_path: "The location of the checkpoint file."
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train squeezenet."""
|
||||
import os
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
|
@ -27,6 +26,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
|
|||
from mindspore.common import set_seed
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id
|
||||
from src.lr_generator import get_lr
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
|
||||
|
@ -54,33 +54,37 @@ def train_net():
|
|||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=target)
|
||||
device_num = 1
|
||||
if config.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_id = get_device_id()
|
||||
device_num = config.device_num
|
||||
context.set_context(device_id=device_id,
|
||||
enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(
|
||||
device_num=config.device_num,
|
||||
device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
# GPU target
|
||||
else:
|
||||
print("Squeezenet training on GPU performs badly now, and it is still in research..."
|
||||
"See model_zoo/research/cv/squeezenet to get up-to-date details.")
|
||||
init()
|
||||
device_num = get_group_size()
|
||||
context.set_auto_parallel_context(
|
||||
device_num=get_group_size(),
|
||||
device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(
|
||||
get_rank()) + "/"
|
||||
|
||||
# obtain the actual batch_size
|
||||
if not hasattr(config, "global_batch_size"):
|
||||
raise AttributeError("'config' object has no attribute 'global_batch_size', please check the yaml file.")
|
||||
batch_size = max(config.global_batch_size // device_num, 1)
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=config.data_path,
|
||||
do_train=True,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size,
|
||||
batch_size=batch_size,
|
||||
target=target)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
|
@ -132,10 +136,6 @@ def train_net():
|
|||
amp_level="O2",
|
||||
keep_batchnorm_fp32=False)
|
||||
else:
|
||||
if target == "GPU":
|
||||
# GPU target
|
||||
print("Squeezenet training on GPU performs badly now, and it is still in research..."
|
||||
"See model_zoo/research/cv/squeezenet to get up-to-date details.")
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||
lr,
|
||||
config.momentum,
|
||||
|
|
Loading…
Reference in New Issue