From d7e2bf974f70ecbe0a29b9811321eef593da89a7 Mon Sep 17 00:00:00 2001 From: Payne Date: Thu, 3 Sep 2020 22:24:20 +0800 Subject: [PATCH] add incremental learn fun --- model_zoo/official/cv/mobilenetv2/Readme.md | 124 ++++--- model_zoo/official/cv/mobilenetv2/eval.py | 71 ++-- .../cv/mobilenetv2/scripts/run_infer.sh | 137 ++++++-- .../cv/mobilenetv2/scripts/run_train.sh | 47 ++- model_zoo/official/cv/mobilenetv2/src/args.py | 65 ++++ .../official/cv/mobilenetv2/src/config.py | 108 ++++-- .../official/cv/mobilenetv2/src/dataset.py | 47 ++- .../official/cv/mobilenetv2/src/launch.py | 41 +-- .../cv/mobilenetv2/src/mobilenetV2.py | 150 ++++++-- .../official/cv/mobilenetv2/src/models.py | 138 ++++++++ .../official/cv/mobilenetv2/src/utils.py | 93 +++++ model_zoo/official/cv/mobilenetv2/train.py | 325 +++++------------- 12 files changed, 872 insertions(+), 474 deletions(-) create mode 100644 model_zoo/official/cv/mobilenetv2/src/args.py create mode 100644 model_zoo/official/cv/mobilenetv2/src/models.py create mode 100644 model_zoo/official/cv/mobilenetv2/src/utils.py diff --git a/model_zoo/official/cv/mobilenetv2/Readme.md b/model_zoo/official/cv/mobilenetv2/Readme.md index 430d5b75865..cc92dde8780 100644 --- a/model_zoo/official/cv/mobilenetv2/Readme.md +++ b/model_zoo/official/cv/mobilenetv2/Readme.md @@ -4,23 +4,22 @@ - [Model Architecture](#model-architecture) - [Dataset](#dataset) - [Features](#features) - - [Mixed Precision](#mixed-precision) + - [Mixed Precision](#mixed-precision) - [Environment Requirements](#environment-requirements) - [Script Description](#script-description) - - [Script and Sample Code](#script-and-sample-code) + - [Script and Sample Code](#script-and-sample-code) - [Training Process](#training-process) - [Evaluation Process](#evaluation-process) - - [Evaluation](#evaluation) + - [Evaluation](#evaluation) - [Model Description](#model-description) - - [Performance](#performance) - - [Training Performance](#evaluation-performance) - - [Inference Performance](#evaluation-performance) + - [Performance](#performance) + - [Training Performance](#evaluation-performance) + - [Inference Performance](#evaluation-performance) - [Description of Random Situation](#description-of-random-situation) - [ModelZoo Homepage](#modelzoo-homepage) # [MobileNetV2 Description](#contents) - MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. @@ -36,78 +35,103 @@ The overall network architecture of MobileNetV2 is show below: Dataset used: [imagenet](http://www.image-net.org/) - Dataset size: ~125G, 1.2W colorful images in 1000 classes - - Train: 120G, 1.2W images - - Test: 5G, 50000 images + - Train: 120G, 1.2W images + - Test: 5G, 50000 images - Data format: RGB images. - - Note: Data will be processed in src/dataset.py - + - Note: Data will be processed in src/dataset.py # [Features](#contents) ## [Mixed Precision(Ascend)](#contents) -The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. +The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. # [Environment Requirements](#contents) -- Hardware(Ascend/GPU) - - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Hardware(Ascend/GPU/CPU) + - Prepare hardware environment with Ascend、GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Framework - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) - For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - # [Script description](#contents) ## [Script and sample code](#contents) ```python -├── MobileNetV2 - ├── Readme.md # descriptions about MobileNetV2 - ├── scripts - │ ├──run_train.sh # shell script for train - │ ├──run_eval.sh # shell script for evaluation - ├── src - │ ├──config.py # parameter configuration +├── MobileNetV2 + ├── Readme.md # descriptions about MobileNetV2 + ├── scripts + │ ├──run_train.sh # shell script for train, fine_tune or incremental learn with CPU, GPU or Ascend + │ ├──run_eval.sh # shell script for evaluation with CPU, GPU or Ascend + ├── src + │ ├──args.py # parse args + │ ├──config.py # parameter configuration │ ├──dataset.py # creating dataset │ ├──launch.py # start python script - │ ├──lr_generator.py # learning rate config + │ ├──lr_generator.py # learning rate config │ ├──mobilenetV2.py # MobileNetV2 architecture + │ ├──models.py # contain define_net and Loss, Monitor + │ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn ├── train.py # training script - ├── eval.py # evaluation script + ├── eval.py # evaluation script ``` ## [Training process](#contents) ### Usage - You can start training using python or shell scripts. The usage of shell scripts as follows: -- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] -- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] +- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] +- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] +- CPU: sh run_trian.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] ### Launch -``` +``` # training example python: - Ascend: python train.py --dataset_path ~/imagenet/train/ --device_targe Ascend - GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU + Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method train + GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method train + CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method train shell: - Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt - GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ + Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ train + GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ train + CPU: sh run_train.sh CPU ~/imagenet/train/ train + +# fine tune example + python: + Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + + shell: + Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + CPU: sh run_train.sh CPU ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt + +# incremental learn example + python: + Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt + GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt + CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt + + shell: + Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt + GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt + CPU: sh run_train.sh CPU ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt ``` ### Result -Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. -``` +``` epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] @@ -120,29 +144,32 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 You can start training using python or shell scripts. The usage of shell scripts as follows: -- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] -- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] +- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] [HEAD_CKPT_PATH] +- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] [HEAD_CKPT_PATH] +- CPU: sh run_infer.sh CPU [DATASET_PATH] [BACKBONE_CKPT_PATH] [HEAD_CKPT_PATH] ### Launch -``` +``` # infer example python: - Ascend: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe Ascend - GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU + Ascend: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform Ascend --head_ckpt ./checkpoint/mobilenetv2_199.ckpt + GPU: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform GPU --head_ckpt ./checkpoint/mobilenetv2_199.ckpt + CPU: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform CPU --head_ckpt ./checkpoint/mobilenetv2_199.ckpt shell: - Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt - GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt + Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt + GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt + CPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt ``` -> checkpoint can be produced in training process. +> checkpoint can be produced in training process. ### Result -Inference result will be stored in the example path, you can find result like the followings in `val.log`. +Inference result will be stored in the example path, you can find result like the followings in `val.log`. -``` +``` result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt ``` @@ -177,7 +204,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625. | Model Version | V1 | | | | Resource | Ascend 910 | NV SMX2 V100-32G | Ascend 310 | | uploaded Date | 05/06/2020 | 05/22/2020 | | -| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 | +| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 | | Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W | | batch_size | | 130(8P) | | | outputs | | | | @@ -191,6 +218,5 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625. In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. # [ModelZoo Homepage](#contents) - -Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). - \ No newline at end of file + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/mobilenetv2/eval.py b/model_zoo/official/cv/mobilenetv2/eval.py index e4ac99013ca..967eae9b9a3 100644 --- a/model_zoo/official/cv/mobilenetv2/eval.py +++ b/model_zoo/official/cv/mobilenetv2/eval.py @@ -15,62 +15,43 @@ """ eval. """ -import os -import argparse -from mindspore import context from mindspore import nn from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import dtype as mstype + from src.dataset import create_dataset -from src.config import config_ascend, config_gpu -from src.mobilenetV2 import mobilenet_v2 - - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--device_target', type=str, default=None, help='run device_target') -args_opt = parser.parse_args() - +from src.config import set_config +from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2 +from src.args import eval_parse_args +from src.models import load_ckpt +from src.utils import switch_precision, set_context if __name__ == '__main__': - config = None - net = None - if args_opt.device_target == "Ascend": - config = config_ascend - device_id = int(os.getenv('DEVICE_ID', '0')) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - device_id=device_id, save_graphs=False) - net = mobilenet_v2(num_classes=config.num_classes, device_target="Ascend") - elif args_opt.device_target == "GPU": - config = config_gpu - context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", save_graphs=False) - net = mobilenet_v2(num_classes=config.num_classes, device_target="GPU") + args_opt = eval_parse_args() + config = set_config(args_opt) + + backbone_net = MobileNetV2Backbone(platform=args_opt.platform) + head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) + net = mobilenet_v2(feature_net, head_net) + + #load the trained checkpoint file to the net for evaluation + if args_opt.head_ckpt: + load_ckpt(backbone_net, args_opt.pretrain_ckpt) + load_ckpt(head_net, args_opt.head_ckpt) else: - raise ValueError("Unsupported device_target.") + load_ckpt(net, args_opt.pretrain_ckpt) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + set_context(config) + switch_precision(net, mstype.float16, config) - if args_opt.device_target == "Ascend": - net.to_float(mstype.float16) - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.to_float(mstype.float32) - - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=False, - config=config, - device_target=args_opt.device_target, - batch_size=config.batch_size) + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config) step_size = dataset.get_dataset_size() - - if args_opt.checkpoint_path: - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) net.set_train(False) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') model = Model(net, loss_fn=loss, metrics={'acc'}) + res = model.eval(dataset) - print("result:", res, "ckpt=", args_opt.checkpoint_path) + print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") + if args_opt.head_ckpt: + print(f"head_ckpt={args_opt.head_ckpt}") diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh index c596e33a23e..5f175578f43 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh @@ -13,10 +13,106 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -if [ $# != 3 ] + + + +run_ascend() +{ + # check checkpoint file + if [ ! -f $3 ] + then + echo "error: CHECKPOINT_PATH=$3 is not a file" + exit 1 + fi + + # set environment + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + export DEVICE_ID=0 + export RANK_ID=0 + export RANK_SIZE=1 + if [ -d "../eval" ]; + then + rm -rf ../eval + fi + mkdir ../eval + cd ../eval || exit + + # launch + python ${BASEPATH}/../eval.py \ + --platform=$1 \ + --dataset_path=$2 \ + --pretrain_ckpt=$3 \ + --head_ckpt=$4 \ + &> ../infer.log & # dataset val folder path +} + +run_gpu() +{ + # check checkpoint file + if [ ! -f $3 ] + then + echo "error: CHECKPOINT_PATH=$3 is not a file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../eval" ]; + then + rm -rf ../eval + fi + mkdir ../eval + cd ../eval || exit + + python ${BASEPATH}/../eval.py \ + --platform=$1 \ + --dataset_path=$2 \ + --pretrain_ckpt=$3 \ + --head_ckpt=$4 \ + &> ../infer.log & # dataset train folder +} + +run_cpu() +{ + # check checkpoint file + if [ ! -f $3 ] + then + echo "error: BACKBONE_CKPT=$3 is not a file" + exit 1 + fi + + # check checkpoint file + if [ ! -f $4 ] + then + echo "error: HEAD_CKPT=$4 is not a file" + exit 1 + fi + + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../eval" ]; + then + rm -rf ../eval + fi + mkdir ../eval + cd ../eval || exit + + python ${BASEPATH}/../eval.py \ + --platform=$1 \ + --dataset_path=$2 \ + --pretrain_ckpt=$3 \ + --head_ckpt=$4 \ + &> ../infer.log & # dataset train folder +} + + +if [ $# -gt 4 ] || [ $# -lt 3 ] then - echo "Ascend: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH] \ - GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" + echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] \ + GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] + CPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]" exit 1 fi @@ -27,29 +123,12 @@ then exit 1 fi -# check checkpoint file -if [ ! -f $3 ] -then - echo "error: CHECKPOINT_PATH=$3 is not a file" -exit 1 -fi - -# set environment -BASEPATH=$(cd "`dirname $0`" || exit; pwd) -export PYTHONPATH=${BASEPATH}:$PYTHONPATH -export DEVICE_ID=0 -export RANK_ID=0 -export RANK_SIZE=1 -if [ -d "../eval" ]; -then - rm -rf ../eval -fi -mkdir ../eval -cd ../eval || exit - -# launch -python ${BASEPATH}/../eval.py \ - --device_target=$1 \ - --dataset_path=$2 \ - --checkpoint_path=$3 \ - &> ../infer.log & # dataset val folder path +if [ $1 = "CPU" ] ; then + run_cpu "$@" +elif [ $1 = "GPU" ] ; then + run_gpu "$@" +elif [ $1 = "Ascend" ] ; then + run_ascend "$@" +else + echo "Unsupported device_target." +fi; diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh index 0cd24a0f2fe..5d4f7544569 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -38,12 +38,14 @@ run_ascend() mkdir ../train cd ../train || exit python ${BASEPATH}/../src/launch.py \ + --platform=$1 \ --nproc_per_node=$2 \ --visible_devices=$3 \ --training_script=${BASEPATH}/../train.py \ --dataset_path=$5 \ - --pre_trained=$6 \ - --device_target=$1 &> ../train.log & # dataset train folder + --train_method=$6 \ + --pretrain_ckpt=$7 \ + &> ../train.log & # dataset train folder } run_gpu() @@ -72,17 +74,45 @@ run_gpu() export CUDA_VISIBLE_DEVICES="$3" mpirun -n $2 --allow-run-as-root \ python ${BASEPATH}/../train.py \ + --platform=$1 \ --dataset_path=$4 \ - --pre_trained=$5 \ - --device_target=$1 \ + --train_method=$5 \ + --pretrain_ckpt=$6 \ &> ../train.log & # dataset train folder } -if [ $# -gt 6 ] || [ $# -lt 4 ] +run_cpu() +{ + + if [ ! -d $4 ] + then + echo "error: DATASET_PATH=$4 is not a directory" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + + python ${BASEPATH}/../train.py \ + --platform=$1 \ + --dataset_path=$2 \ + --train_method=$3 \ + --pretrain_ckpt=$4 \ + &> ../train.log & # dataset train folder +} + +if [ $# -gt 7 ] || [ $# -lt 4 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] \n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \ + CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \ " exit 1 fi @@ -91,7 +121,8 @@ if [ $1 = "Ascend" ] ; then run_ascend "$@" elif [ $1 = "GPU" ] ; then run_gpu "$@" +elif [ $1 = "CPU" ] ; then + run_cpu "$@" else echo "Unsupported device_target." fi; - diff --git a/model_zoo/official/cv/mobilenetv2/src/args.py b/model_zoo/official/cv/mobilenetv2/src/args.py new file mode 100644 index 00000000000..184b65a6326 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/src/args.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse + + +def launch_parse_args(): + + launch_parser = argparse.ArgumentParser(description="mindspore distributed training launch helper utilty \ + that will spawn up multiple distributed processes") + launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ + help='run platform, only support GPU, CPU and Ascend') + launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \ + help="The number of processes to launch on each node, for D training, this is recommended to be set \ + to the number of D in your system so that each process can be bound to a single D.") + launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \ + visible devices sequentially") + launch_parser.add_argument("--training_script", type=str, default="./train.py", help="The full path to \ + the single D training program/script to be launched in parallel, followed by all the arguments for \ + the training script") + + launch_args, unknown = launch_parser.parse_known_args() + launch_args.train_script_args = unknown + launch_args.training_script_args += ["--platform", launch_args.platform] + return launch_args + +def train_parse_args(): + train_parser = argparse.ArgumentParser(description='Image classification trian') + train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') + train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \ + help='run platform, only support CPU, GPU and Ascend') + train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ + for fine tune or incremental learning') + train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ + help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ + train from initialization model") + + train_args = train_parser.parse_args() + return train_args + +def eval_parse_args(): + eval_parser = argparse.ArgumentParser(description='Image classification eval') + eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') + eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ + help='run platform, only support GPU, CPU and Ascend') + eval_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ + for fine tune or incremental learning') + eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \ + for fine tune or incremental learning') + eval_args = eval_parser.parse_args() + + return eval_args + \ No newline at end of file diff --git a/model_zoo/official/cv/mobilenetv2/src/config.py b/model_zoo/official/cv/mobilenetv2/src/config.py index 98e0aef0ec6..07c94a8db4d 100644 --- a/model_zoo/official/cv/mobilenetv2/src/config.py +++ b/model_zoo/official/cv/mobilenetv2/src/config.py @@ -15,40 +15,80 @@ """ network config setting, will be used in train.py and eval.py """ +import os from easydict import EasyDict as ed -config_ascend = ed({ - "num_classes": 1000, - "image_height": 224, - "image_width": 224, - "batch_size": 256, - "epoch_size": 200, - "warmup_epochs": 4, - "lr": 0.4, - "momentum": 0.9, - "weight_decay": 4e-5, - "label_smooth": 0.1, - "loss_scale": 1024, - "save_checkpoint": True, - "save_checkpoint_epochs": 1, - "keep_checkpoint_max": 200, - "save_checkpoint_path": "./checkpoint", -}) +def set_config(args): + config_cpu = ed({ + "num_classes": 26, + "image_height": 224, + "image_width": 224, + "batch_size": 150, + "epoch_size": 15, + "warmup_epochs": 0, + "lr_max": 0.03, + "lr_end": 0.03, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 20, + "save_checkpoint_path": "./checkpoint", + "platform": args.platform + }) + config_gpu = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 150, + "epoch_size": 200, + "warmup_epochs": 0, + "lr": 0.8, + "lr_max": 0.03, + "lr_end": 0.03, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 200, + "save_checkpoint_path": "./checkpoint", + "platform": args.platform, + "ccl": "nccl", + }) + config_ascend = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 256, + "epoch_size": 200, + "warmup_epochs": 4, + "lr": 0.4, + "lr_max": 0.03, + "lr_end": 0.03, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 200, + "save_checkpoint_path": "./checkpoint", + "platform": args.platform, + "ccl": "hccl", + "device_id": int(os.getenv('DEVICE_ID', '0')), + "rank_id": int(os.getenv('RANK_ID', '0')), + "rank_size": int(os.getenv('RANK_SIZE', '1')), + "run_distribute": int(os.getenv('RANK_SIZE', '1')) > 1. + }) + config = ed({"CPU": config_cpu, + "GPU": config_gpu, + "Ascend": config_ascend}) -config_gpu = ed({ - "num_classes": 1000, - "image_height": 224, - "image_width": 224, - "batch_size": 150, - "epoch_size": 200, - "warmup_epochs": 0, - "lr": 0.8, - "momentum": 0.9, - "weight_decay": 4e-5, - "label_smooth": 0.1, - "loss_scale": 1024, - "save_checkpoint": True, - "save_checkpoint_epochs": 1, - "keep_checkpoint_max": 200, - "save_checkpoint_path": "./checkpoint", -}) + if args.platform not in config.keys(): + raise ValueError("Unsupport platform.") + + return config[args.platform] diff --git a/model_zoo/official/cv/mobilenetv2/src/dataset.py b/model_zoo/official/cv/mobilenetv2/src/dataset.py index 528ead053cf..1180d7b021a 100644 --- a/model_zoo/official/cv/mobilenetv2/src/dataset.py +++ b/model_zoo/official/cv/mobilenetv2/src/dataset.py @@ -16,25 +16,31 @@ create train or eval dataset. """ import os +from tqdm import tqdm +import numpy as np + +from mindspore import Tensor +from mindspore.train.model import Model import mindspore.common.dtype as mstype import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 -def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): +def create_dataset(dataset_path, do_train, config, repeat_num=1): """ create a train or eval dataset Args: dataset_path(string): the path of dataset. do_train(bool): whether dataset is used for train or eval. + config(struct): the config of train and eval in diffirent platform. repeat_num(int): the repeat times of dataset. Default: 1. - batch_size(int): the batch size of dataset. Default: 32. + Returns: dataset """ - if device_target == "Ascend": + if config.platform == "Ascend": rank_size = int(os.getenv("RANK_SIZE", '1')) rank_id = int(os.getenv("RANK_ID", '0')) if rank_size == 1: @@ -42,15 +48,16 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, else: ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=rank_size, shard_id=rank_id) - elif device_target == "GPU": + elif config.platform == "GPU": if do_train: from mindspore.communication.management import get_rank, get_group_size ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=get_group_size(), shard_id=get_rank()) else: ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) - else: - raise ValueError("Unsupported device_target.") + elif config.platform == "CPU": + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + resize_height = config.image_height resize_width = config.image_width @@ -81,9 +88,35 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, ds = ds.shuffle(buffer_size=buffer_size) # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.batch(config.batch_size, drop_remainder=True) # apply dataset repeat operation ds = ds.repeat(repeat_num) return ds + +def extract_features(net, dataset_path, config): + features_folder = dataset_path + '_features' + if not os.path.exists(features_folder): + os.makedirs(features_folder) + dataset = create_dataset(dataset_path=dataset_path, + do_train=False, + config=config, + repeat_num=1) + step_size = dataset.get_dataset_size() + pbar = tqdm(list(dataset.create_dict_iterator())) + model = Model(net) + i = 0 + for data in pbar: + features_path = os.path.join(features_folder, f"feature_{i}.npy") + label_path = os.path.join(features_folder, f"label_{i}.npy") + if not(os.path.exists(features_path) and os.path.exists(label_path)): + image = data["image"] + label = data["label"] + features = model.predict(Tensor(image)) + np.save(features_path, features.asnumpy()) + np.save(label_path, label) + pbar.set_description("Process dataset batch: %d"%(i+1)) + i += 1 + + return step_size diff --git a/model_zoo/official/cv/mobilenetv2/src/launch.py b/model_zoo/official/cv/mobilenetv2/src/launch.py index f5c97b0bd70..0b42a5d753c 100644 --- a/model_zoo/official/cv/mobilenetv2/src/launch.py +++ b/model_zoo/official/cv/mobilenetv2/src/launch.py @@ -17,44 +17,11 @@ import os import sys import subprocess import shutil -from argparse import ArgumentParser - -def parse_args(): - """ - parse args . - - Args: - - Returns: - args. - - Examples: - >>> parse_args() - """ - parser = ArgumentParser(description="mindspore distributed training launch " - "helper utilty that will spawn up " - "multiple distributed processes") - parser.add_argument("--nproc_per_node", type=int, default=1, - help="The number of processes to launch on each node, " - "for D training, this is recommended to be set " - "to the number of D in your system so that " - "each process can be bound to a single D.") - parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", - help="will use the visible devices sequentially") - parser.add_argument("--training_script", type=str, - help="The full path to the single D training " - "program/script to be launched in parallel, " - "followed by all the arguments for the " - "training script") - # rest from the training program - args, unknown = parser.parse_known_args() - args.training_script_args = unknown - return args - +from args import launch_parse_args def main(): print("start", __file__) - args = parse_args() + args = launch_parse_args() print(args) visible_devices = args.visible_devices.split(',') assert os.path.isfile(args.training_script) @@ -79,8 +46,8 @@ def main(): os.mkdir(device_dir) os.chdir(device_dir) cmd = [sys.executable, '-u'] - cmd.append(args.training_script) - cmd.extend(args.training_script_args) + cmd.append(args.train_script) + cmd.extend(args.train_script_args) log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) processes.append(process) diff --git a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py index 6d0b6f38e0f..a0ecc250349 100644 --- a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py +++ b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py @@ -20,7 +20,7 @@ from mindspore.ops.operations import TensorAdd from mindspore import Parameter, Tensor from mindspore.common.initializer import initializer -__all__ = ['mobilenet_v2'] +__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2'] def _make_divisible(v, divisor, min_value=None): @@ -119,17 +119,19 @@ class ConvBNReLU(nn.Cell): >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) """ - def __init__(self, device_target, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1): super(ConvBNReLU, self).__init__() padding = (kernel_size - 1) // 2 if groups == 1: conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) else: - if device_target == "Ascend": + if platform in ("CPU", "GPU"): + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, group=in_planes, pad_mode='pad', \ + padding=padding) + elif platform == "Ascend": conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - elif device_target == "GPU": - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, - group=in_planes, pad_mode='pad', padding=padding) + else: + raise ValueError("Unsupported Device, only support CPU, GPU and Ascend.") layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] self.features = nn.SequentialCell(layers) @@ -156,7 +158,7 @@ class InvertedResidual(nn.Cell): >>> ResidualBlock(3, 256, 1, 1) """ - def __init__(self, device_target, inp, oup, stride, expand_ratio): + def __init__(self, platform, inp, oup, stride, expand_ratio): super(InvertedResidual, self).__init__() assert stride in [1, 2] @@ -165,10 +167,10 @@ class InvertedResidual(nn.Cell): layers = [] if expand_ratio != 1: - layers.append(ConvBNReLU(device_target, inp, hidden_dim, kernel_size=1)) + layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1)) layers.extend([ # dw - ConvBNReLU(device_target, hidden_dim, hidden_dim, + ConvBNReLU(platform, hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), # pw-linear nn.Conv2d(hidden_dim, oup, kernel_size=1, @@ -186,8 +188,7 @@ class InvertedResidual(nn.Cell): return self.add(identity, x) return x - -class MobileNetV2(nn.Cell): +class MobileNetV2Backbone(nn.Cell): """ MobileNetV2 architecture. @@ -204,12 +205,10 @@ class MobileNetV2(nn.Cell): >>> MobileNetV2(num_classes=1000) """ - def __init__(self, device_target, num_classes=1000, width_mult=1., - has_dropout=False, inverted_residual_setting=None, round_nearest=8): - super(MobileNetV2, self).__init__() + def __init__(self, platform, width_mult=1., inverted_residual_setting=None, round_nearest=8, + input_channel=32, last_channel=1280): + super(MobileNetV2Backbone, self).__init__() block = InvertedResidual - input_channel = 32 - last_channel = 1280 # setting of inverted residual blocks self.cfgs = inverted_residual_setting if inverted_residual_setting is None: @@ -227,28 +226,22 @@ class MobileNetV2(nn.Cell): # building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(device_target, 3, input_channel, stride=2)] + features = [ConvBNReLU(platform, 3, input_channel, stride=2)] # building inverted residual blocks for t, c, n, s in self.cfgs: output_channel = _make_divisible(c * width_mult, round_nearest) for i in range(n): stride = s if i == 0 else 1 - features.append(block(device_target, input_channel, output_channel, stride, expand_ratio=t)) + features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t)) input_channel = output_channel # building last several layers - features.append(ConvBNReLU(device_target, input_channel, self.out_channels, kernel_size=1)) + features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1)) # make it nn.CellList self.features = nn.SequentialCell(features) - # mobilenet head - head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else - [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) - self.head = nn.SequentialCell(head) - self._initialize_weights() def construct(self, x): x = self.features(x) - x = self.head(x) return x def _initialize_weights(self): @@ -277,16 +270,115 @@ class MobileNetV2(nn.Cell): Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) m.beta.set_parameter_data( Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): + + @property + def get_features(self): + return self.features + +class MobileNetV2Head(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (Cell): number of classes. + has_dropout (bool): Is dropout used. Default is false + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False): + super(MobileNetV2Head, self).__init__() + # mobilenet head + head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else + [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)]) + self.head = nn.SequentialCell(head) + self._initialize_weights() + + def construct(self, x): + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, nn.Dense): m.weight.set_parameter_data(Tensor(np.random.normal( 0, 0.01, m.weight.data.shape).astype("float32"))) if m.bias is not None: m.bias.set_parameter_data( Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + @property + def get_head(self): + return self.head +class MobileNetV2(nn.Cell): + """ + MobileNetV2 architecture. -def mobilenet_v2(**kwargs): + Args: + backbone(nn.Cell): + head(nn.Cell): + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(backbone, head) """ - Constructs a MobileNet V2 model + + def __init__(self, platform, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \ + round_nearest=8, input_channel=32, last_channel=1280): + super(MobileNetV2, self).__init__() + self.backbone = MobileNetV2Backbone(platform=platform, width_mult=width_mult, \ + inverted_residual_setting=inverted_residual_setting, \ + round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_features + self.head = MobileNetV2Head(input_channel=self.backbone.out_channel, num_classes=num_classes, \ + has_dropout=has_dropout).get_head + + def construct(self, x): + x = self.backbone(x) + x = self.head(x) + return x + +class MobileNetV2Combine(nn.Cell): """ - return MobileNetV2(**kwargs) + MobileNetV2 architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, backbone, head): + super(MobileNetV2Combine, self).__init__() + self.backbone = backbone + self.head = head + + def construct(self, x): + x = self.backbone(x) + x = self.head(x) + return x + +def mobilenet_v2(backbone, head): + return MobileNetV2Combine(backbone, head) diff --git a/model_zoo/official/cv/mobilenetv2/src/models.py b/model_zoo/official/cv/mobilenetv2/src/models.py new file mode 100644 index 00000000000..fcbbc1bf47e --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/src/models.py @@ -0,0 +1,138 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import time +import numpy as np +from mindspore import Tensor +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.train.callback import Callback +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2 + +class CrossEntropyWithLabelSmooth(_Loss): + """ + CrossEntropyWith LabelSmooth. + + Args: + smooth_factor (float): smooth factor, default=0. + num_classes (int): num classes + + Returns: + None. + + Examples: + >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) + """ + + def __init__(self, smooth_factor=0., num_classes=1000): + super(CrossEntropyWithLabelSmooth, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / + (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + self.cast = P.Cast() + + def construct(self, logit, label): + one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], + self.on_value, self.off_value) + out_loss = self.ce(logit, one_hot_label) + out_loss = self.mean(out_loss, 0) + return out_loss + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( + cb_params.cur_epoch_num - + 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) + +def load_ckpt(network, pretrain_ckpt_path, trainable=True): + """ + incremental_learning or not + """ + param_dict = load_checkpoint(pretrain_ckpt_path) + load_param_into_net(network, param_dict) + if not trainable: + for param in network.get_parameters(): + param.requires_grad = False + +def define_net(args, config): + backbone_net = MobileNetV2Backbone(platform=args.platform) + head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) + net = mobilenet_v2(backbone_net, head_net) + + # load the ckpt file to the network for fine tune or incremental leaning + if args.pretrain_ckpt: + if args.train_method == "fine_tune": + load_ckpt(net, args.pretrain_ckpt) + elif args.train_method == "incremental_learn": + load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False) + elif args.train_method == "train": + pass + else: + raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") + + return backbone_net, head_net, net diff --git a/model_zoo/official/cv/mobilenetv2/src/utils.py b/model_zoo/official/cv/mobilenetv2/src/utils.py new file mode 100644 index 00000000000..9f54ab5d2d4 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/src/utils.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import random +import numpy as np + +from mindspore import context +from mindspore import nn +from mindspore.common import dtype as mstype +from mindspore.train.model import ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.communication.management import get_rank, init +from mindspore.dataset import engine as de + +from src.models import Monitor + +def switch_precision(net, data_type, config): + if config.platform == "Ascend": + net.to_float(data_type) + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.to_float(mstype.float32) + +def context_device_init(config): + + if config.platform == "CPU": + context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) + + elif config.platform == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + + elif config.platform == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id, + save_graphs=False) + if config.run_distribute: + context.set_auto_parallel_context(device_num=config.rank_size, + parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + init() + else: + raise ValueError("Only support CPU, GPU and Ascend.") + +def set_context(config): + if config.platform == "CPU": + context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, + save_graphs=False) + elif config.platform == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, + device_id=config.device_id, save_graphs=False) + elif config.platform == "GPU": + context.set_context(mode=context.GRAPH_MODE, + device_target=args_opt.platform, save_graphs=False) + +def config_ckpoint(config, lr, step_size): + cb = None + if config.platform in ("CPU", "GPU") or config.rank_id == 0: + cb = [Monitor(lr_init=lr.asnumpy())] + + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_save_dir = config.save_checkpoint_path + + if config.platform == "GPU": + ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" + + ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + return cb + + + +def set_random_seed(seed=1): + random.seed(seed) + np.random.seed(seed) + de.config.set_seed(seed) diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index 8c433392d01..0381f8cb7b0 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -16,263 +16,116 @@ import os import time -import argparse import random import numpy as np -from mindspore import context from mindspore import Tensor -from mindspore import nn -from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn.optim.momentum import Momentum from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.nn.loss.loss import _Loss -from mindspore.ops import operations as P -from mindspore.ops import functional as F from mindspore.common import dtype as mstype from mindspore.train.model import Model -from mindspore.context import ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init, get_group_size, get_rank -import mindspore.dataset.engine as de +from mindspore.train.serialization import _exec_save_checkpoint -from src.dataset import create_dataset +from src.dataset import create_dataset, extract_features from src.lr_generator import get_lr -from src.config import config_gpu, config_ascend -from src.mobilenetV2 import mobilenet_v2 +from src.config import set_config -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') -parser.add_argument('--device_target', type=str, default=None, help='run device_target') -args_opt = parser.parse_args() - -if args_opt.device_target == "Ascend": - device_id = int(os.getenv('DEVICE_ID', '0')) - rank_id = int(os.getenv('RANK_ID', '0')) - rank_size = int(os.getenv('RANK_SIZE', '1')) - run_distribute = rank_size > 1 - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, save_graphs=False) -elif args_opt.device_target == "GPU": - context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", - save_graphs=False) - init() - context.set_auto_parallel_context(device_num=get_group_size(), - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) -else: - raise ValueError("Unsupported device target.") - - -class CrossEntropyWithLabelSmooth(_Loss): - """ - CrossEntropyWith LabelSmooth. - - Args: - smooth_factor (float): smooth factor, default=0. - num_classes (int): num classes - - Returns: - None. - - Examples: - >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) - """ - - def __init__(self, smooth_factor=0., num_classes=1000): - super(CrossEntropyWithLabelSmooth, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / - (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - self.cast = P.Cast() - - def construct(self, logit, label): - one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], - self.on_value, self.off_value) - out_loss = self.ce(logit, one_hot_label) - out_loss = self.mean(out_loss, 0) - return out_loss - - -class Monitor(Callback): - """ - Monitor loss and time. - - Args: - lr_init (numpy array): train lr - - Returns: - None - - Examples: - >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) - """ - - def __init__(self, lr_init=None): - super(Monitor, self).__init__() - self.lr_init = lr_init - self.lr_init_len = len(lr_init) - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, - per_step_mseconds, - np.mean(self.losses))) - - def step_begin(self, run_context): - self.step_time = time.time() - - def step_end(self, run_context): - cb_params = run_context.original_args() - step_mseconds = (time.time() - self.step_time) * 1000 - step_loss = cb_params.net_outputs - - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) - - self.losses.append(step_loss) - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num - - print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( - cb_params.cur_epoch_num - - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, - np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) +from src.args import train_parse_args +from src.utils import set_random_seed, context_device_init, switch_precision, config_ckpoint +from src.models import CrossEntropyWithLabelSmooth, define_net +set_random_seed(1) if __name__ == '__main__': - if args_opt.device_target == "GPU": - # train on gpu - print("train args: ", args_opt) - print("cfg: ", config_gpu) + args_opt = train_parse_args() + config = set_config(args_opt) + start = time.time() - # define network - net = mobilenet_v2(num_classes=config_gpu.num_classes, device_target="GPU") - # define loss - if config_gpu.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth, - num_classes=config_gpu.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - # define dataset - epoch_size = config_gpu.epoch_size - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_gpu, - device_target=args_opt.device_target, - repeat_num=1, - batch_size=config_gpu.batch_size) + print(f"train args: {args_opt}\ncfg: {config}") + + #set context and device init + context_device_init(config) + + # define network + backbone_net, head_net, net = define_net(args_opt, config) + + # CPU only support "incremental_learn" + if args_opt.train_method == "incremental_learn": + step_size = extract_features(backbone_net, args_opt.dataset_path, config) + net = head_net + + elif args_opt.train_method in ("train", "fine_tune"): + if args_opt.platform == "CPU": + raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".") + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config) step_size = dataset.get_dataset_size() - # resume - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - # get learning rate - loss_scale = FixedLossScaleManager( - config_gpu.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_gpu.lr, - warmup_epochs=config_gpu.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) + # Currently, only Ascend support switch precision. + switch_precision(net, mstype.float16, config) - # define optimization - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, - config_gpu.weight_decay, config_gpu.loss_scale) - # define model + # define loss + if config.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth( + smooth_factor=config.label_smooth, num_classes=config.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + epoch_size = config.epoch_size + + # get learning rate + lr = Tensor(get_lr(global_step=0, + lr_init=0, + lr_end=config.lr_end, + lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, + total_epochs=epoch_size, + steps_per_epoch=step_size)) + + if args_opt.train_method == "incremental_learn": + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) + + network = WithLossCell(net, loss) + network = TrainOneStepCell(net, opt) + network.set_train() + + features_path = args_opt.dataset_path + '_features' + idx_list = list(range(step_size)) + + if os.path.isdir(config.save_checkpoint_path): + os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time())) + os.mkdir(config.save_checkpoint_path) + + for epoch in range(epoch_size): + random.shuffle(idx_list) + epoch_start = time.time() + losses = [] + for j in idx_list: + feature = Tensor(np.load(os.path.join(features_path, f"feature_{j}.npy"))) + label = Tensor(np.load(os.path.join(features_path, f"label_{j}.npy"))) + losses.append(network(feature, label).asnumpy()) + epoch_mseconds = (time.time()-epoch_start) * 1000 + per_step_mseconds = epoch_mseconds / step_size + # lr cause to pynative, but cpu doesn't support this mode + # print("\r epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}, lr: {}"\ + # .format(epoch + 1, step_step, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses)), \ + # lr[(epoch+1)*step_size - 1]), end="") + print("\r epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\ + .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ + end="") + if (epoch + 1) % config.save_checkpoint_epochs == 0: + _exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ + f"mobilenetv2_head_{epoch+1}.ckpt")) + print("total cost {:5.4f} s".format(time.time() - start)) + + elif args_opt.train_method in ("train", "fine_tune"): + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \ + config.weight_decay, config.loss_scale) model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) + cb = config_ckpoint(config, lr, step_size) print("============== Starting Training ==============") - cb = [Monitor(lr_init=lr.asnumpy())] - ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" - if config_gpu.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_gpu.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) - cb += [ckpt_cb] - # begin train model.train(epoch_size, dataset, callbacks=cb) print("============== End Training ==============") - elif args_opt.device_target == "Ascend": - # train on ascend - print("train args: ", args_opt, "\ncfg: ", config_ascend, - "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) - - if run_distribute: - context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, mirror_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) - init() - - epoch_size = config_ascend.epoch_size - net = mobilenet_v2(num_classes=config_ascend.num_classes, device_target="Ascend") - net.to_float(mstype.float16) - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.to_float(mstype.float32) - if config_ascend.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_ascend, - device_target=args_opt.device_target, - repeat_num=1, - batch_size=config_ascend.batch_size) - step_size = dataset.get_dataset_size() - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - - loss_scale = FixedLossScaleManager( - config_ascend.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_ascend.lr, - warmup_epochs=config_ascend.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum, - config_ascend.weight_decay, config_ascend.loss_scale) - - model = Model(net, loss_fn=loss, optimizer=opt, - loss_scale_manager=loss_scale) - - cb = None - if rank_id == 0: - cb = [Monitor(lr_init=lr.asnumpy())] - if config_ascend.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_ascend.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint( - prefix="mobilenetV2", directory=config_ascend.save_checkpoint_path, config=config_ck) - cb += [ckpt_cb] - model.train(epoch_size, dataset, callbacks=cb) - else: - raise ValueError("Unsupported device_target.")