!10818 Modify CenterNet scripts for cpu adaption and user friendliness
From: @shibeiji Reviewed-by: @linqingke,@c_34,@linqingke Signed-off-by: @wuxuejian
This commit is contained in:
commit
720baf670e
|
@ -119,19 +119,27 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
Note: 1.the first run of training will generate the mindrecord file, which will take a long time.
|
||||
2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory.
|
||||
3.LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory, if no just set ""
|
||||
4.RUN_MODE support validation and testing, set to be "val"/"test"
|
||||
|
||||
```shell
|
||||
# create dataset in mindrecord format
|
||||
bash scripts/convert_dataset_to_mindrecord.sh
|
||||
|
||||
# standalone training
|
||||
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [EPOCH_SIZE]
|
||||
# standalone training on Ascend
|
||||
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
|
||||
|
||||
# distributed training
|
||||
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE]
|
||||
# standalone training on CPU
|
||||
bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
|
||||
|
||||
# eval
|
||||
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID]
|
||||
# distributed training on Ascend
|
||||
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE]
|
||||
|
||||
# eval on Ascend
|
||||
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
|
||||
|
||||
# eval on CPU
|
||||
bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -153,9 +161,11 @@ bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID]
|
|||
│ │ ├──get_distribute_pretrain_cmd.py // script for distributed pretraining
|
||||
│ │ ├──README.md
|
||||
│ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord
|
||||
│ ├──run_standalone_train_ascend.sh // shell script for standalone pretrain on ascend
|
||||
│ ├──run_distributed_train_ascend.sh // shell script for distributed pretrain on ascend
|
||||
│ ├──run_standalone_train_ascend.sh // shell script for standalone training on ascend
|
||||
│ ├──run_distributed_train_ascend.sh // shell script for distributed training on ascend
|
||||
│ ├──run_standalone_eval_ascend.sh // shell script for standalone evaluation on ascend
|
||||
│ ├──run_standalone_train_cpu.sh // shell script for standalone training on cpu
|
||||
│ ├──run_standalone_eval_cpu.sh // shell script for standalone evaluation on cpu
|
||||
└── src
|
||||
├──__init__.py
|
||||
├──centernet_pose.py // centernet networks, training entry
|
||||
|
@ -259,7 +269,6 @@ config for training.
|
|||
|
||||
```text
|
||||
config for evaluation.
|
||||
flip_test whether to use flip test: True | False, default is False
|
||||
soft_nms nms after decode: True | False, default is True
|
||||
keep_res keep original or fix resolution: True | False, default is False
|
||||
multi_scales use multi-scales of image: List, default is [1.0]
|
||||
|
@ -350,12 +359,12 @@ bash scripts/convert_dataset_to_mindrecord.sh
|
|||
|
||||
The command above will run in the background, after converting mindrecord files will be located in path specified by yourself.
|
||||
|
||||
### Training
|
||||
### Standalone Training
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
```bash
|
||||
bash scripts/run_standalone_pretrain_ascend.sh 0 1
|
||||
bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt
|
||||
```
|
||||
|
||||
The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows:
|
||||
|
@ -368,12 +377,31 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap
|
|||
...
|
||||
```
|
||||
|
||||
#### Running on CPU
|
||||
|
||||
```bash
|
||||
bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt
|
||||
```
|
||||
|
||||
The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows (rusume from pretrained checkpoint and batch_size was set to be 8):
|
||||
|
||||
```text
|
||||
# grep "epoch" training_log.txt
|
||||
...
|
||||
epoch: 0.0, current epoch percent: 0.00, step: 1, time of per steps: 66.693 s, outputs are 3.645
|
||||
epoch: 0.0, current epoch percent: 0.00, step: 2, time of per steps: 46.594 s, outputs are 4.862
|
||||
epoch: 0.0, current epoch percent: 0.00, step: 3, time of per steps: 44.718 s, outputs are 3.927
|
||||
epoch: 0.0, current epoch percent: 0.00, step: 4, time of per steps: 45.113 s, outputs are 3.910
|
||||
epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, outputs are 3.749
|
||||
...
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
```bash
|
||||
bash scripts/run_distributed_pretrain_ascend.sh /path/coco2017 /path/mindrecord /path/hccl.json
|
||||
bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json
|
||||
```
|
||||
|
||||
The command above will run in the background, you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finished, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:
|
||||
|
@ -394,7 +422,11 @@ epoch: 0.0, current epoch percent: 0.002, step: 200, outputs are (Tensor(shape=[
|
|||
|
||||
```bash
|
||||
# Evaluation base on validation dataset will be done automatically, while for test or test-dev dataset, the accuracy should be upload to the CodaLab official website(https://competitions.codalab.org).
|
||||
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID]
|
||||
# On Ascend
|
||||
bash scripts/run_standalone_eval_ascend.sh device_id val(or test) /path/coco_dataset /path/load_ckpt
|
||||
|
||||
# On CPU
|
||||
bash scripts/run_standalone_eval_cpu.sh val(or test) /path/coco_dataset /path/load_ckpt
|
||||
```
|
||||
|
||||
you can see the MAP result below as below:
|
||||
|
@ -439,7 +471,7 @@ python export.py [DEVICE_ID]
|
|||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
### Training Performance On Ascend
|
||||
|
||||
CenterNet on 11.8K images(The annotation and data format must be the same as coco)
|
||||
|
||||
|
@ -460,7 +492,7 @@ CenterNet on 11.8K images(The annotation and data format must be the same as coc
|
|||
| Checkpoint | 242M (.ckpt file) |
|
||||
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet> |
|
||||
|
||||
### Inference Performance
|
||||
### Inference Performance On Ascend
|
||||
|
||||
CenterNet on validation(5K images) and test-dev(40K images)
|
||||
|
||||
|
|
|
@ -36,6 +36,8 @@ from src.config import dataset_config, net_config, eval_config
|
|||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
parser = argparse.ArgumentParser(description='CenterNet evaluation')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
|
||||
|
@ -52,15 +54,20 @@ def predict():
|
|||
'''
|
||||
Predict function
|
||||
'''
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
enable_nms_fp16 = True
|
||||
else:
|
||||
enable_nms_fp16 = False
|
||||
|
||||
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
|
||||
coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config,
|
||||
enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,)
|
||||
coco.init(args_opt.data_dir, keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
|
||||
coco.init(args_opt.data_dir, keep_res=eval_config.keep_res)
|
||||
dataset = coco.create_eval_dataset()
|
||||
|
||||
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
|
||||
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.K, enable_nms_fp16)
|
||||
net_for_eval.set_train(False)
|
||||
|
||||
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
|
||||
|
@ -103,9 +110,7 @@ def predict():
|
|||
print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.))
|
||||
|
||||
# post-process
|
||||
soft_nms = eval_config.soft_nms or len(eval_config.multi_scales) > 0
|
||||
detections = merge_outputs(detections, soft_nms)
|
||||
|
||||
detections = merge_outputs(detections, eval_config.soft_nms)
|
||||
# get prediction result
|
||||
pred_json = convert_eval_format(detections, image_id)
|
||||
gt_image_info = coco.coco.loadImgs([image_id])
|
||||
|
|
|
@ -31,7 +31,7 @@ args = parser.parse_args()
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
|
||||
net = CenterNetMultiPoseEval(net_config, eval_config.K)
|
||||
net.set_train(False)
|
||||
|
||||
param_dict = load_checkpoint(export_config.ckpt_file)
|
||||
|
|
|
@ -39,8 +39,7 @@ def parse_args():
|
|||
parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
|
||||
help="Hyper Parameter config path, it is better to use absolute path")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset directory")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
|
||||
help="Prefix of MindRecord dataset filename.")
|
||||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--hccl_config_dir", type=str, default="",
|
||||
help="Hccl config path, it is better to use absolute path")
|
||||
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
|
||||
|
@ -72,7 +71,7 @@ def distribute_train():
|
|||
|
||||
run_script = args.run_script_dir
|
||||
mindrecord_dir = args.mindrecord_dir
|
||||
mindrecord_prefix = args.mindrecord_prefix
|
||||
load_checkpoint_path = args.load_checkpoint_path
|
||||
cf = configparser.ConfigParser()
|
||||
cf.read(args.hyper_parameter_config_dir)
|
||||
cfg = dict(cf.items("config"))
|
||||
|
@ -151,7 +150,7 @@ def distribute_train():
|
|||
" 'device_num' or 'mindrecord_dir'! ")
|
||||
run_cmd += opt
|
||||
run_cmd += " --mindrecord_dir=" + mindrecord_dir
|
||||
run_cmd += " --mindrecord_prefix=" + mindrecord_prefix
|
||||
run_cmd += " --load_checkpoint_path=" + load_checkpoint_path
|
||||
run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
|
||||
+ str(rank_size) + ' >./training_log.txt 2>&1 &'
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ enable_save_ckpt=true
|
|||
do_shuffle=true
|
||||
enable_data_sink=true
|
||||
data_sink_steps=50
|
||||
load_checkpoint_path=""
|
||||
save_checkpoint_path=./
|
||||
save_checkpoint_steps=3000
|
||||
save_checkpoint_num=1
|
||||
|
|
|
@ -14,21 +14,26 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_distributed_train_ascend.sh DATA_DIR MINDRECORD_DIR RANK_TABLE_FILE"
|
||||
echo "for example: bash run_distributed_train_ascend.sh /path/dataset /path/mindrecord /path/hccl.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================"
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE"
|
||||
echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json"
|
||||
echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "For hyper parameter, please note that you should customize the scripts:
|
||||
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
|
||||
echo "=============================================================================================================="
|
||||
echo "================================================================================================================"
|
||||
CUR_DIR=`pwd`
|
||||
MINDRECORD_DIR=$1
|
||||
LOAD_CHECKPOINT_PATH=$2
|
||||
HCCL_RANK_FILE=$3
|
||||
|
||||
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
|
||||
--run_script_dir=${CUR_DIR}/train.py \
|
||||
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
|
||||
--mindrecord_dir=$1 \
|
||||
--hccl_config_dir=$2 \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
|
||||
--hccl_config_dir=$HCCL_RANK_FILE \
|
||||
--hccl_time_out=1200 \
|
||||
--cmd_file=distributed_cmd.sh
|
||||
|
||||
|
|
|
@ -16,11 +16,14 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_standalone_eval_ascend.sh DEVICE_ID"
|
||||
echo "for example: bash run_standalone_eval_ascend.sh 0"
|
||||
echo "bash run_standalone_eval_ascend.sh DEVICE_ID RUN_MODE DATA_DIR LOAD_CHECKPOINT_PATH"
|
||||
echo "for example of validation: bash run_standalone_eval_ascend.sh 0 val /path/coco_dataset /path/load_ckpt"
|
||||
echo "for example of test: bash run_standalone_eval_ascend.sh 0 test /path/coco_dataset /path/load_ckpt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
DEVICE_ID=$1
|
||||
RUN_MODE=$2
|
||||
DATA_DIR=$3
|
||||
LOAD_CHECKPOINT_PATH=$4
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
|
@ -42,10 +45,11 @@ else
|
|||
fi
|
||||
|
||||
python ${PROJECT_DIR}/../eval.py \
|
||||
--device_target=Ascend \
|
||||
--device_id=$DEVICE_ID \
|
||||
--load_checkpoint_path="" \
|
||||
--data_dir="" \
|
||||
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
|
||||
--data_dir=$DATA_DIR \
|
||||
--run_mode=$RUN_MODE \
|
||||
--visual_image=true \
|
||||
--enable_eval=true \
|
||||
--save_result_dir="" \
|
||||
--run_mode=val > eval_log.txt 2>&1 &
|
||||
--save_result_dir=./ > eval_log.txt 2>&1 &
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_standalone_eval_cpu.sh RUN_MODE DATA_DIR LOAD_CHECKPOINT_PATH"
|
||||
echo "for example of validation: bash run_standalone_eval_cpu.sh val /path/coco_dataset /path/load_ckpt"
|
||||
echo "for example of test: bash run_standalone_eval_cpu.sh test /path/coco_dataset /path/load_ckpt"
|
||||
echo "=============================================================================================================="
|
||||
RUN_MODE=$1
|
||||
DATA_DIR=$2
|
||||
LOAD_CHECKPOINT_PATH=$3
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
# install nms module from third party
|
||||
if python -c "import nms" > /dev/null 2>&1
|
||||
then
|
||||
echo "NMS module already exits, no need reinstall."
|
||||
else
|
||||
echo "NMS module was not found, install it now..."
|
||||
git clone https://github.com/xingyizhou/CenterNet.git
|
||||
cd CenterNet/src/lib/external/
|
||||
make
|
||||
python setup.py install
|
||||
cd -
|
||||
rm -rf CenterNet
|
||||
fi
|
||||
|
||||
python ${PROJECT_DIR}/../eval.py \
|
||||
--device_target=CPU \
|
||||
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
|
||||
--data_dir=$DATA_DIR \
|
||||
--run_mode=$RUN_MODE \
|
||||
--visual_image=true \
|
||||
--enable_eval=true \
|
||||
--save_result_dir=./ > eval_log.txt 2>&1 &
|
|
@ -16,12 +16,14 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE"
|
||||
echo "for example: bash run_standalone_pretrain_ascend.sh 0 350"
|
||||
echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
|
||||
echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt"
|
||||
echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" "
|
||||
echo "=============================================================================================================="
|
||||
|
||||
DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
MINDRECORD_DIR=$2
|
||||
LOAD_CHECKPOINT_PATH=$3
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
@ -33,16 +35,16 @@ python ${PROJECT_DIR}/../train.py \
|
|||
--distribute=false \
|
||||
--need_profiler=false \
|
||||
--profiler_path=./profiler \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--enable_save_ckpt=true \
|
||||
--do_shuffle=true \
|
||||
--enable_data_sink=true \
|
||||
--data_sink_steps=50 \
|
||||
--load_checkpoint_path="" \
|
||||
--epoch_size=350 \
|
||||
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--mindrecord_dir="" \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--mindrecord_prefix="coco_hp.train.mind" \
|
||||
--visual_image=false \
|
||||
--save_result_dir="" > training_log.txt 2>&1 &
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
|
||||
echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt"
|
||||
echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" "
|
||||
echo "=============================================================================================================="
|
||||
|
||||
MINDRECORD_DIR=$1
|
||||
LOAD_CHECKPOINT_PATH=$2
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python ${PROJECT_DIR}/../train.py \
|
||||
--device_target=CPU \
|
||||
--enable_save_ckpt=true \
|
||||
--do_shuffle=true \
|
||||
--epoch_size=1 \
|
||||
--load_checkpoint_path=$LOAD_CHECKPOINT_PATH \
|
||||
--save_checkpoint_steps=1000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--mindrecord_prefix="coco_hp.train.mind" \
|
||||
--visual_image=false \
|
||||
--save_result_dir="" > training_log.txt 2>&1 &
|
|
@ -15,7 +15,7 @@
|
|||
"""CenterNet Init."""
|
||||
|
||||
from .centernet_pose import GatherMultiPoseFeatureCell, CenterNetMultiPoseLossCell, \
|
||||
CenterNetWithLossScaleCell, CenterNetMultiPoseEval
|
||||
CenterNetWithLossScaleCell, CenterNetMultiPoseEval, CenterNetWithoutLossScaleCell
|
||||
from .dataset import COCOHP
|
||||
from .visual import visual_allimages, visual_image
|
||||
from .decode import MultiPoseDecode
|
||||
|
@ -23,6 +23,7 @@ from .post_process import convert_eval_format, to_float, resize_detection, post_
|
|||
|
||||
__all__ = [
|
||||
"GatherMultiPoseFeatureCell", "CenterNetMultiPoseLossCell", "CenterNetWithLossScaleCell", \
|
||||
"CenterNetMultiPoseEval", "COCOHP", "visual_allimages", "visual_image", "MultiPoseDecode", \
|
||||
"convert_eval_format", "to_float", "resize_detection", "post_process", "merge_outputs"
|
||||
"CenterNetMultiPoseEval", "CenterNetWithoutLossScaleCell", "COCOHP", "visual_allimages", \
|
||||
"visual_image", "MultiPoseDecode", "convert_eval_format", "to_float", "resize_detection", \
|
||||
"post_process", "merge_outputs"
|
||||
]
|
||||
|
|
|
@ -197,6 +197,46 @@ class CenterNetMultiPoseLossCell(nn.Cell):
|
|||
return total_loss
|
||||
|
||||
|
||||
class CenterNetWithoutLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of centernet training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
|
||||
Returns:
|
||||
Tuple of Tensors, the loss, overflow flag and scaling sens of the network.
|
||||
"""
|
||||
def __init__(self, network, optimizer):
|
||||
super(CenterNetWithoutLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.image = ImagePreProcess()
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
|
||||
|
||||
@ops.add_flags(has_effect=True)
|
||||
def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
|
||||
hm_hp, hp_offset, hp_ind, hp_mask):
|
||||
"""Defines the computation performed."""
|
||||
image = self.image(image)
|
||||
weights = self.weights
|
||||
loss = self.network(image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
|
||||
hm_hp, hp_offset, hp_ind, hp_mask)
|
||||
|
||||
grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, kps,
|
||||
kps_mask, reg, hm_hp, hp_offset,
|
||||
hp_ind, hp_mask)
|
||||
succ = self.optimizer(grads)
|
||||
ret = loss
|
||||
return ops.depend(ret, succ)
|
||||
|
||||
|
||||
class CenterNetWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of centernet training.
|
||||
|
@ -279,17 +319,16 @@ class CenterNetMultiPoseEval(nn.Cell):
|
|||
|
||||
Args:
|
||||
net_config: The config info of CenterNet network.
|
||||
flip_test(bool): Flip data augmentation or not. Default: False.
|
||||
K(number): Max number of output objects. Default: 100.
|
||||
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, detection of images(bboxes, score, keypoints and category id of each objects)
|
||||
"""
|
||||
def __init__(self, net_config, flip_test=False, K=100):
|
||||
def __init__(self, net_config, K=100, enable_nms_fp16=True):
|
||||
super(CenterNetMultiPoseEval, self).__init__()
|
||||
self.network = GatherMultiPoseFeatureCell(net_config)
|
||||
self.decode = MultiPoseDecode(net_config, flip_test, K)
|
||||
self.flip_test = flip_test
|
||||
self.decode = MultiPoseDecode(net_config, K, enable_nms_fp16)
|
||||
self.shape = ops.Shape()
|
||||
self.reshape = ops.Reshape()
|
||||
|
||||
|
|
|
@ -104,8 +104,7 @@ train_config = edict({
|
|||
|
||||
|
||||
eval_config = edict({
|
||||
'flip_test': False,
|
||||
'soft_nms': False,
|
||||
'soft_nms': True,
|
||||
'keep_res': True,
|
||||
'multi_scales': [1.0],
|
||||
'pad': 31,
|
||||
|
|
|
@ -17,7 +17,6 @@ Data operations, will be used in train.py
|
|||
"""
|
||||
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import argparse
|
||||
import cv2
|
||||
|
@ -66,7 +65,7 @@ class COCOHP(ds.Dataset):
|
|||
if not os.path.exists(self.save_path):
|
||||
os.makedirs(self.save_path)
|
||||
|
||||
def init(self, data_dir, keep_res=False, flip_test=False):
|
||||
def init(self, data_dir, keep_res=False):
|
||||
"""initailize additional info"""
|
||||
logger.info('Initializing coco 2017 {} data.'.format(self.run_mode))
|
||||
if not os.path.isdir(data_dir):
|
||||
|
@ -94,7 +93,6 @@ class COCOHP(ds.Dataset):
|
|||
self.images = image_ids
|
||||
self.num_samples = len(self.images)
|
||||
self.keep_res = keep_res
|
||||
self.flip_test = flip_test
|
||||
if self.run_mode != "train":
|
||||
self.pad = 31
|
||||
logger.info('Loaded {} {} samples'.format(self.run_mode, self.num_samples))
|
||||
|
@ -167,7 +165,7 @@ class COCOHP(ds.Dataset):
|
|||
ret = (img, image_id)
|
||||
return ret
|
||||
|
||||
def pre_process_for_test(self, image, img_id, scale, meta=None):
|
||||
def pre_process_for_test(self, image, img_id, scale):
|
||||
"""image pre-process for evaluation"""
|
||||
b, h, w, ch = image.shape
|
||||
assert b == 1, "only single image was supported here"
|
||||
|
@ -191,17 +189,8 @@ class COCOHP(ds.Dataset):
|
|||
flags=cv2.INTER_LINEAR)
|
||||
inp_img = (inp_image.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std
|
||||
|
||||
h, w, ch = inp_img.shape
|
||||
images = copy.deepcopy(inp_img)
|
||||
if self.flip_test:
|
||||
flip_image = inp_img[:, ::-1, :]
|
||||
inp_img = inp_img.reshape((1, h, w, ch))
|
||||
flip_image = flip_image.reshape((1, h, w, ch))
|
||||
# (2, h, w, c)
|
||||
images = np.concatenate((inp_img, flip_image), axis=0)
|
||||
else:
|
||||
images = images.reshape((1, h, w, ch))
|
||||
images = images.transpose(0, 3, 1, 2)
|
||||
eval_image = inp_img.reshape((1,) + inp_img.shape)
|
||||
eval_image = eval_image.transpose(0, 3, 1, 2)
|
||||
|
||||
meta = {'c': c, 's': s,
|
||||
'out_height': inp_height // self.net_opt.down_ratio,
|
||||
|
@ -244,7 +233,7 @@ class COCOHP(ds.Dataset):
|
|||
image_name = "gt_" + self.run_mode + "_image_" + str(img_id) + "_scale_" + str(scale) + ".png"
|
||||
cv2.imwrite("{}/{}".format(self.save_path, image_name), inp_image)
|
||||
|
||||
return images, meta
|
||||
return eval_image, meta
|
||||
|
||||
def preprocess_fn(self, img, num_objects, keypoints, bboxes, category_id):
|
||||
"""image pre-process and augmentation"""
|
||||
|
|
|
@ -30,25 +30,32 @@ class NMS(nn.Cell):
|
|||
|
||||
Args:
|
||||
kernel(int): Maxpooling kernel size. Default: 3.
|
||||
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, heatmap after non-maximum suppression.
|
||||
"""
|
||||
def __init__(self, kernel=3):
|
||||
def __init__(self, kernel=3, enable_nms_fp16=True):
|
||||
super(NMS, self).__init__()
|
||||
self.pad = (kernel - 1) // 2
|
||||
self.cast = ops.Cast()
|
||||
self.dtype = ops.DType()
|
||||
self.equal = ops.Equal()
|
||||
self.max_pool = nn.MaxPool2d(kernel, stride=1, pad_mode="same")
|
||||
self.enable_fp16 = enable_nms_fp16
|
||||
|
||||
def construct(self, heat):
|
||||
"""Non-maximum suppression"""
|
||||
dtype = self.dtype(heat)
|
||||
if self.enable_fp16:
|
||||
heat = self.cast(heat, mstype.float16)
|
||||
heat_max = self.max_pool(heat)
|
||||
keep = self.equal(heat, heat_max)
|
||||
keep = self.cast(keep, dtype)
|
||||
heat = self.cast(heat, dtype)
|
||||
else:
|
||||
heat_max = self.max_pool(heat)
|
||||
keep = self.equal(heat, heat_max)
|
||||
heat = heat * keep
|
||||
return heat
|
||||
|
||||
|
@ -127,17 +134,23 @@ class GatherFeatureByInd(nn.Cell):
|
|||
"""
|
||||
Gather features by index
|
||||
|
||||
Args: None
|
||||
Args:
|
||||
enable_cpu_gather (bool): Use cpu operator GatherD to gather feature or not, adaption for CPU. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, enable_cpu_gatherd=True):
|
||||
super(GatherFeatureByInd, self).__init__()
|
||||
self.tile = ops.Tile()
|
||||
self.shape = ops.Shape()
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.reshape = ops.Reshape()
|
||||
self.enable_cpu_gatherd = enable_cpu_gatherd
|
||||
if self.enable_cpu_gatherd:
|
||||
self.gather_nd = ops.GatherD()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
else:
|
||||
self.gather_nd = ops.GatherNd()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
|
@ -147,15 +160,21 @@ class GatherFeatureByInd(nn.Cell):
|
|||
b, J, K = self.shape(ind)
|
||||
feat = self.reshape(feat, (b, J, K, -1))
|
||||
_, _, _, N = self.shape(feat)
|
||||
if self.enable_cpu_gatherd:
|
||||
# (b, J, K, N)
|
||||
index = self.expand_dims(ind, -1)
|
||||
index = self.tile(index, (1, 1, 1, N))
|
||||
feat = self.gather_nd(feat, 2, index)
|
||||
else:
|
||||
ind = self.reshape(ind, (-1, 1))
|
||||
ind_b = nn.Range(0, b * J, 1)()
|
||||
ind_b = self.reshape(ind_b, (-1, 1))
|
||||
ind_b = self.tile(ind_b, (1, K))
|
||||
ind_b = self.reshape(ind_b, (-1, 1))
|
||||
index = self.concat((ind_b, ind))
|
||||
# (b, N, 2)
|
||||
# (b*J, K, 2)
|
||||
index = self.reshape(index, (-1, K, 2))
|
||||
# (b, N, c)
|
||||
# (b*J, K)
|
||||
feat = self.reshape(feat, (-1, K, N))
|
||||
feat = self.gather_nd(feat, index)
|
||||
feat = self.reshape(feat, (b, J, K, -1))
|
||||
|
@ -285,17 +304,16 @@ class MultiPoseDecode(nn.Cell):
|
|||
|
||||
Args:
|
||||
net_config(edict): config info for CenterNet network.
|
||||
flip_test(bool): flip test of not. Default: False.
|
||||
K(int): maximum objects number. Default: 100.
|
||||
enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, multi-objects detections.
|
||||
"""
|
||||
def __init__(self, net_config, flip_test=False, K=100):
|
||||
def __init__(self, net_config, K=100, enable_nms_fp16=True):
|
||||
super(MultiPoseDecode, self).__init__()
|
||||
self.K = K
|
||||
self.flip_test = flip_test
|
||||
self.nms = NMS()
|
||||
self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
|
||||
self.shape = ops.Shape()
|
||||
self.gather_topk = GatherTopK()
|
||||
self.gather_topk_channel = GatherTopKChannel()
|
||||
|
@ -336,8 +354,6 @@ class MultiPoseDecode(nn.Cell):
|
|||
def construct(self, feature):
|
||||
"""gather detections"""
|
||||
heat = feature[0]
|
||||
if self.flip_test:
|
||||
heat = self.flip_tensor(heat)
|
||||
K = self.K
|
||||
b, _, _, _ = self.shape(heat)
|
||||
heat = self.nms(heat)
|
||||
|
@ -346,8 +362,6 @@ class MultiPoseDecode(nn.Cell):
|
|||
xs = self.reshape(xs, (b, K, 1))
|
||||
|
||||
kps = feature[1]
|
||||
if self.flip_test:
|
||||
kps = self.flip_lr_off(kps)
|
||||
num_joints = self.shape(kps)[1] / 2
|
||||
# (b, K, num_joints*2)
|
||||
kps = self.trans_gather_feature(kps, inds)
|
||||
|
@ -365,15 +379,11 @@ class MultiPoseDecode(nn.Cell):
|
|||
kps = self.reshape(kps, (b, K, num_joints * 2))
|
||||
|
||||
wh = feature[2]
|
||||
if self.flip_test:
|
||||
wh = self.flip_tensor(wh)
|
||||
wh = self.trans_gather_feature(wh, inds)
|
||||
ws, hs = self.half(wh)
|
||||
|
||||
if self.reg_offset:
|
||||
reg = feature[self.reg_ind]
|
||||
if self.flip_test:
|
||||
reg, _ = self.half_first(reg)
|
||||
reg = self.trans_gather_feature(reg, inds)
|
||||
reg = self.reshape(reg, (b, K, 2))
|
||||
reg_w, reg_h = self.half(reg)
|
||||
|
@ -387,16 +397,12 @@ class MultiPoseDecode(nn.Cell):
|
|||
|
||||
if self.hm_hp:
|
||||
hm_hp = feature[self.hm_hp_ind]
|
||||
if self.flip_test:
|
||||
hm_hp = self.flip_lr(hm_hp)
|
||||
hm_hp = self.nms(hm_hp)
|
||||
# (b, num_joints, K)
|
||||
hm_score, hm_inds, hm_ys, hm_xs = self.gather_topk_channel(hm_hp, K=K)
|
||||
|
||||
if self.reg_hp_offset:
|
||||
hp_offset = feature[self.reg_hp_ind]
|
||||
if self.flip_test:
|
||||
hp_offset, _ = self.half_first(hp_offset)
|
||||
hp_offset = self.trans_gather_feature(hp_offset, self.reshape(hm_inds, (b, -1)))
|
||||
hp_offset = self.reshape(hp_offset, (b, num_joints, K, 2))
|
||||
hp_ws, hp_hs = self.half(hp_offset)
|
||||
|
|
|
@ -17,6 +17,7 @@ Functional Cells to be used.
|
|||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
@ -119,21 +120,34 @@ class GatherFeature(nn.Cell):
|
|||
"""
|
||||
Gather feature at specified position
|
||||
|
||||
Args: None
|
||||
Args:
|
||||
enable_cpu_gather (bool): Use cpu operator GatherD to gather feature or not, adaption for CPU. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, feature at spectified position
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, enable_cpu_gather=True):
|
||||
super(GatherFeature, self).__init__()
|
||||
self.tile = ops.Tile()
|
||||
self.shape = ops.Shape()
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.reshape = ops.Reshape()
|
||||
self.gather_nd = ops.GatherNd()
|
||||
self.enable_cpu_gather = enable_cpu_gather
|
||||
if self.enable_cpu_gather:
|
||||
self.gather_nd = ops.GatherD()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
else:
|
||||
self.gather_nd = ops.GatherND()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""gather by specified index"""
|
||||
if self.enable_cpu_gather:
|
||||
_, _, c = self.shape(feat)
|
||||
# (b, N, c)
|
||||
index = self.expand_dims(ind, -1)
|
||||
index = self.tile(index, (1, 1, c))
|
||||
feat = self.gather_nd(feat, 1, index)
|
||||
else:
|
||||
# (b, N)->(b*N, 1)
|
||||
b, N = self.shape(ind)
|
||||
ind = self.reshape(ind, (-1, 1))
|
||||
|
@ -477,11 +491,19 @@ class LossCallBack(Callback):
|
|||
|
||||
Args:
|
||||
dataset_size (int): Dataset size. Default: -1.
|
||||
enable_static_time (bool): enable static time cost, adaption for CPU. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_size=-1):
|
||||
def __init__(self, dataset_size=-1, enable_static_time=False):
|
||||
super(LossCallBack, self).__init__()
|
||||
self._dataset_size = dataset_size
|
||||
self._enable_static_time = enable_static_time
|
||||
|
||||
def step_begin(self, run_context):
|
||||
"""
|
||||
Get begining time of each step
|
||||
"""
|
||||
self._begin_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
|
@ -493,11 +515,19 @@ class LossCallBack(Callback):
|
|||
if percent == 0:
|
||||
percent = 1
|
||||
epoch_num -= 1
|
||||
if self._enable_static_time:
|
||||
cur_time = time.time()
|
||||
time_per_step = cur_time - self._begin_time
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, time per step: {} s, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, "%.3f" % time_per_step,
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
else:
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
else:
|
||||
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
|
||||
|
||||
class CenterNetPolynomialDecayLR(LearningRateSchedule):
|
||||
|
|
|
@ -31,12 +31,15 @@ from mindspore.common import set_seed
|
|||
from mindspore.profiler import Profiler
|
||||
from src.dataset import COCOHP
|
||||
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
|
||||
from src import CenterNetWithoutLossScaleCell
|
||||
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
|
||||
from src.config import dataset_config, net_config, train_config
|
||||
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
parser = argparse.ArgumentParser(description='CenterNet training')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
|
||||
|
@ -125,12 +128,17 @@ def _get_optimizer(network, dataset_size):
|
|||
|
||||
def train():
|
||||
"""training CenterNet"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
context.set_context(enable_auto_mixed_precision=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
rank = 0
|
||||
device_num = 1
|
||||
num_workers = 8
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(enable_auto_mixed_precision=False)
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
if args_opt.distribute == "true":
|
||||
D.init()
|
||||
device_num = args_opt.device_num
|
||||
|
@ -142,9 +150,10 @@ def train():
|
|||
device_num=device_num)
|
||||
_set_parallel_all_reduce_split()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
num_workers = 8
|
||||
args_opt.distribute = "false"
|
||||
args_opt.need_profiler = "false"
|
||||
args_opt.enable_data_sink = "false"
|
||||
|
||||
# Start create dataset!
|
||||
# mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
|
||||
logger.info("Begin creating dataset for CenterNet")
|
||||
|
@ -167,7 +176,8 @@ def train():
|
|||
|
||||
optimizer = _get_optimizer(net_with_loss, dataset_size)
|
||||
|
||||
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
|
||||
enable_static_time = args_opt.device_target == "CPU"
|
||||
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
|
||||
if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||
|
@ -178,12 +188,13 @@ def train():
|
|||
if args_opt.load_checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
|
||||
load_param_into_net(net_with_loss, param_dict)
|
||||
|
||||
if args_opt.device_target == "Ascend":
|
||||
net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||
sens=train_config.loss_scale_value)
|
||||
else:
|
||||
net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)
|
||||
|
||||
model = Model(net_with_grads)
|
||||
|
||||
model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
|
Loading…
Reference in New Issue