add resnet18

This commit is contained in:
jiangzhenguang 2021-01-25 14:36:18 +08:00
parent 590e889cba
commit 9a07d38d55
16 changed files with 407 additions and 100 deletions

View File

@ -25,7 +25,7 @@
ResNet (residual neural network) was proposed by Kaiming He and other four Chinese of Microsoft Research Institute. Through the use of ResNet unit, it successfully trained 152 layers of neural network, and won the championship in ilsvrc2015. The error rate on top 5 was 3.57%, and the parameter quantity was lower than vggnet, so the effect was very outstanding. Traditional convolution network or full connection network will have more or less information loss. At the same time, it will lead to the disappearance or explosion of gradient, which leads to the failure of deep network training. ResNet solves this problem to a certain extent. By passing the input information to the output, the integrity of the information is protected. The whole network only needs to learn the part of the difference between input and output, which simplifies the learning objectives and difficulties.The structure of ResNet can accelerate the training of neural network very quickly, and the accuracy of the model is also greatly improved. At the same time, ResNet is very popular, even can be directly used in the concept net network.
These are examples of training ResNet50/ResNet101/SE-ResNet50 with CIFAR-10/ImageNet2012 dataset in MindSpore.ResNet50 and ResNet101 can reference [paper 1](https://arxiv.org/pdf/1512.03385.pdf) below, and SE-ResNet50 is a variant of ResNet50 which reference [paper 2](https://arxiv.org/abs/1709.01507) and [paper 3](https://arxiv.org/abs/1812.01187) below, Training SE-ResNet50 for just 24 epochs using 8 Ascend 910, we can reach top-1 accuracy of 75.9%.(Training ResNet101 with dataset CIFAR-10 and SE-ResNet50 with CIFAR-10 is not supported yet.)
These are examples of training ResNet18/ResNet50/ResNet101/SE-ResNet50 with CIFAR-10/ImageNet2012 dataset in MindSpore.ResNet50 and ResNet101 can reference [paper 1](https://arxiv.org/pdf/1512.03385.pdf) below, and SE-ResNet50 is a variant of ResNet50 which reference [paper 2](https://arxiv.org/abs/1709.01507) and [paper 3](https://arxiv.org/abs/1812.01187) below, Training SE-ResNet50 for just 24 epochs using 8 Ascend 910, we can reach top-1 accuracy of 75.9%.(Training ResNet101 with dataset CIFAR-10 and SE-ResNet50 with CIFAR-10 is not supported yet.)
## Paper
@ -97,30 +97,30 @@ After installing MindSpore via the official website, you can start training and
```bash
# distributed training
Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
Usage: bash run_distribute_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training
Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
Usage: bash run_standalone_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
[PRETRAINED_CKPT_PATH](optional)
# run evaluation example
Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
- Running on GPU
```bash
# distributed training example
sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training example
sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# infer example
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
# gpu benchmark example
sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)
bash run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)
```
- Running on CPU
@ -170,7 +170,7 @@ python eval.py --net=[resnet50|resnet101] --dataset=[cifar10|imagenet2012] --dat
Parameters for both training and evaluation can be set in config.py.
- Config for ResNet50, CIFAR-10 dataset
- Config for ResNet18 and ResNet50, CIFAR-10 dataset
```bash
"class_num": 10, # dataset class num
@ -191,7 +191,7 @@ Parameters for both training and evaluation can be set in config.py.
"lr_max": 0.1, # maximum learning rate
```
- Config for ResNet50, ImageNet2012 dataset
- Config for ResNet18 and ResNet50, ImageNet2012 dataset
```bash
"class_num": 1001, # dataset class number
@ -267,14 +267,14 @@ Parameters for both training and evaluation can be set in config.py.
```bash
# distributed training
Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
Usage: bash run_distribute_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training
Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
Usage: bash run_standalone_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
[PRETRAINED_CKPT_PATH](optional)
# run evaluation example
Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
@ -288,19 +288,19 @@ Training result will be stored in the example path, whose folder name begins wit
```bash
# distributed training example
sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training example
sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# infer example
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
# gpu benchmark training example
sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)
bash run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)
# gpu benchmark infer example
sh run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) [DTYPE](optional)
bash run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) [DTYPE](optional)
```
For distributed training, a hostfile configuration needs to be created in advance.
@ -312,17 +312,41 @@ Please follow the instructions in the link [GPU-Multi-Host](https://www.mindspor
- Parameter server training Ascend example
```bash
sh run_parameter_server_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
- Parameter server training GPU example
```bash
sh run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
### Result
- Training ResNet18 with CIFAR-10 dataset
```bash
# distribute training result(8 pcs)
epoch: 1 step: 195, loss is 1.5783054
epoch: 2 step: 195, loss is 1.0682616
epoch: 3 step: 195, loss is 0.8836588
epoch: 4 step: 195, loss is 0.36090446
epoch: 5 step: 195, loss is 0.80853784
...
```
- Training ResNet18 with ImageNet2012 dataset
```bash
# distribute training result(8 pcs)
epoch: 1 step: 625, loss is 4.757934
epoch: 2 step: 625, loss is 4.0891967
epoch: 3 step: 625, loss is 3.9131956
epoch: 4 step: 625, loss is 3.5302577
epoch: 5 step: 625, loss is 3.597817
...
```
- Training ResNet50 with CIFAR-10 dataset
```bash
@ -391,12 +415,12 @@ epoch: [0/1] step: [100/5004], loss is 6.814013Epoch time: 3437.154 ms, fps: 148
```bash
# evaluation
Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
```bash
# evaluation example
sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
bash run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
```
> checkpoint can be produced in training process.
@ -404,13 +428,25 @@ sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train
#### Running on GPU
```bash
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
### Result
Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the following in log.
- Evaluating ResNet18 with CIFAR-10 dataset
```bash
result: {'acc': 0.9402043269230769} ckpt=~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
```
- Evaluating ResNet18 with ImageNet2012 dataset
```bash
result: {'acc': 0.7053685897435897} ckpt=train_parallel0/resnet-90_5004.ckpt
```
- Evaluating ResNet50 with CIFAR-10 dataset
```bash
@ -442,6 +478,46 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499
### Evaluation Performance
#### ResNet18 on CIFAR-10
| Parameters | Ascend 910 |
| -------------------------- | -------------------------------------- |
| Model Version | ResNet18 |
| Resource | Ascend 910CPU 2.60GHz 192coresMemory 755G |
| uploaded Date | 02/25/2021 (month/day/year) |
| MindSpore Version | 1.1.1-alpha |
| Dataset | CIFAR-10 |
| Training Parameters | epoch=90, steps per epoch=195, batch_size = 32 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.0002519517 |
| Speed | 10 ms/step8pcs |
| Total time | 3 mins |
| Parameters (M) | 11.2 |
| Checkpoint for Fine tuning | 86M (.ckpt file) |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) |
#### ResNet18 on ImageNet2012
| Parameters | Ascend 910 |
| -------------------------- | -------------------------------------- |
| Model Version | ResNet18 |
| Resource | Ascend 910CPU 2.60GHz 192coresMemory 755G |
| uploaded Date | 02/25/2021 (month/day/year) |
| MindSpore Version | 1.1.1-alpha |
| Dataset | ImageNet2012 |
| Training Parameters | epoch=90, steps per epoch=626, batch_size = 256 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 2.15702 |
| Speed | 140ms/step8pcs |
| Total time | 131 mins |
| Parameters (M) | 11.7 |
| Checkpoint for Fine tuning | 90M (.ckpt file) |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) |
#### ResNet50 on CIFAR-10
| Parameters | Ascend 910 | GPU |
@ -524,6 +600,34 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499
### Inference Performance
#### ResNet18 on CIFAR-10
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | ResNet18 |
| Resource | Ascend 910 |
| Uploaded Date | 02/25/2021 (month/day/year) |
| MindSpore Version | 1.1.1-alpha |
| Dataset | CIFAR-10 |
| batch_size | 32 |
| outputs | probability |
| Accuracy | 94.02% |
| Model for inference | 43M (.air file) |
#### ResNet18 on ImageNet2012
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | ResNet18 |
| Resource | Ascend 910 |
| Uploaded Date | 02/25/2021 (month/day/year) |
| MindSpore Version | 1.1.1-alpha |
| Dataset | ImageNet2012 |
| batch_size | 256 |
| outputs | probability |
| Accuracy | 70.53% |
| Model for inference | 45M (.air file) |
#### ResNet50 on CIFAR-10
| Parameters | Ascend | GPU |

View File

@ -28,7 +28,7 @@
残差神经网络ResNet由微软研究院何凯明等五位华人提出通过ResNet单元成功训练152层神经网络赢得了ILSVRC2015冠军。ResNet前五项的误差率为3.57%参数量低于VGGNet因此效果非常显著。传统的卷积网络或全连接网络或多或少存在信息丢失的问题还会造成梯度消失或爆炸导致深度网络训练失败ResNet则在一定程度上解决了这个问题。通过将输入信息传递给输出确保信息完整性。整个网络只需要学习输入和输出的差异部分简化了学习目标和难度。ResNet的结构大幅提高了神经网络训练的速度并且大大提高了模型的准确率。正因如此ResNet十分受欢迎甚至可以直接用于ConceptNet网络。
如下为MindSpore使用CIFAR-10/ImageNet2012数据集对ResNet50/ResNet101/SE-ResNet50进行训练的示例。ResNet50和ResNet101可参考[论文1](https://arxiv.org/pdf/1512.03385.pdf)SE-ResNet50是ResNet50的一个变体可参考[论文2](https://arxiv.org/abs/1709.01507)和[论文3](https://arxiv.org/abs/1812.01187)。使用8卡Ascend 910训练SE-ResNet50仅需24个周期TOP1准确率就达到了75.9%暂不支持用CIFAR-10数据集训练ResNet101以及用用CIFAR-10数据集训练SE-ResNet50
如下为MindSpore使用CIFAR-10/ImageNet2012数据集对ResNet18/ResNet50/ResNet101/SE-ResNet50进行训练的示例。ResNet50和ResNet101可参考[论文1](https://arxiv.org/pdf/1512.03385.pdf)SE-ResNet50是ResNet50的一个变体可参考[论文2](https://arxiv.org/abs/1709.01507)和[论文3](https://arxiv.org/abs/1812.01187)。使用8卡Ascend 910训练SE-ResNet50仅需24个周期TOP1准确率就达到了75.9%暂不支持用CIFAR-10数据集训练ResNet101以及用用CIFAR-10数据集训练SE-ResNet50
## 论文
@ -100,27 +100,27 @@ ResNet的总体网络架构如下
```text
# 分布式训练
用法sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
用法:bash run_distribute_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 单机训练
用法sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
用法:bash run_standalone_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
[PRETRAINED_CKPT_PATH](可选)
# 运行评估示例
用法sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
用法:bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
- GPU处理器环境运行
```text
# 分布式训练示例
sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 单机训练示例
sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 推理示例
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
# 脚本说明
@ -154,7 +154,7 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
在config.py中可以同时配置训练参数和评估参数。
- 配置ResNet50和CIFAR-10数据集。
- 配置ResNet18、ResNet50和CIFAR-10数据集。
```text
"class_num":10, # 数据集类数
@ -175,7 +175,7 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
"lr_max":0.1, # 最大学习率
```
- 配置ResNet50和ImageNet2012数据集。
- 配置ResNet18、ResNet50和ImageNet2012数据集。
```text
"class_num":1001, # 数据集类数
@ -251,14 +251,14 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
```text
# 分布式训练
用法sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
用法:bash run_distribute_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 单机训练
用法sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
用法:bash run_standalone_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH]
[PRETRAINED_CKPT_PATH](可选)
# 运行评估示例
用法sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
用法:bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
@ -272,13 +272,13 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
```text
# 分布式训练示例
sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 单机训练示例
sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
# 推理示例
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
#### 运行参数服务器模式训练
@ -286,17 +286,41 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
- Ascend参数服务器训练示例
```text
sh run_parameter_server_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
```
- GPU参数服务器训练示例
```text
sh run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
```
### 结果
- 使用CIFAR-10数据集训练ResNet18
```text
# 分布式训练结果8P
epoch: 1 step: 195, loss is 1.5783054
epoch: 2 step: 195, loss is 1.0682616
epoch: 3 step: 195, loss is 0.8836588
epoch: 4 step: 195, loss is 0.36090446
epoch: 5 step: 195, loss is 0.80853784
...
```
- 使用ImageNet2012数据集训练ResNet18
```text
# 分布式训练结果8P
epoch: 1 step: 625, loss is 4.757934
epoch: 2 step: 625, loss is 4.0891967
epoch: 3 step: 625, loss is 3.9131956
epoch: 4 step: 625, loss is 3.5302577
epoch: 5 step: 625, loss is 3.597817
...
```
- 使用CIFAR-10数据集训练ResNet50
```text
@ -358,12 +382,12 @@ epoch:5 step:5004, loss is 3.3501816
```bash
# 评估
Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
```bash
# 评估示例
sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
bash run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
```
> 训练过程中可以生成检查点。
@ -371,13 +395,25 @@ sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train
#### GPU处理器环境运行
```bash
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
### 结果
评估结果保存在示例路径中文件夹名为“eval”。您可在此路径下的日志找到如下结果
- 使用CIFAR-10数据集评估ResNet18
```bash
result: {'acc': 0.9402043269230769} ckpt=~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt
```
- 使用ImageNet2012数据集评估ResNet18
```bash
result: {'acc': 0.7053685897435897} ckpt=train_parallel0/resnet-90_5004.ckpt
```
- 使用CIFAR-10数据集评估ResNet50
```text
@ -409,6 +445,46 @@ result:{'top_5_accuracy':0.9342589628681178, 'top_1_accuracy':0.768065781049936}
### 评估性能
#### CIFAR-10上的ResNet18
| 参数 | Ascend 910 |
| -------------------------- | -------------------------------------- |
| 模型版本 | ResNet18 |
| 资源 | Ascend 910CPU2.60GHz192核内存755G |
| 上传日期 | 2021-02-25 |
| MindSpore版本 | 1.1.1-alpha |
| 数据集 | CIFAR-10 |
| 训练参数 | epoch=90, steps per epoch=195, batch_size = 32 |
| 优化器 | Momentum |
| 损失函数 | Softmax交叉熵 |
| 输出 | 概率 |
| 损失 | 0.0002519517 |
| 速度 | 10毫秒/步8卡 |
| 总时长 | 3分钟 |
| 参数(M) | 11.2 |
| 微调检查点 | 86.ckpt文件 |
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) |
#### ImageNet2012上的ResNet18
| 参数 | Ascend 910 |
| -------------------------- | -------------------------------------- |
| 模型版本 | ResNet18 |
| 资源 | Ascend 910CPU2.60GHz192核内存755G |
| 上传日期 | 2020-04-01 ; |
| MindSpore版本 | 1.1.1-alpha |
| 数据集 | ImageNet2012 |
| 训练参数 | epoch=90, steps per epoch=626, batch_size = 256 |
| 优化器 | Momentum |
| 损失函数 | Softmax交叉熵 |
| 输出 | 概率 |
| 损失 | 2.15702 |
| 速度 | 140毫秒/步8卡 |
| 总时长 | 131分钟 |
| 参数(M) | 11.7 |
| 微调检查点| 90M.ckpt文件 |
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) |
#### CIFAR-10上的ResNet50
| 参数 | Ascend 910 | GPU |

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,7 +23,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet18, '
'resnet50 or resnet101')
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@ -34,14 +35,18 @@ args_opt = parser.parse_args()
set_seed(1)
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.net in ("resnet18", "resnet50"):
if args_opt.net == "resnet18":
from src.resnet import resnet18 as resnet
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,7 +22,9 @@ import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
parser = argparse.ArgumentParser(description='resnet export')
parser.add_argument('--network_dataset', type=str, default='resnet50_cifar10', choices=['resnet50_cifar10',
parser.add_argument('--network_dataset', type=str, default='resnet50_cifar10', choices=['resnet18_cifar10',
'resnet18_imagenet2012',
'resnet50_cifar10',
'resnet50_imagenet2012',
'resnet101_imagenet2012',
"se-resnet50_imagenet2012"],
@ -44,8 +46,13 @@ if args.device_target == "Ascend":
if __name__ == '__main__':
if args.network_dataset == 'resnet50_cifar10':
if args.network_dataset == 'resnet18_cifar10':
from src.config import config1 as config
from src.resnet import resnet18 as resnet
elif args.network_dataset == 'resnet18_imagenet2012':
from src.config import config2 as config
from src.resnet import resnet18 as resnet
elif args.network_dataset == 'resnet50_cifar10':
from src.config import config1 as config
from src.resnet import resnet50 as resnet
elif args.network_dataset == 'resnet50_imagenet2012':

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,14 +15,14 @@
# ============================================================================
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
then
echo "Usage: bash run_distribute_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
if [ $1 != "resnet18" ] && [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1
fi

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,11 +16,11 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
if [ $1 != "resnet18" ] && [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
exit 1

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) \
echo "Usage: bash run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) \
[DTYPE](optional)"
echo "Example: sh run_eval_gpu_resnet_benchmark.sh /path/imagenet/train /path/ckpt 256 FP16"
exit 1

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional)\
echo "Usage: bash run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional)\
[DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)"
echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8 true /path/ckpt"
exit 1

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: bash run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: bash run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,11 +16,11 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: bash run_standalone_train.sh [resnet18|resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
if [ $1 != "resnet18" ] && [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: bash run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@ from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
@ -32,6 +33,7 @@ def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
@ -104,37 +106,49 @@ def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
if res_base:
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=1, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
if res_base:
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if res_base:
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _bn(channel):
def _bn(channel, res_base=False):
if res_base:
return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
@ -146,7 +160,7 @@ def _bn_last(channel):
def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
@ -196,8 +210,8 @@ class ResidualBlock(nn.Cell):
self.bn3 = _bn_last(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = nn.ReLU()
@ -220,7 +234,6 @@ class ResidualBlock(nn.Cell):
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
self.add = P.Add()
def construct(self, x):
identity = x
@ -249,7 +262,69 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = out + identity
out = self.relu(out)
return out
class ResidualBlockBase(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlockBase(3, 256, stride=2)
"""
def __init__(self,
in_channel,
out_channel,
stride=1,
res_base=True,
use_se=False,
se_block=False):
super(ResidualBlockBase, self).__init__()
self.res_base = res_base
self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
self.bn1d = _bn(out_channel)
self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
self.bn2d = _bn(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=use_se, res_base=self.res_base),
_bn(out_channel, res_base)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1d(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2d(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
@ -287,12 +362,14 @@ class ResNet(nn.Cell):
out_channels,
strides,
num_classes,
use_se=False):
use_se=False,
res_base=False):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.use_se = use_se
self.res_base = res_base
self.se_block = False
if self.use_se:
self.se_block = True
@ -304,10 +381,16 @@ class ResNet(nn.Cell):
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
self.bn1 = _bn(64, self.res_base)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
if self.res_base:
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
@ -385,6 +468,8 @@ class ResNet(nn.Cell):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.res_base:
x = self.pad(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
@ -399,6 +484,28 @@ class ResNet(nn.Cell):
return out
def resnet18(class_num=10):
"""
Get ResNet18 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet18 neural network.
Examples:
>>> net = resnet18(10)
"""
return ResNet(ResidualBlockBase,
[2, 2, 2, 2],
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
class_num,
res_base=True)
def resnet50(class_num=10):
"""
Get ResNet50 neural network.
@ -419,6 +526,7 @@ def resnet50(class_num=10):
[1, 2, 2, 2],
class_num)
def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.
@ -440,6 +548,7 @@ def se_resnet50(class_num=1001):
class_num,
use_se=True)
def resnet101(class_num=1001):
"""
Get ResNet101 neural network.

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ from src.CrossEntropySmooth import CrossEntropySmooth
from src.config import cfg
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, resnet18, resnet50 or resnet101')
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
@ -50,14 +50,18 @@ args_opt = parser.parse_args()
set_seed(1)
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.net in ("resnet18", "resnet50"):
if args_opt.net == "resnet18":
from src.resnet import resnet18 as resnet
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
@ -94,6 +98,8 @@ if __name__ == '__main__':
set_algo_parameters(elementwise_op_strategy_follow=True)
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
elif args_opt.net == "resnet18":
context.set_auto_parallel_context(all_reduce_fusion_config=[40, 61])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313])
init()
@ -136,7 +142,7 @@ if __name__ == '__main__':
from src.lr_generator import get_thor_lr
lr = get_thor_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
else:
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
if args_opt.net in ("resnet18", "resnet50", "se-resnet50"):
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)