diff --git a/model_zoo/official/cv/dpn/README.md b/model_zoo/official/cv/dpn/README.md index 19e1d0aa4eb..ac4960f9881 100644 --- a/model_zoo/official/cv/dpn/README.md +++ b/model_zoo/official/cv/dpn/README.md @@ -89,13 +89,13 @@ The DPN models use ImageNet-1K dataset to train and validate in this repository. To train the DPNs, run the shell script `scripts/train_standalone.sh` with the format below: ```shell -sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)] +sh scripts/train_standalone.sh [device_id] [train_data_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)] ``` To validate the DPNs, run the shell script `scripts/eval.sh` with the format below: ```shell -sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt] +sh scripts/eval.sh [device_id] [eval_data_dir] [checkpoint_path] ``` # [Script Description](#contents) @@ -116,6 +116,11 @@ The structure of the files in this repository is shown below. │ ├─ dpn.py // dpns implementation │ ├─ imagenet_dataset.py // dataset processor and provider │ └─ lr_scheduler.py // dpn learning rate scheduler + ├── model_utils + ├──config.py // Parameter config + ├──moxing_adapter.py // modelarts device configuration + ├──device_adapter.py // Device Config + ├──local_adapter.py // local device config ├─ eval.py // evaluation script ├─ train.py // training script ├─ export.py // export model @@ -124,11 +129,11 @@ The structure of the files in this repository is shown below. 
## [Script Parameters](#contents) -Parameters for both training and evaluation can be set in `src/config.py` +Parameters for both training and evaluation and export can be set in `default_config.yaml` - Configurations for DPN92 with ImageNet-1K dataset -```python +```default_config.yaml # model config config.image_size = (224,224) # inpute image size config.num_classes = 1000 # dataset class number @@ -174,7 +179,7 @@ config.keep_checkpoint_max = 3 # only keep the last keep_checkpoint Run `scripts/train_standalone.sh` to train the model standalone. The usage of the script is: ```shell -sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)] +sh scripts/train_standalone.sh [device_id] [train_data_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)] ``` For example, you can run the shell command below to launch the training procedure. @@ -212,10 +217,16 @@ The model checkpoint will be saved into `[ckpt_path_to_save]`. #### Running on Ascend + For distributed training, a hccl configuration file with JSON format needs to be created in advance. + + Please follow the instructions in the link below: + + . + Run `scripts/train_distributed.sh` to train the model distributed. The usage of the script is: ```text -sh scripts/train_distributed.sh [rank_table] [dataset_dir] [ckpt_path_to_save] [rank_size] [eval_each_epoch] [pretrained_ckpt(optional)] +sh scripts/train_distributed.sh [rank_table] [train_data_dir] [ckpt_path_to_save] [rank_size] [eval_each_epoch] [pretrained_ckpt(optional)] ``` For example, you can run the shell command below to launch the training procedure. @@ -243,7 +254,7 @@ The model checkpoint will be saved into `[ckpt_path_to_save]`. Run `scripts/eval.sh` to evaluate the model with one Ascend processor. 
The usage of the script is: ```text -sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt] +sh scripts/eval.sh [device_id] [eval_data_dir] [checkpoint_path] ``` For example, you can run the shell command below to launch the validation procedure. @@ -259,6 +270,58 @@ Evaluation result: {'top_5_accuracy': 0.9449223751600512, 'top_1_accuracy': 0.79 DPN evaluate success! ``` +- Running on ModelArts +- If you want to train the model on ModelArts, you can refer to the [official guidance document](https://support.huaweicloud.com/modelarts/) of ModelArts. + +```python +# Example of distributed training of DPN on ModelArts: +# Dataset storage layout + +# ├── ImageNet_Original # dir +# ├── train # train dir +# ├── train_dataset # train_dataset dir +# ├── train_predtrained # pretrained checkpoint dir, if any +# ├── eval # eval dir +# ├── eval_dataset # eval dataset dir +# ├── checkpoint # ckpt files dir + +# (1) Choose either a (modify the yaml file parameters) or b (set the parameters when creating the ModelArts training job). +# a. set "enable_modelarts=True". +# set "is_distributed=1" +# set "ckpt_path=/cache/train/outputs_imagenet/" +# set "train_data_dir=/cache/data/train/train_dataset/" +# set "pretrained=/cache/data/train/train_predtrained/pred file name" (if there are no pre-trained weights, set pretrained="") + +# b. 
add the "enable_modelarts=True" parameter on the ModelArts interface. +# Set the parameters required by method a on the ModelArts interface +# Note: The path parameter does not need to be quoted + +# (2) Set the path of the network configuration file "config_path=/The path of config in default_config.yaml/" +# (3) Set the code path on the ModelArts interface "/path/dpn". +# (4) Set the model's startup file on the ModelArts interface "train.py". +# (5) Set the data path of the model on the ModelArts interface ".../ImageNet_Original" (choose the ImageNet_Original folder path), +# the output path of the model "Output file path" and the log path of the model "Job log path". +# (6) Start training the model. + +# Example of model inference on ModelArts +# (1) Place the trained model in the corresponding location of the bucket. +# (2) Choose either a or b. +# a. set "enable_modelarts=True". +# set "eval_data_dir=/cache/data/eval/eval_dataset/" +# set "checkpoint_path=/cache/data/eval/checkpoint/" + +# b. 
Add the "enable_modelarts=True" parameter on the ModelArts interface. +# Set the parameters required by method a on the ModelArts interface +# Note: The path parameter does not need to be quoted + +# (3) Set the path of the network configuration file "config_path=/The path of config in default_config.yaml/" +# (4) Set the code path on the ModelArts interface "/path/dpn". +# (5) Set the model's startup file on the ModelArts interface "eval.py". +# (6) Set the data path of the model on the ModelArts interface ".../ImageNet_Original" (choose the ImageNet_Original folder path), +# the output path of the model "Output file path" and the log path of the model "Job log path". +# (7) Start model inference. +``` + # [Model Description](#contents) ## [Performance](#contents) diff --git a/model_zoo/official/cv/dpn/default_config.yaml b/model_zoo/official/cv/dpn/default_config.yaml new file mode 100644 index 00000000000..1cfc07d28c1 --- /dev/null +++ b/model_zoo/official/cv/dpn/default_config.yaml @@ -0,0 +1,88 @@ +# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path" +device_target: "Ascend" +enable_profiling: False + +# ====================================================================================== +# common options +is_distributed: 0 +image_size: [224, 224] +batch_size: 32 +num_parallel_workers: 4 +rank: 0 +group_size: 1 +num_classes: 1000 +label_smooth: False +label_smooth_factor: 0.0 + +# ====================================================================================== +# Training options +backbone: 'dpn92' +is_save_on_master: True + +# training config +pretrained: "" +ckpt_path: "./" +eval_each_epoch: 0 +global_step: 0 +epoch_size: 180 +loss_scale_num: 1024 +momentum: 0.9 +weight_decay: 0.0001 + +# learning rate 
config +lr_schedule: "warmup" +lr_init: 0.01 +lr_max: 0.1 +factor: 0.1 +epoch_number_to_drop: [5, 15] +warmup_epochs: 5 + +# dataset config +train_data_dir: "" +dataset: "imagenet-1K" +keep_checkpoint_max: 3 + +# ====================================================================================== +# Eval options +eval_data_dir: "" +checkpoint_path: "" + + +# export options +device_id: 0 +width: 224 +height: 224 +file_name: "dpn" +file_format: "MINDIR" + + +--- +# Help description for each configuration +enable_modelarts: "Whether to train on ModelArts (default: False)" +data_url: "Url for modelarts" +train_url: "Url for modelarts" +data_path: "The location of input data" +output_path: "The location of the output file" +device_target: "Target device type, GPU or Ascend (default: Ascend)" +enable_profiling: "Whether to enable profiling while training (default: False)" +train_data_dir: "Imagenet train data dir" +pretrained: "ckpt path to load" +is_distributed: "if multi device" +ckpt_path: "ckpt path to save" +eval_each_epoch: "evaluate on each epoch" +eval_data_dir: "eval data dir" +checkpoint_path: "ckpt path to load" +device_id: "device id" +width: "input width" +height: "input height" +file_name: "dpn output file name" +file_format: "choices [AIR, ONNX, MINDIR]" diff --git a/model_zoo/official/cv/dpn/eval.py b/model_zoo/official/cv/dpn/eval.py index 83593d2b553..4813c6ca324 100644 --- a/model_zoo/official/cv/dpn/eval.py +++ b/model_zoo/official/cv/dpn/eval.py @@ -13,70 +13,53 @@ # limitations under the License. 
# ============================================================================ """DPN model eval with MindSpore""" -import os -import argparse - from mindspore import context from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.train.model import Model from mindspore.common import set_seed from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.dpn import dpns -from src.config import config from src.imagenet_dataset import classification_dataset +from src.dpn import dpns +from src.crossentropy import CrossEntropy +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id + + set_seed(1) + + # set context -device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", save_graphs=False, device_id=device_id) + device_target=config.device_target, save_graphs=False, device_id=get_device_id()) -def parse_args(): - """parameters""" - parser = argparse.ArgumentParser('dpn evaluating') - # dataset related - parser.add_argument('--data_dir', type=str, default='', help='eval data dir') - # network related - parser.add_argument('--pretrained', type=str, default='', help='ckpt path to load') - args, _ = parser.parse_known_args() - args.image_size = config.image_size - args.num_classes = config.num_classes - args.batch_size = config.batch_size - args.num_parallel_workers = config.num_parallel_workers - args.backbone = config.backbone - args.loss_scale_num = config.loss_scale_num - args.rank = config.rank - args.group_size = config.group_size - args.dataset = config.dataset - return args - - -def dpn_evaluate(args): +@moxing_wrapper(pre_process=None) +def dpn_evaluate(): # create evaluate dataset - eval_path = os.path.join(args.data_dir, 'val') - eval_dataset = classification_dataset(eval_path, - image_size=args.image_size, - num_parallel_workers=args.num_parallel_workers, - 
per_batch_size=args.batch_size, + # eval_path = os.path.join(config.eval_data_dir, 'val') + eval_dataset = classification_dataset(config.eval_data_dir, + image_size=config.image_size, + num_parallel_workers=config.num_parallel_workers, + per_batch_size=config.batch_size, max_epoch=1, - rank=args.rank, + rank=config.rank, shuffle=False, - group_size=args.group_size, + group_size=config.group_size, mode='eval') # create network - net = dpns[args.backbone](num_classes=args.num_classes) + net = dpns[config.backbone](num_classes=config.num_classes) # load checkpoint - load_param_into_net(net, load_checkpoint(args.pretrained)) - print("load checkpoint from [{}].".format(args.pretrained)) + load_param_into_net(net, load_checkpoint(config.checkpoint_path)) + print("load checkpoint from [{}].".format(config.checkpoint_path)) # loss - if args.dataset == "imagenet-1K": + if config.dataset == "imagenet-1K": loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') else: - if not args.label_smooth: - args.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) + if not config.label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes) # create model model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss, @@ -87,5 +70,5 @@ def dpn_evaluate(args): if __name__ == '__main__': - dpn_evaluate(parse_args()) + dpn_evaluate() print('DPN evaluate success!') diff --git a/model_zoo/official/cv/dpn/export.py b/model_zoo/official/cv/dpn/export.py index be8a1d96fda..6d3a9f2bf27 100644 --- a/model_zoo/official/cv/dpn/export.py +++ b/model_zoo/official/cv/dpn/export.py @@ -12,30 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -"""Export DPN""" -import argparse +"""Export DPN +suggest run as python export.py --file_name [filename] --file_format [file format] --checkpoint_path [ckpt path] +""" + import numpy as np - from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export - from src.dpn import dpns -from src.config import config +from src.model_utils.config import config -parser = argparse.ArgumentParser(description="export dpn") -parser.add_argument("--device_id", type=int, default=0, help="device id") -parser.add_argument("--ckpt_file", type=str, required=True, help="dpn ckpt file.") -parser.add_argument("--width", type=int, default=224, help="input width") -parser.add_argument("--height", type=int, default=224, help="input height") -parser.add_argument("--file_name", type=str, default="dpn", help="dpn output file name.") -parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], - default="MINDIR", help="file format") -parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", - help="device target") -args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) -if args.device_target == "Ascend": - context.set_context(device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) if __name__ == "__main__": # define net @@ -44,9 +33,9 @@ if __name__ == "__main__": net = dpns[backbone](num_classes=num_classes) # load checkpoint - param_dict = load_checkpoint(args.ckpt_file) + param_dict = load_checkpoint(config.checkpoint_path) load_param_into_net(net, param_dict) net.set_train(False) - image = Tensor(np.zeros([config.batch_size, 3, args.height, args.width], np.float32)) - export(net, image, file_name=args.file_name, 
file_format=args.file_format) + image = Tensor(np.zeros([config.batch_size, 3, config.height, config.width], np.float32)) + export(net, image, file_name=config.file_name, file_format=config.file_format) diff --git a/model_zoo/official/cv/dpn/scripts/eval.sh b/model_zoo/official/cv/dpn/scripts/eval.sh index ea8c415f1b4..58b3b2d26de 100644 --- a/model_zoo/official/cv/dpn/scripts/eval.sh +++ b/model_zoo/official/cv/dpn/scripts/eval.sh @@ -18,5 +18,5 @@ DATA_DIR=$2 PATH_CHECKPOINT=$3 python eval.py \ - --pretrained=$PATH_CHECKPOINT \ - --data_dir=$DATA_DIR > eval_log.txt 2>&1 & + --checkpoint_path=$PATH_CHECKPOINT \ + --eval_data_dir=$DATA_DIR > eval_log.txt 2>&1 & diff --git a/model_zoo/official/cv/dpn/scripts/train_distributed.sh b/model_zoo/official/cv/dpn/scripts/train_distributed.sh index 000134d6547..7e7b90f1c77 100644 --- a/model_zoo/official/cv/dpn/scripts/train_distributed.sh +++ b/model_zoo/official/cv/dpn/scripts/train_distributed.sh @@ -36,6 +36,7 @@ do rm -rf ./train_parallel$i mkdir ./train_parallel$i cp -r ./src ./train_parallel$i + cp ./*yaml ./train_parallel$i cp ./train.py ./train_parallel$i echo "start training for rank $i, device $DEVICE_ID" @@ -47,12 +48,12 @@ do --is_distributed=1 \ --ckpt_path=$SAVE_PATH \ --eval_each_epoch=$EVAL_EACH_EPOCH\ - --data_dir=$DATA_DIR > log.txt 2>&1 & + --train_data_dir=$DATA_DIR > log.txt 2>&1 & echo "python train.py \ --is_distributed=1 \ --ckpt_path=$SAVE_PATH \ --eval_each_epoch=$EVAL_EACH_EPOCH\ - --data_dir=$DATA_DIR > log.txt 2>&1 &" + --train_data_dir=$DATA_DIR > log.txt 2>&1 &" fi if [ $# == 6 ] @@ -62,7 +63,7 @@ do --eval_each_epoch=$EVAL_EACH_EPOCH\ --ckpt_path=$SAVE_PATH \ --pretrained=$PATH_CHECKPOINT \ - --data_dir=$DATA_DIR > log.txt 2>&1 & + --train_data_dir=$DATA_DIR > log.txt 2>&1 & fi cd ../ diff --git a/model_zoo/official/cv/dpn/scripts/train_standalone.sh b/model_zoo/official/cv/dpn/scripts/train_standalone.sh index 883a074b19e..c2ae33552a3 100644 --- 
a/model_zoo/official/cv/dpn/scripts/train_standalone.sh +++ b/model_zoo/official/cv/dpn/scripts/train_standalone.sh @@ -30,12 +30,12 @@ then --is_distributed=0 \ --ckpt_path=$SAVE_CKPT_PATH\ --eval_each_epoch=$EVAL_EACH_EPOCH\ - --data_dir=$DATA_DIR > train_log.txt 2>&1 & + --train_data_dir=$DATA_DIR > train_log.txt 2>&1 & echo " python train.py \ --is_distributed=0 \ --ckpt_path=$SAVE_CKPT_PATH\ --eval_each_epoch=$EVAL_EACH_EPOCH\ - --data_dir=$DATA_DIR > train_log.txt 2>&1 &" + --train_data_dir=$DATA_DIR > train_log.txt 2>&1 &" fi if [ $# == 5 ] then @@ -43,6 +43,6 @@ then --is_distributed=0 \ --ckpt_path=$SAVE_CKPT_PATH\ --pretrained=$PATH_CHECKPOINT \ - --data_dir=$DATA_DIR\ + --train_data_dir=$DATA_DIR\ --eval_each_epoch=$EVAL_EACH_EPOCH > train_log.txt 2>&1 & fi \ No newline at end of file diff --git a/model_zoo/official/cv/dpn/src/config.py b/model_zoo/official/cv/dpn/src/config.py deleted file mode 100644 index 8d92ab749c2..00000000000 --- a/model_zoo/official/cv/dpn/src/config.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -""" -network config setting, will be used in train.py and eval.py -""" -from easydict import EasyDict as edict - -# config for dpn,imagenet-1K -config = edict() - -# model config -config.image_size = (224, 224) # inpute image size -config.num_classes = 1000 # dataset class number -config.backbone = 'dpn92' # backbone network -config.is_save_on_master = True - -# parallel config -config.num_parallel_workers = 4 # number of workers to read the data -config.rank = 0 # local rank of distributed -config.group_size = 1 # group size of distributed - -# training config -config.batch_size = 32 # batch_size -config.global_step = 0 # start step of learning rate -config.epoch_size = 180 # epoch_size -config.loss_scale_num = 1024 # loss scale -# optimizer config -config.momentum = 0.9 # momentum (SGD) -config.weight_decay = 1e-4 # weight_decay (SGD) -# learning rate config -config.lr_schedule = 'warmup' # learning rate schedule -config.lr_init = 0.01 # init learning rate -config.lr_max = 0.1 # max learning rate -config.factor = 0.1 # factor of lr to drop -config.epoch_number_to_drop = [5, 15] # learing rate will drop after these epochs -config.warmup_epochs = 5 # warmup epochs in learning rate schedule - -# dataset config -config.dataset = "imagenet-1K" # dataset -config.label_smooth = False # label_smooth -config.label_smooth_factor = 0.0 # label_smooth_factor - -# parameter save config -config.keep_checkpoint_max = 3 # only keep the last keep_checkpoint_max checkpoint diff --git a/model_zoo/official/cv/dpn/src/model_utils/__init__.py b/model_zoo/official/cv/dpn/src/model_utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/official/cv/dpn/src/model_utils/config.py b/model_zoo/official/cv/dpn/src/model_utils/config.py new file mode 100644 index 00000000000..efc856cf0cf --- /dev/null +++ b/model_zoo/official/cv/dpn/src/model_utils/config.py @@ -0,0 +1,130 @@ 
+# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License Version 2.0(the "License"); +# you may not use this file except in compliance with the License. +# you may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0# +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==================================================================================== + +"""Parse arguments""" +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + + +_config_path = '../../default_config.yaml' + + +class Config: + """ + Configuration namespace. Convert dictionary to members + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'): + """ + Parse command line arguments to the configuration according to the default yaml + + Args: + parser: Parent parser + cfg: Base configuration + helper: Helper description + cfg_path: Path to the default yaml config + """ + parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]', + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path) + choice = 
choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file + + Args: + yaml_path: Path to the yaml config + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml') + print(cfg_helper) + except: + raise ValueError('Failed to parse yaml') + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments + + Args: + args: command line arguments + cfg: Base configuration + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments + """ + parser = argparse.ArgumentParser(description='default name', add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path), + help='Config file path') + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff 
--git a/model_zoo/official/cv/dpn/src/model_utils/device_adapter.py b/model_zoo/official/cv/dpn/src/model_utils/device_adapter.py new file mode 100644 index 00000000000..ad8415af0f6 --- /dev/null +++ b/model_zoo/official/cv/dpn/src/model_utils/device_adapter.py @@ -0,0 +1,26 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License Version 2.0(the "License"); +# you may not use this file except in compliance with the License. +# you may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0# +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==================================================================================== + +"""Device adapter for ModelArts""" + +from .config import config +if config.enable_modelarts: + from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + 'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id' +] diff --git a/model_zoo/official/cv/dpn/src/model_utils/local_adapter.py b/model_zoo/official/cv/dpn/src/model_utils/local_adapter.py new file mode 100644 index 00000000000..4ff88c4fba5 --- /dev/null +++ b/model_zoo/official/cv/dpn/src/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License Version 2.0(the "License"); +# you may not use this file except in compliance with the License. 
+# you may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0# +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==================================================================================== + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return 'Local Job' diff --git a/model_zoo/official/cv/dpn/src/model_utils/moxing_adapter.py b/model_zoo/official/cv/dpn/src/model_utils/moxing_adapter.py new file mode 100644 index 00000000000..c2d2282402b --- /dev/null +++ b/model_zoo/official/cv/dpn/src/model_utils/moxing_adapter.py @@ -0,0 +1,124 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License Version 2.0(the "License"); +# you may not use this file except in compliance with the License. +# you may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0# +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==================================================================================== + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from .config import config + + +_global_syn_count = 0 + + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local + Uploca data from local directory to remote obs in contrast + """ + import moxing as mox + import time + global _global_syn_count + sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count) + _global_syn_count += 1 + + # Each server contains 8 devices as most + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print('from path: ', from_path) + print('to path: ', to_path) + mox.file.copy_parallel(from_path, to_path) + print('===finished data synchronization===') + try: + os.mknod(sync_lock) + except IOError: + pass + print('===save flag===') + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + print('Finish sync data from {} to {}'.format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs + """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print('Dataset downloaded: ', os.listdir(config.data_path)) + if config.checkpoint_url: + if not 
os.path.exists(config.load_path): + # os.makedirs(config.load_path) + print('=' * 20 + 'makedirs') + if os.path.isdir(config.load_path): + print('=' * 20 + 'makedirs success') + else: + print('=' * 20 + 'makedirs fail') + sync_data(config.checkpoint_url, config.load_path) + print('Preload downloaded: ', os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print('Workspace downloaded: ', os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + run_func(*args, **kwargs) + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print('Start to copy output directory') + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper diff --git a/model_zoo/official/cv/dpn/train.py b/model_zoo/official/cv/dpn/train.py index 7a5e57eeafb..a356925c3c7 100644 --- a/model_zoo/official/cv/dpn/train.py +++ b/model_zoo/official/cv/dpn/train.py @@ -14,8 +14,7 @@ # ============================================================================ """DPN model train with MindSpore""" import os -import argparse - +from ast import literal_eval from mindspore import context from mindspore import Tensor from mindspore.nn import SGD @@ -24,145 +23,112 @@ from mindspore.train.model import Model from mindspore.context import ParallelMode from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.communication.management import init, get_group_size, get_rank +from mindspore.communication.management import init, get_group_size from mindspore.common import set_seed from 
mindspore.train.serialization import load_checkpoint, load_param_into_net - from src.imagenet_dataset import classification_dataset from src.dpn import dpns -from src.config import config from src.lr_scheduler import get_lr_drop, get_lr_warmup from src.crossentropy import CrossEntropy from src.callbacks import SaveCallback +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num -device_id = int(os.getenv('DEVICE_ID')) set_seed(1) -def parse_args(): - """parameters""" - parser = argparse.ArgumentParser('dpn training') - - # dataset related - parser.add_argument('--data_dir', type=str, default='', help='Imagenet data dir') - # network related - parser.add_argument('--pretrained', default='', type=str, help='ckpt path to load') - # distributed related - parser.add_argument('--is_distributed', type=int, default=1, help='if multi device') - parser.add_argument('--ckpt_path', type=str, default='', help='ckpt path to save') - parser.add_argument('--eval_each_epoch', type=int, default=0, help='evaluate on each epoch') - args, _ = parser.parse_known_args() - args.image_size = config.image_size - args.num_classes = config.num_classes - args.lr_init = config.lr_init - args.lr_max = config.lr_max - args.factor = config.factor - args.global_step = config.global_step - args.epoch_number_to_drop = config.epoch_number_to_drop - args.epoch_size = config.epoch_size - args.warmup_epochs = config.warmup_epochs - args.weight_decay = config.weight_decay - args.momentum = config.momentum - args.batch_size = config.batch_size - args.num_parallel_workers = config.num_parallel_workers - args.backbone = config.backbone - args.loss_scale_num = config.loss_scale_num - args.is_save_on_master = config.is_save_on_master - args.rank = config.rank - args.group_size = config.group_size - args.dataset = config.dataset - args.label_smooth = config.label_smooth - 
args.label_smooth_factor = config.label_smooth_factor - args.keep_checkpoint_max = config.keep_checkpoint_max - args.lr_schedule = config.lr_schedule - return args +def modelarts_pre_process(): + pass -def dpn_train(args): +@moxing_wrapper(pre_process=modelarts_pre_process) +def dpn_train(): # init context + device_id = get_device_id() context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", save_graphs=False, device_id=device_id) + device_target=config.device_target, save_graphs=False, device_id=device_id) # init distributed - if args.is_distributed: + if config.is_distributed: init() - args.rank = get_rank() - args.group_size = get_group_size() - context.set_auto_parallel_context(device_num=args.group_size, parallel_mode=ParallelMode.DATA_PARALLEL, + config.rank = get_rank_id() + config.group_size = get_group_size() + config.device_num = get_device_num() + context.set_auto_parallel_context(device_num=config.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) # select for master rank save ckpt or all rank save, compatible for model parallel - args.rank_save_ckpt_flag = 0 - if args.is_save_on_master: - if args.rank == 0: - args.rank_save_ckpt_flag = 1 + config.rank_save_ckpt_flag = 0 + if config.is_save_on_master: + if config.rank == 0: + config.rank_save_ckpt_flag = 1 else: - args.rank_save_ckpt_flag = 1 + config.rank_save_ckpt_flag = 1 # create dataset - args.train_dir = os.path.join(args.data_dir, 'train') - args.eval_dir = os.path.join(args.data_dir, 'val') - train_dataset = classification_dataset(args.train_dir, - image_size=args.image_size, - per_batch_size=args.batch_size, + train_dataset = classification_dataset(config.train_data_dir, + image_size=config.image_size, + per_batch_size=config.batch_size, max_epoch=1, - num_parallel_workers=args.num_parallel_workers, + num_parallel_workers=config.num_parallel_workers, shuffle=True, - rank=args.rank, - group_size=args.group_size) - if args.eval_each_epoch: + rank=config.rank, 
+ group_size=config.group_size) + if config.eval_each_epoch: print("create eval_dataset") - eval_dataset = classification_dataset(args.eval_dir, - image_size=args.image_size, - per_batch_size=args.batch_size, + eval_dataset = classification_dataset(config.eval_data_dir, + image_size=config.image_size, + per_batch_size=config.batch_size, max_epoch=1, - num_parallel_workers=args.num_parallel_workers, + num_parallel_workers=config.num_parallel_workers, shuffle=False, - rank=args.rank, - group_size=args.group_size, + rank=config.rank, + group_size=config.group_size, mode='eval') train_step_size = train_dataset.get_dataset_size() # choose net - net = dpns[args.backbone](num_classes=args.num_classes) + net = dpns[config.backbone](num_classes=config.num_classes) # load checkpoint - if os.path.isfile(args.pretrained): + if os.path.isfile(config.pretrained): print("load ckpt") - load_param_into_net(net, load_checkpoint(args.pretrained)) + load_param_into_net(net, load_checkpoint(config.pretrained)) # learing rate schedule - if args.lr_schedule == 'drop': + if config.lr_schedule == 'drop': print("lr_schedule:drop") - lr = Tensor(get_lr_drop(global_step=args.global_step, - total_epochs=args.epoch_size, + lr = Tensor(get_lr_drop(global_step=config.global_step, + total_epochs=config.epoch_size, steps_per_epoch=train_step_size, - lr_init=args.lr_init, - factor=args.factor)) - elif args.lr_schedule == 'warmup': + lr_init=config.lr_init, + factor=config.factor)) + elif config.lr_schedule == 'warmup': print("lr_schedule:warmup") - lr = Tensor(get_lr_warmup(global_step=args.global_step, - total_epochs=args.epoch_size, + lr = Tensor(get_lr_warmup(global_step=config.global_step, + total_epochs=config.epoch_size, steps_per_epoch=train_step_size, - lr_init=args.lr_init, - lr_max=args.lr_max, - warmup_epochs=args.warmup_epochs)) + lr_init=config.lr_init, + lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs)) # optimizer + config.weight_decay = literal_eval(config.weight_decay) 
opt = SGD(net.trainable_params(), lr, - momentum=args.momentum, - weight_decay=args.weight_decay, - loss_scale=args.loss_scale_num) + momentum=config.momentum, + weight_decay=config.weight_decay, + loss_scale=config.loss_scale_num) # loss scale - loss_scale = FixedLossScaleManager(args.loss_scale_num, False) + loss_scale = FixedLossScaleManager(config.loss_scale_num, False) # loss function - if args.dataset == "imagenet-1K": + if config.dataset == "imagenet-1K": print("Use SoftmaxCrossEntropyWithLogits") loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') else: - if not args.label_smooth: - args.label_smooth_factor = 0.0 + if not config.label_smooth: + config.label_smooth_factor = 0.0 print("Use Label_smooth CrossEntropy") - loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes) # create model model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, @@ -175,19 +141,19 @@ def dpn_train(args): loss_cb = LossMonitor() time_cb = TimeMonitor(data_size=train_step_size) cb = [loss_cb, time_cb] - if args.rank_save_ckpt_flag: - if args.eval_each_epoch: - save_cb = SaveCallback(model, eval_dataset, args.ckpt_path) + if config.rank_save_ckpt_flag: + if config.eval_each_epoch: + save_cb = SaveCallback(model, eval_dataset, config.ckpt_path) cb += [save_cb] else: config_ck = CheckpointConfig(save_checkpoint_steps=train_step_size, - keep_checkpoint_max=args.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="dpn", directory=args.ckpt_path, config=config_ck) + keep_checkpoint_max=config.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix="dpn", directory=config.ckpt_path, config=config_ck) cb.append(ckpoint_cb) # train model - model.train(args.epoch_size, train_dataset, callbacks=cb) + model.train(config.epoch_size, train_dataset, callbacks=cb) if __name__ == '__main__': - dpn_train(parse_args()) + dpn_train() 
print('DPN training success!')