forked from mindspore-Ecosystem/mindspore
modify dataset creating procedure for CenterNet
This commit is contained in:
parent
cfc9c740b1
commit
8fc4705d67
|
@ -117,18 +117,21 @@ Dataset used: [COCO2017](<https://cocodataset.org/>)
|
|||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
Note: 1.the first run will generate the mindrecord file, which will take a long time.
|
||||
2.VALIDATION_JSON_FILE is ground truth label file. CHECKPOINT_PATH is a checkpoint file after training.
|
||||
Note: 1.the first run of training will generate the mindrecord file, which will take a long time.
|
||||
2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory.
|
||||
|
||||
```shell
|
||||
# create dataset in mindrecord format
|
||||
bash scripts/convert_dataset_to_mindrecord.sh
|
||||
|
||||
# standalone training
|
||||
bash run_standalone_train_ascend.sh [DEVICE_ID] [EPOCH_SIZE]
|
||||
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [EPOCH_SIZE]
|
||||
|
||||
# distributed training
|
||||
bash run_distributed_train_ascend.sh [COCO_DATASET_PATH] [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE]
|
||||
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE]
|
||||
|
||||
# eval
|
||||
bash run_standalone_eval_ascend.sh [DEVICE_ID]
|
||||
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -149,6 +152,7 @@ bash run_standalone_eval_ascend.sh [DEVICE_ID]
|
|||
│ │ ├──hyper_parameter_config.ini // hyper parameter for distributed pretraining
|
||||
│ │ ├──get_distribute_pretrain_cmd.py // script for distributed pretraining
|
||||
│ │ ├──README.md
|
||||
│ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord
|
||||
│ ├──run_standalone_train_ascend.sh // shell script for standalone pretrain on ascend
|
||||
│ ├──run_distributed_train_ascend.sh // shell script for distributed pretrain on ascend
|
||||
│ ├──run_standalone_eval_ascend.sh // shell script for standalone evaluation on ascend
|
||||
|
@ -168,6 +172,19 @@ bash run_standalone_eval_ascend.sh [DEVICE_ID]
|
|||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
### Create MindRecord type dataset
|
||||
|
||||
```text
|
||||
usage: dataset.py [--coco_data_dir COCO_DATA_DIR]
|
||||
[--mindrecord_dir MINDRECORD_DIR]
|
||||
[--mindrecord_prefix MINDRECORD_PREFIX]
|
||||
|
||||
options:
|
||||
--coco_data_dir path to coco dataset directory: PATH, default is ""
|
||||
--mindrecord_dir path to mindrecord dataset directory: PATH, default is ""
|
||||
--mindrecord_prefix prefix of MindRecord dataset filename: STR, default is "coco_hp.train.mind"
|
||||
```
|
||||
|
||||
### Training
|
||||
|
||||
```text
|
||||
|
@ -180,7 +197,8 @@ usage: train.py [--device_target DEVICE_TARGET] [--distribute DISTRIBUTE]
|
|||
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
|
||||
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
|
||||
[--save_checkpoint_steps N] [--save_checkpoint_num N]
|
||||
[--data_dir DATA_DIR] [--mindrecord_dir MINDRECORD_DIR]
|
||||
[--mindrecord_dir MINDRECORD_DIR]
|
||||
[--mindrecord_prefix MINDRECORD_PREFIX]
|
||||
[--visual_image VISUAL_IMAGE] [--save_result_dir SAVE_RESULT_DIR]
|
||||
|
||||
options:
|
||||
|
@ -201,8 +219,8 @@ options:
|
|||
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
|
||||
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
|
||||
--save_checkpoint_num number for saving checkpoint files: N, default is 1
|
||||
--data_dir path to original dataset directory: PATH, default is ""
|
||||
--mindrecord_dir path to mindrecord dataset directory: PATH, default is ""
|
||||
--mindrecord_prefix prefix of MindRecord dataset filename: STR, default is "coco_hp.train.mind"
|
||||
--visual_image whether visualize the image and annotation info: "true" | "false", default is "false"
|
||||
--save_result_dir path to save the visualization results: PATH, default is ""
|
||||
```
|
||||
|
@ -214,7 +232,7 @@ usage: eval.py [--device_target DEVICE_TARGET] [--device_id N]
|
|||
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
|
||||
[--data_dir DATA_DIR] [--run_mode RUN_MODE]
|
||||
[--visual_image VISUAL_IMAGE]
|
||||
[enable_eval ENABLE_EVAL] [--save_result_dir SAVE_RESULT_DIR]
|
||||
[--enable_eval ENABLE_EVAL] [--save_result_dir SAVE_RESULT_DIR]
|
||||
options:
|
||||
--device_target device where the code will be implemented: "Ascend" | "CPU", default is "Ascend"
|
||||
--device_id device id to run task, default is 0
|
||||
|
@ -324,6 +342,14 @@ Parameters for optimizer and learning rate:
|
|||
|
||||
## [Training Process](#contents)
|
||||
|
||||
Before your first training, convert coco type dataset to mindrecord files is needed to improve performance on host.
|
||||
|
||||
```bash
|
||||
bash scripts/convert_dataset_to_mindrecord.sh
|
||||
```
|
||||
|
||||
The command above will run in the background, after converting mindrecord files will be located in path specified by yourself.
|
||||
|
||||
### Training
|
||||
|
||||
#### Running on Ascend
|
||||
|
|
|
@ -55,9 +55,9 @@ def predict():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
|
||||
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
|
||||
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode=args_opt.run_mode)
|
||||
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,
|
||||
keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
|
||||
coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config,
|
||||
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,)
|
||||
coco.init(args_opt.data_dir, keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
|
||||
dataset = coco.create_eval_dataset()
|
||||
|
||||
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
|
||||
|
|
|
@ -38,12 +38,9 @@ def parse_args():
|
|||
help="Run script path, it is better to use absolute path")
|
||||
parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
|
||||
help="Hyper Parameter config path, it is better to use absolute path")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
|
||||
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by "
|
||||
"data_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
|
||||
"rather than data_dir and anno_path. Default is ./Mindrecord_train")
|
||||
parser.add_argument("--data_dir", type=str, default="",
|
||||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset directory")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
|
||||
help="Prefix of MindRecord dataset filename.")
|
||||
parser.add_argument("--hccl_config_dir", type=str, default="",
|
||||
help="Hccl config path, it is better to use absolute path")
|
||||
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
|
||||
|
@ -74,8 +71,8 @@ def distribute_train():
|
|||
args = parse_args()
|
||||
|
||||
run_script = args.run_script_dir
|
||||
data_dir = args.data_dir
|
||||
mindrecord_dir = args.mindrecord_dir
|
||||
mindrecord_prefix = args.mindrecord_prefix
|
||||
cf = configparser.ConfigParser()
|
||||
cf.read(args.hyper_parameter_config_dir)
|
||||
cfg = dict(cf.items("config"))
|
||||
|
@ -142,7 +139,6 @@ def distribute_train():
|
|||
|
||||
print("core_nums:", cmdopt)
|
||||
print("epoch_size:", str(cfg['epoch_size']))
|
||||
print("data_dir:", data_dir)
|
||||
print("mindrecord_dir:", mindrecord_dir)
|
||||
print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/training_log.txt")
|
||||
|
||||
|
@ -150,12 +146,12 @@ def distribute_train():
|
|||
|
||||
run_cmd = 'taskset -c ' + cmdopt + ' nohup python ' + run_script + " "
|
||||
opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()])
|
||||
if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt):
|
||||
if ('device_id' in opt) or ('device_num' in opt) or ('mindrecord_dir' in opt):
|
||||
raise ValueError("hyper_parameter_config.ini can not setting 'device_id',"
|
||||
" 'device_num' or 'data_dir'! ")
|
||||
" 'device_num' or 'mindrecord_dir'! ")
|
||||
run_cmd += opt
|
||||
run_cmd += " --data_dir=" + data_dir
|
||||
run_cmd += " --mindrecord_dir=" + mindrecord_dir
|
||||
run_cmd += " --mindrecord_prefix=" + mindrecord_prefix
|
||||
run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
|
||||
+ str(rank_size) + ' >./training_log.txt 2>&1 &'
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ load_checkpoint_path=""
|
|||
save_checkpoint_path=./
|
||||
save_checkpoint_steps=3000
|
||||
save_checkpoint_num=1
|
||||
mindrecord_prefix="coco_hp.train.mind"
|
||||
need_profiler=false
|
||||
profiler_path=./profiler
|
||||
visual_image=false
|
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash convert_dataset_to_mindrecord.sh"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
export GLOG_v=1
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
||||
python ${PROJECT_DIR}/../src/dataset.py \
|
||||
--coco_data_dir="" \
|
||||
--mindrecord_dir="" \
|
||||
--mindrecord_prefix="coco_hp.train.mind" > create_dataset.log 2>&1 &
|
|
@ -27,9 +27,8 @@ CUR_DIR=`pwd`
|
|||
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
|
||||
--run_script_dir=${CUR_DIR}/train.py \
|
||||
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
|
||||
--data_dir=$1 \
|
||||
--mindrecord_dir=$2 \
|
||||
--hccl_config_dir=$3 \
|
||||
--mindrecord_dir=$1 \
|
||||
--hccl_config_dir=$2 \
|
||||
--hccl_time_out=1200 \
|
||||
--cmd_file=distributed_cmd.sh
|
||||
|
||||
|
|
|
@ -48,4 +48,4 @@ python ${PROJECT_DIR}/../eval.py \
|
|||
--visual_image=true \
|
||||
--enable_eval=true \
|
||||
--save_result_dir="" \
|
||||
--run_mode=val > log.txt 2>&1 &
|
||||
--run_mode=val > eval_log.txt 2>&1 &
|
|
@ -42,7 +42,7 @@ python ${PROJECT_DIR}/../train.py \
|
|||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir="" \
|
||||
--mindrecord_dir="" \
|
||||
--mindrecord_prefix="coco_hp.train.mind" \
|
||||
--visual_image=false \
|
||||
--save_result_dir=""> log.txt 2>&1 &
|
||||
--save_result_dir="" > training_log.txt 2>&1 &
|
|
@ -19,6 +19,7 @@ Data operations, will be used in train.py
|
|||
import os
|
||||
import copy
|
||||
import math
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pycocotools.coco as coco
|
||||
|
@ -26,10 +27,9 @@ import pycocotools.coco as coco
|
|||
import mindspore.dataset.engine.datasets as de
|
||||
from mindspore import log as logger
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from .image import color_aug
|
||||
from .image import get_affine_transform, affine_transform
|
||||
from .image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
|
||||
from .visual import visual_image
|
||||
from src.image import color_aug, get_affine_transform, affine_transform
|
||||
from src.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian, draw_dense_reg
|
||||
from src.visual import visual_image
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
|
@ -47,40 +47,39 @@ class COCOHP(de.Dataset):
|
|||
Returns:
|
||||
Prepocessed training or testing dataset for CenterNet network.
|
||||
"""
|
||||
def __init__(self, data_dir, data_opt, net_opt, run_mode):
|
||||
def __init__(self, data_opt, run_mode="train", net_opt=None, enable_visual_image=False, save_path=None):
|
||||
super(COCOHP, self).__init__()
|
||||
if not os.path.isdir(data_dir):
|
||||
raise RuntimeError("Invalid dataset path")
|
||||
assert run_mode in ["train", "test", "val"], "only train/test/val mode are supported"
|
||||
self.run_mode = run_mode
|
||||
|
||||
if self.run_mode != "test":
|
||||
self.annot_path = os.path.join(data_dir, 'annotations',
|
||||
'person_keypoints_{}2017.json').format(self.run_mode)
|
||||
else:
|
||||
self.annot_path = os.path.join(data_dir, 'annotations', 'image_info_test-dev2017.json')
|
||||
self.image_path = os.path.join(data_dir, '{}2017').format(self.run_mode)
|
||||
|
||||
self._data_rng = np.random.RandomState(123)
|
||||
self.data_opt = data_opt
|
||||
self.data_opt.mean = self.data_opt.mean.reshape(1, 1, 3)
|
||||
self.data_opt.std = self.data_opt.std.reshape(1, 1, 3)
|
||||
self.net_opt = net_opt
|
||||
self.coco = coco.COCO(self.annot_path)
|
||||
|
||||
|
||||
def init(self, enable_visual_image=False, save_path=None, keep_res=False, flip_test=False):
|
||||
"""initailize additional info"""
|
||||
logger.info('Initializing coco 2017 {} data.'.format(self.run_mode))
|
||||
logger.info('Image path: {}'.format(self.image_path))
|
||||
logger.info('Annotations: {}'.format(self.annot_path))
|
||||
assert run_mode in ["train", "test", "val"], "only train/test/val mode are supported"
|
||||
self.run_mode = run_mode
|
||||
|
||||
if net_opt is not None:
|
||||
self.net_opt = net_opt
|
||||
self.enable_visual_image = enable_visual_image
|
||||
if self.enable_visual_image:
|
||||
self.save_path = os.path.join(save_path, self.run_mode, "input_image")
|
||||
if not os.path.exists(self.save_path):
|
||||
os.makedirs(self.save_path)
|
||||
|
||||
|
||||
def init(self, data_dir, keep_res=False, flip_test=False):
|
||||
"""initailize additional info"""
|
||||
logger.info('Initializing coco 2017 {} data.'.format(self.run_mode))
|
||||
if not os.path.isdir(data_dir):
|
||||
raise RuntimeError("Invalid dataset path")
|
||||
if self.run_mode != "test":
|
||||
self.annot_path = os.path.join(data_dir, 'annotations',
|
||||
'person_keypoints_{}2017.json').format(self.run_mode)
|
||||
else:
|
||||
self.annot_path = os.path.join(data_dir, 'annotations', 'image_info_test-dev2017.json')
|
||||
self.image_path = os.path.join(data_dir, '{}2017').format(self.run_mode)
|
||||
logger.info('Image path: {}'.format(self.image_path))
|
||||
logger.info('Annotations: {}'.format(self.annot_path))
|
||||
|
||||
self.coco = coco.COCO(self.annot_path)
|
||||
image_ids = self.coco.getImgIds()
|
||||
if self.run_mode != "test":
|
||||
self.images = []
|
||||
|
@ -102,8 +101,15 @@ class COCOHP(de.Dataset):
|
|||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def transfer_coco_to_mindrecord(self, mindrecord_dir, file_name, shard_num=1):
|
||||
def transfer_coco_to_mindrecord(self, mindrecord_dir, file_name="coco_hp.train.mind", shard_num=1):
|
||||
"""Create MindRecord file by image_dir and anno_path."""
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if os.path.isdir(self.image_path) and os.path.exists(self.annot_path):
|
||||
logger.info("Create MindRecord based on COCO_HP dataset")
|
||||
else:
|
||||
raise ValueError('data_dir {} or anno_path {} does not exist'.format(self.image_path, self.annot_path))
|
||||
|
||||
mindrecord_path = os.path.join(mindrecord_dir, file_name)
|
||||
writer = FileWriter(mindrecord_path, shard_num)
|
||||
centernet_json = {
|
||||
|
@ -139,6 +145,7 @@ class COCOHP(de.Dataset):
|
|||
"category_id": np.array(category_id, np.int32)}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
logger.info("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
|
||||
|
||||
def _coco_box_to_bbox(self, box):
|
||||
|
@ -393,19 +400,11 @@ class COCOHP(de.Dataset):
|
|||
return ret
|
||||
|
||||
|
||||
def create_train_dataset(self, mindrecord_dir, prefix, batch_size=1,
|
||||
def create_train_dataset(self, mindrecord_dir, prefix="coco_hp.train.mind", batch_size=1,
|
||||
device_num=1, rank=0, num_parallel_workers=1, do_shuffle=True):
|
||||
"""create train dataset based on mindrecord file"""
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if os.path.isdir(self.image_path) and os.path.exists(self.annot_path):
|
||||
logger.info("Create MindRecord based on COCO_HP dataset")
|
||||
self.transfer_coco_to_mindrecord(mindrecord_dir, prefix, shard_num=8)
|
||||
logger.info("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
raise ValueError('data_dir {} or anno_path {} does not exist'.format(self.image_path, self.annot_path))
|
||||
else:
|
||||
logger.info("MindRecord dataset already exists, dir: {}".format(mindrecord_dir))
|
||||
raise ValueError('MindRecord data_dir {} does not exist'.format(mindrecord_dir))
|
||||
|
||||
files = os.listdir(mindrecord_dir)
|
||||
data_files = []
|
||||
|
@ -415,7 +414,6 @@ class COCOHP(de.Dataset):
|
|||
if not data_files:
|
||||
raise ValueError('data_dir {} have no data files'.format(mindrecord_dir))
|
||||
|
||||
|
||||
columns = ["image", "num_objects", "keypoints", "bbox", "category_id"]
|
||||
ds = de.MindDataset(data_files,
|
||||
columns_list=columns,
|
||||
|
@ -447,3 +445,17 @@ class COCOHP(de.Dataset):
|
|||
ds = de.GeneratorDataset(generator, column, num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
|
||||
return ds
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Convert coco2017 dataset to mindrecord to improve performance on host
|
||||
from src.config import dataset_config
|
||||
parser = argparse.ArgumentParser(description='CenterNet MindRecord dataset')
|
||||
parser.add_argument("--coco_data_dir", type=str, default="", help="Coco dataset directory.")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="", help="MindRecord dataset dir.")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
|
||||
help="Prefix of MindRecord dataset filename.")
|
||||
args_opt = parser.parse_args()
|
||||
dsc = COCOHP(dataset_config, run_mode="train")
|
||||
dsc.init(args_opt.coco_data_dir)
|
||||
dsc.transfer_coco_to_mindrecord(args_opt.mindrecord_dir, args_opt.mindrecord_prefix, shard_num=8)
|
||||
|
|
|
@ -20,7 +20,6 @@ import math
|
|||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
|
@ -494,11 +493,11 @@ class LossCallBack(Callback):
|
|||
if percent == 0:
|
||||
percent = 1
|
||||
epoch_num -= 1
|
||||
logger.info("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
|
||||
else:
|
||||
logger.info("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
|
||||
|
||||
class CenterNetPolynomialDecayLR(LearningRateSchedule):
|
||||
|
|
|
@ -58,13 +58,9 @@ parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save c
|
|||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
|
||||
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="",
|
||||
help="Mindrecord files directory. If is empty, mindrecord format files will be generated"
|
||||
"based on the original dataset and annotation information. If mindrecord_dir isn't empty,"
|
||||
"mindrecord_dir will be used inplace of data_dir and anno_path.")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
|
||||
"the absolute image path is joined by the data_dir "
|
||||
"and the relative path in anno_path")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset files directory")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
|
||||
help="Prefix of MindRecord dataset filename.")
|
||||
parser.add_argument("--visual_image", type=str, default="false", help="Visulize the ground truth and predicted image")
|
||||
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
|
||||
|
||||
|
@ -148,16 +144,15 @@ def train():
|
|||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
num_workers = device_num * 8
|
||||
num_workers = 8
|
||||
# Start create dataset!
|
||||
# mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
|
||||
logger.info("Begin creating dataset for CenterNet")
|
||||
prefix = "coco_hp.train.mind"
|
||||
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode="train")
|
||||
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
|
||||
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, prefix, batch_size=train_config.batch_size,
|
||||
device_num=device_num, rank=rank, num_parallel_workers=num_workers,
|
||||
do_shuffle=args_opt.do_shuffle == 'true')
|
||||
coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config,
|
||||
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
|
||||
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
|
||||
batch_size=train_config.batch_size, device_num=device_num, rank=rank,
|
||||
num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true')
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
logger.info("Create dataset done!")
|
||||
|
||||
|
|
Loading…
Reference in New Issue