!18706 centernet can be used on ModelArts

Merge pull request !18706 from 郑彬/master
This commit is contained in:
i-robot 2021-06-23 06:46:02 +00:00 committed by Gitee
commit 18eed2fb33
17 changed files with 798 additions and 300 deletions

View File

@ -97,7 +97,7 @@ Dataset used: [COCO2017](https://cocodataset.org/)
pip install mmcv==0.2.14
```
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
And change the COCO_ROOT and other settings you need in `default_config.yaml`. The directory structure is as follows:
```path
.
@ -115,12 +115,15 @@ Dataset used: [COCO2017](https://cocodataset.org/)
# [Quick Start](#contents)
- running on local
After installing MindSpore via the official website, you can start training and evaluation as follows:
Note: 1. The first run of training will generate the mindrecord file, which will take a long time.
2. MINDRECORD_DATASET_PATH is the mindrecord dataset directory.
3. LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory; if none, just set ""
4. RUN_MODE supports validation and testing; set it to "val"/"test"
3. For `train.py`, LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory; if none, just set "".
4. For `eval.py`, LOAD_CHECKPOINT_PATH is the checkpoint to be evaluated.
5. RUN_MODE supports validation and testing; set it to "val"/"test"
```shell
# create dataset in mindrecord format
@ -142,6 +145,61 @@ bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LO
bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
```
- running on ModelArts
If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); then you can start training as follows.
- Training with a single card on ModelArts
```python
# (1) Upload the code folder to the S3 bucket.
# (2) Click "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/centernet" on the website UI interface.
# (4) Set the startup file to "/{path}/centernet/train.py" on the website UI interface.
# (5) Perform a or b.
#     a. Setting parameters in /{path}/centernet/default_config.yaml.
#        1. Set "enable_modelarts: True"
#        2. Set "epoch_size: 350"
#        3. Set "distribute: 'true'"
#        4. Set "data_sink_steps: 50"
#        5. Set "save_checkpoint_path: ./checkpoints"
#     b. Adding parameters on the website UI interface.
#        1. Add "enable_modelarts=True"
#        2. Add "epoch_size=350"
#        3. Add "distribute=true"
#        4. Add "data_sink_steps=50"
#        5. Add "save_checkpoint_path=./checkpoints"
# (6) Upload the mindrecord dataset to the S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
```
- Evaluating with a single card on ModelArts
```python
# (1) Upload the code folder 'centernet' to the S3 bucket.
# (2) Clone https://github.com/xingyizhou/CenterNet.git locally, and put the folder 'CenterNet' under the folder 'centernet' in the S3 bucket.
# (3) Click "create training task" on the website UI interface.
# (4) Set the code directory to "/{path}/centernet" on the website UI interface.
# (5) Set the startup file to "/{path}/centernet/eval.py" on the website UI interface.
# (6) Perform a or b.
#     a. Setting parameters in /{path}/centernet/default_config.yaml.
#        1. Set "enable_modelarts: True"
#        2. Set "run_mode: 'val'"
#        3. Set "load_checkpoint_path: ./{path}/*.ckpt" ('load_checkpoint_path' is the path of the checkpoint to be evaluated, relative to `eval.py`; the checkpoint file must be included in the code directory.)
#     b. Adding parameters on the website UI interface.
#        1. Add "enable_modelarts=True"
#        2. Add "run_mode=val"
#        3. Add "load_checkpoint_path=./{path}/*.ckpt" ('load_checkpoint_path' is the path of the checkpoint to be evaluated, relative to `eval.py`; the checkpoint file must be included in the code directory.)
# (7) Upload the dataset (not in mindrecord format) to the S3 bucket.
# (8) Check the "data storage location" on the website UI interface and set the "Dataset path".
# (9) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (10) Under the item "resource pool selection", select the specification of a single card.
# (11) Create your job.
```
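Step (2) is needed because evaluation depends on the compiled soft-NMS extension from the original CenterNet repository; on ModelArts, `eval.py` builds it at startup. Below is a minimal sketch of that bootstrap, mirroring `modelarts_pre_process` in `eval.py` later in this diff:
```python
# Sketch of the soft-NMS bootstrap run on ModelArts (mirrors
# modelarts_pre_process in eval.py); assumes the 'CenterNet' folder
# sits next to eval.py, as described in step (2).
import os

try:
    from nms import soft_nms_39
except ImportError:
    cur_path = os.path.dirname(os.path.abspath(__file__))
    os.system('cd {}/CenterNet/src/lib/external/ && make && python setup.py install && cd -'.format(cur_path))
```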
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@ -154,12 +212,13 @@ bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_P
├── eval.py // testing and evaluation outputs
├── export.py // convert mindspore model to air model
├── README.md // descriptions about CenterNet
├── default_config.yaml // parameter configuration
├── scripts
│ ├──ascend_distributed_launcher
│ │ ├──__init__.py
│ │ ├──hyper_parameter_config.ini // hyper parameter for distributed training
│ │ ├──get_distribute_train_cmd.py // script for distributed training
│ │ └──README.md
│ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord
│ ├──run_standalone_train_ascend.sh // shell script for standalone training on ascend
│ ├──run_distributed_train_ascend.sh // shell script for distributed training on ascend
@ -167,10 +226,14 @@ bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_P
│ ├──run_standalone_train_cpu.sh // shell script for standalone training on cpu
│ ├──run_standalone_eval_cpu.sh // shell script for standalone evaluation on cpu
└── src
├──model_utils
│ ├──config.py // parsing parameter configuration file of "*.yaml"
│ ├──device_adapter.py // local or ModelArts training
│ ├──local_adapter.py // get related environment variables on local
│ └──moxing_adapter.py // get related environment variables and transfer data on ModelArts
├──__init__.py
├──centernet_pose.py // centernet networks, training entry
├──dataset.py // generate dataloader and data processing entry
├──config.py // centernet unique configs
├──dcn_v2.py // deformable convolution operator v2
├──decode.py // decode the head features
├──backbone_dla.py // deep layer aggregation backbone
@ -255,34 +318,34 @@ options:
### Options and Parameters
Parameters for training and evaluation can be set in file `config.py` and `finetune_eval_config.py` respectively.
Parameters for training and evaluation can be set in file `default_config.yaml`.
#### Options
```text
config for training.
batch_size batch size of input dataset: N, default is 32
loss_scale_value initial value of loss scale: N, default is 1024
optimizer optimizer used in the network: Adam, default is Adam
lr_schedule schedules to get the learning rate
```python
train_config:
batch_size: 32 // batch size of input dataset: N, default is 32
loss_scale_value: 1024 // initial value of loss scale: N, default is 1024
optimizer: 'Adam' // optimizer used in the network: Adam, default is Adam
lr_schedule: 'MultiDecay' // schedules to get the learning rate
```
```text
config for evaluation.
soft_nms nms after decode: True | False, default is True
keep_res keep original or fix resolution: True | False, default is False
multi_scales use multi-scales of image: List, default is [1.0]
pad pad size when keep original resolution, default is 31
K number of bboxes to be computed by TopK, default is 100
score_thresh threshold of score when visualize image and annotation info
eval_config:
soft_nms: True // nms after decode: True | False, default is True
keep_res: True // keep original or fix resolution: True | False, default is False
multi_scales: [1.0] // use multi-scales of image: List, default is [1.0]
pad: 31 // pad size when keep original resolution, default is 31
K: 100 // number of bboxes to be computed by TopK, default is 100
score_thresh: 0.3 // threshold of score when visualize image and annotation info
```
```text
config for export.
input_res input resolution of the model air, default is [512, 512]
ckpt_file checkpoint file, default is "./ckkt_file.ckpt"
export_format the exported format of model air, default is MINDIR
export_name the exported file name, default is "CentNet_MultiPose"
input_res: dataset_config.input_res // input resolution of the model air, default is [512, 512]
ckpt_file: "./ckpt_file.ckpt" // checkpoint file, default is "./ckpt_file.ckpt"
export_format: "MINDIR" // the exported format of model air, default is MINDIR
export_name: "CenterNet_MultiPose" // the exported file name, default is "CenterNet_MultiPose"
```
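All of these sections are exposed as attributes by `src/model_utils/config.py` (added later in this diff). A minimal access sketch, assuming the repository root is on `PYTHONPATH`:
```python
# Minimal sketch of reading the merged configuration; these names are
# exported at the bottom of src/model_utils/config.py.
from src.model_utils.config import config, train_config, eval_config, export_config

print(train_config.optimizer)       # 'Adam'
print(eval_config.K)                # 100
print(export_config.export_format)  # 'MINDIR'
print(config.device_target)         # 'Ascend' unless overridden on the CLI
```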
#### Parameters
@ -462,8 +525,35 @@ overall performance on coco2017 test-dev dataset
If you want to infer the network on Ascend 310, you should convert the model to AIR:
- Export on local
```python
python export.py [DEVICE_ID]
python export.py --device_id [DEVICE_ID] --export_format MINDIR --export_load_ckpt [CKPT_FILE_PATH] --export_name [EXPORT_FILE_NAME]
```
- Export on ModelArts (if you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); then you can start as follows)
```python
# (1) Upload the code folder to the S3 bucket.
# (2) Click "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/centernet" on the website UI interface.
# (4) Set the startup file to "/{path}/centernet/export.py" on the website UI interface.
# (5) Perform a or b.
#     a. Setting parameters in /{path}/centernet/default_config.yaml.
#        1. Set "enable_modelarts: True"
#        2. Set "export_load_ckpt: ./{path}/*.ckpt" ('export_load_ckpt' is the path of the checkpoint to be exported, relative to `export.py`; the checkpoint file must be included in the code directory.)
#        3. Set "export_name: centernet"
#        4. Set "export_format: MINDIR"
#     b. Adding parameters on the website UI interface.
#        1. Add "enable_modelarts=True"
#        2. Add "export_load_ckpt=./{path}/*.ckpt" ('export_load_ckpt' is the path of the checkpoint to be exported, relative to `export.py`; the checkpoint file must be included in the code directory.)
#        3. Add "export_name=centernet"
#        4. Add "export_format=MINDIR"
# (6) Check the "data storage location" on the website UI interface and set the "Dataset path" (this step is not used by export, but the UI requires it).
# (7) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (8) Under the item "resource pool selection", select the specification of a single card.
# (9) Create your job.
# You will see centernet.mindir under {Output file path}.
```
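To sanity-check the exported file, the MINDIR can be loaded back with MindSpore's generic graph loader. A hedged sketch (the `(1, 3, 512, 512)` input shape is an assumption derived from `input_res: [512, 512]` above):
```python
# Hedged sketch: load the exported MINDIR and run one forward pass.
# 'centernet.mindir' is the file produced above; the input shape assumes
# batch size 1 and input_res [512, 512].
import numpy as np
import mindspore as ms
import mindspore.nn as nn

graph = ms.load("centernet.mindir")
net = nn.GraphCell(graph)
dummy = ms.Tensor(np.random.uniform(-1.0, 1.0, (1, 3, 512, 512)).astype(np.float32))
print(net(dummy))
```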
# [Model Description](#contents)

View File

@ -0,0 +1,176 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
# ==============================================================================
# prepare *.mindrecord* data
coco_data_dir: ""
mindrecord_dir: "" # also used by train.py
mindrecord_prefix: "coco_hp.train.mind"
# train related
visual_image: "false"
save_result_dir: ""
device_id: 0
device_num: 1
distribute: 'false'
need_profiler: "false"
profiler_path: "./profiler"
epoch_size: 1
train_steps: -1
enable_save_ckpt: "true"
do_shuffle: "true"
enable_data_sink: "true"
data_sink_steps: 1
save_checkpoint_path: ""
load_checkpoint_path: ""
save_checkpoint_steps: 1000
save_checkpoint_num: 1
# test related
data_dir: ""
run_mode: "test"
enable_eval: "true"
# export related
export_load_ckpt: ''
export_format: ''
export_name: ''
dataset_config:
num_classes: 1
num_joints: 17
max_objs: 32
input_res: [512, 512]
output_res: [128, 128]
rand_crop: False
shift: 0.1
scale: 0.4
aug_rot: 0.0
rotate: 0
flip_prop: 0.5
mean: np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32)
std: np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32)
flip_idx: [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
edges: [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6],
[5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12],
[12, 14], [14, 16], [11, 13], [13, 15]]
eig_val: np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32)
eig_vec: np.array([[-0.58752847, -0.69563484, 0.41340352],
[-0.5832747, 0.00994535, -0.81221408],
[-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32)
categories: [{"supercategory": "person",
"id": 1,
"name": "person",
"keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"],
"skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13],
[6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3],
[2, 4], [3, 5], [4, 6], [5, 7]]}]
net_config:
down_ratio: 4
last_level: 6
final_kernel: 1
stage_levels: [1, 1, 1, 2, 2, 1]
stage_channels: [16, 32, 64, 128, 256, 512]
head_conv: 256
dense_hp: True
hm_hp: True
reg_hp_offset: True
reg_offset: True
hm_weight: 1
off_weight: 1
wh_weight: 0.1
hp_weight: 1
hm_hp_weight: 1
mse_loss: False
reg_loss: 'l1'
train_config:
batch_size: 32
loss_scale_value: 1024
optimizer: 'Adam'
lr_schedule: 'MultiDecay'
Adam:
weight_decay: 0.0
decay_filter: "lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma')"
PolyDecay:
learning_rate: 0.00012 # 1.2e-4
end_learning_rate: 0.0000005 # 5e-7
power: 5.0
eps: 0.0000001 # 1e-7
warmup_steps: 2000
MultiDecay:
learning_rate: 0.00012 # 1.2e-4
eps: 0.0000001 # 1e-7
warmup_steps: 2000
multi_epochs: [270, 300]
factor: 10
eval_config:
soft_nms: True
keep_res: True
multi_scales: [1.0]
pad: 31
K: 100
score_thresh: 0.3
export_config:
input_res: dataset_config.input_res
ckpt_file: "./ckpt_file.ckpt"
export_format: "MINDIR"
export_name: "CenterNet_MultiPose"
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
distribute: "Run distribute, default is false."
need_profiler: "Profiling to parsing runtime info, default is false."
profiler_path: "The path to save profiling data"
epoch_size: "Epoch size, default is 1."
train_steps: "Training Steps, default is -1, i.e. run all steps according to epoch number."
device_id: "Device id, default is 0."
device_num: "Use device nums, default is 1."
enable_save_ckpt: "Enable save checkpoint, default is true."
do_shuffle: "Enable shuffle for dataset, default is true."
enable_data_sink: "Enable data sink, default is true."
data_sink_steps: "Sink steps for each epoch, default is 1."
save_checkpoint_path: "Save checkpoint path"
load_checkpoint_path: "Load checkpoint file path"
save_checkpoint_steps: "Save checkpoint steps, default is 1000."
save_checkpoint_num: "Save checkpoint numbers, default is 1."
mindrecord_dir: "Mindrecord dataset files directory"
mindrecord_prefix: "Prefix of MindRecord dataset filename."
visual_image: "Visualize the ground truth and predicted image"
save_result_dir: "The path to save the predict results"
data_dir: "Dataset directory; the absolute image path is formed by joining data_dir and the relative path in anno_path"
run_mode: "test or validation, default is test."
enable_eval: "Whether evaluate accuracy after prediction"
---
device_target: ['Ascend', 'CPU']
distribute: ["true", "false"]
need_profiler: ["true", "false"]
enable_save_ckpt: ["true", "false"]
do_shuffle: ["true", "false"]
enable_data_sink: ["true", "false"]
export_format: ["MINDIR"]
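Note that `mean`, `std`, `eig_val` and `eig_vec` above are stored as literal `np.array(...)` strings; the `Config` class in `src/model_utils/config.py` (later in this diff) converts them back into numpy arrays. A standalone sketch of that conversion:
```python
# Standalone sketch of the "np.array(...)" string handling performed by the
# Config class; 'raw' stands in for a value produced by yaml parsing.
import ast
import numpy as np

raw = "np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32)"
if raw[:9] == 'np.array(' and raw[-17:] == 'dtype=np.float32)':
    value = np.array(ast.literal_eval(raw[9:raw.rfind(']') + 1]), dtype=np.float32)
    print(value.dtype, value)  # float32 [0.40789654 0.44719302 0.47026115]
```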

View File

@ -20,7 +20,6 @@ import os
import time
import copy
import json
import argparse
import cv2
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
@ -31,53 +30,62 @@ import mindspore.log as logger
from src import COCOHP, CenterNetMultiPoseEval
from src import convert_eval_format, post_process, merge_outputs
from src import visual_image
from src.config import dataset_config, net_config, eval_config
from src.model_utils.config import config, dataset_config, net_config, eval_config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet evaluation')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is joined by the data_dir "
"and the relative path in anno_path")
parser.add_argument("--run_mode", type=str, default="test", help="test or validation, default is test.")
parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image")
parser.add_argument("--enable_eval", type=str, default="true", help="Whether evaluate accuracy after prediction")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def modelarts_pre_process():
'''modelarts pre process function.'''
try:
from nms import soft_nms_39
print('soft_nms_39_attributes: {}'.format(soft_nms_39.__dir__()))
except ImportError:
print('NMS not installed! Trying to install...\n')
cur_path = os.path.dirname(os.path.abspath(__file__))
os.system('cd {}/CenterNet/src/lib/external/ && make && python setup.py install && cd - '.format(cur_path))
try:
from nms import soft_nms_39
print('soft_nms_39_attributes: {}'.format(soft_nms_39.__dir__()))
except ImportError:
print('Installation failed! Check if the folder "./CenterNet" exists.')
else:
print('Installed nms successfully')
config.data_dir = config.data_path
config.load_checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config.load_checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def predict():
'''
Predict function
'''
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.device_target == "Ascend":
context.set_context(device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
enable_nms_fp16 = True
else:
enable_nms_fp16 = False
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config,
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,)
coco.init(args_opt.data_dir, keep_res=eval_config.keep_res)
logger.info("Begin creating {} dataset".format(config.run_mode))
coco = COCOHP(dataset_config, run_mode=config.run_mode, net_opt=net_config,
enable_visual_image=(config.visual_image == "true"), save_path=config.save_result_dir,)
coco.init(config.data_dir, keep_res=eval_config.keep_res)
dataset = coco.create_eval_dataset()
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.K, enable_nms_fp16)
net_for_eval.set_train(False)
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
param_dict = load_checkpoint(config.load_checkpoint_path)
load_param_into_net(net_for_eval, param_dict)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
save_path = os.path.join(config.save_result_dir, config.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
if args_opt.visual_image == "true":
if config.visual_image == "true":
save_pred_image_path = os.path.join(save_path, "pred_image")
if not os.path.exists(save_pred_image_path):
os.makedirs(save_pred_image_path)
@ -119,25 +127,25 @@ def predict():
pred_annos["images"].append(image_info)
for image_anno in pred_json["annotations"]:
pred_annos["annotations"].append(image_anno)
if args_opt.visual_image == "true":
if config.visual_image == "true":
img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name'])
gt_image = cv2.imread(img_file)
if args_opt.run_mode != "test":
if config.run_mode != "test":
annos = coco.coco.loadAnns(coco.anns[image_id])
visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path)
anno = copy.deepcopy(pred_json["annotations"])
visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
save_path = os.path.join(config.save_result_dir, config.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(args_opt.run_mode)
pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(config.run_mode)
json.dump(pred_annos, open(pred_anno_file, 'w'))
pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(args_opt.run_mode)
pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(config.run_mode)
json.dump(pred_annos["annotations"], open(pred_res_file, 'w'))
if args_opt.run_mode != "test" and args_opt.enable_eval:
if config.run_mode != "test" and config.enable_eval:
run_eval(coco.annot_path, pred_res_file)
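`run_eval` itself is unchanged by this commit and not shown in the diff; a hedged sketch of what it amounts to with pycocotools (the two arguments are the annotation file and the prediction file written above):
```python
# Hedged sketch of COCO keypoint evaluation in the spirit of run_eval;
# standard pycocotools usage, not the repository's exact implementation.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def run_eval_sketch(anno_path, pred_res_file):
    coco_gt = COCO(anno_path)                 # ground-truth annotations
    coco_dt = coco_gt.loadRes(pred_res_file)  # predicted results
    coco_eval = COCOeval(coco_gt, coco_dt, iouType='keypoints')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
```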

View File

@ -16,21 +16,26 @@
Export CenterNet mindir model.
"""
import argparse
import os
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src import CenterNetMultiPoseEval
from src.config import net_config, eval_config, export_config
from src.model_utils.config import config, net_config, eval_config, export_config
from src.model_utils.moxing_adapter import moxing_wrapper
parser = argparse.ArgumentParser(description='centernet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
def modelarts_pre_process():
'''modelarts pre process function.'''
export_config.ckpt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), export_config.ckpt_file)
export_config.export_name = os.path.join(config.output_path, export_config.export_name)
if __name__ == '__main__':
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
'''export function'''
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=config.device_id)
net = CenterNetMultiPoseEval(net_config, eval_config.K)
net.set_train(False)
@ -42,3 +47,7 @@ if __name__ == '__main__':
input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32))
export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format)
if __name__ == '__main__':
run_export()

View File

@ -29,14 +29,20 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
export DEVICE_ID=$DEVICE_ID
# install nms module from third party
if python -c "import nms" > /dev/null 2>&1
then
echo "NMS module already exits, no need reinstall."
else
if [ -f './CenterNet' ]
then
echo "NMS module was not found, but has been downloaded"
else
echo "NMS module was not found, install it now..."
git clone https://github.com/xingyizhou/CenterNet.git
fi
cd CenterNet/src/lib/external/ || exit
make
python setup.py install

View File

@ -33,9 +33,14 @@ export GLOG_logtostderr=0
if python -c "import nms" > /dev/null 2>&1
then
echo "NMS module already exits, no need reinstall."
else
if [ -f './CenterNet' ]
then
echo "NMS module was not found, but has been downloaded"
else
echo "NMS module was not found, install it now..."
git clone https://github.com/xingyizhou/CenterNet.git
fi
cd CenterNet/src/lib/external/ || exit
make
python setup.py install

View File

@ -35,6 +35,7 @@ PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
export DEVICE_ID=$DEVICE_ID
python ${PROJECT_DIR}/../train.py \
--distribute=false \

View File

@ -30,7 +30,7 @@ from .backbone_dla import DLASeg
from .utils import Sigmoid, GradScale
from .utils import FocalLoss, RegLoss, RegWeightedL1Loss
from .decode import MultiPoseDecode
from .config import dataset_config as data_cfg
from .model_utils.config import dataset_config as data_cfg
def _generate_feature(cin, cout, kernel_size, head_name, head_conv=0):

View File

@ -1,120 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, train.py eval.py
"""
import numpy as np
from easydict import EasyDict as edict
dataset_config = edict({
'num_classes': 1,
'num_joints': 17,
'max_objs': 32,
'input_res': [512, 512],
'output_res': [128, 128],
'rand_crop': False,
'shift': 0.1,
'scale': 0.4,
'aug_rot': 0.0,
'rotate': 0,
'flip_prop': 0.5,
'mean': np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32),
'std': np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32),
'flip_idx': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
'edges': [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6],
[5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12],
[12, 14], [14, 16], [11, 13], [13, 15]],
'eig_val': np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32),
'eig_vec': np.array([[-0.58752847, -0.69563484, 0.41340352],
[-0.5832747, 0.00994535, -0.81221408],
[-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32),
'categories': [{"supercategory": "person",
"id": 1,
"name": "person",
"keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"],
"skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13],
[6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3],
[2, 4], [3, 5], [4, 6], [5, 7]]}],
})
net_config = edict({
'down_ratio': 4,
'last_level': 6,
'final_kernel': 1,
'stage_levels': [1, 1, 1, 2, 2, 1],
'stage_channels': [16, 32, 64, 128, 256, 512],
'head_conv': 256,
'dense_hp': True,
'hm_hp': True,
'reg_hp_offset': True,
'reg_offset': True,
'hm_weight': 1,
'off_weight': 1,
'wh_weight': 0.1,
'hp_weight': 1,
'hm_hp_weight': 1,
'mse_loss': False,
'reg_loss': 'l1',
})
train_config = edict({
'batch_size': 32,
'loss_scale_value': 1024,
'optimizer': 'Adam',
'lr_schedule': 'MultiDecay',
'Adam': edict({
'weight_decay': 0.0,
'decay_filter': lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma'),
}),
'PolyDecay': edict({
'learning_rate': 1.2e-4,
'end_learning_rate': 5e-7,
'power': 5.0,
'eps': 1e-7,
'warmup_steps': 2000,
}),
'MultiDecay': edict({
'learning_rate': 1.2e-4,
'eps': 1e-7,
'warmup_steps': 2000,
'multi_epochs': [270, 300],
'factor': 10,
})
})
eval_config = edict({
'soft_nms': True,
'keep_res': True,
'multi_scales': [1.0],
'pad': 31,
'K': 100,
'score_thresh': 0.3
})
export_config = edict({
'input_res': dataset_config.input_res,
'ckpt_file': "./ckpt_file.ckpt",
'export_format': "MINDIR",
'export_name': "CenterNet_MultiPose",
})

View File

@ -15,10 +15,9 @@
"""
Data operations, will be used in train.py
"""
import sys
import os
import math
import argparse
import cv2
import numpy as np
import pycocotools.coco as coco
@ -26,6 +25,14 @@ import pycocotools.coco as coco
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
try:
from src.image import get_affine_transform, affine_transform
from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
from src.visual import visual_image
except ImportError as import_error:
print('Import Error: {}; trying to append path/centernet/src/../'.format(import_error))
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from src.image import get_affine_transform, affine_transform
from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
from src.visual import visual_image
@ -428,14 +435,8 @@ class COCOHP(ds.Dataset):
if __name__ == '__main__':
# Convert coco2017 dataset to mindrecord to improve performance on host
from src.config import dataset_config
from src.model_utils.config import config, dataset_config
parser = argparse.ArgumentParser(description='CenterNet MindRecord dataset')
parser.add_argument("--coco_data_dir", type=str, default="", help="Coco dataset directory.")
parser.add_argument("--mindrecord_dir", type=str, default="", help="MindRecord dataset dir.")
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
help="Prefix of MindRecord dataset filename.")
args_opt = parser.parse_args()
dsc = COCOHP(dataset_config, run_mode="train")
dsc.init(args_opt.coco_data_dir)
dsc.transfer_coco_to_mindrecord(args_opt.mindrecord_dir, args_opt.mindrecord_prefix, shard_num=8)
dsc.init(config.coco_data_dir)
dsc.transfer_coco_to_mindrecord(config.mindrecord_dir, config.mindrecord_prefix, shard_num=8)
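With argparse removed, the converter's paths now come from `model_utils/config.py`, so they can be set in `default_config.yaml` or overridden on the command line; a hypothetical invocation (paths are placeholders):
```python
# Hypothetical invocation of the mindrecord converter; the flag names map
# to keys in default_config.yaml, and the paths are placeholders.
#   python src/dataset.py --coco_data_dir=/data/coco2017 \
#       --mindrecord_dir=/data/coco_hp_mindrecord \
#       --mindrecord_prefix=coco_hp.train.mind
```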

View File

@ -0,0 +1,160 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
import numpy as np
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, str) and (v[:9] == 'np.array(' and v[-17:] == 'dtype=np.float32)'):
v = np.array(ast.literal_eval(v[9:v.rfind(']') + 1]), dtype=np.float32)
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def extra_operations(cfg):
"""
Do extra work on Config object.
Args:
cfg: Object after instantiation of class 'Config'.
"""
cfg.train_config.Adam.decay_filter = lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma')
cfg.export_config.input_res = cfg.dataset_config.input_res
if cfg.export_load_ckpt:
cfg.export_config.ckpt_file = cfg.export_load_ckpt
if cfg.export_name:
cfg.export_config.export_name = cfg.export_name
if cfg.export_format:
cfg.export_config.export_format = cfg.export_format
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
config_obj = Config(final_config)
extra_operations(config_obj)
return config_obj
config = get_config()
dataset_config = config.dataset_config
net_config = config.net_config
train_config = config.train_config
eval_config = config.eval_config
export_config = config.export_config
if __name__ == '__main__':
print(config)
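Because `get_config()` runs at import time and turns every scalar yaml key into an argparse flag, all modules that import `config` share a single merged namespace. A hedged override sketch:
```python
# Hedged sketch of the CLI override path: each non-list, non-dict key in
# default_config.yaml becomes a flag, so a hypothetical run such as
#   python train.py --epoch_size=350 --distribute=true --data_sink_steps=50
# makes the values below reflect the command line instead of the yaml.
from src.model_utils.config import config

print(config.epoch_size, config.distribute, config.data_sink_steps)
```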

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
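Callers never branch on the platform themselves; they import from the adapter and transparently get the right implementation. A minimal local sketch (assumes `enable_modelarts: False` and that the repository root is the working directory):
```python
# Minimal sketch: with enable_modelarts False, the adapter resolves to
# local_adapter, which reads the DEVICE_ID/RANK_SIZE/RANK_ID env variables.
import os
os.environ['DEVICE_ID'] = '3'  # must be set before the first import

from src.model_utils.device_adapter import get_device_id, get_device_num

print(get_device_id())   # 3
print(get_device_num())  # 1 unless RANK_SIZE is set
```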

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory if the first url is a remote url and the second one is a local path.
Conversely, upload data from a local directory to remote OBS.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
# print("os.mknod({}) success".format(sync_lock))
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
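The decorator is what lets the same entry points run unchanged on both platforms: off ModelArts it simply calls `run_func`, while on ModelArts it syncs OBS data in, runs, then syncs outputs back. A schematic use, mirroring `train.py` and `eval.py` in this commit (`fix_paths` is a hypothetical stand-in for the `modelarts_pre_process` hooks):
```python
# Schematic use of moxing_wrapper; fix_paths is a hypothetical stand-in for
# the modelarts_pre_process hooks defined in train.py and eval.py.
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def fix_paths():
    config.mindrecord_dir = config.data_path  # remap inputs to /cache/data

@moxing_wrapper(pre_process=fix_paths)
def main():
    print("training would start here")

if __name__ == '__main__':
    main()
```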

View File

@ -17,16 +17,10 @@ Post-process functions after decoding
"""
import numpy as np
from src.config import dataset_config as config
from src.model_utils.config import dataset_config as config
from .image import get_affine_transform, affine_transform, transform_preds
from .visual import coco_box_to_bbox
try:
from nms import soft_nms_39
except ImportError:
print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
'and see run_standalone_eval.sh for more details to install it\n')
_NUM_JOINTS = config.num_joints
@ -48,6 +42,11 @@ def merge_outputs(detections, soft_nms=True):
"""merge detections together by nms"""
results = np.concatenate([detection for detection in detections], axis=0).astype(np.float32)
if soft_nms:
try:
from nms import soft_nms_39
except ImportError:
print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
'and see run_standalone_eval.sh for more details to install it\n')
soft_nms_39(results, Nt=0.5, threshold=0.01, method=2)
results = results.tolist()
return results

View File

@ -22,7 +22,7 @@ import random
import cv2
import numpy as np
import pycocotools.coco as COCO
from .config import dataset_config as data_cfg
from .model_utils.config import dataset_config as data_cfg
from .image import get_affine_transform, affine_transform
_NUM_JOINTS = data_cfg.num_joints

View File

@ -17,7 +17,6 @@ Train CenterNet and get network model files(.ckpt)
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
@ -29,46 +28,17 @@ from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src import CenterNetWithoutLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config
from src.model_utils.config import config, dataset_config, net_config, train_config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
help="Profiling to parsing runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1,"
"i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
help="Enable save checkpoint, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset files directory")
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
help="Prefix of MindRecord dataset filename.")
parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def _set_parallel_all_reduce_split():
"""set centernet all_reduce fusion split"""
@ -102,7 +72,7 @@ def _get_optimizer(network, dataset_size):
lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
end_learning_rate=train_config.PolyDecay.end_learning_rate,
warmup_steps=train_config.PolyDecay.warmup_steps,
decay_steps=args_opt.train_steps,
decay_steps=config.train_steps,
power=train_config.PolyDecay.power)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
elif train_config.lr_schedule == "MultiDecay":
@ -110,7 +80,7 @@ def _get_optimizer(network, dataset_size):
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
if not multi_epochs:
multi_epochs = [args_opt.epoch_size]
multi_epochs = [config.epoch_size]
lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
warmup_steps=train_config.MultiDecay.warmup_steps,
multi_epochs=multi_epochs,
@ -126,83 +96,90 @@ def _get_optimizer(network, dataset_size):
return optimizer
def modelarts_pre_process():
'''modelarts pre process function.'''
config.mindrecord_dir = config.data_path
config.save_checkpoint_path = os.path.join(config.output_path, config.save_checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
"""training CenterNet"""
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(save_graphs=False)
ckpt_save_dir = args_opt.save_checkpoint_path
ckpt_save_dir = config.save_checkpoint_path
rank = 0
device_num = 1
num_workers = 8
if args_opt.device_target == "Ascend":
if config.device_target == "Ascend":
context.set_context(enable_auto_mixed_precision=False)
context.set_context(device_id=args_opt.device_id)
if args_opt.distribute == "true":
context.set_context(device_id=get_device_id())
if config.distribute == "true":
D.init()
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
device_num = get_device_num()
rank = get_rank_id()
ckpt_save_dir = config.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
_set_parallel_all_reduce_split()
else:
args_opt.distribute = "false"
args_opt.need_profiler = "false"
args_opt.enable_data_sink = "false"
config.distribute = "false"
config.need_profiler = "false"
config.enable_data_sink = "false"
# Start create dataset!
# mindrecord files will be generated at config.mindrecord_dir, such as centernet.mindrecord0, 1, ... file_num.
logger.info("Begin creating dataset for CenterNet")
coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config,
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
enable_visual_image=(config.visual_image == "true"), save_path=config.save_result_dir)
dataset = coco.create_train_dataset(config.mindrecord_dir, config.mindrecord_prefix,
batch_size=train_config.batch_size, device_num=device_num, rank=rank,
num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true')
num_parallel_workers=num_workers, do_shuffle=config.do_shuffle == 'true')
dataset_size = dataset.get_dataset_size()
logger.info("Create dataset done!")
net_with_loss = CenterNetMultiPoseLossCell(net_config)
new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
new_repeat_count = config.epoch_size * dataset_size // config.data_sink_steps
if config.train_steps > 0:
new_repeat_count = min(new_repeat_count, config.train_steps // config.data_sink_steps)
else:
args_opt.train_steps = args_opt.epoch_size * dataset_size
logger.info("train steps: {}".format(args_opt.train_steps))
config.train_steps = config.epoch_size * dataset_size
logger.info("train steps: {}".format(config.train_steps))
optimizer = _get_optimizer(net_with_loss, dataset_size)
enable_static_time = args_opt.device_target == "CPU"
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
enable_static_time = config.device_target == "CPU"
callback = [TimeMonitor(config.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
if config.enable_save_ckpt == "true" and get_device_id() % min(8, device_num) == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
if config.load_checkpoint_path:
param_dict = load_checkpoint(config.load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict)
if args_opt.device_target == "Ascend":
if config.device_target == "Ascend":
net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
sens=train_config.loss_scale_value)
else:
net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads)
model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
sink_size=args_opt.data_sink_steps)
model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(config.enable_data_sink == "true"),
sink_size=config.data_sink_steps)
if __name__ == '__main__':
if args_opt.need_profiler == "true":
profiler = Profiler(output_path=args_opt.profiler_path)
if config.need_profiler == "true":
profiler = Profiler(output_path=config.profiler_path)
set_seed(0)
train()
if args_opt.need_profiler == "true":
if config.need_profiler == "true":
profiler.analyse()