!19134 centerface test

Merge pull request !19134 from huchunmei/center1
i-robot 2021-07-02 09:36:50 +00:00 committed by Gitee
commit 0aef1943d0
23 changed files with 802 additions and 435 deletions

View File

@ -175,7 +175,7 @@ step6: eval
# cd ../dependency/evaluate;
# python setup.py install;
# cd -; #cd ../../scripts;
sh eval_all.sh
sh eval_all.sh [ground_truth_path]
```
# [Script Description](#contents)
@ -192,6 +192,7 @@ sh eval_all.sh
├── postprocess.py // 310infer postprocess scripts
├── README.md // descriptions about CenterFace
├── ascend310_infer // application for 310 inference
├── default_config.yaml // Training parameter profile
├── scripts
│ ├──run_infer_310.sh // shell script for infer on ascend310
│ ├──eval.sh // evaluate a single testing result
@ -205,7 +206,6 @@ sh eval_all.sh
│ ├──__init__.py
│ ├──centerface.py // centerface networks, training entry
│ ├──dataset.py // generate dataloader and data processing entry
│ ├──config.py // centerface unique configs
│ ├──losses.py // losses for centerface
│ ├──lr_scheduler.py // learning rate scheduler
│ ├──mobile_v2.py // modified mobilenet_v2 backbone
@ -213,6 +213,11 @@ sh eval_all.sh
│ ├──var_init.py // weight initialization
│ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore
│ ├──convert_weight.py // CenterFace model convert to mindspore
│ └──model_utils
│    ├──config.py // Processing configuration parameters
│    ├──device_adapter.py // Get cloud ID
│    ├──local_adapter.py // Get local ID
│    └──moxing_adapter.py // Parameter processing
└── dependency // third party codes: MIT License
├──extd // training dependency: data augmentation
│ ├──utils
@ -451,7 +456,7 @@ cd ../../../scripts;
```python
# you need to change the parameter in eval.sh
# default eval the ckpt saved in ./scripts/output/centerface/999
sh eval.sh
sh eval.sh [ground_truth_path]
```
2. eval many testing outputs for the user to choose the best one
@ -459,7 +464,7 @@ cd ../../../scripts;
```python
# you need to change the parameter in eval_all.sh
# default eval the ckpt saved in ./scripts/output/centerface/[89-140]
sh eval_all.sh
sh eval_all.sh [ground_truth_path]
```
3. test+eval
@ -475,7 +480,7 @@ sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [CKPT]
you can see the mAP below from eval.sh
```log
(ci3.7) [root@bms-aiserver scripts]# ./eval.sh
(ci3.7) [root@bms-aiserver scripts]# ./eval.sh ./ground_truth_path
start eval
==================== Results = ==================== ./scripts/output/centerface/999
Easy Val AP: 0.923914407045363
@ -488,7 +493,7 @@ end eval
you can see the mAP below from eval_all.sh
```log
(ci3.7) [root@bms-aiserver scripts]# ./eval_all.sh
(ci3.7) [root@bms-aiserver scripts]# ./eval_all.sh ./ground_truth_path
==================== Results = ==================== ./scripts/output/centerface/89
Easy Val AP: 0.8884892849068273
Medium Val AP: 0.8928813452811216

View File

@ -0,0 +1,234 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
device_target: Ascend
enable_profiling: False
# ==============================================================================
# Config setup
flip_idx: [[0, 1], [3, 4]]
default_resolution: [512, 512]
heads: {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 10} # landmarks = 5 joints * 2 coordinates; YAML does not evaluate 5 * 2
head_conv: 64
max_objs: 64
rand_crop: True
scale: 0.4
shift: 0.1
aug_rot: 0
color_aug: True
flip: 0.5
input_res: 512 #768 #800
output_res: 128 #192 #200
num_classes: 1
num_joints: 5
reg_offset: True
hm_hp: True
reg_hp_offset: True
dense_hp: False
hm_weight: 1.0
wh_weight: 0.1
off_weight: 1.0
lm_weight: 0.1
rotate: 0
# for test
mean: [0.408, 0.447, 0.470]
std: [0.289, 0.274, 0.278]
test_scales: [0.999,]
nms: 1
flip_test: 0
fix_res: True
input_h: 832 #800
input_w: 832 #800
K: 200
down_ratio: 4
test_batch_size: 1
master_batch_size: 8
num_workers: 8
not_rand_crop: False
no_color_aug: False
# ==============================================================================
# train.py mindspore coco training
# dataset related
data_dir: '/cache/data/'
annot_path: '/cache/data/annotations/train_wider_face.json'
img_dir: '/cache/data/images/WIDER_train/images/'
per_batch_size: 8
# network related
pretrained_backbone: '/cache/data/centerface/mobilenet_v2.ckpt'
resume: ''
# optimizer and lr related
lr_scheduler: 'multistep'
lr: 0.004 # 4e-3
lr_epochs: '90,120'
lr_gamma: 0.1
eta_min: 0.
t_max: 140
max_epoch: 140
warmup_epochs: 0
weight_decay: 0.0005
momentum: 0.9
optimizer: 'adam'
# loss related
loss_scale: 1024
label_smooth: 0
label_smooth_factor: 0.1
# logging related
log_interval: 100
ckpt_path: './output'
ckpt_interval: 0 # None
is_save_on_master: 1
# distributed related
is_distributed: 1
rank: 0
group_size: 1
# roma obs
#train_url: ""
# profiler init; enable only when debugging. Do not enable for training, since it costs memory and disk space
need_profiler: 0
# reset default config
training_shape: ""
resize_rate: 0 # None
# test.py mindspore coco training
test_model: ''
ground_truth_mat: ''
save_dir: ''
ground_truth_path: ''
eval: 0
eval_script_path: ''
ckpt_name: ""
device_num: 1
steps_per_epoch: 198
start: 0
end: 18
# export.py centerface export
device_id: 0
batch_size: 1
ckpt_file: ''
file_name: "centerface"
file_format: 'AIR'
# centerface preprocess
dataset_path: ''
preprocess_path: ''
# postprocess / centerface calcul AP
result_path: ''
label_file: ''
meta_file: ''
save_path: ''
# ==============================================================================
# src/convert_weight.py
ckpt_fn: '/model_path/centerface.ckpt'
pt_fn: '/model_path/centerface.pth'
out_fn: '/model_path/centerface_out.ckpt'
pt2ckpt: 1
# src/convert_weight_mobilenetv2.py
ckpt_fn_v2: '/model_path/mobilenet_v2_key.ckpt'
pt_fn_v2: '/model_path/mobilenet_v2-b0353104.pth'
out_ckpt_fn: '/model_path/mobilenet_v2-b0353104.ckpt'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
data_dir: 'train data dir'
annot_path: 'train data annotation path'
img_dir: 'train data img dir'
per_batch_size: 'batch size per gpu'
pretrained_backbone: 'model_path, local pretrained backbone model to load'
resume: 'path of pretrained centerface_model'
lr_scheduler: 'lr-scheduler, option type: exponential, cosine_annealing'
lr: 'learning rate of the training'
lr_epochs: 'epoch of lr changing'
lr_gamma: 'decrease lr by a factor of exponential lr_scheduler'
eta_min: 'eta_min in cosine_annealing scheduler'
t_max: 'T-max in cosine_annealing scheduler'
max_epoch: 'max epoch num to train the model'
warmup_epochs: 'warmup epoch'
weight_decay: 'weight decay'
momentum: 'momentum'
optimizer: 'optimizer type, default: adam'
loss_scale: 'static loss scale'
label_smooth: 'whether to use label smooth in CE'
label_smooth_factor: 'smooth strength of original one-hot'
log_interval: 'logging interval'
ckpt_path: 'checkpoint save location'
ckpt_interval: 'ckpt_interval'
is_save_on_master: 'save ckpt on master or all rank'
is_distributed: 'if multi device'
rank: 'local rank of distributed'
group_size: 'world size of distributed'
need_profiler: 'whether use profiler'
training_shape: 'fix training shape'
resize_rate: 'resize rate for multi-scale training'
test_model: 'test model dir'
ground_truth_mat: 'ground_truth, mat type'
save_dir: 'save_path for evaluate'
ground_truth_path: 'ground_truth path, containing all mat files'
eval: 'if do eval after test'
eval_script_path: 'evaluate script path'
ckpt_name: 'input model name'
steps_per_epoch: 'steps for each epoch'
start: 'start loop number, used to calculate first epoch number'
end: 'end loop number, used to calculate last epoch number'
batch_size: "batch size"
ckpt_file: "Checkpoint file path."
ckpt_fn: 'ckpt for user to get cell/module name'
pt_fn: 'checkpoint filename to convert'
out_fn: 'convert output ckpt/pth path'
pt2ckpt: '1 : pt2ckpt; 0 : ckpt2pt'
enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set to true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
dataset: 'Dataset, default is coco.'
pre_trained: 'Pretrain file path.'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_name: "output file name."
file_format: 'file format'
img_path: "image file path."
result_path: "result file path."
dataset_path: "dataset path."
preprocess_path: "preprocess path."
label_file: "label file"
meta_file: "label file"
save_path: "label file"
---
platform: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "MINDIR"]
freeze_layer: ["", "none", "backbone"]
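A note on the `heads` entry above: YAML performs no arithmetic, so a value written as `5 * 2` loads as the literal string `'5 * 2'` instead of the integer 10, which is why the mapping spells the number out. A minimal self-contained check with PyYAML:

```python
import yaml

# YAML treats '5 * 2' as a plain string scalar, not an expression.
print(yaml.safe_load("heads: {'hm': 1, 'landmarks': 5 * 2}"))
# {'heads': {'hm': 1, 'landmarks': '5 * 2'}}
print(yaml.safe_load("heads: {'hm': 1, 'landmarks': 10}"))
# {'heads': {'hm': 1, 'landmarks': 10}}
```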

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import numpy as np
import mindspore
@ -21,27 +20,22 @@ from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
from src.config import ConfigCenterface
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.model_utils.moxing_adapter import moxing_wrapper
parser = argparse.ArgumentParser(description='centerface export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="centerface", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
def modelarts_process():
pass
if __name__ == '__main__':
config = ConfigCenterface()
@moxing_wrapper(pre_process=modelarts_process)
def export_centerface():
net = CenterfaceMobilev2()
param_dict = load_checkpoint(args.ckpt_file)
param_dict = load_checkpoint(config.ckpt_file)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
@ -55,5 +49,8 @@ if __name__ == '__main__':
net = CenterFaceWithNms(net)
net.set_train(False)
input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
input_data = Tensor(np.zeros([config.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
export_centerface()

View File

@ -14,9 +14,9 @@
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from src.config import ConfigCenterface
from src.model_utils.config import config
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
from dependency.evaluate.eval import evaluation
@ -36,16 +36,8 @@ dct_map = {'16': '16--Award_Ceremony', '26': '26--Soldier_Drilling', '29': '29--
'14': '14--Traffic', '41': '41--Swimming', '46': '46--Jockey', '10': '10--People_Marching',
'54': '54--Rescue', '57': '57--Angler', '31': '31--Waiter_Waitress', '27': '27--Spa', '21': '21--Festival'}
parser = argparse.ArgumentParser(description='centerface calcul AP')
parser.add_argument("--result_path", type=str, required=True, default='', help="result file path")
parser.add_argument("--label_file", type=str, required=True, default='', help="label file")
parser.add_argument("--meta_file", type=str, required=True, default='', help="label file")
parser.add_argument("--save_path", type=str, required=True, default='', help="label file")
args = parser.parse_args()
def cal_acc(result_path, label_file, meta_file, save_path):
config = ConfigCenterface()
detector = CenterFaceDetector(config, None)
if not os.path.exists(save_path):
for im_dir in dct_map.values():
@ -87,4 +79,4 @@ def cal_acc(result_path, label_file, meta_file, save_path):
if __name__ == '__main__':
cal_acc(args.result_path, args.label_file, args.meta_file, args.save_path)
cal_acc(config.result_path, config.label_file, config.meta_file, config.save_path)

View File

@ -14,21 +14,14 @@
# ============================================================================
"""pre process for 310 inference"""
import os
import argparse
import shutil
import cv2
import numpy as np
from src.config import ConfigCenterface
from src.model_utils.config import config
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
parser = argparse.ArgumentParser(description="centerface preprocess")
parser.add_argument("--dataset_path", type=str, required=True, help="dataset path.")
parser.add_argument("--preprocess_path", type=str, required=True, help="preprocess path.")
args = parser.parse_args()
def preprocess(dataset_path, preprocess_path):
config = ConfigCenterface()
event_list = os.listdir(dataset_path)
input_path = os.path.join(preprocess_path, "input")
meta_path = os.path.join(preprocess_path, "meta/meta")
@ -65,4 +58,4 @@ def preprocess(dataset_path, preprocess_path):
if __name__ == '__main__':
preprocess(args.dataset_path, args.preprocess_path)
preprocess(config.dataset_path, config.preprocess_path)

View File

@ -16,7 +16,7 @@
root=$PWD
save_path=$root/output/centerface/
ground_truth_path=$root/dataset/centerface/ground_truth
ground_truth_path=$1
echo "start eval"
python ../dependency/evaluate/eval.py --pred=$save_path --gt=$ground_truth_path
echo "end eval"

View File

@ -16,11 +16,11 @@
root=$PWD
save_path=$root/output/centerface/
ground_truth_path=$root/dataset/centerface/ground_truth
ground_truth_path=$1
#for i in $(seq start_epoch end_epoch+1)
for i in $(seq 89 200)
do
python ../dependency/evaluate/eval.py --pred=$save_path$i --gt=$ground_truth_path &
python ../dependency/evaluate/eval.py --pred=$save_path$i --gt=$ground_truth_path >> log_eval_all.txt 2>&1 &
sleep 10
done
wait

View File

@ -55,27 +55,28 @@ device_id=0
ckpt="0-125_24750.ckpt" # the model saved for epoch=125
ground_truth_path=$root/dataset/centerface/ground_truth
if [ $# == 1 ]
if [ $# -ge 1 ]
then
model_path=$(get_real_path $1)
if [ ! -f $model_path ]
# if [ ! -f $model_path ]
if [ ! -d $model_path ]
then
echo "error: model_path=$model_path is not a file"
echo "error: model_path=$model_path is not a dir"
exit 1
fi
fi
if [ $# == 2 ]
if [ $# -ge 2 ]
then
dataset_path=$(get_real_path $2)
if [ ! -f $dataset_path ]
if [ ! -d $dataset_path ]
then
echo "error: dataset_path=$dataset_path is not a file"
echo "error: dataset_path=$dataset_path is not a dir"
exit 1
fi
fi
if [ $# == 3 ]
if [ $# -ge 3 ]
then
ground_truth_mat=$(get_real_path $3)
if [ ! -f $ground_truth_mat ]
@ -85,27 +86,27 @@ then
fi
fi
if [ $# == 4 ]
if [ $# -ge 4 ]
then
save_path=$(get_real_path $4)
if [ ! -f $save_path ]
if [ ! -d $save_path ]
then
echo "error: save_path=$save_path is not a file"
echo "error: save_path=$save_path is not a dir"
exit 1
fi
fi
if [ $# == 5 ]
if [ $# -ge 5 ]
then
device_id=$5
fi
if [ $# == 6 ]
if [ $# -ge 6 ]
then
ckpt=$6
fi
if [ $# == 7 ]
if [ $# -ge 7 ]
then
ground_truth_path=$(get_real_path $7)
if [ ! -f $ground_truth_path ]

View File

@ -63,27 +63,28 @@ steps_per_epoch=198 #198 for 8P; 1583 for 1p
start=11 # start epoch number = start * device_num + min(device_phy_id) + 1
end=18 # end epoch number = end * device_num + max(device_phy_id) + 1
if [ $# == 1 ]
if [ $# -ge 1 ]
then
model_path=$(get_real_path $1)
if [ ! -f $model_path ]
# if [ ! -f $model_path ]
if [ ! -d $model_path ]
then
echo "error: model_path=$model_path is not a file"
echo "error: model_path=$model_path is not a dir"
exit 1
fi
fi
if [ $# == 2 ]
if [ $# -ge 2 ]
then
dataset_path=$(get_real_path $2)
if [ ! -f $dataset_path ]
if [ ! -d $dataset_path ]
then
echo "error: dataset_path=$dataset_path is not a file"
echo "error: dataset_path=$dataset_path is not a dir"
exit 1
fi
fi
if [ $# == 3 ]
if [ $# -ge 3 ]
then
ground_truth_mat=$(get_real_path $3)
if [ ! -f $ground_truth_mat ]
@ -93,27 +94,27 @@ then
fi
fi
if [ $# == 4 ]
if [ $# -ge 4 ]
then
save_path=$(get_real_path $4)
if [ ! -f $save_path ]
if [ ! -d $save_path ]
then
echo "error: save_path=$save_path is not a file"
echo "error: save_path=$save_path is not a dir"
exit 1
fi
fi
if [ $# == 5 ]
if [ $# -ge 5 ]
then
device_num=$5
fi
if [ $# == 6 ]
if [ $# -ge 6 ]
then
steps_per_epoch=$6
fi
if [ $# == 7 ]
if [ $# -ge 7 ]
then
start=$7
fi

View File

@ -51,7 +51,7 @@ annot_path=$dataset_path/annotations/train.json
img_dir=$dataset_path/images/train/images
rank_table=$root/rank_table_8p.json
if [ $# == 1 ]
if [ $# -ge 1 ]
then
rank_table=$(get_real_path $1)
if [ ! -f $rank_table ]
@ -61,7 +61,7 @@ then
fi
fi
if [ $# == 2 ]
if [ $# -ge 2 ]
then
pretrained_backbone=$(get_real_path $2)
if [ ! -f $pretrained_backbone ]
@ -71,17 +71,17 @@ then
fi
fi
if [ $# == 3 ]
if [ $# -ge 3 ]
then
dataset_path=$(get_real_path $3)
if [ ! -f $dataset_path ]
if [ ! -d $dataset_path ]
then
echo "error: dataset_path=$dataset_path is not a file"
echo "error: dataset_path=$dataset_path is not a dir"
exit 1
fi
fi
if [ $# == 4 ]
if [ $# -ge 4 ]
then
annot_path=$(get_real_path $4)
if [ ! -f $annot_path ]
@ -91,12 +91,12 @@ then
fi
fi
if [ $# == 5 ]
if [ $# -ge 5 ]
then
img_dir=$(get_real_path $5)
if [ ! -f $img_dir ]
if [ ! -d $img_dir ]
then
echo "error: img_dir=$img_dir is not a file"
echo "error: img_dir=$img_dir is not a dir"
exit 1
fi
fi

View File

@ -50,38 +50,28 @@ annot_path=$dataset_path/annotations/train.json
img_dir=$dataset_path/images/train/images
use_device_id=0
if [ $# == 1 ]
if [ $# -ge 1 ]
then
use_device_id=$1
fi
if [ $# == 2 ]
if [ $# -ge 2 ]
then
use_device_id=$1
pretrained_backbone=$(get_real_path $2)
fi
if [ $# == 3 ]
if [ $# -ge 3 ]
then
use_device_id=$1
pretrained_backbone=$(get_real_path $2)
dataset_path=$(get_real_path $3)
fi
if [ $# == 4 ]
if [ $# -ge 4 ]
then
use_device_id=$1
pretrained_backbone=$(get_real_path $2)
dataset_path=$(get_real_path $3)
annot_path=$(get_real_path $4)
fi
if [ $# == 5 ]
if [ $# -ge 5 ]
then
use_device_id=$1
pretrained_backbone=$(get_real_path $2)
dataset_path=$(get_real_path $3)
annot_path=$(get_real_path $4)
img_dir=$(get_real_path $5)
fi

View File

@ -14,7 +14,8 @@
# ============================================================================
"""centerface networks"""
from src.config import ConfigCenterface
# from src.config import ConfigCenterface
from src.model_utils.config import config
from src.mobile_v2 import mobilenet_v2
from src.losses import FocalLoss, SmoothL1LossNew, SmoothL1LossNewCMask
@ -132,7 +133,7 @@ class CenterfaceMobilev2(nn.Cell):
def __init__(self):
super(CenterfaceMobilev2, self).__init__()
self.config = ConfigCenterface()
self.config = config
self.base = mobilenet_v2()
channels = self.base.feat_channel
@ -199,7 +200,7 @@ class CenterFaceWithLossCell(nn.Cell):
def __init__(self, network):
super(CenterFaceWithLossCell, self).__init__()
self.centerface_network = network
self.config = ConfigCenterface()
self.config = config
self.loss = CenterFaceLoss(self.config.wh_weight, self.config.reg_offset, self.config.off_weight,
self.config.hm_weight, self.config.lm_weight)
self.reduce_sum = P.ReduceSum()
@ -295,7 +296,7 @@ class CenterFaceWithNms(nn.Cell):
def __init__(self, network):
super(CenterFaceWithNms, self).__init__()
self.centerface_network = network
self.config = ConfigCenterface()
self.config = config
# two types of maxpool; an alternative is self.maxpool2d = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='same')
self.maxpool2d = P.MaxPoolWithArgmax(kernel_size=3, strides=1, pad_mode='same')
self.topk = P.TopK(sorted=True)

View File

@ -1,63 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""centerface unique configs"""
class ConfigCenterface():
"""
Config setup
"""
flip_idx = [[0, 1], [3, 4]]
default_resolution = [512, 512]
heads = {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 5 * 2}
head_conv = 64
max_objs = 64
rand_crop = True
scale = 0.4
shift = 0.1
aug_rot = 0
color_aug = True
flip = 0.5
input_res = 512 #768 #800
output_res = 128 #192 #200
num_classes = 1
num_joints = 5
reg_offset = True
hm_hp = True
reg_hp_offset = True
dense_hp = False
hm_weight = 1.0
wh_weight = 0.1
off_weight = 1.0
lm_weight = 0.1
rotate = 0
# for test
mean = [0.408, 0.447, 0.470]
std = [0.289, 0.274, 0.278]
test_scales = [0.999,]
nms = 1
flip_test = 0
fix_res = True
input_h = 832 #800
input_w = 832 #800
K = 200
down_ratio = 4
test_batch_size = 1
master_batch_size = 8
num_workers = 8
not_rand_crop = False
no_color_aug = False
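Every field of the deleted ConfigCenterface class now lives in default_config.yaml and is reachable as an attribute of the shared config object; an illustrative sketch (assuming the repo layout on the import path):

```python
from src.model_utils.config import config

# Same fields as the deleted ConfigCenterface, now backed by default_config.yaml.
print(config.input_res, config.output_res)  # 512 128
print(config.heads)  # {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 10}
```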

View File

@ -16,20 +16,11 @@
Centerface model transform
"""
import os
import argparse
import torch
from mindspore.train.serialization import load_checkpoint, save_checkpoint
from mindspore import Tensor
from src.model_utils.config import config
parser = argparse.ArgumentParser(description='')
parser.add_argument('--ckpt_fn', type=str, default='/model_path/centerface.ckpt',
help='ckpt for user to get cell/module name')
parser.add_argument('--pt_fn', type=str, default='/model_path/centerface.pth', help='checkpoint filename to convert')
parser.add_argument('--out_fn', type=str, default='/model_path/centerface_out.ckpt',
help='convert output ckpt/pth path')
parser.add_argument('--pt2ckpt', type=int, default=1, help='1 : pt2ckpt; 0 : ckpt2pt')
args = parser.parse_args()
def load_model(model_path):
"""
@ -160,10 +151,10 @@ def ckpt_to_pt(pt, ckpt, out_path):
return state_dict
if __name__ == "__main__":
if args.pt2ckpt == 1:
pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_fn)
elif args.pt2ckpt == 0:
ckpt_to_pt(args.pt_fn, args.ckpt_fn, args.out_fn)
if config.pt2ckpt == 1:
pt_to_ckpt(config.pt_fn, config.ckpt_fn, config.out_fn)
elif config.pt2ckpt == 0:
ckpt_to_pt(config.pt_fn, config.ckpt_fn, config.out_fn)
else:
# user defined functions
pass

View File

@ -16,20 +16,14 @@
Mobilenet model transform: torch => mindspore
"""
import os
import argparse
import torch
from mindspore.train.serialization import load_checkpoint, save_checkpoint
from mindspore import Tensor
from src.model_utils.config import config
parser = argparse.ArgumentParser(description='')
parser.add_argument('--ckpt_fn', type=str, default='/model_path/mobilenet_v2_key.ckpt',
help='ckpt for user to get cell/module name')
parser.add_argument('--pt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.pth',
help='checkpoint filename to convert')
parser.add_argument('--out_ckpt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.ckpt',
help='convert output ckpt path')
args = parser.parse_args()
config.ckpt_fn = config.ckpt_fn_v2
config.pt_fn = config.pt_fn_v2
def load_model(model_path):
"""
@ -129,4 +123,4 @@ def pt_to_ckpt(pt, ckpt, out_ckpt):
if __name__ == "__main__":
# beta <=> bias, gamma <=> weight
pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_ckpt_fn)
pt_to_ckpt(config.pt_fn, config.ckpt_fn, config.out_ckpt_fn)

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
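For reference, a minimal standalone sketch (illustration only, not the repo code) of the two mechanisms get_config() combines: default_config.yaml holds three `---`-separated YAML documents (defaults, help strings, choices), and the parsed dictionary is wrapped so nested entries become attribute access:

```python
import yaml

# Three '---'-separated documents, as in default_config.yaml.
sample = """
lr: 0.004
device_target: Ascend
---
lr: 'learning rate of the training'
device_target: 'target device'
---
device_target: ['Ascend', 'GPU', 'CPU']
"""
cfg, cfg_help, cfg_choices = list(yaml.load_all(sample, Loader=yaml.FullLoader))

# Dict-to-attribute wrapper, mirroring the Config class above.
class Cfg:
    def __init__(self, d):
        for k, v in d.items():
            setattr(self, k, Cfg(v) if isinstance(v, dict) else v)

c = Cfg(cfg)
print(c.lr)                          # 0.004
print(cfg_help['lr'])                # learning rate of the training
print(cfg_choices['device_target'])  # ['Ascend', 'GPU', 'CPU']
```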

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
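Usage sketch (assuming the repo layout): callers import the getters from device_adapter and transparently get either the local or the ModelArts implementation, selected by config.enable_modelarts at import time:

```python
# Hypothetical caller somewhere under src/; the import is identical whether
# the job runs locally or on ModelArts.
from src.model_utils.device_adapter import get_device_id, get_device_num

device_id = get_device_id()    # DEVICE_ID env var locally
device_num = get_device_num()  # RANK_SIZE env var locally
print(device_id, device_num)
```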

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"  # JOB_ID may be unset, in which case getenv returns None
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory if the first URL is a remote URL and the second is a local path.
Otherwise, upload data from the local directory to remote OBS.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
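Putting it together, a usage sketch of moxing_wrapper in the shape the entry scripts of this PR use it (the function names other than moxing_wrapper and config are illustrative):

```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def remap_paths():
    # pre_process hook: point the config at the ModelArts cache layout
    config.data_dir = config.data_path

@moxing_wrapper(pre_process=remap_paths)
def main():
    # On ModelArts this runs after the dataset sync; afterwards everything
    # under config.output_path is uploaded to config.train_url.
    print("training with batch size", config.per_batch_size)

if __name__ == "__main__":
    main()
```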

View File

@ -91,11 +91,11 @@ def load_backbone(net, ckpt_path, args):
else:
not_found_param.append(mobilev2_beta)
args.logger.info('================found_param {}========='.format(len(find_param)))
args.logger.info(find_param)
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
args.logger.info(not_found_param)
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
print('================found_param {}========='.format(len(find_param)))
print(find_param)
print('================not_found_param {}========='.format(len(not_found_param)))
print(not_found_param)
print('=====load {} successfully ====='.format(ckpt_path))
return net

View File

@ -17,67 +17,54 @@ Test centerface example
"""
import os
import time
import argparse
import datetime
import scipy.io as sio
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.utils import get_logger
from src.var_init import default_recurisive_init
from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
from src.config import ConfigCenterface
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
from dependency.evaluate.eval import evaluation
dev_id = int(os.getenv('DEVICE_ID'))
dev_id = get_device_id()
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
device_target="Ascend", save_graphs=False, device_id=dev_id)
parser = argparse.ArgumentParser('mindspore coco training')
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--test_model', type=str, default='', help='test model dir')
parser.add_argument('--ground_truth_mat', type=str, default='', help='ground_truth, mat type')
parser.add_argument('--save_dir', type=str, default='', help='save_path for evaluate')
parser.add_argument('--ground_truth_path', type=str, default='', help='ground_truth path, contain all mat file')
parser.add_argument('--eval', type=int, default=0, help='if do eval after test')
parser.add_argument('--eval_script_path', type=str, default='', help='evaluate script path')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_name', type=str, default="", help='input model name')
parser.add_argument('--device_num', type=int, default=1, help='device num for testing')
parser.add_argument('--steps_per_epoch', type=int, default=198, help='steps for each epoch')
parser.add_argument('--start', type=int, default=0, help='start loop number, used to calculate first epoch number')
parser.add_argument('--end', type=int, default=18, help='end loop number, used to calculate last epoch number')
def modelarts_process():
config.data_dir = config.data_path
config.save_dir = config.output_path
config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
args, _ = parser.parse_known_args()
if __name__ == "__main__":
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
@moxing_wrapper(pre_process=modelarts_process)
def test_centerface():
"""" test_centerface """
config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if args.ckpt_name != "":
args.start = 0
args.end = 1
if config.ckpt_name != "":
config.start = 0
config.end = 1
for loop in range(args.start, args.end, 1):
for loop in range(config.start, config.end, 1):
network = CenterfaceMobilev2()
default_recurisive_init(network)
if args.ckpt_name == "":
ckpt_num = loop * args.device_num + args.rank + 1
ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt"
if config.ckpt_name == "":
ckpt_num = loop * config.device_num + config.rank + 1
ckpt_name = "0-" + str(ckpt_num) + "_" + str(config.steps_per_epoch * ckpt_num) + ".ckpt"
else:
ckpt_name = args.ckpt_name
ckpt_name = config.ckpt_name
test_model = args.test_model + ckpt_name
test_model = config.test_model + ckpt_name
if not test_model:
args.logger.info('load_model {} none'.format(test_model))
print('load_model {} none'.format(test_model))
continue
if os.path.isfile(test_model):
@ -92,30 +79,28 @@ if __name__ == "__main__":
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(test_model))
print('load_model {} success'.format(test_model))
else:
args.logger.info('{} not exists or not a pre-trained file'.format(test_model))
print('{} not exists or not a pre-trained file'.format(test_model))
continue
train_network_type_nms = 1 # default with nms
if train_network_type_nms:
network = CenterFaceWithNms(network)
args.logger.info('train network type with nms')
print('train network type with nms')
network.set_train(False)
args.logger.info('finish get network')
config = ConfigCenterface()
print('finish get network')
# test network -----------
start = time.time()
ground_truth_mat = sio.loadmat(args.ground_truth_mat)
ground_truth_mat = sio.loadmat(config.ground_truth_mat)
event_list = ground_truth_mat['event_list']
file_list = ground_truth_mat['file_list']
if args.ckpt_name == "":
save_path = args.save_dir + str(ckpt_num) + '/'
if config.ckpt_name == "":
save_path = config.save_dir + str(ckpt_num) + '/'
else:
save_path = args.save_dir+ '/'
save_path = config.save_dir + '/'
detector = CenterFaceDetector(config, network)
for index, event in enumerate(event_list):
@ -123,12 +108,12 @@ if __name__ == "__main__":
im_dir = event[0][0]
if not os.path.exists(save_path + im_dir):
os.makedirs(save_path + im_dir)
args.logger.info('save_path + im_dir={}'.format(save_path + im_dir))
print('save_path + im_dir={}'.format(save_path + im_dir))
for num, file in enumerate(file_list_item):
im_name = file[0][0]
zip_name = '%s/%s.jpg' % (im_dir, im_name)
img_path = os.path.join(args.data_dir, zip_name)
args.logger.info('img_path={}'.format(img_path))
img_path = os.path.join(config.data_dir, zip_name)
print('img_path={}'.format(img_path))
dets = detector.run(img_path)['results']
@ -139,22 +124,25 @@ if __name__ == "__main__":
x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4]
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), s))
f.close()
args.logger.info('event:{}, num:{}'.format(index + 1, num + 1))
print('event:{}, num:{}'.format(index + 1, num + 1))
end = time.time()
args.logger.info("============num {} time {}".format(num, (end-start)*1000))
print("============num {} time {}".format(num, (end-start)*1000))
start = end
if args.eval:
args.logger.info('==========start eval===============')
args.logger.info("test output path = {}".format(save_path))
if config.eval:
print('==========start eval===============')
print("test output path = {}".format(save_path))
if os.path.isdir(save_path):
evaluation(save_path, args.ground_truth_path)
evaluation(save_path, config.ground_truth_path)
else:
args.logger.info('no test output path')
args.logger.info('==========end eval===============')
print('no test output path')
print('==========end eval===============')
if args.ckpt_name != "":
if config.ckpt_name != "":
break
args.logger.info('==========end testing===============')
print('==========end testing===============')
if __name__ == "__main__":
test_centerface()
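For clarity, a worked example of the checkpoint-name arithmetic that test_centerface() loops over, using the single-device defaults from default_config.yaml (device_num=1, rank=0, steps_per_epoch=198, start=0, end=18):

```python
# ckpt_num -> "0-<epoch>_<global_step>.ckpt", as saved by ModelCheckpoint
device_num, rank, steps_per_epoch = 1, 0, 198
for loop in (0, 1, 17):
    ckpt_num = loop * device_num + rank + 1
    print("0-{}_{}.ckpt".format(ckpt_num, steps_per_epoch * ckpt_num))
# 0-1_198.ckpt
# 0-2_396.ckpt
# 0-18_3564.ckpt
```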

View File

@ -18,7 +18,6 @@ Train centerface and get network model files(.ckpt)
import os
import time
import argparse
import datetime
import numpy as np
@ -35,7 +34,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.profiler.profiling import Profiler
from mindspore.common import set_seed
from src.utils import get_logger
from src.utils import AverageMeter
from src.lr_scheduler import warmup_step_lr
from src.lr_scheduler import warmup_cosine_annealing_lr, \
@ -44,78 +42,23 @@ from src.lr_scheduler import MultiStepLR
from src.var_init import default_recurisive_init
from src.centerface import CenterfaceMobilev2
from src.utils import load_backbone, get_param_groups
from src.config import ConfigCenterface
from src.centerface import CenterFaceWithLossCell, TrainingWrapper
from src.dataset import GetDataLoader
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
set_seed(1)
dev_id = int(os.getenv('DEVICE_ID'))
dev_id = get_device_id()
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False)
parser = argparse.ArgumentParser('mindspore coco training')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--annot_path', type=str, default='', help='train data annotation path')
parser.add_argument('--img_dir', type=str, default='', help='train data img dir')
parser.add_argument('--per_batch_size', default=8, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone'
' model to load')
parser.add_argument('--resume', default='', type=str, help='path of pretrained centerface_model')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='multistep', type=str,
help='lr-scheduler, option type: exponential, cosine_annealing')
parser.add_argument('--lr', default=4e-3, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='90,120', help='epoch of lr changing')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--t_max', type=int, default=140, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=140, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--optimizer', default='adam', type=str,
help='optimizer type, default: adam')
# loss related
parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
# profiler init, can open when you debug. if train, donot open, since it cost memory and disk space
parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler')
# reset default config
parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')
args, _ = parser.parse_known_args()
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
args.t_max = args.max_epoch
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
if config.lr_scheduler == 'cosine_annealing' and config.max_epoch > config.t_max:
config.t_max = config.max_epoch
config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
def convert_training_shape(args_):
"""
@ -135,34 +78,40 @@ class InternalCallbackParam(dict):
self[para_name] = para_value
def modelarts_pre_process():
config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_centerface():
pass
if __name__ == "__main__":
train_centerface()
print('\ntrain.py config:\n', config)
# init distributed
if args.is_distributed:
if config.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
config.rank = get_rank()
config.group_size = get_group_size()
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 0
if config.is_save_on_master:
if config.rank == 0:
config.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if args.need_profiler:
profiler = Profiler(output_path=args.outputs_dir)
if config.need_profiler:
profiler = Profiler(output_path=config.outputs_dir)
loss_meter = AverageMeter('loss')
context.reset_auto_parallel_context()
if args.is_distributed:
if config.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
degree = get_group_size()
else:
@ -176,14 +125,14 @@ if __name__ == "__main__":
# init; to avoid overflow, the std of some weights should be small enough
default_recurisive_init(network)
if args.pretrained_backbone:
network = load_backbone(network, args.pretrained_backbone, args)
args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
if config.pretrained_backbone:
network = load_backbone(network, config.pretrained_backbone, config)
print('load pre-trained backbone {} into network'.format(config.pretrained_backbone))
else:
args.logger.info('Not load pre-trained backbone, please be careful')
print('Not load pre-trained backbone, please be careful')
if os.path.isfile(args.resume):
param_dict = load_checkpoint(args.resume)
if os.path.isfile(config.resume):
param_dict = load_checkpoint(config.resume)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
@ -194,107 +143,94 @@ if __name__ == "__main__":
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load_model {} success'.format(args.resume))
print('load_model {} success'.format(config.resume))
else:
args.logger.info('{} not set/exists or not a pre-trained file'.format(args.resume))
print('{} not set/exists or not a pre-trained file'.format(config.resume))
network = CenterFaceWithLossCell(network)
args.logger.info('finish get network')
print('finish get network')
config = ConfigCenterface()
config.data_dir = args.data_dir
config.annot_path = args.annot_path
config.img_dir = args.img_dir
config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor
# -------------reset config-----------------
if args.training_shape:
config.multi_scale = [convert_training_shape(args)]
if args.resize_rate:
config.resize_rate = args.resize_rate
if config.training_shape:
config.multi_scale = [convert_training_shape(config)]
# data loader
data_loader, args.steps_per_epoch = GetDataLoader(per_batch_size=args.per_batch_size,
max_epoch=args.max_epoch,
rank=args.rank,
group_size=args.group_size,
config=config,
split='train')
args.steps_per_epoch = args.steps_per_epoch // args.max_epoch
args.logger.info('Finish loading dataset')
data_loader, config.steps_per_epoch = GetDataLoader(per_batch_size=config.per_batch_size, \
max_epoch=config.max_epoch, rank=config.rank, group_size=config.group_size, \
config=config, split='train')
config.steps_per_epoch = config.steps_per_epoch // config.max_epoch
print('Finish loading dataset')
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
if not config.ckpt_interval:
config.ckpt_interval = config.steps_per_epoch
# lr scheduler
if args.lr_scheduler == 'multistep':
lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch, args.max_epoch,
args.warmup_epochs)
if config.lr_scheduler == 'multistep':
lr_fun = MultiStepLR(config.lr, config.lr_epochs, config.lr_gamma, config.steps_per_epoch, config.max_epoch,
config.warmup_epochs)
lr = lr_fun.get_lr()
elif args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma
elif config.lr_scheduler == 'exponential':
lr = warmup_step_lr(config.lr,
config.lr_epochs,
config.steps_per_epoch,
config.warmup_epochs,
config.max_epoch,
gamma=config.lr_gamma
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_V2':
lr = warmup_cosine_annealing_lr_v2(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif args.lr_scheduler == 'cosine_annealing_sample':
lr = warmup_cosine_annealing_lr_sample(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.t_max,
args.eta_min)
elif config.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(config.lr,
config.steps_per_epoch,
config.warmup_epochs,
config.max_epoch,
config.t_max,
config.eta_min)
elif config.lr_scheduler == 'cosine_annealing_V2':
lr = warmup_cosine_annealing_lr_v2(config.lr,
config.steps_per_epoch,
config.warmup_epochs,
config.max_epoch,
config.t_max,
config.eta_min)
elif config.lr_scheduler == 'cosine_annealing_sample':
lr = warmup_cosine_annealing_lr_sample(config.lr,
config.steps_per_epoch,
config.warmup_epochs,
config.max_epoch,
config.t_max,
config.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
raise NotImplementedError(config.lr_scheduler)
if args.optimizer == "adam":
if config.optimizer == "adam":
opt = Adam(params=get_param_groups(network),
learning_rate=Tensor(lr),
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
args.logger.info("use adam optimizer")
elif args.optimizer == "sgd":
weight_decay=config.weight_decay,
loss_scale=config.loss_scale)
print("use adam optimizer")
elif config.optimizer == "sgd":
opt = SGD(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
momentum=config.momentum,
weight_decay=config.weight_decay,
loss_scale=config.loss_scale)
else:
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
momentum=config.momentum,
weight_decay=config.weight_decay,
loss_scale=config.loss_scale)
network = TrainingWrapper(network, opt, sens=args.loss_scale)
network = TrainingWrapper(network, opt, sens=config.loss_scale)
network.set_train()
if args.rank_save_ckpt_flag:
if config.rank_save_ckpt_flag:
# checkpoint save
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
ckpt_max_num = config.max_epoch * config.steps_per_epoch // config.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval,
keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
directory=config.outputs_dir,
prefix='{}'.format(config.rank))
cb_params = InternalCallbackParam()
cb_params.train_network = network
cb_params.epoch_num = ckpt_max_num
@ -302,14 +238,14 @@ if __name__ == "__main__":
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)
args.logger.info('args.steps_per_epoch = {} args.ckpt_interval ={}'.format(args.steps_per_epoch,
args.ckpt_interval))
print('config.steps_per_epoch = {} config.ckpt_interval ={}'.format(config.steps_per_epoch, \
config.ckpt_interval))
t_end = time.time()
for i_all, batch_load in enumerate(data_loader):
i = i_all % args.steps_per_epoch
epoch = i_all // args.steps_per_epoch + 1
i = i_all % config.steps_per_epoch
epoch = i_all // config.steps_per_epoch + 1
images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks = batch_load
images = Tensor(images)
@ -327,33 +263,28 @@ if __name__ == "__main__":
overflow = np.all(overflow.asnumpy())
loss = loss.asnumpy()
loss_meter.update(loss)
args.logger.info('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format(epoch,
i,
loss_meter,
loss,
overflow,
scaling.asnumpy()
))
print('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format( \
epoch, i, loss_meter, loss, overflow, scaling.asnumpy()))
if args.rank_save_ckpt_flag:
if config.rank_save_ckpt_flag:
# ckpt progress
cb_params.cur_epoch_num = epoch
cb_params.cur_step_num = i + 1 + (epoch-1)*args.steps_per_epoch
cb_params.batch_num = i + 2 + (epoch-1)*args.steps_per_epoch
cb_params.cur_step_num = i + 1 + (epoch-1)*config.steps_per_epoch
cb_params.batch_num = i + 2 + (epoch-1)*config.steps_per_epoch
ckpt_cb.step_end(run_context)
if (i_all+1) % args.steps_per_epoch == 0:
if (i_all+1) % config.steps_per_epoch == 0:
time_used = time.time() - t_end
fps = args.per_batch_size * args.steps_per_epoch * args.group_size / time_used
if args.rank == 0:
args.logger.info(
fps = config.per_batch_size * config.steps_per_epoch * config.group_size / time_used
if config.rank == 0:
print(
'epoch[{}], {}, {:.2f} imgs/sec, lr:{}'
.format(epoch, loss_meter, fps, lr[i + (epoch-1)*args.steps_per_epoch])
.format(epoch, loss_meter, fps, lr[i + (epoch-1)*config.steps_per_epoch])
)
t_end = time.time()
loss_meter.reset()
if args.need_profiler:
if config.need_profiler:
profiler.analyse()
args.logger.info('==========end training===============')
print('==========end training===============')
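Finally, a worked example of the imgs/sec figure the loop above prints at each epoch boundary, using the 8P defaults (per_batch_size=8, steps_per_epoch=198, group_size=8) and an assumed epoch wall time of 100 s:

```python
per_batch_size, steps_per_epoch, group_size = 8, 198, 8
time_used = 100.0  # assumed epoch duration in seconds
fps = per_batch_size * steps_per_epoch * group_size / time_used
print("{:.2f} imgs/sec".format(fps))  # 126.72 imgs/sec
```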