From 1c44d8500fb3ceb15c336c1053fd05b438b53737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=BD=AC?= Date: Thu, 12 Aug 2021 15:44:18 +0800 Subject: [PATCH] gpu for squeezenet --- model_zoo/official/cv/squeezenet/README.md | 282 ++++++++++-------- .../scripts/run_distribute_train_gpu.sh | 107 +++++++ .../cv/squeezenet/scripts/run_eval_gpu.sh | 99 ++++++ .../scripts/run_standalone_train_gpu.sh | 109 +++++++ .../squeezenet/squeezenet_cifar10_config.yaml | 4 +- .../squeezenet_imagenet_config.yaml | 4 +- .../squeezenet_residual_cifar10_config.yaml | 4 +- .../squeezenet_residual_imagenet_config.yaml | 4 +- model_zoo/official/cv/squeezenet/train.py | 24 +- 9 files changed, 486 insertions(+), 151 deletions(-) create mode 100644 model_zoo/official/cv/squeezenet/scripts/run_distribute_train_gpu.sh create mode 100644 model_zoo/official/cv/squeezenet/scripts/run_eval_gpu.sh create mode 100644 model_zoo/official/cv/squeezenet/scripts/run_standalone_train_gpu.sh diff --git a/model_zoo/official/cv/squeezenet/README.md b/model_zoo/official/cv/squeezenet/README.md index abdb58694df..0bfb9caaf08 100644 --- a/model_zoo/official/cv/squeezenet/README.md +++ b/model_zoo/official/cv/squeezenet/README.md @@ -92,6 +92,19 @@ After installing MindSpore via the official website, you can start training and Usage: bash scripts/run_eval.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [CHECKPOINT_PATH] ``` +- running on GPU + + ```bash + # distributed training + Usage: bash scripts/run_distribute_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) + + # standalone training + Usage: bash scripts/run_standalone_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [PRETRAINED_CKPT_PATH](optional) + + # run evaluation example + Usage: bash scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH] + ``` + - 
running on CPU ```bash @@ -139,46 +152,53 @@ After installing MindSpore via the official website, you can start training and ## [Script and Sample Code](#contents) -```shell +```text . └── squeezenet ├── README.md - ├── ascend310_infer # application for 310 inference + ├── ascend310_infer # application for 310 inference ├── scripts - ├── run_distribute_train.sh # launch ascend distributed training(8 pcs) - ├── run_standalone_train.sh # launch ascend standalone training(1 pcs) - ├── run_eval.sh # launch ascend evaluation - ├── run_infer_310.sh # shell script for 310 infer + ├── run_distribute_train.sh # launch ascend distributed training(8 pcs) + ├── run_distribute_train_gpu.sh # launch GPU distributed training(8 pcs) + ├── run_standalone_train.sh # launch ascend standalone training(1 pcs) + ├── run_standalone_train_gpu.sh # launch GPU standalone training(1 pcs) + ├── run_train_cpu.sh # launch CPU training + ├── run_eval.sh # launch ascend evaluation + ├── run_eval_gpu.sh # launch GPU evaluation + ├── run_eval_cpu.sh # launch CPU evaluation + ├── run_infer_310.sh # shell script for 310 infer ├── src - ├── dataset.py # data preprocessing - ├── CrossEntropySmooth.py # loss definition for ImageNet dataset - ├── lr_generator.py # generate learning rate for each step - └── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual + ├── dataset.py # data preprocessing + ├── CrossEntropySmooth.py # loss definition for ImageNet dataset + ├── lr_generator.py # generate learning rate for each step + └── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual ├── model_utils - │ ├── device_adapter.py # device adapter - │ ├── local_adapter.py # local adapter - │ ├── moxing_adapter.py # moxing adapter - │ ├── config.py # parameter analysis + │ ├── device_adapter.py # device adapter + │ ├── local_adapter.py # local adapter + │ ├── moxing_adapter.py # moxing adapter + │ └── config.py # parameter analysis ├── 
squeezenet_cifar10_config.yaml # parameter configuration ├── squeezenet_imagenet_config.yaml # parameter configuration ├── squeezenet_residual_cifar10_config.yaml # parameter configuration ├── squeezenet_residual_imagenet_config.yaml # parameter configuration ├── train.py # train net ├── eval.py # eval net - └── export.py # export checkpoint files into geir/onnx - ├── postprocess.py # postprocess script - ├── preprocess.py # preprocess script + ├── export.py # export checkpoint files into geir/onnx + ├── postprocess.py # postprocess script + ├── preprocess.py # preprocess script + ├── requirements.txt + └── mindspore_hub_conf.py # mindspore hub interface ``` ## [Script Parameters](#contents) -Parameters for both training and evaluation can be set in config.py +Parameters for both training and evaluation can be set in *.yaml - config for SqueezeNet, CIFAR-10 dataset ```py "class_num": 10, # dataset class num - "batch_size": 32, # batch size of input tensor + "global_batch_size": 32, # the total batch_size for training and evaluation "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum "weight_decay": 1e-4, # weight decay @@ -199,7 +219,7 @@ Parameters for both training and evaluation can be set in config.py ```py "class_num": 1000, # dataset class num - "batch_size": 32, # batch size of input tensor + "global_batch_size": 256, # the total batch_size for training and evaluation "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum "weight_decay": 7e-5, # weight decay @@ -222,7 +242,7 @@ Parameters for both training and evaluation can be set in config.py ```py "class_num": 10, # dataset class num - "batch_size": 32, # batch size of input tensor + "global_batch_size": 32, # the total batch_size for training and evaluation "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum "weight_decay": 1e-4, # weight decay @@ -243,7 +263,7 @@ Parameters for both training and evaluation can be set in config.py ```py "class_num": 1000, # dataset class num - 
"batch_size": 32, # batch size of input tensor + "global_batch_size": 256, # The total batch_size for training and evaluation "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum "weight_decay": 7e-5, # weight decay @@ -262,7 +282,7 @@ Parameters for both training and evaluation can be set in config.py "lr_max": 0.01, # maximum learning rate ``` -For more configuration details, please refer the script `config.py`. +For more configuration details, please refer the file `*.yaml`. ## [Training Process](#contents) @@ -469,137 +489,137 @@ Inference result is saved in current path, you can find result like this in acc. #### SqueezeNet on CIFAR-10 -| Parameters | Ascend | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | SqueezeNet | -| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | -| uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | CIFAR-10 | -| Training Parameters | epoch=120, steps=195, batch_size=32, lr=0.01 | -| Optimizer | Momentum | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 0.0496 | -| Speed | 1pc: 16.7 ms/step; 8pcs: 17.0 ms/step | -| Total time | 1pc: 55.5 mins; 8pcs: 15.0 mins | -| Parameters (M) | 4.8 | -| Checkpoint for Fine tuning | 6.4M (.ckpt file) | -| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | +| Parameters | Ascend | GPU | +| -------------------------- | ----------------------------------------------------------- | --- | +| Model Version | SqueezeNet | SqueezeNet | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G | +| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | CIFAR-10 | CIFAR-10 | +| Training Parameters | epoch=120, steps=195, batch_size=32, lr=0.01 | 1pc:epoch=120, steps=1562, 
batch_size=32, lr=0.01; 8pcs:epoch=120, steps=1562, batch_size=4, lr=0.01| +| Optimizer | Momentum | Momentum | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | +| outputs | probability | probability | +| Loss | 0.0496 | 1pc:0.0892, 8pcs:0.0130 | +| Speed | 1pc: 16.7 ms/step; 8pcs: 17.0 ms/step | 1pc: 28.6 ms/step; 8pcs: 10.8 ms/step | +| Total time | 1pc: 55.5 mins; 8pcs: 15.0 mins | 1pc: 90mins; 8pcs: 34mins | +| Parameters (M) | 4.8 | 0.74 | +| Checkpoint for Fine tuning | 6.4M (.ckpt file) | 6.4M (.ckpt file)| +| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | #### SqueezeNet on ImageNet -| Parameters | Ascend | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | SqueezeNet | -| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | -| uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | ImageNet | -| Training Parameters | epoch=200, steps=5004, batch_size=32, lr=0.01 | -| Optimizer | Momentum | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 2.9150 | -| Speed | 8pcs: 19.9 ms/step | -| Total time | 8pcs: 5.2 hours | -| Parameters (M) | 4.8 | -| Checkpoint for Fine tuning | 13.3M (.ckpt file) | -| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | +| Parameters | Ascend | GPU | +| -------------------------- | ----------------------------------------------------------- | --- | +| Model Version | SqueezeNet | SqueezeNet | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G | +| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | ImageNet | ImageNet | +| 
Training Parameters | epoch=200, steps=5004, batch_size=32, lr=0.01 | epoch=200, steps=5004, batch_size=32, lr=0.01 | +| Optimizer | Momentum | Momentum | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | +| outputs | probability | probability | +| Loss | 2.9150 | 3.009 | +| Speed | 8pcs: 19.9 ms/step | 8pcs: 43.5ms/step| +| Total time | 8pcs: 5.2 hours | 8pcs: 12.1 hours | +| Parameters (M) | 4.8 | 1.25 | +| Checkpoint for Fine tuning | 13.3M (.ckpt file) | 13.3M (.ckpt file) | +| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | #### SqueezeNet_Residual on CIFAR-10 -| Parameters | Ascend | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | SqueezeNet_Residual | -| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | -| uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | CIFAR-10 | -| Training Parameters | epoch=150, steps=195, batch_size=32, lr=0.01 | -| Optimizer | Momentum | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 0.0641 | -| Speed | 1pc: 16.9 ms/step; 8pcs: 17.3 ms/step | -| Total time | 1pc: 68.6 mins; 8pcs: 20.9 mins | -| Parameters (M) | 4.8 | -| Checkpoint for Fine tuning | 6.5M (.ckpt file) | -| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | +| Parameters | Ascend | GPU | +| -------------------------- | ----------------------------------------------------------- | --- | +| Model Version | SqueezeNet_Residual | SqueezeNet_Residual | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G | +| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| 
Dataset | CIFAR-10 | CIFAR-10 | +| Training Parameters | epoch=150, steps=195, batch_size=32, lr=0.01 | 1pc:epoch=150, steps=1562, batch_size=32, lr=0.01; 8pcs: epoch=150, steps=1562, batch_size=4| +| Optimizer | Momentum | Momentum | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | +| outputs | probability | probability | +| Loss | 0.0641 | 1pc: 0.0402; 8pcs:0.004 | +| Speed | 1pc: 16.9 ms/step; 8pcs: 17.3 ms/step | 1pc: 29.4 ms/step; 8pcs:11.0 ms/step | +| Total time | 1pc: 68.6 mins; 8pcs: 20.9 mins | 1pc: 115 mins; 8pcs: 43.5 mins | +| Parameters (M) | 4.8 | 0.74 | +| Checkpoint for Fine tuning | 6.5M (.ckpt file) | 6.5M (.ckpt file) | +| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | #### SqueezeNet_Residual on ImageNet -| Parameters | Ascend | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | SqueezeNet_Residual | -| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | -| uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | ImageNet | -| Training Parameters | epoch=300, steps=5004, batch_size=32, lr=0.01 | -| Optimizer | Momentum | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 2.9040 | -| Speed | 8pcs: 20.2 ms/step | -| Total time | 8pcs: 8.0 hours | -| Parameters (M) | 4.8 | -| Checkpoint for Fine tuning | 15.3M (.ckpt file) | -| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | +| Parameters | Ascend | GPU | +| -------------------------- | ----------------------------------------------------------- | --- | +| Model Version | SqueezeNet_Residual | SqueezeNet_Residual | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 
V100-32G | +| uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | ImageNet | ImageNet | +| Training Parameters | epoch=300, steps=5004, batch_size=32, lr=0.01 | epoch=300, steps=5004, batch_size=32, lr=0.01 | +| Optimizer | Momentum | Momentum | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | +| outputs | probability | probability | +| Loss | 2.9040 | 2.969 | +| Speed | 8pcs: 20.2 ms/step | 8pcs: 44.1 ms/step | +| Total time | 8pcs: 8.0 hours | 8pcs: 18.4 hours | +| Parameters (M) | 4.8 | 1.25 | +| Checkpoint for Fine tuning | 15.3M (.ckpt file) | 15.3M (.ckpt file) | +| Scripts | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | [squeezenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/squeezenet) | ### Inference Performance #### SqueezeNet on CIFAR-10 -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | SqueezeNet | -| Resource | Ascend 910; OS Euler2.8 | -| Uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | CIFAR-10 | -| batch_size | 32 | -| outputs | probability | -| Accuracy | 1pc: 89.0%; 8pcs: 84.4% | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --- | +| Model Version | SqueezeNet | SqueezeNet | +| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G | +| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | CIFAR-10 | CIFAR-10 | +| batch_size | 32 | 1pc:32; 8pcs:4 | +| outputs | probability | probability | +| Accuracy | 1pc: 89.0%; 8pcs: 84.4% | 1pc: 89.0%; 8pcs: 88.8%| #### SqueezeNet on ImageNet -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | SqueezeNet | -| Resource | Ascend 910; OS Euler2.8 | -| Uploaded Date | 11/06/2020 
(month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | ImageNet | -| batch_size | 32 | -| outputs | probability | -| Accuracy | 8pcs: 58.5%(TOP1), 81.1%(TOP5) | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --- | +| Model Version | SqueezeNet | SqueezeNet | +| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G | +| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | ImageNet | ImageNet | +| batch_size | 32 | 32 | +| outputs | probability | probability | +| Accuracy | 8pcs: 58.5%(TOP1), 81.1%(TOP5) | 8pcs: 58.5%(TOP1), 80.7%(TOP5) | #### SqueezeNet_Residual on CIFAR-10 -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | SqueezeNet_Residual | -| Resource | Ascend 910; OS Euler2.8 | -| Uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | CIFAR-10 | -| batch_size | 32 | -| outputs | probability | -| Accuracy | 1pc: 90.8%; 8pcs: 87.4% | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --- | +| Model Version | SqueezeNet_Residual | SqueezeNet_Residual | +| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G | +| Uploaded Date | 11/06/2020 (month/day/year) | 8/26/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | CIFAR-10 | CIFAR-10 | +| batch_size | 32 | 1pc:32; 8pcs:4 | +| outputs | probability | probability | +| Accuracy | 1pc: 90.8%; 8pcs: 87.4% | 1pc: 90.7%; 8pcs: 90.5% | #### SqueezeNet_Residual on ImageNet -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | SqueezeNet_Residual | -| Resource | Ascend 910; OS Euler2.8 | -| Uploaded Date | 11/06/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | -| Dataset | ImageNet | -| batch_size | 32 | -| outputs | probability | -| Accuracy | 8pcs: 60.9%(TOP1), 82.6%(TOP5) | +| Parameters | Ascend | GPU | +| 
------------------- | --------------------------- | --- | +| Model Version | SqueezeNet_Residual | SqueezeNet_Residual | +| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G | +| Uploaded Date | 11/06/2020 (month/day/year) | 8/24/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.4.0 | +| Dataset | ImageNet | ImageNet | +| batch_size | 32 | 32 | +| outputs | probability | probability | +| Accuracy | 8pcs: 60.9%(TOP1), 82.6%(TOP5) | 8pcs: 60.2%(TOP1), 82.3%(TOP5)| ### 310 Inference Performance diff --git a/model_zoo/official/cv/squeezenet/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/squeezenet/scripts/run_distribute_train_gpu.sh new file mode 100644 index 00000000000..01fca7257dc --- /dev/null +++ b/model_zoo/official/cv/squeezenet/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 3 ] && [ $# != 4 ] +then + echo "Usage: bash scripts/run_distribute_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" + exit 1 +fi + +if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ] +then + echo "error: the selected net is neither squeezenet nor squeezenet_residual" + exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet" + exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) + +if [ $# == 4 ] +then + PATH2=$(get_real_path $4) +fi + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 +fi + +if [ $# == 4 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" + exit 1 +fi + +ulimit -u unlimited +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export RANK_SIZE=8 +BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")") +CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" +if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" +elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml" +else + echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}" +exit 1 +fi + +TRAIN_OUTPUT=$BASE_PATH/train_parallel_$1_$2 +if [ -d $TRAIN_OUTPUT ]; then + rm -rf $TRAIN_OUTPUT +fi +mkdir $TRAIN_OUTPUT +cp ./train.py $TRAIN_OUTPUT 
+cp -r ./src $TRAIN_OUTPUT +cp -r ./model_utils $TRAIN_OUTPUT +cp $CONFIG_FILE $TRAIN_OUTPUT +cd $TRAIN_OUTPUT || exit + +if [ $# == 3 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ + python train.py --net_name=$1 --dataset=$2 --run_distribute=True --output_path='./output'\ + --device_target="GPU" --data_path=$PATH1 \ + --config_path=${CONFIG_FILE##*/} &> log & +fi + +if [ $# == 4 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ + python train.py --net_name=$1 --dataset=$2 --run_distribute=True --output_path='./output'\ + --device_target="GPU" --data_path=$PATH1 --pre_trained=$PATH2 \ + --config_path=${CONFIG_FILE##*/} &> log & +fi +cd .. \ No newline at end of file diff --git a/model_zoo/official/cv/squeezenet/scripts/run_eval_gpu.sh b/model_zoo/official/cv/squeezenet/scripts/run_eval_gpu.sh new file mode 100644 index 00000000000..58d7f349b72 --- /dev/null +++ b/model_zoo/official/cv/squeezenet/scripts/run_eval_gpu.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 5 ] +then + echo "Usage: bash scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ] +then + echo "error: the selected net is neither squeezenet nor squeezenet_residual" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $4) +PATH2=$(get_real_path $5) + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +expr $3 + 0 &>/dev/null +if [ $? != 0 ]; then + echo "DEVICE_ID=$3 is not an integer!" 
+exit 1 +fi + +ulimit -u unlimited +export CUDA_VISIBLE_DEVICES=$3 + +BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")") +CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" + +if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" +elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml" +else + echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}" +exit 1 +fi + +EVAL_OUTPUT=$BASE_PATH/eval_$3_$1_$2 +if [ -d $EVAL_OUTPUT ]; +then + rm -rf $EVAL_OUTPUT +fi +mkdir $EVAL_OUTPUT +cp ./eval.py $EVAL_OUTPUT +cp -r ./src $EVAL_OUTPUT +cp -r ./model_utils $EVAL_OUTPUT +cp $CONFIG_FILE $EVAL_OUTPUT +cd $EVAL_OUTPUT || exit +env > env.log +echo "start evaluation for device $3" +python eval.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --checkpoint_file_path=$PATH2 --device_target="GPU" \ +--config_path=${CONFIG_FILE##*/} --output_path='./output' &> log & +cd .. diff --git a/model_zoo/official/cv/squeezenet/scripts/run_standalone_train_gpu.sh b/model_zoo/official/cv/squeezenet/scripts/run_standalone_train_gpu.sh new file mode 100644 index 00000000000..de27e1d2d80 --- /dev/null +++ b/model_zoo/official/cv/squeezenet/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 4 ] && [ $# != 5 ] +then + echo "Usage: bash scripts/run_standalone_train.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATA_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ] +then + echo "error: the selected net is neither squeezenet nor squeezenet_residual" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $4) + +if [ $# == 5 ] +then + PATH2=$(get_real_path $5) +fi + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 5 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +expr $3 + 0 &>/dev/null +if [ $? = 2 ]; then + echo "DEVICE_ID=$3 is not an integer!" 
+exit 1 +fi + +ulimit -u unlimited +export CUDA_VISIBLE_DEVICES=$3 +BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")") +CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" +if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml" +elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml" +elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then + CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml" +else + echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}" +exit 1 +fi + +TRAIN_OUTPUT=$BASE_PATH/train_standalone$3_$1_$2 +if [ -d $TRAIN_OUTPUT ]; +then + rm -rf $TRAIN_OUTPUT +fi +mkdir $TRAIN_OUTPUT +cp ./train.py $TRAIN_OUTPUT +cp -r ./src $TRAIN_OUTPUT +cp -r ./model_utils $TRAIN_OUTPUT +cp $CONFIG_FILE $TRAIN_OUTPUT +cd $TRAIN_OUTPUT || exit +echo "start training for device $3" +env > env.log +if [ $# == 4 ] +then + python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --config_path=${CONFIG_FILE##*/} \ + --output_path='./output' --device_target='GPU' &> log & +fi + +if [ $# == 5 ] +then + python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --pre_trained=$PATH2 \ + --config_path=${CONFIG_FILE##*/} --output_path='./output' --device_target='GPU' &> log & +fi +cd .. 
diff --git a/model_zoo/official/cv/squeezenet/squeezenet_cifar10_config.yaml b/model_zoo/official/cv/squeezenet/squeezenet_cifar10_config.yaml index 5bc7187786f..30ce52e4b6f 100644 --- a/model_zoo/official/cv/squeezenet/squeezenet_cifar10_config.yaml +++ b/model_zoo/official/cv/squeezenet/squeezenet_cifar10_config.yaml @@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_cifar10-120_195.ckpt" net_name: "suqeezenet" dataset : "cifar10" class_num: 10 -batch_size: 32 +global_batch_size: 32 loss_scale: 1024 momentum: 0.9 weight_decay: 0.0001 @@ -55,7 +55,7 @@ load_path: "The location of checkpoint for obs" device_target: "Target device type, available: [Ascend, GPU, CPU]" enable_profiling: "Whether enable profiling while training, default: False" num_classes: "Class for dataset" -batch_size: "Batch size for training and evaluation" +global_batch_size: "The total batch_size for training and evaluation" epoch_size: "Total training epochs." keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" checkpoint_path: "The location of the checkpoint file." 
diff --git a/model_zoo/official/cv/squeezenet/squeezenet_imagenet_config.yaml b/model_zoo/official/cv/squeezenet/squeezenet_imagenet_config.yaml index 1c2ed4f34da..0e203981f3e 100644 --- a/model_zoo/official/cv/squeezenet/squeezenet_imagenet_config.yaml +++ b/model_zoo/official/cv/squeezenet/squeezenet_imagenet_config.yaml @@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_imagenet-200_5004.ckpt" net_name: "suqeezenet" dataset : "imagenet" class_num: 1000 -batch_size: 32 +global_batch_size: 256 loss_scale: 1024 momentum: 0.9 weight_decay: 0.00007 @@ -57,7 +57,7 @@ load_path: 'The location of checkpoint for obs' device_target: 'Target device type, available: [Ascend, GPU, CPU]' enable_profiling: 'Whether enable profiling while training, default: False' num_classes: 'Class for dataset' -batch_size: "Batch size for training and evaluation" +global_batch_size: "The total batch_size for training and evaluation" epoch_size: "Total training epochs." keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" checkpoint_path: "The location of the checkpoint file." 
diff --git a/model_zoo/official/cv/squeezenet/squeezenet_residual_cifar10_config.yaml b/model_zoo/official/cv/squeezenet/squeezenet_residual_cifar10_config.yaml index 2f00f6811dc..95adc46da04 100644 --- a/model_zoo/official/cv/squeezenet/squeezenet_residual_cifar10_config.yaml +++ b/model_zoo/official/cv/squeezenet/squeezenet_residual_cifar10_config.yaml @@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_residual_cifar10-150_195.ckpt" net_name: "suqeezenet_residual" dataset : "cifar10" class_num: 10 -batch_size: 32 +global_batch_size: 32 loss_scale: 1024 momentum: 0.9 weight_decay: 0.0001 @@ -55,7 +55,7 @@ load_path: "The location of checkpoint for obs" device_target: "Target device type, available: [Ascend, GPU, CPU]" enable_profiling: "Whether enable profiling while training, default: False" num_classes: "Class for dataset" -batch_size: "Batch size for training and evaluation" +global_batch_size: "The total batch_size for training and evaluation." epoch_size: "Total training epochs." keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" checkpoint_path: "The location of the checkpoint file." 
diff --git a/model_zoo/official/cv/squeezenet/squeezenet_residual_imagenet_config.yaml b/model_zoo/official/cv/squeezenet/squeezenet_residual_imagenet_config.yaml index 9b4ea654da0..8aa84960ff7 100644 --- a/model_zoo/official/cv/squeezenet/squeezenet_residual_imagenet_config.yaml +++ b/model_zoo/official/cv/squeezenet/squeezenet_residual_imagenet_config.yaml @@ -21,7 +21,7 @@ checkpoint_file_path: "suqeezenet_residual_imagenet-300_5004.ckpt" net_name: "suqeezenet_residual" dataset : "imagenet" class_num: 1000 -batch_size: 32 +global_batch_size: 256 loss_scale: 1024 momentum: 0.9 weight_decay: 0.00007 @@ -57,7 +57,7 @@ load_path: "The location of checkpoint for obs" device_target: "Target device type, available: [Ascend, GPU, CPU]" enable_profiling: "Whether enable profiling while training, default: False" num_classes: "Class for dataset" -batch_size: "Batch size for training and evaluation" +global_batch_size: "The total batch_size across all devices for training and evaluation." epoch_size: "Total training epochs." keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" checkpoint_path: "The location of the checkpoint file." diff --git a/model_zoo/official/cv/squeezenet/train.py b/model_zoo/official/cv/squeezenet/train.py index ad6a751331a..d58d5b5b0cb 100755 --- a/model_zoo/official/cv/squeezenet/train.py +++ b/model_zoo/official/cv/squeezenet/train.py @@ -13,7 +13,6 @@ # limitations under the License. 
# ============================================================================ """train squeezenet.""" -import os from mindspore import context from mindspore import Tensor from mindspore.nn.optim.momentum import Momentum @@ -27,6 +26,7 @@ from mindspore.communication.management import init, get_rank, get_group_size from mindspore.common import set_seed from model_utils.config import config from model_utils.moxing_adapter import moxing_wrapper +from model_utils.device_adapter import get_device_id from src.lr_generator import get_lr from src.CrossEntropySmooth import CrossEntropySmooth @@ -54,33 +54,37 @@ def train_net(): # init context context.set_context(mode=context.GRAPH_MODE, device_target=target) + device_num = 1 if config.run_distribute: if target == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) + device_id = get_device_id() + device_num = config.device_num context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context( - device_num=config.device_num, + device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) init() # GPU target else: - print("Squeezenet training on GPU performs badly now, and it is still in research..." 
- "See model_zoo/research/cv/squeezenet to get up-to-date details.") init() + device_num = get_group_size() context.set_auto_parallel_context( - device_num=get_group_size(), + device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str( get_rank()) + "/" - + # obtain the actual batch_size + if not hasattr(config, "global_batch_size"): + raise AttributeError("'config' object has no attribute 'global_batch_size', please check the yaml file.") + batch_size = max(config.global_batch_size // device_num, 1) # create dataset dataset = create_dataset(dataset_path=config.data_path, do_train=True, repeat_num=1, - batch_size=config.batch_size, + batch_size=batch_size, target=target) step_size = dataset.get_dataset_size() @@ -132,10 +136,6 @@ def train_net(): amp_level="O2", keep_batchnorm_fp32=False) else: - if target == "GPU": - # GPU target - print("Squeezenet training on GPU performs badly now, and it is still in research..." - "See model_zoo/research/cv/squeezenet to get up-to-date details.") opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,