diff --git a/model_zoo/research/cv/centernet/README.md b/model_zoo/research/cv/centernet/README.md index 327e5684e41..a609aa7931c 100644 --- a/model_zoo/research/cv/centernet/README.md +++ b/model_zoo/research/cv/centernet/README.md @@ -97,7 +97,7 @@ Dataset used: [COCO2017](https://cocodataset.org/) pip install mmcv==0.2.14 ``` - And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: + And change the COCO_ROOT and other settings you need in `default_config.yaml`. The directory structure is as follows: ```path . @@ -115,32 +115,90 @@ Dataset used: [COCO2017](https://cocodataset.org/) # [Quick Start](#contents) -After installing MindSpore via the official website, you can start training and evaluation as follows: +- running on local -Note: 1.the first run of training will generate the mindrecord file, which will take a long time. - 2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory. - 3.LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory, if no just set "" - 4.RUN_MODE support validation and testing, set to be "val"/"test" + After installing MindSpore via the official website, you can start training and evaluation as follows: -```shell -# create dataset in mindrecord format -bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR] + Note: 1.the first run of training will generate the mindrecord file, which will take a long time. + 2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory. + 3.For `train.py`, LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory, if no just set "". + 4.For `eval.py`, LOAD_CHECKPOINT_PATH is the checkpoint to be evaluated. + 5.RUN_MODE support validation and testing, set to be "val"/"test" -# standalone training on Ascend -bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional) + ```shell + # create dataset in mindrecord format + bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR] -# standalone training on CPU -bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional) + # standalone training on Ascend + bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional) -# distributed training on Ascend -bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional) + # standalone training on CPU + bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional) -# eval on Ascend -bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] + # distributed training on Ascend + bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional) -# eval on CPU -bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] -``` + # eval on Ascend + bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] + + # eval on CPU + bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] + ``` + +- running on ModelArts + + If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows + + - Training with single cards on ModelArts + + ```python + # (1) Upload the code folder to S3 bucket. + # (2) Click to "create training task" on the website UI interface. + # (3) Set the code directory to "/{path}/centernet" on the website UI interface. + # (4) Set the startup file to /{path}/centernet/train.py" on the website UI interface. + # (5) Perform a or b. + # a. setting parameters in /{path}/centernet/default_config.yaml. + # 1. Set ”enable_modelarts: True“ + # 2. Set “epoch_size: 350” + # 3. Set “distribute: 'true'” + # 4. Set “data_sink_steps: 50” + # 5. Set “save_checkpoint_path: ./checkpoints” + # b. adding on the website UI interface. + # 1. Add ”enable_modelarts=True“ + # 2. Add “epoch_size=350” + # 3. Add “distribute=true” + # 4. Add “data_sink_steps=50” + # 5. Add “save_checkpoint_path=./checkpoints” + # (6) Upload the mindrecdrd dataset to S3 bucket. + # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path. + # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. + # (9) Under the item "resource pool selection", select the specification of single cards. + # (10) Create your job. + ``` + + - evaluating with single card on ModelArts + + ```python + # (1) Upload the code folder 'centernet' to S3 bucket. + # (2) Git clone https://github.com/xingyizhou/CenterNet.git on local, and put the folder 'CenterNet' under the folder 'centernet' on s3 bucket. + # (3) Click to "create training task" on the website UI interface. + # (4) Set the code directory to "/{path}/centernet" on the website UI interface. + # (5) Set the startup file to /{path}/centernet/eval.py" on the website UI interface. + # (6) Perform a or b. + # a. setting parameters in /{path}/centernet/default_config.yaml. + # 1. Set ”enable_modelarts: True“ + # 2. Set “run_mode: 'val'” + # 3. Set “load_checkpoint_path: ./{path}/*.ckpt”('load_checkpoint_path' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.) + # b. adding on the website UI interface. + # 1. Add ”enable_modelarts=True“ + # 2. Add “run_mode=val” + # 3. Add “load_checkpoint_path=./{path}/*.ckpt”('load_checkpoint_path' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.) + # (7) Upload the dataset(not mindrecord format) to S3 bucket. + # (8) Check the "data storage location" on the website UI interface and set the "Dataset path" path. + # (9) Set the "Output file path" and "Job log path" to your path on the website UI interface. + # (10) Under the item "resource pool selection", select the specification of a single card. + # (11) Create your job. + ``` # [Script Description](#contents) @@ -154,12 +212,13 @@ bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_P ├── eval.py // testing and evaluation outputs ├── export.py // convert mindspore model to air model ├── README.md // descriptions about CenterNet + ├── default_config.yaml // parameter configuration ├── scripts - │ ├── ascend_distributed_launcher - │ │ ├──__init__.py - │ │ ├──hyper_parameter_config.ini // hyper parameter for distributed training - │ │ ├──get_distribute_train_cmd.py // script for distributed training - │ │ ├──README.md + │ ├──ascend_distributed_launcher + │ │ ├──__init__.py + │ │ ├──hyper_parameter_config.ini // hyper parameter for distributed training + │ │ ├──get_distribute_train_cmd.py // script for distributed training + │ │ └──README.md │ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord │ ├──run_standalone_train_ascend.sh // shell script for standalone training on ascend │ ├──run_distributed_train_ascend.sh // shell script for distributed training on ascend @@ -167,10 +226,14 @@ bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_P │ ├──run_standalone_train_cpu.sh // shell script for standalone training on cpu │ ├──run_standalone_eval_cpu.sh // shell script for standalone evaluation on cpu └── src + ├──model_utils + │ ├──config.py // parsing parameter configuration file of "*.yaml" + │ ├──device_adapter.py // local or ModelArts training + │ ├──local_adapter.py // get related environment variables on local + │ └──moxing_adapter.py // get related environment variables abd transfer data on ModelArts ├──__init__.py ├──centernet_pose.py // centernet networks, training entry ├──dataset.py // generate dataloader and data processing entry - ├──config.py // centernet unique configs ├──dcn_v2.py // deformable convolution operator v2 ├──decode.py // decode the head features ├──backbone_dla.py // deep layer aggregation backbone @@ -255,34 +318,34 @@ options: ### Options and Parameters -Parameters for training and evaluation can be set in file `config.py` and `finetune_eval_config.py` respectively. +Parameters for training and evaluation can be set in file `default_config.yaml`. #### Options -```text -config for training. - batch_size batch size of input dataset: N, default is 32 - loss_scale_value initial value of loss scale: N, default is 1024 - optimizer optimizer used in the network: Adam, default is Adam - lr_schedule schedules to get the learning rate +```python +train_config: + batch_size: 32 // batch size of input dataset: N, default is 32 + loss_scale_value: 1024 // initial value of loss scale: N, default is 1024 + optimizer: 'Adam' // optimizer used in the network: Adam, default is Adam + lr_schedule: 'MultiDecay' // schedules to get the learning rate ``` ```text -config for evaluation. - soft_nms nms after decode: True | False, default is True - keep_res keep original or fix resolution: True | False, default is False - multi_scales use multi-scales of image: List, default is [1.0] - pad pad size when keep original resolution, default is 31 - K number of bboxes to be computed by TopK, default is 100 - score_thresh threshold of score when visualize image and annotation info +eval_config: + soft_nms: True // nms after decode: True | False, default is True + keep_res: c // keep original or fix resolution: True | False, default is False + multi_scales: [1.0] // use multi-scales of image: List, default is [1.0] + pad: 31 // pad size when keep original resolution, default is 31 + K: 100 // number of bboxes to be computed by TopK, default is 100 + score_thresh: 0.3 // threshold of score when visualize image and annotation info ``` ```text config for export. - input_res input resolution of the model air, default is [512, 512] - ckpt_file checkpoint file, default is "./ckkt_file.ckpt" - export_format the exported format of model air, default is MINDIR - export_name the exported file name, default is "CentNet_MultiPose" + input_res: dataset_config.input_res // input resolution of the model air, default is [512, 512] + ckpt_file: "./ckpt_file.ckpt" // checkpoint file, default is "./ckkt_file.ckpt" + export_format: "MINDIR" // the exported format of model air, default is MINDIR + export_name: "CenterNet_MultiPose" // the exported file name, default is "CentNet_MultiPose" ``` #### Parameters @@ -462,9 +525,36 @@ overall performance on coco2017 test-dev dataset If you want to infer the network on Ascend 310, you should convert the model to AIR: -```python -python export.py [DEVICE_ID] -``` +- Export on local + + ```python + python export.py --device_id [DEVICE_ID] --export_format MINDIR --export_load_ckpt [CKPT_FILE__PATH] --export_name [EXPORT_FILE_NAME] + ``` + +- Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start as follows) + + ```python + # (1) Upload the code folder to S3 bucket. + # (2) Click to "create training task" on the website UI interface. + # (3) Set the code directory to "/{path}/centernet" on the website UI interface. + # (4) Set the startup file to /{path}/centernet/export.py" on the website UI interface. + # (5) Perform a or b. + # a. setting parameters in /{path}/centernet/default_config.yaml. + # 1. Set ”enable_modelarts: True“ + # 2. Set “export_load_ckpt: ./{path}/*.ckpt”('export_load_ckpt' indicates the path of the weight file to be exported relative to the file `export.py`, and the weight file must be included in the code directory.) + # 3. Set ”export_name: centernet“ + # 4. Set ”export_format:MINDIR“ + # b. adding on the website UI interface. + # 1. Add ”enable_modelarts=True“ + # 2. Add “export_load_ckpt=./{path}/*.ckpt”('export_load_ckpt' indicates the path of the weight file to be exported relative to the file `export.py`, and the weight file must be included in the code directory.) + # 3. Add ”export_name=centernet“ + # 4. Add ”export_format=MINDIR“ + # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (This step is useless, but necessary.). + # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. + # (9) Under the item "resource pool selection", select the specification of a single card. + # (10) Create your job. + # You will see centernet.mindir under {Output file path}. + ``` # [Model Description](#contents) diff --git a/model_zoo/research/cv/centernet/default_config.yaml b/model_zoo/research/cv/centernet/default_config.yaml new file mode 100644 index 00000000000..f2804def23a --- /dev/null +++ b/model_zoo/research/cv/centernet/default_config.yaml @@ -0,0 +1,176 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path" +device_target: "Ascend" +enable_profiling: False + +# ============================================================================== +# prepare *.mindrecord* data +coco_data_dir: "" +mindrecord_dir: "" # also used by train.py +mindrecord_prefix: "coco_hp.train.mind" +# train related +visual_image: "false" +save_result_dir: "" +device_id: 0 +device_num: 1 + +distribute: 'false' +need_profiler: "false" +profiler_path: "./profiler" +epoch_size: 1 +train_steps: -1 +enable_save_ckpt: "true" +do_shuffle: "true" +enable_data_sink: "true" +data_sink_steps: 1 +save_checkpoint_path: "" +load_checkpoint_path: "" +save_checkpoint_steps: 1000 +save_checkpoint_num: 1 +# test related +data_dir: "" +run_mode: "test" +enable_eval: "true" +# export related +export_load_ckpt: '' +export_format: '' +export_name: '' + +dataset_config: + num_classes: 1 + num_joints: 17 + max_objs: 32 + input_res: [512, 512] + output_res: [128, 128] + rand_crop: False + shift: 0.1 + scale: 0.4 + aug_rot: 0.0 + rotate: 0 + flip_prop: 0.5 + mean: np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32) + std: np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32) + flip_idx: [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + edges: [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6], + [5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12], + [12, 14], [14, 16], [11, 13], [13, 15]] + eig_val: np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32) + eig_vec: np.array([[-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32) + categories: [{"supercategory": "person", + "id": 1, + "name": "person", + "keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear", + "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", + "left_wrist", "right_wrist", "left_hip", "right_hip", + "left_knee", "right_knee", "left_ankle", "right_ankle"], + "skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], + [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], + [2, 4], [3, 5], [4, 6], [5, 7]]}] + +net_config: + down_ratio: 4 + last_level: 6 + final_kernel: 1 + stage_levels: [1, 1, 1, 2, 2, 1] + stage_channels: [16, 32, 64, 128, 256, 512] + head_conv: 256 + dense_hp: True + hm_hp: True + reg_hp_offset: True + reg_offset: True + hm_weight: 1 + off_weight: 1 + wh_weight: 0.1 + hp_weight: 1 + hm_hp_weight: 1 + mse_loss: False + reg_loss: 'l1' + +train_config: + batch_size: 32 + loss_scale_value: 1024 + optimizer: 'Adam' + lr_schedule: 'MultiDecay' + Adam: + weight_decay: 0.0 + decay_filter: "lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma')" + PolyDecay: + learning_rate: 0.00012 # 1.2e-4 + end_learning_rate: 0.0000005 # 5e-7 + power: 5.0 + eps: 0.0000001 # 1e-7 + warmup_steps: 2000 + MultiDecay: + learning_rate: 0.00012 # 1.2e-4 + eps: 0.0000001 # 1e-7 + warmup_steps: 2000 + multi_epochs: [270, 300] + factor: 10 + +eval_config: + soft_nms: True + keep_res: True + multi_scales: [1.0] + pad: 31 + K: 100 + score_thresh: 0.3 + +export_config: + input_res: dataset_config.input_res + ckpt_file: "./ckpt_file.ckpt" + export_format: "MINDIR" + export_name: "CenterNet_MultiPose" + +--- + +# Help description for each configuration +enable_modelarts: "Whether training on modelarts, default: False" +data_url: "Url for modelarts" +train_url: "Url for modelarts" +data_path: "The location of the input data." +output_path: "The location of the output file." +device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend." +enable_profiling: 'Whether enable profiling while training, default: False' + +distribute: "Run distribute, default is false." +need_profiler: "Profiling to parsing runtime info, default is false." +profiler_path: "The path to save profiling data" +epoch_size: "Epoch size, default is 1." +train_steps: "Training Steps, default is -1, i.e. run all steps according to epoch number." +device_id: "Device id, default is 0." +device_num: "Use device nums, default is 1." +enable_save_ckpt: "Enable save checkpoint, default is true." +do_shuffle: "Enable shuffle for dataset, default is true." +enable_data_sink: "Enable data sink, default is true." +data_sink_steps: "Sink steps for each epoch, default is 1." +save_checkpoint_path: "Save checkpoint path" +load_checkpoint_path: "Load checkpoint file path" +save_checkpoint_steps: "Save checkpoint steps, default is 1000." +save_checkpoint_num: "Save checkpoint numbers, default is 1." +mindrecord_dir: "Mindrecord dataset files directory" +mindrecord_prefix: "Prefix of MindRecord dataset filename." +visual_image: "Visulize the ground truth and predicted image" +save_result_dir: "The path to save the predict results" + +data_dir: "Dataset directory, the absolute image path is joined by the data_dir, and the relative path in anno_path" +run_mode: "test or validation, default is test." +enable_eval: "Whether evaluate accuracy after prediction" +--- + +device_target: ['Ascend', 'CPU'] +distribute: ["true", "false"] +need_profiler: ["true", "false"] +enable_save_ckpt: ["true", "false"] +do_shuffle: ["true", "false"] +enable_data_sink: ["true", "false"] +export_format: ["MINDIR"] \ No newline at end of file diff --git a/model_zoo/research/cv/centernet/eval.py b/model_zoo/research/cv/centernet/eval.py index 11eb970853c..023fc0a0048 100644 --- a/model_zoo/research/cv/centernet/eval.py +++ b/model_zoo/research/cv/centernet/eval.py @@ -20,7 +20,6 @@ import os import time import copy import json -import argparse import cv2 from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval @@ -31,53 +30,62 @@ import mindspore.log as logger from src import COCOHP, CenterNetMultiPoseEval from src import convert_eval_format, post_process, merge_outputs from src import visual_image -from src.config import dataset_config, net_config, eval_config +from src.model_utils.config import config, dataset_config, net_config, eval_config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id _current_dir = os.path.dirname(os.path.realpath(__file__)) -parser = argparse.ArgumentParser(description='CenterNet evaluation') -parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], - help='device where the code will be implemented. (Default: Ascend)') -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") -parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, " - "the absolute image path is joined by the data_dir " - "and the relative path in anno_path") -parser.add_argument("--run_mode", type=str, default="test", help="test or validation, default is test.") -parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image") -parser.add_argument("--enable_eval", type=str, default="true", help="Whether evaluate accuracy after prediction") -parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results") -args_opt = parser.parse_args() +def modelarts_pre_process(): + '''modelarts pre process function.''' + try: + from nms import soft_nms_39 + print('soft_nms_39_attributes: {}'.format(soft_nms_39.__dir__())) + except ImportError: + print('NMS not installed! trying installing...\n') + cur_path = os.path.dirname(os.path.abspath(__file__)) + os.system('cd {}/CenterNet/src/lib/external/ && make && python setup.py install && cd - '.format(cur_path)) + try: + from nms import soft_nms_39 + print('soft_nms_39_attributes: {}'.format(soft_nms_39.__dir__())) + except ImportError: + print('Installing failed! check if the folder "./CenterNet" exists.') + else: + print('Install nms successfully') + config.data_dir = config.data_path + config.load_checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config.load_checkpoint_path) + +@moxing_wrapper(pre_process=modelarts_pre_process) def predict(): ''' Predict function ''' - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - if args_opt.device_target == "Ascend": - context.set_context(device_id=args_opt.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + if config.device_target == "Ascend": + context.set_context(device_id=get_device_id()) enable_nms_fp16 = True else: enable_nms_fp16 = False - logger.info("Begin creating {} dataset".format(args_opt.run_mode)) - coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config, - enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,) - coco.init(args_opt.data_dir, keep_res=eval_config.keep_res) + logger.info("Begin creating {} dataset".format(config.run_mode)) + coco = COCOHP(dataset_config, run_mode=config.run_mode, net_opt=net_config, + enable_visual_image=(config.visual_image == "true"), save_path=config.save_result_dir,) + coco.init(config.data_dir, keep_res=eval_config.keep_res) dataset = coco.create_eval_dataset() net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.K, enable_nms_fp16) net_for_eval.set_train(False) - param_dict = load_checkpoint(args_opt.load_checkpoint_path) + param_dict = load_checkpoint(config.load_checkpoint_path) load_param_into_net(net_for_eval, param_dict) # save results - save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode) + save_path = os.path.join(config.save_result_dir, config.run_mode) if not os.path.exists(save_path): os.makedirs(save_path) - if args_opt.visual_image == "true": + if config.visual_image == "true": save_pred_image_path = os.path.join(save_path, "pred_image") if not os.path.exists(save_pred_image_path): os.makedirs(save_pred_image_path) @@ -119,25 +127,25 @@ def predict(): pred_annos["images"].append(image_info) for image_anno in pred_json["annotations"]: pred_annos["annotations"].append(image_anno) - if args_opt.visual_image == "true": + if config.visual_image == "true": img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name']) gt_image = cv2.imread(img_file) - if args_opt.run_mode != "test": + if config.run_mode != "test": annos = coco.coco.loadAnns(coco.anns[image_id]) visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path) anno = copy.deepcopy(pred_json["annotations"]) visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh) # save results - save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode) + save_path = os.path.join(config.save_result_dir, config.run_mode) if not os.path.exists(save_path): os.makedirs(save_path) - pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(args_opt.run_mode) + pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(config.run_mode) json.dump(pred_annos, open(pred_anno_file, 'w')) - pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(args_opt.run_mode) + pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(config.run_mode) json.dump(pred_annos["annotations"], open(pred_res_file, 'w')) - if args_opt.run_mode != "test" and args_opt.enable_eval: + if config.run_mode != "test" and config.enable_eval: run_eval(coco.annot_path, pred_res_file) diff --git a/model_zoo/research/cv/centernet/export.py b/model_zoo/research/cv/centernet/export.py index cf63d4973da..e33fb639d29 100644 --- a/model_zoo/research/cv/centernet/export.py +++ b/model_zoo/research/cv/centernet/export.py @@ -16,21 +16,26 @@ Export CenterNet mindir model. """ -import argparse +import os import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from src import CenterNetMultiPoseEval -from src.config import net_config, eval_config, export_config +from src.model_utils.config import config, net_config, eval_config, export_config +from src.model_utils.moxing_adapter import moxing_wrapper -parser = argparse.ArgumentParser(description='centernet export') -parser.add_argument("--device_id", type=int, default=0, help="Device id") -args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) +def modelarts_pre_process(): + '''modelarts pre process function.''' + export_config.ckpt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), export_config.ckpt_file) + export_config.export_name = os.path.join(config.output_path, export_config.export_name) -if __name__ == '__main__': + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_export(): + '''export function''' + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=config.device_id) net = CenterNetMultiPoseEval(net_config, eval_config.K) net.set_train(False) @@ -42,3 +47,7 @@ if __name__ == '__main__': input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32)) export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format) + + +if __name__ == '__main__': + run_export() diff --git a/model_zoo/research/cv/centernet/scripts/run_standalone_eval_ascend.sh b/model_zoo/research/cv/centernet/scripts/run_standalone_eval_ascend.sh index 7b371f43532..3a38a3bd1b0 100644 --- a/model_zoo/research/cv/centernet/scripts/run_standalone_eval_ascend.sh +++ b/model_zoo/research/cv/centernet/scripts/run_standalone_eval_ascend.sh @@ -29,14 +29,20 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) CUR_DIR=`pwd` export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 +export DEVICE_ID=$DEVICE_ID # install nms module from third party if python -c "import nms" > /dev/null 2>&1 then echo "NMS module already exits, no need reinstall." else - echo "NMS module was not found, install it now..." - git clone https://github.com/xingyizhou/CenterNet.git + if [ -f './CenterNet' ] + then + echo "NMS module was not found, but has been downloaded" + else + echo "NMS module was not found, install it now..." + git clone https://github.com/xingyizhou/CenterNet.git + fi cd CenterNet/src/lib/external/ || exit make python setup.py install diff --git a/model_zoo/research/cv/centernet/scripts/run_standalone_eval_cpu.sh b/model_zoo/research/cv/centernet/scripts/run_standalone_eval_cpu.sh index ccbef46350a..b01a50e50d6 100644 --- a/model_zoo/research/cv/centernet/scripts/run_standalone_eval_cpu.sh +++ b/model_zoo/research/cv/centernet/scripts/run_standalone_eval_cpu.sh @@ -34,8 +34,13 @@ if python -c "import nms" > /dev/null 2>&1 then echo "NMS module already exits, no need reinstall." else - echo "NMS module was not found, install it now..." - git clone https://github.com/xingyizhou/CenterNet.git + if [ -f './CenterNet' ] + then + echo "NMS module was not found, but has been downloaded" + else + echo "NMS module was not found, install it now..." + git clone https://github.com/xingyizhou/CenterNet.git + fi cd CenterNet/src/lib/external/ || exit make python setup.py install diff --git a/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh b/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh index d67face326d..2e8efc704c6 100644 --- a/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh @@ -35,6 +35,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) CUR_DIR=`pwd` export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 +export DEVICE_ID=$DEVICE_ID python ${PROJECT_DIR}/../train.py \ --distribute=false \ diff --git a/model_zoo/research/cv/centernet/src/centernet_pose.py b/model_zoo/research/cv/centernet/src/centernet_pose.py index 9cac9a8fa55..929f658e481 100644 --- a/model_zoo/research/cv/centernet/src/centernet_pose.py +++ b/model_zoo/research/cv/centernet/src/centernet_pose.py @@ -30,7 +30,7 @@ from .backbone_dla import DLASeg from .utils import Sigmoid, GradScale from .utils import FocalLoss, RegLoss, RegWeightedL1Loss from .decode import MultiPoseDecode -from .config import dataset_config as data_cfg +from .model_utils.config import dataset_config as data_cfg def _generate_feature(cin, cout, kernel_size, head_name, head_conv=0): diff --git a/model_zoo/research/cv/centernet/src/config.py b/model_zoo/research/cv/centernet/src/config.py deleted file mode 100644 index ff02cb4412b..00000000000 --- a/model_zoo/research/cv/centernet/src/config.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in dataset.py, train.py, eval.py -""" - -import numpy as np -from easydict import EasyDict as edict - - -dataset_config = edict({ - 'num_classes': 1, - 'num_joints': 17, - 'max_objs': 32, - 'input_res': [512, 512], - 'output_res': [128, 128], - 'rand_crop': False, - 'shift': 0.1, - 'scale': 0.4, - 'aug_rot': 0.0, - 'rotate': 0, - 'flip_prop': 0.5, - 'mean': np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32), - 'std': np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32), - 'flip_idx': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]], - 'edges': [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6], - [5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12], - [12, 14], [14, 16], [11, 13], [13, 15]], - 'eig_val': np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32), - 'eig_vec': np.array([[-0.58752847, -0.69563484, 0.41340352], - [-0.5832747, 0.00994535, -0.81221408], - [-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32), - 'categories': [{"supercategory": "person", - "id": 1, - "name": "person", - "keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear", - "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", - "left_wrist", "right_wrist", "left_hip", "right_hip", - "left_knee", "right_knee", "left_ankle", "right_ankle"], - "skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], - [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], - [2, 4], [3, 5], [4, 6], [5, 7]]}], -}) - - -net_config = edict({ - 'down_ratio': 4, - 'last_level': 6, - 'final_kernel': 1, - 'stage_levels': [1, 1, 1, 2, 2, 1], - 'stage_channels': [16, 32, 64, 128, 256, 512], - 'head_conv': 256, - 'dense_hp': True, - 'hm_hp': True, - 'reg_hp_offset': True, - 'reg_offset': True, - 'hm_weight': 1, - 'off_weight': 1, - 'wh_weight': 0.1, - 'hp_weight': 1, - 'hm_hp_weight': 1, - 'mse_loss': False, - 'reg_loss': 'l1', -}) - - -train_config = edict({ - 'batch_size': 32, - 'loss_scale_value': 1024, - 'optimizer': 'Adam', - 'lr_schedule': 'MultiDecay', - 'Adam': edict({ - 'weight_decay': 0.0, - 'decay_filter': lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma'), - }), - 'PolyDecay': edict({ - 'learning_rate': 1.2e-4, - 'end_learning_rate': 5e-7, - 'power': 5.0, - 'eps': 1e-7, - 'warmup_steps': 2000, - }), - 'MultiDecay': edict({ - 'learning_rate': 1.2e-4, - 'eps': 1e-7, - 'warmup_steps': 2000, - 'multi_epochs': [270, 300], - 'factor': 10, - }) -}) - - -eval_config = edict({ - 'soft_nms': True, - 'keep_res': True, - 'multi_scales': [1.0], - 'pad': 31, - 'K': 100, - 'score_thresh': 0.3 -}) - - -export_config = edict({ - 'input_res': dataset_config.input_res, - 'ckpt_file': "./ckpt_file.ckpt", - 'export_format': "MINDIR", - 'export_name': "CenterNet_MultiPose", -}) diff --git a/model_zoo/research/cv/centernet/src/dataset.py b/model_zoo/research/cv/centernet/src/dataset.py index 26a5ac19a7d..df154bbbe8d 100644 --- a/model_zoo/research/cv/centernet/src/dataset.py +++ b/model_zoo/research/cv/centernet/src/dataset.py @@ -15,10 +15,9 @@ """ Data operations, will be used in train.py """ - +import sys import os import math -import argparse import cv2 import numpy as np import pycocotools.coco as coco @@ -26,9 +25,17 @@ import pycocotools.coco as coco import mindspore.dataset as ds from mindspore import log as logger from mindspore.mindrecord import FileWriter -from src.image import get_affine_transform, affine_transform -from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg -from src.visual import visual_image + +try: + from src.image import get_affine_transform, affine_transform + from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg + from src.visual import visual_image +except ImportError as import_error: + print('Import Error: {}, trying append path/centernet/src/../'.format(import_error)) + sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + from src.image import get_affine_transform, affine_transform + from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg + from src.visual import visual_image _current_dir = os.path.dirname(os.path.realpath(__file__)) cv2.setNumThreads(0) @@ -428,14 +435,8 @@ class COCOHP(ds.Dataset): if __name__ == '__main__': # Convert coco2017 dataset to mindrecord to improve performance on host - from src.config import dataset_config + from src.model_utils.config import config, dataset_config - parser = argparse.ArgumentParser(description='CenterNet MindRecord dataset') - parser.add_argument("--coco_data_dir", type=str, default="", help="Coco dataset directory.") - parser.add_argument("--mindrecord_dir", type=str, default="", help="MindRecord dataset dir.") - parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind", - help="Prefix of MindRecord dataset filename.") - args_opt = parser.parse_args() dsc = COCOHP(dataset_config, run_mode="train") - dsc.init(args_opt.coco_data_dir) - dsc.transfer_coco_to_mindrecord(args_opt.mindrecord_dir, args_opt.mindrecord_prefix, shard_num=8) + dsc.init(config.coco_data_dir) + dsc.transfer_coco_to_mindrecord(config.mindrecord_dir, config.mindrecord_prefix, shard_num=8) diff --git a/model_zoo/research/cv/centernet/src/model_utils/config.py b/model_zoo/research/cv/centernet/src/model_utils/config.py new file mode 100644 index 00000000000..33028f2e85a --- /dev/null +++ b/model_zoo/research/cv/centernet/src/model_utils/config.py @@ -0,0 +1,160 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml +import numpy as np + + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, str) and (v[:9] == 'np.array(' and v[-17:] == 'dtype=np.float32)'): + v = np.array(ast.literal_eval(v[9:v.rfind(']') + 1]), dtype=np.float32) + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def extra_operations(cfg): + """ + Do extra work on Config object. + + Args: + cfg: Object after instantiation of class 'Config'. + """ + cfg.train_config.Adam.decay_filter = lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma') + cfg.export_config.input_res = cfg.dataset_config.input_res + if cfg.export_load_ckpt: + cfg.export_config.ckpt_file = cfg.export_load_ckpt + if cfg.export_name: + cfg.export_config.export_name = cfg.export_name + if cfg.export_format: + cfg.export_config.export_format = cfg.export_format + + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) + final_config = merge(args, default) + config_obj = Config(final_config) + extra_operations(config_obj) + return config_obj + + +config = get_config() +dataset_config = config.dataset_config +net_config = config.net_config +train_config = config.train_config +eval_config = config.eval_config +export_config = config.export_config + +if __name__ == '__main__': + print(config) diff --git a/model_zoo/research/cv/centernet/src/model_utils/device_adapter.py b/model_zoo/research/cv/centernet/src/model_utils/device_adapter.py new file mode 100644 index 00000000000..9c3d21d5e47 --- /dev/null +++ b/model_zoo/research/cv/centernet/src/model_utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from src.model_utils.config import config + +if config.enable_modelarts: + from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/research/cv/centernet/src/model_utils/local_adapter.py b/model_zoo/research/cv/centernet/src/model_utils/local_adapter.py new file mode 100644 index 00000000000..769fa6dc78e --- /dev/null +++ b/model_zoo/research/cv/centernet/src/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/research/cv/centernet/src/model_utils/moxing_adapter.py b/model_zoo/research/cv/centernet/src/model_utils/moxing_adapter.py new file mode 100644 index 00000000000..09cb0f0cf0f --- /dev/null +++ b/model_zoo/research/cv/centernet/src/model_utils/moxing_adapter.py @@ -0,0 +1,123 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from mindspore.profiler import Profiler +from src.model_utils.config import config + +_global_sync_count = 0 + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. + """ + import moxing as mox + import time + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + # print("os.mknod({}) success".format(sync_lock)) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs. + """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + if config.enable_profiling: + profiler = Profiler() + + run_func(*args, **kwargs) + + if config.enable_profiling: + profiler.analyse() + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper diff --git a/model_zoo/research/cv/centernet/src/post_process.py b/model_zoo/research/cv/centernet/src/post_process.py index e905d8322d8..0978379ea12 100644 --- a/model_zoo/research/cv/centernet/src/post_process.py +++ b/model_zoo/research/cv/centernet/src/post_process.py @@ -17,16 +17,10 @@ Post-process functions after decoding """ import numpy as np -from src.config import dataset_config as config +from src.model_utils.config import dataset_config as config from .image import get_affine_transform, affine_transform, transform_preds from .visual import coco_box_to_bbox -try: - from nms import soft_nms_39 -except ImportError: - print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n' - 'and see run_standalone_eval.sh for more details to install it\n') - _NUM_JOINTS = config.num_joints @@ -48,6 +42,11 @@ def merge_outputs(detections, soft_nms=True): """merge detections together by nms""" results = np.concatenate([detection for detection in detections], axis=0).astype(np.float32) if soft_nms: + try: + from nms import soft_nms_39 + except ImportError: + print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n' + 'and see run_standalone_eval.sh for more details to install it\n') soft_nms_39(results, Nt=0.5, threshold=0.01, method=2) results = results.tolist() return results diff --git a/model_zoo/research/cv/centernet/src/visual.py b/model_zoo/research/cv/centernet/src/visual.py index 92e4b0a7503..a983752d6bf 100644 --- a/model_zoo/research/cv/centernet/src/visual.py +++ b/model_zoo/research/cv/centernet/src/visual.py @@ -22,7 +22,7 @@ import random import cv2 import numpy as np import pycocotools.coco as COCO -from .config import dataset_config as data_cfg +from .model_utils.config import dataset_config as data_cfg from .image import get_affine_transform, affine_transform _NUM_JOINTS = data_cfg.num_joints diff --git a/model_zoo/research/cv/centernet/train.py b/model_zoo/research/cv/centernet/train.py index 11a4066e185..0fc0113943a 100644 --- a/model_zoo/research/cv/centernet/train.py +++ b/model_zoo/research/cv/centernet/train.py @@ -17,7 +17,6 @@ Train CenterNet and get network model files(.ckpt) """ import os -import argparse import mindspore.communication.management as D from mindspore.communication.management import get_rank from mindspore import context @@ -29,46 +28,17 @@ from mindspore.nn.optim import Adam from mindspore import log as logger from mindspore.common import set_seed from mindspore.profiler import Profiler + from src.dataset import COCOHP from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell from src import CenterNetWithoutLossScaleCell from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR -from src.config import dataset_config, net_config, train_config +from src.model_utils.config import config, dataset_config, net_config, train_config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num _current_dir = os.path.dirname(os.path.realpath(__file__)) -parser = argparse.ArgumentParser(description='CenterNet training') -parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], - help='device where the code will be implemented. (Default: Ascend)') -parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], - help="Run distribute, default is false.") -parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"], - help="Profiling to parsing runtime info, default is false.") -parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data") -parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") -parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1," - "i.e. run all steps according to epoch number.") -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") -parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"], - help="Enable save checkpoint, default is true.") -parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"], - help="Enable shuffle for dataset, default is true.") -parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"], - help="Enable data sink, default is true.") -parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") -parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") -parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") -parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.") -parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") -parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset files directory") -parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind", - help="Prefix of MindRecord dataset filename.") -parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image") -parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results") - -args_opt = parser.parse_args() - def _set_parallel_all_reduce_split(): """set centernet all_reduce fusion split""" @@ -102,7 +72,7 @@ def _get_optimizer(network, dataset_size): lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate, end_learning_rate=train_config.PolyDecay.end_learning_rate, warmup_steps=train_config.PolyDecay.warmup_steps, - decay_steps=args_opt.train_steps, + decay_steps=config.train_steps, power=train_config.PolyDecay.power) optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0) elif train_config.lr_schedule == "MultiDecay": @@ -110,7 +80,7 @@ def _get_optimizer(network, dataset_size): if not isinstance(multi_epochs, (list, tuple)): raise TypeError("multi_epochs must be list or tuple.") if not multi_epochs: - multi_epochs = [args_opt.epoch_size] + multi_epochs = [config.epoch_size] lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate, warmup_steps=train_config.MultiDecay.warmup_steps, multi_epochs=multi_epochs, @@ -126,83 +96,90 @@ def _get_optimizer(network, dataset_size): return optimizer +def modelarts_pre_process(): + '''modelarts pre process function.''' + config.mindrecord_dir = config.data_path + config.save_checkpoint_path = os.path.join(config.output_path, config.save_checkpoint_path) + + +@moxing_wrapper(pre_process=modelarts_pre_process) def train(): """training CenterNet""" - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) context.set_context(reserve_class_name_in_scope=False) context.set_context(save_graphs=False) - ckpt_save_dir = args_opt.save_checkpoint_path + ckpt_save_dir = config.save_checkpoint_path rank = 0 device_num = 1 num_workers = 8 - if args_opt.device_target == "Ascend": + if config.device_target == "Ascend": context.set_context(enable_auto_mixed_precision=False) - context.set_context(device_id=args_opt.device_id) - if args_opt.distribute == "true": + context.set_context(device_id=get_device_id()) + if config.distribute == "true": D.init() - device_num = args_opt.device_num - rank = args_opt.device_id % device_num - ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/' + device_num = get_device_num() + rank = get_rank_id() + ckpt_save_dir = config.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/' context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) _set_parallel_all_reduce_split() else: - args_opt.distribute = "false" - args_opt.need_profiler = "false" - args_opt.enable_data_sink = "false" + config.distribute = "false" + config.need_profiler = "false" + config.enable_data_sink = "false" # Start create dataset! # mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num. logger.info("Begin creating dataset for CenterNet") coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config, - enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir) - dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix, + enable_visual_image=(config.visual_image == "true"), save_path=config.save_result_dir) + dataset = coco.create_train_dataset(config.mindrecord_dir, config.mindrecord_prefix, batch_size=train_config.batch_size, device_num=device_num, rank=rank, - num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true') + num_parallel_workers=num_workers, do_shuffle=config.do_shuffle == 'true') dataset_size = dataset.get_dataset_size() logger.info("Create dataset done!") net_with_loss = CenterNetMultiPoseLossCell(net_config) - new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps - if args_opt.train_steps > 0: - new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) + new_repeat_count = config.epoch_size * dataset_size // config.data_sink_steps + if config.train_steps > 0: + new_repeat_count = min(new_repeat_count, config.train_steps // config.data_sink_steps) else: - args_opt.train_steps = args_opt.epoch_size * dataset_size - logger.info("train steps: {}".format(args_opt.train_steps)) + config.train_steps = config.epoch_size * dataset_size + logger.info("train steps: {}".format(config.train_steps)) optimizer = _get_optimizer(net_with_loss, dataset_size) - enable_static_time = args_opt.device_target == "CPU" - callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)] - if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0: - config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, - keep_checkpoint_max=args_opt.save_checkpoint_num) + enable_static_time = config.device_target == "CPU" + callback = [TimeMonitor(config.data_sink_steps), LossCallBack(dataset_size, enable_static_time)] + if config.enable_save_ckpt == "true" and get_device_id() % min(8, device_num) == 0: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + keep_checkpoint_max=config.save_checkpoint_num) ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet', directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck) callback.append(ckpoint_cb) - if args_opt.load_checkpoint_path: - param_dict = load_checkpoint(args_opt.load_checkpoint_path) + if config.load_checkpoint_path: + param_dict = load_checkpoint(config.load_checkpoint_path) load_param_into_net(net_with_loss, param_dict) - if args_opt.device_target == "Ascend": + if config.device_target == "Ascend": net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer, sens=train_config.loss_scale_value) else: net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer) model = Model(net_with_grads) - model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"), - sink_size=args_opt.data_sink_steps) + model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(config.enable_data_sink == "true"), + sink_size=config.data_sink_steps) if __name__ == '__main__': - if args_opt.need_profiler == "true": - profiler = Profiler(output_path=args_opt.profiler_path) + if config.need_profiler == "true": + profiler = Profiler(output_path=config.profiler_path) set_seed(0) train() - if args_opt.need_profiler == "true": + if config.need_profiler == "true": profiler.analyse()