forked from mindspore-Ecosystem/mindspore
!19134 centerface test
Merge pull request !19134 from huchunmei/center1
This commit is contained in:
commit
0aef1943d0
|
@ -175,7 +175,7 @@ step6: eval
|
|||
# cd ../dependency/evaluate;
|
||||
# python setup.py install;
|
||||
# cd -; #cd ../../scripts;
|
||||
sh eval_all.sh
|
||||
sh eval_all.sh [ground_truth_path]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -192,6 +192,7 @@ sh eval_all.sh
|
|||
├── postprocess.py // 310infer postprocess scripts
|
||||
├── README.md // descriptions about CenterFace
|
||||
├── ascend310_infer // application for 310 inference
|
||||
├─default_config.yaml // Training parameter profile
|
||||
├── scripts
|
||||
│ ├──run_infer_310.sh // shell script for infer on ascend310
|
||||
│ ├──eval.sh // evaluate a single testing result
|
||||
|
@ -205,7 +206,6 @@ sh eval_all.sh
|
|||
│ ├──__init__.py
|
||||
│ ├──centerface.py // centerface networks, training entry
|
||||
│ ├──dataset.py // generate dataloader and data processing entry
|
||||
│ ├──config.py // centerface unique configs
|
||||
│ ├──losses.py // losses for centerface
|
||||
│ ├──lr_scheduler.py // learning rate scheduler
|
||||
│ ├──mobile_v2.py // modified mobilenet_v2 backbone
|
||||
|
@ -213,6 +213,11 @@ sh eval_all.sh
|
|||
│ ├──var_init.py // weight initialization
|
||||
│ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore
|
||||
│ ├──convert_weight.py // CenterFace model convert to mindspore
|
||||
| └──model_utils
|
||||
| ├──config.py // Processing configuration parameters
|
||||
| ├──device_adapter.py // Get cloud ID
|
||||
| ├──local_adapter.py // Get local ID
|
||||
| └ ──moxing_adapter.py // Parameter processing
|
||||
└── dependency // third party codes: MIT License
|
||||
├──extd // training dependency: data augmentation
|
||||
│ ├──utils
|
||||
|
@ -451,7 +456,7 @@ cd ../../../scripts;
|
|||
```python
|
||||
# you need to change the parameter in eval.sh
|
||||
# default eval the ckpt saved in ./scripts/output/centerface/999
|
||||
sh eval.sh
|
||||
sh eval.sh [ground_truth_path]
|
||||
```
|
||||
|
||||
2. eval many testing output for user to choose the best one
|
||||
|
@ -459,7 +464,7 @@ cd ../../../scripts;
|
|||
```python
|
||||
# you need to change the parameter in eval_all.sh
|
||||
# default eval the ckpt saved in ./scripts/output/centerface/[89-140]
|
||||
sh eval_all.sh
|
||||
sh eval_all.sh [ground_truth_path]
|
||||
```
|
||||
|
||||
3. test+eval
|
||||
|
@ -475,7 +480,7 @@ sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [CKPT]
|
|||
you can see the MAP below by eval.sh
|
||||
|
||||
```log
|
||||
(ci3.7) [root@bms-aiserver scripts]# ./eval.sh
|
||||
(ci3.7) [root@bms-aiserver scripts]# ./eval.sh ./ground_truth_path
|
||||
start eval
|
||||
==================== Results = ==================== ./scripts/output/centerface/999
|
||||
Easy Val AP: 0.923914407045363
|
||||
|
@ -488,7 +493,7 @@ end eval
|
|||
you can see the MAP below by eval_all.sh
|
||||
|
||||
```log
|
||||
(ci3.7) [root@bms-aiserver scripts]# ./eval_all.sh
|
||||
(ci3.7) [root@bms-aiserver scripts]# ./eval_all.sh ./ground_truth_path
|
||||
==================== Results = ==================== ./scripts/output/centerface/89
|
||||
Easy Val AP: 0.8884892849068273
|
||||
Medium Val AP: 0.8928813452811216
|
||||
|
|
|
@ -0,0 +1,234 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
checkpoint_path: './checkpoint/'
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# Config setup
|
||||
flip_idx: [[0, 1], [3, 4]]
|
||||
default_resolution: [512, 512]
|
||||
heads: {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 5 * 2}
|
||||
head_conv: 64
|
||||
max_objs: 64
|
||||
|
||||
rand_crop: True
|
||||
scale: 0.4
|
||||
shift: 0.1
|
||||
aug_rot: 0
|
||||
color_aug: True
|
||||
flip: 0.5
|
||||
input_res: 512 #768 #800
|
||||
output_res: 128 #192 #200
|
||||
num_classes: 1
|
||||
num_joints: 5
|
||||
reg_offset: True
|
||||
hm_hp: True
|
||||
reg_hp_offset: True
|
||||
dense_hp: False
|
||||
hm_weight: 1.0
|
||||
wh_weight: 0.1
|
||||
off_weight: 1.0
|
||||
lm_weight: 0.1
|
||||
rotate: 0
|
||||
|
||||
# for test
|
||||
mean: [0.408, 0.447, 0.470]
|
||||
std: [0.289, 0.274, 0.278]
|
||||
test_scales: [0.999,]
|
||||
nms: 1
|
||||
flip_test: 0
|
||||
fix_res: True
|
||||
input_h: 832 #800
|
||||
input_w: 832 #800
|
||||
K: 200
|
||||
down_ratio: 4
|
||||
test_batch_size: 1
|
||||
|
||||
master_batch_size: 8
|
||||
num_workers: 8
|
||||
not_rand_crop: False
|
||||
no_color_aug: False
|
||||
|
||||
# ==============================================================================
|
||||
# train.py mindspore coco training
|
||||
# dataset related
|
||||
data_dir: '/cache/data/'
|
||||
annot_path: '/cache/data/annotations/train_wider_face.json'
|
||||
img_dir: '/cache/data/images/WIDER_train/images/'
|
||||
per_batch_size: 8
|
||||
|
||||
# network related
|
||||
pretrained_backbone: '/cache/data/centerface/mobilenet_v2.ckpt'
|
||||
resume: ''
|
||||
|
||||
# optimizer and lr related
|
||||
lr_scheduler: 'multistep'
|
||||
lr: 0.004 # 4e-3
|
||||
lr_epochs: '90,120'
|
||||
lr_gamma: 0.1
|
||||
eta_min: 0.
|
||||
t_max: 140
|
||||
max_epoch: 140
|
||||
warmup_epochs: 0
|
||||
weight_decay: 0.0005
|
||||
momentum: 0.9
|
||||
optimizer: 'adam'
|
||||
|
||||
# loss related
|
||||
loss_scale: 1024
|
||||
label_smooth: 0
|
||||
label_smooth_factor: 0.1
|
||||
|
||||
# logging related
|
||||
log_interval: 100
|
||||
ckpt_path: './output'
|
||||
ckpt_interval: 0 # None
|
||||
|
||||
is_save_on_master: 1
|
||||
|
||||
# distributed related
|
||||
is_distributed: 1
|
||||
rank: 0
|
||||
group_size: 1
|
||||
|
||||
# roma obs
|
||||
#train_url: ""
|
||||
|
||||
# profiler init, can open when you debug. if train, donot open, since it cost memory and disk space
|
||||
need_profiler: 0
|
||||
|
||||
# reset default config
|
||||
training_shape: ""
|
||||
resize_rate: 0 # None
|
||||
|
||||
# test.py mindspore coco training'
|
||||
test_model: ''
|
||||
ground_truth_mat: ''
|
||||
save_dir: ''
|
||||
ground_truth_path: ''
|
||||
eval: 0
|
||||
eval_script_path: ''
|
||||
ckpt_name: ""
|
||||
device_num: 1
|
||||
steps_per_epoch: 198
|
||||
start: 0
|
||||
end: 18
|
||||
|
||||
# export.py centerface export'
|
||||
device_id: 0
|
||||
batch_size: 1
|
||||
ckpt_file: ''
|
||||
file_name: "centerface"
|
||||
file_format: 'AIR'
|
||||
|
||||
# centerface preprocess"
|
||||
dataset_path: ''
|
||||
preprocess_path: ''
|
||||
|
||||
# postprocess / centerface calcul AP
|
||||
result_path: ''
|
||||
label_file: ''
|
||||
meta_file: ''
|
||||
save_path: ''
|
||||
|
||||
# ==============================================================================
|
||||
# src/convert_weight.py
|
||||
ckpt_fn: '/model_path/centerface.ckpt'
|
||||
pt_fn: '/model_path/centerface.pth'
|
||||
out_fn: '/model_path/centerface_out.ckpt'
|
||||
pt2ckpt: 1
|
||||
|
||||
# src/convert_weight_mobilenetv2.py
|
||||
ckpt_fn_v2: '/model_path/mobilenet_v2_key.ckpt'
|
||||
pt_fn_v2: '/model_path/mobilenet_v2-b0353104.pth'
|
||||
out_ckpt_fn: '/model_path/mobilenet_v2-b0353104.ckpt'
|
||||
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
data_dir: 'train data dir'
|
||||
annot_path: 'train data annotation path'
|
||||
img_dir: 'train data img dir'
|
||||
per_batch_size: 'batch size for per gpu'
|
||||
pretrained_backbone: 'model_path, local pretrained backbone model to load'
|
||||
resume: 'path of pretrained centerface_model'
|
||||
lr_scheduler: 'lr-scheduler, option type: exponential, cosine_annealing'
|
||||
lr: 'learning rate of the training'
|
||||
lr_epochs: 'epoch of lr changing'
|
||||
lr_gamma: 'decrease lr by a factor of exponential lr_scheduler'
|
||||
eta_min: 'eta_min in cosine_annealing scheduler'
|
||||
t_max: 'T-max in cosine_annealing scheduler'
|
||||
max_epoch: 'max epoch num to train the model'
|
||||
warmup_epochs: 'warmup epoch'
|
||||
weight_decay: 'weight decay'
|
||||
momentum: 'momentum'
|
||||
optimizer: 'optimizer type, default: adam'
|
||||
loss_scale: 'static loss scale'
|
||||
label_smooth: 'whether to use label smooth in CE'
|
||||
label_smooth_factor: 'smooth strength of original one-hot'
|
||||
log_interval: 'logging interval'
|
||||
ckpt_path: 'checkpoint save location'
|
||||
ckpt_interval: 'ckpt_interval'
|
||||
is_save_on_master: 'save ckpt on master or all rank'
|
||||
is_distributed: 'if multi device'
|
||||
rank: 'local rank of distributed'
|
||||
group_size: 'world size of distributed'
|
||||
need_profiler: 'whether use profiler'
|
||||
training_shape: 'fix training shape'
|
||||
resize_rate: 'resize rate for multi-scale training'
|
||||
test_model: 'test model dir'
|
||||
ground_truth_mat: 'ground_truth, mat type'
|
||||
save_dir: 'save_path for evaluate'
|
||||
ground_truth_path: 'ground_truth path, contain all mat file'
|
||||
eval: 'if do eval after test'
|
||||
eval_script_path: 'evaluate script path'
|
||||
ckpt_name: 'input model name'
|
||||
steps_per_epoch: 'steps for each epoch'
|
||||
start: 'start loop number, used to calculate first epoch number'
|
||||
end: 'end loop number, used to calculate last epoch number'
|
||||
batch_size: "batch size"
|
||||
ckpt_file: "Checkpoint file path."
|
||||
|
||||
ckpt_fn: 'ckpt for user to get cell/module name'
|
||||
pt_fn: 'checkpoint filename to convert'
|
||||
out_fn: 'convert output ckpt/pth path'
|
||||
pt2ckpt: '1 : pt2ckpt; 0 : ckpt2pt'
|
||||
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
|
||||
run_distribute: 'Run distribute, default is false.'
|
||||
do_train: 'Do train or not, default is true.'
|
||||
do_eval: 'Do eval or not, default is false.'
|
||||
dataset: 'Dataset, default is coco.'
|
||||
pre_trained: 'Pretrain file path.'
|
||||
device_id: 'Device id, default is 0.'
|
||||
device_num: 'Use device nums, default is 1.'
|
||||
rank_id: 'Rank id, default is 0.'
|
||||
file_name: "output file name."
|
||||
file_format: 'file format'
|
||||
img_path: "image file path."
|
||||
result_path: "result file path."
|
||||
|
||||
dataset_path: "dataset path."
|
||||
preprocess_path: "preprocess path."
|
||||
label_file: "label file"
|
||||
meta_file: "label file"
|
||||
save_path: "label file"
|
||||
|
||||
---
|
||||
platform: ['Ascend', 'GPU', 'CPU']
|
||||
file_format: ["AIR", "MINDIR"]
|
||||
freeze_layer: ["", "none", "backbone"]
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
|
@ -21,27 +20,22 @@ from mindspore import context, Tensor
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
|
||||
from src.config import ConfigCenterface
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description='centerface export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="centerface", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
def modelarts_process():
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = ConfigCenterface()
|
||||
@moxing_wrapper(pre_process=modelarts_process)
|
||||
def export_centerface():
|
||||
net = CenterfaceMobilev2()
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
|
||||
|
@ -55,5 +49,8 @@ if __name__ == '__main__':
|
|||
net = CenterFaceWithNms(net)
|
||||
net.set_train(False)
|
||||
|
||||
input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
|
||||
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
input_data = Tensor(np.zeros([config.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
|
||||
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
if __name__ == '__main__':
|
||||
export_centerface()
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
# ============================================================================
|
||||
"""post process for 310 inference"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.config import ConfigCenterface
|
||||
from src.model_utils.config import config
|
||||
|
||||
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
|
||||
from dependency.evaluate.eval import evaluation
|
||||
|
||||
|
@ -36,16 +36,8 @@ dct_map = {'16': '16--Award_Ceremony', '26': '26--Soldier_Drilling', '29': '29--
|
|||
'14': '14--Traffic', '41': '41--Swimming', '46': '46--Jockey', '10': '10--People_Marching',
|
||||
'54': '54--Rescue', '57': '57--Angler', '31': '31--Waiter_Waitress', '27': '27--Spa', '21': '21--Festival'}
|
||||
|
||||
parser = argparse.ArgumentParser(description='centerface calcul AP')
|
||||
parser.add_argument("--result_path", type=str, required=True, default='', help="result file path")
|
||||
parser.add_argument("--label_file", type=str, required=True, default='', help="label file")
|
||||
parser.add_argument("--meta_file", type=str, required=True, default='', help="label file")
|
||||
parser.add_argument("--save_path", type=str, required=True, default='', help="label file")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def cal_acc(result_path, label_file, meta_file, save_path):
|
||||
config = ConfigCenterface()
|
||||
detector = CenterFaceDetector(config, None)
|
||||
if not os.path.exists(save_path):
|
||||
for im_dir in dct_map.values():
|
||||
|
@ -87,4 +79,4 @@ def cal_acc(result_path, label_file, meta_file, save_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cal_acc(args.result_path, args.label_file, args.meta_file, args.save_path)
|
||||
cal_acc(config.result_path, config.label_file, config.meta_file, config.save_path)
|
||||
|
|
|
@ -14,21 +14,14 @@
|
|||
# ============================================================================
|
||||
"""pre process for 310 inference"""
|
||||
import os
|
||||
import argparse
|
||||
import shutil
|
||||
import cv2
|
||||
import numpy as np
|
||||
from src.config import ConfigCenterface
|
||||
from src.model_utils.config import config
|
||||
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
|
||||
|
||||
parser = argparse.ArgumentParser(description="centerface preprocess")
|
||||
parser.add_argument("--dataset_path", type=str, required=True, help="dataset path.")
|
||||
parser.add_argument("--preprocess_path", type=str, required=True, help="preprocess path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def preprocess(dataset_path, preprocess_path):
|
||||
config = ConfigCenterface()
|
||||
event_list = os.listdir(dataset_path)
|
||||
input_path = os.path.join(preprocess_path, "input")
|
||||
meta_path = os.path.join(preprocess_path, "meta/meta")
|
||||
|
@ -65,4 +58,4 @@ def preprocess(dataset_path, preprocess_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
preprocess(args.dataset_path, args.preprocess_path)
|
||||
preprocess(config.dataset_path, config.preprocess_path)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
root=$PWD
|
||||
save_path=$root/output/centerface/
|
||||
ground_truth_path=$root/dataset/centerface/ground_truth
|
||||
ground_truth_path=$1
|
||||
echo "start eval"
|
||||
python ../dependency/evaluate/eval.py --pred=$save_path --gt=$ground_truth_path
|
||||
echo "end eval"
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
|
||||
root=$PWD
|
||||
save_path=$root/output/centerface/
|
||||
ground_truth_path=$root/dataset/centerface/ground_truth
|
||||
ground_truth_path=$1
|
||||
#for i in $(seq start_epoch end_epoch+1)
|
||||
for i in $(seq 89 200)
|
||||
do
|
||||
python ../dependency/evaluate/eval.py --pred=$save_path$i --gt=$ground_truth_path &
|
||||
python ../dependency/evaluate/eval.py --pred=$save_path$i --gt=$ground_truth_path >> log_eval_all.txt 2>&1 &
|
||||
sleep 10
|
||||
done
|
||||
wait
|
||||
|
|
|
@ -55,27 +55,28 @@ device_id=0
|
|||
ckpt="0-125_24750.ckpt" # the model saved for epoch=125
|
||||
ground_truth_path=$root/dataset/centerface/ground_truth
|
||||
|
||||
if [ $# == 1 ]
|
||||
if [ $# -ge 1 ]
|
||||
then
|
||||
model_path=$(get_real_path $1)
|
||||
if [ ! -f $model_path ]
|
||||
# if [ ! -f $model_path ]
|
||||
if [ ! -d $model_path ]
|
||||
then
|
||||
echo "error: model_path=$model_path is not a file"
|
||||
echo "error: model_path=$model_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
if [ $# -ge 2 ]
|
||||
then
|
||||
dataset_path=$(get_real_path $2)
|
||||
if [ ! -f $dataset_path ]
|
||||
if [ ! -d $dataset_path ]
|
||||
then
|
||||
echo "error: dataset_path=$dataset_path is not a file"
|
||||
echo "error: dataset_path=$dataset_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
if [ $# -ge 3 ]
|
||||
then
|
||||
ground_truth_mat=$(get_real_path $3)
|
||||
if [ ! -f $ground_truth_mat ]
|
||||
|
@ -85,27 +86,27 @@ then
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 4 ]
|
||||
if [ $# -ge 4 ]
|
||||
then
|
||||
save_path=$(get_real_path $4)
|
||||
if [ ! -f $save_path ]
|
||||
if [ ! -d $save_path ]
|
||||
then
|
||||
echo "error: save_path=$save_path is not a file"
|
||||
echo "error: save_path=$save_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 5 ]
|
||||
if [ $# -ge 5 ]
|
||||
then
|
||||
device_id=$5
|
||||
fi
|
||||
|
||||
if [ $# == 6 ]
|
||||
if [ $# -ge 6 ]
|
||||
then
|
||||
ckpt=$6
|
||||
fi
|
||||
|
||||
if [ $# == 7 ]
|
||||
if [ $# -ge 7 ]
|
||||
then
|
||||
ground_truth_path=$(get_real_path $7)
|
||||
if [ ! -f $ground_truth_path ]
|
||||
|
|
|
@ -63,27 +63,28 @@ steps_per_epoch=198 #198 for 8P; 1583 for 1p
|
|||
start=11 # start epoch number = start * device_num + min(device_phy_id) + 1
|
||||
end=18 # end epoch number = end * device_num + max(device_phy_id) + 1
|
||||
|
||||
if [ $# == 1 ]
|
||||
if [ $# -ge 1 ]
|
||||
then
|
||||
model_path=$(get_real_path $1)
|
||||
if [ ! -f $model_path ]
|
||||
# if [ ! -f $model_path ]
|
||||
if [ ! -d $model_path ]
|
||||
then
|
||||
echo "error: model_path=$model_path is not a file"
|
||||
echo "error: model_path=$model_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
if [ $# -ge 2 ]
|
||||
then
|
||||
dataset_path=$(get_real_path $2)
|
||||
if [ ! -f $dataset_path ]
|
||||
if [ ! -d $dataset_path ]
|
||||
then
|
||||
echo "error: dataset_path=$dataset_path is not a file"
|
||||
echo "error: dataset_path=$dataset_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
if [ $# -ge 3 ]
|
||||
then
|
||||
ground_truth_mat=$(get_real_path $3)
|
||||
if [ ! -f $ground_truth_mat ]
|
||||
|
@ -93,27 +94,27 @@ then
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 4 ]
|
||||
if [ $# -ge 4 ]
|
||||
then
|
||||
save_path=$(get_real_path $4)
|
||||
if [ ! -f $save_path ]
|
||||
if [ ! -d $save_path ]
|
||||
then
|
||||
echo "error: save_path=$save_path is not a file"
|
||||
echo "error: save_path=$save_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 5 ]
|
||||
if [ $# -ge 5 ]
|
||||
then
|
||||
device_num=$5
|
||||
fi
|
||||
|
||||
if [ $# == 6 ]
|
||||
if [ $# -ge 6 ]
|
||||
then
|
||||
steps_per_epoch=$6
|
||||
fi
|
||||
|
||||
if [ $# == 7 ]
|
||||
if [ $# -ge 7 ]
|
||||
then
|
||||
start=$7
|
||||
fi
|
||||
|
|
|
@ -51,7 +51,7 @@ annot_path=$dataset_path/annotations/train.json
|
|||
img_dir=$dataset_path/images/train/images
|
||||
rank_table=$root/rank_table_8p.json
|
||||
|
||||
if [ $# == 1 ]
|
||||
if [ $# -ge 1 ]
|
||||
then
|
||||
rank_table=$(get_real_path $1)
|
||||
if [ ! -f $rank_table ]
|
||||
|
@ -61,7 +61,7 @@ then
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
if [ $# -ge 2 ]
|
||||
then
|
||||
pretrained_backbone=$(get_real_path $2)
|
||||
if [ ! -f $pretrained_backbone ]
|
||||
|
@ -71,17 +71,17 @@ then
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
if [ $# -ge 3 ]
|
||||
then
|
||||
dataset_path=$(get_real_path $3)
|
||||
if [ ! -f $dataset_path ]
|
||||
if [ ! -d $dataset_path ]
|
||||
then
|
||||
echo "error: dataset_path=$dataset_path is not a file"
|
||||
echo "error: dataset_path=$dataset_path is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 4 ]
|
||||
if [ $# -ge 4 ]
|
||||
then
|
||||
annot_path=$(get_real_path $4)
|
||||
if [ ! -f $annot_path ]
|
||||
|
@ -91,12 +91,12 @@ then
|
|||
fi
|
||||
fi
|
||||
|
||||
if [ $# == 5 ]
|
||||
if [ $# -ge 5 ]
|
||||
then
|
||||
img_dir=$(get_real_path $5)
|
||||
if [ ! -f $img_dir ]
|
||||
if [ ! -d $img_dir ]
|
||||
then
|
||||
echo "error: img_dir=$img_dir is not a file"
|
||||
echo "error: img_dir=$img_dir is not a dir"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
|
|
@ -50,38 +50,28 @@ annot_path=$dataset_path/annotations/train.json
|
|||
img_dir=$dataset_path/images/train/images
|
||||
use_device_id=0
|
||||
|
||||
if [ $# == 1 ]
|
||||
if [ $# -ge 1 ]
|
||||
then
|
||||
use_device_id=$1
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
if [ $# -ge 2 ]
|
||||
then
|
||||
use_device_id=$1
|
||||
pretrained_backbone=$(get_real_path $2)
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
if [ $# -ge 3 ]
|
||||
then
|
||||
use_device_id=$1
|
||||
pretrained_backbone=$(get_real_path $2)
|
||||
dataset_path=$(get_real_path $3)
|
||||
fi
|
||||
|
||||
if [ $# == 4 ]
|
||||
if [ $# -ge 4 ]
|
||||
then
|
||||
use_device_id=$1
|
||||
pretrained_backbone=$(get_real_path $2)
|
||||
dataset_path=$(get_real_path $3)
|
||||
annot_path=$(get_real_path $4)
|
||||
fi
|
||||
|
||||
if [ $# == 5 ]
|
||||
if [ $# -ge 5 ]
|
||||
then
|
||||
use_device_id=$1
|
||||
pretrained_backbone=$(get_real_path $2)
|
||||
dataset_path=$(get_real_path $3)
|
||||
annot_path=$(get_real_path $4)
|
||||
img_dir=$(get_real_path $5)
|
||||
fi
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# ============================================================================
|
||||
"""centerface networks"""
|
||||
|
||||
from src.config import ConfigCenterface
|
||||
# from src.config import ConfigCenterface
|
||||
from src.model_utils.config import config
|
||||
from src.mobile_v2 import mobilenet_v2
|
||||
from src.losses import FocalLoss, SmoothL1LossNew, SmoothL1LossNewCMask
|
||||
|
||||
|
@ -132,7 +133,7 @@ class CenterfaceMobilev2(nn.Cell):
|
|||
|
||||
def __init__(self):
|
||||
super(CenterfaceMobilev2, self).__init__()
|
||||
self.config = ConfigCenterface()
|
||||
self.config = config
|
||||
|
||||
self.base = mobilenet_v2()
|
||||
channels = self.base.feat_channel
|
||||
|
@ -199,7 +200,7 @@ class CenterFaceWithLossCell(nn.Cell):
|
|||
def __init__(self, network):
|
||||
super(CenterFaceWithLossCell, self).__init__()
|
||||
self.centerface_network = network
|
||||
self.config = ConfigCenterface()
|
||||
self.config = config
|
||||
self.loss = CenterFaceLoss(self.config.wh_weight, self.config.reg_offset, self.config.off_weight,
|
||||
self.config.hm_weight, self.config.lm_weight)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
@ -295,7 +296,7 @@ class CenterFaceWithNms(nn.Cell):
|
|||
def __init__(self, network):
|
||||
super(CenterFaceWithNms, self).__init__()
|
||||
self.centerface_network = network
|
||||
self.config = ConfigCenterface()
|
||||
self.config = config
|
||||
# two type of maxpool self.maxpool2d = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='same')
|
||||
self.maxpool2d = P.MaxPoolWithArgmax(kernel_size=3, strides=1, pad_mode='same')
|
||||
self.topk = P.TopK(sorted=True)
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""centerface unique configs"""
|
||||
|
||||
class ConfigCenterface():
|
||||
"""
|
||||
Config setup
|
||||
"""
|
||||
flip_idx = [[0, 1], [3, 4]]
|
||||
default_resolution = [512, 512]
|
||||
heads = {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 5 * 2}
|
||||
head_conv = 64
|
||||
max_objs = 64
|
||||
|
||||
rand_crop = True
|
||||
scale = 0.4
|
||||
shift = 0.1
|
||||
aug_rot = 0
|
||||
color_aug = True
|
||||
flip = 0.5
|
||||
input_res = 512 #768 #800
|
||||
output_res = 128 #192 #200
|
||||
num_classes = 1
|
||||
num_joints = 5
|
||||
reg_offset = True
|
||||
hm_hp = True
|
||||
reg_hp_offset = True
|
||||
dense_hp = False
|
||||
hm_weight = 1.0
|
||||
wh_weight = 0.1
|
||||
off_weight = 1.0
|
||||
lm_weight = 0.1
|
||||
rotate = 0
|
||||
|
||||
# for test
|
||||
mean = [0.408, 0.447, 0.470]
|
||||
std = [0.289, 0.274, 0.278]
|
||||
test_scales = [0.999,]
|
||||
nms = 1
|
||||
flip_test = 0
|
||||
fix_res = True
|
||||
input_h = 832 #800
|
||||
input_w = 832 #800
|
||||
K = 200
|
||||
down_ratio = 4
|
||||
test_batch_size = 1
|
||||
|
||||
master_batch_size = 8
|
||||
num_workers = 8
|
||||
not_rand_crop = False
|
||||
no_color_aug = False
|
|
@ -16,20 +16,11 @@
|
|||
Centerface model transform
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
||||
from mindspore import Tensor
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--ckpt_fn', type=str, default='/model_path/centerface.ckpt',
|
||||
help='ckpt for user to get cell/module name')
|
||||
parser.add_argument('--pt_fn', type=str, default='/model_path/centerface.pth', help='checkpoint filename to convert')
|
||||
parser.add_argument('--out_fn', type=str, default='/model_path/centerface_out.ckpt',
|
||||
help='convert output ckpt/pth path')
|
||||
parser.add_argument('--pt2ckpt', type=int, default=1, help='1 : pt2ckpt; 0 : ckpt2pt')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def load_model(model_path):
|
||||
"""
|
||||
|
@ -160,10 +151,10 @@ def ckpt_to_pt(pt, ckpt, out_path):
|
|||
return state_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.pt2ckpt == 1:
|
||||
pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_fn)
|
||||
elif args.pt2ckpt == 0:
|
||||
ckpt_to_pt(args.pt_fn, args.ckpt_fn, args.out_fn)
|
||||
if config.pt2ckpt == 1:
|
||||
pt_to_ckpt(config.pt_fn, config.ckpt_fn, config.out_fn)
|
||||
elif config.pt2ckpt == 0:
|
||||
ckpt_to_pt(config.pt_fn, config.ckpt_fn, config.out_fn)
|
||||
else:
|
||||
# user defined functions
|
||||
pass
|
||||
|
|
|
@ -16,20 +16,14 @@
|
|||
Mobilenet model transform: torch => mindspore
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
||||
from mindspore import Tensor
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--ckpt_fn', type=str, default='/model_path/mobilenet_v2_key.ckpt',
|
||||
help='ckpt for user to get cell/module name')
|
||||
parser.add_argument('--pt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.pth',
|
||||
help='checkpoint filename to convert')
|
||||
parser.add_argument('--out_ckpt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.ckpt',
|
||||
help='convert output ckpt path')
|
||||
|
||||
args = parser.parse_args()
|
||||
config.ckpt_fn = config.ckpt_fn_v2
|
||||
config.pt_fn = config.pt_fn_v2
|
||||
|
||||
def load_model(model_path):
|
||||
"""
|
||||
|
@ -129,4 +123,4 @@ def pt_to_ckpt(pt, ckpt, out_ckpt):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# beta <=> bias, gamma <=> weight
|
||||
pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_ckpt_fn)
|
||||
pt_to_ckpt(config.pt_fn, config.ckpt_fn, config.out_ckpt_fn)
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -91,11 +91,11 @@ def load_backbone(net, ckpt_path, args):
|
|||
else:
|
||||
not_found_param.append(mobilev2_beta)
|
||||
|
||||
args.logger.info('================found_param {}========='.format(len(find_param)))
|
||||
args.logger.info(find_param)
|
||||
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
|
||||
args.logger.info(not_found_param)
|
||||
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
|
||||
print('================found_param {}========='.format(len(find_param)))
|
||||
print(find_param)
|
||||
print('================not_found_param {}========='.format(len(not_found_param)))
|
||||
print(not_found_param)
|
||||
print('=====load {} successfully ====='.format(ckpt_path))
|
||||
|
||||
return net
|
||||
|
||||
|
|
|
@ -17,67 +17,54 @@ Test centerface example
|
|||
"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import datetime
|
||||
import scipy.io as sio
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.utils import get_logger
|
||||
from src.var_init import default_recurisive_init
|
||||
from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
|
||||
from src.config import ConfigCenterface
|
||||
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
|
||||
from dependency.evaluate.eval import evaluation
|
||||
|
||||
dev_id = int(os.getenv('DEVICE_ID'))
|
||||
dev_id = get_device_id()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
|
||||
device_target="Ascend", save_graphs=False, device_id=dev_id)
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore coco training')
|
||||
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
|
||||
parser.add_argument('--test_model', type=str, default='', help='test model dir')
|
||||
parser.add_argument('--ground_truth_mat', type=str, default='', help='ground_truth, mat type')
|
||||
parser.add_argument('--save_dir', type=str, default='', help='save_path for evaluate')
|
||||
parser.add_argument('--ground_truth_path', type=str, default='', help='ground_truth path, contain all mat file')
|
||||
parser.add_argument('--eval', type=int, default=0, help='if do eval after test')
|
||||
parser.add_argument('--eval_script_path', type=str, default='', help='evaluate script path')
|
||||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||
parser.add_argument('--ckpt_name', type=str, default="", help='input model name')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='device num for testing')
|
||||
parser.add_argument('--steps_per_epoch', type=int, default=198, help='steps for each epoch')
|
||||
parser.add_argument('--start', type=int, default=0, help='start loop number, used to calculate first epoch number')
|
||||
parser.add_argument('--end', type=int, default=18, help='end loop number, used to calculate last epoch number')
|
||||
def modelarts_process():
|
||||
config.data_dir = config.data_path
|
||||
config.save_dir = config.output_path
|
||||
config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# logger
|
||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||
args.logger.save_args(args)
|
||||
@moxing_wrapper(pre_process=modelarts_process)
|
||||
def test_centerface():
|
||||
"""" test_centerface """
|
||||
config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
if args.ckpt_name != "":
|
||||
args.start = 0
|
||||
args.end = 1
|
||||
if config.ckpt_name != "":
|
||||
config.start = 0
|
||||
config.end = 1
|
||||
|
||||
for loop in range(args.start, args.end, 1):
|
||||
for loop in range(config.start, config.end, 1):
|
||||
network = CenterfaceMobilev2()
|
||||
default_recurisive_init(network)
|
||||
|
||||
if args.ckpt_name == "":
|
||||
ckpt_num = loop * args.device_num + args.rank + 1
|
||||
ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt"
|
||||
if config.ckpt_name == "":
|
||||
ckpt_num = loop * config.device_num + config.rank + 1
|
||||
ckpt_name = "0-" + str(ckpt_num) + "_" + str(config.steps_per_epoch * ckpt_num) + ".ckpt"
|
||||
else:
|
||||
ckpt_name = args.ckpt_name
|
||||
ckpt_name = config.ckpt_name
|
||||
|
||||
test_model = args.test_model + ckpt_name
|
||||
test_model = config.test_model + ckpt_name
|
||||
if not test_model:
|
||||
args.logger.info('load_model {} none'.format(test_model))
|
||||
print('load_model {} none'.format(test_model))
|
||||
continue
|
||||
|
||||
if os.path.isfile(test_model):
|
||||
|
@ -92,30 +79,28 @@ if __name__ == "__main__":
|
|||
param_dict_new[key] = values
|
||||
|
||||
load_param_into_net(network, param_dict_new)
|
||||
args.logger.info('load_model {} success'.format(test_model))
|
||||
print('load_model {} success'.format(test_model))
|
||||
else:
|
||||
args.logger.info('{} not exists or not a pre-trained file'.format(test_model))
|
||||
print('{} not exists or not a pre-trained file'.format(test_model))
|
||||
continue
|
||||
|
||||
train_network_type_nms = 1 # default with num
|
||||
if train_network_type_nms:
|
||||
network = CenterFaceWithNms(network)
|
||||
args.logger.info('train network type with nms')
|
||||
print('train network type with nms')
|
||||
network.set_train(False)
|
||||
args.logger.info('finish get network')
|
||||
|
||||
config = ConfigCenterface()
|
||||
print('finish get network')
|
||||
|
||||
# test network -----------
|
||||
start = time.time()
|
||||
|
||||
ground_truth_mat = sio.loadmat(args.ground_truth_mat)
|
||||
ground_truth_mat = sio.loadmat(config.ground_truth_mat)
|
||||
event_list = ground_truth_mat['event_list']
|
||||
file_list = ground_truth_mat['file_list']
|
||||
if args.ckpt_name == "":
|
||||
save_path = args.save_dir + str(ckpt_num) + '/'
|
||||
if config.ckpt_name == "":
|
||||
save_path = config.save_dir + str(ckpt_num) + '/'
|
||||
else:
|
||||
save_path = args.save_dir+ '/'
|
||||
save_path = config.save_dir+ '/'
|
||||
detector = CenterFaceDetector(config, network)
|
||||
|
||||
for index, event in enumerate(event_list):
|
||||
|
@ -123,12 +108,12 @@ if __name__ == "__main__":
|
|||
im_dir = event[0][0]
|
||||
if not os.path.exists(save_path + im_dir):
|
||||
os.makedirs(save_path + im_dir)
|
||||
args.logger.info('save_path + im_dir={}'.format(save_path + im_dir))
|
||||
print('save_path + im_dir={}'.format(save_path + im_dir))
|
||||
for num, file in enumerate(file_list_item):
|
||||
im_name = file[0][0]
|
||||
zip_name = '%s/%s.jpg' % (im_dir, im_name)
|
||||
img_path = os.path.join(args.data_dir, zip_name)
|
||||
args.logger.info('img_path={}'.format(img_path))
|
||||
img_path = os.path.join(config.data_dir, zip_name)
|
||||
print('img_path={}'.format(img_path))
|
||||
|
||||
dets = detector.run(img_path)['results']
|
||||
|
||||
|
@ -139,22 +124,25 @@ if __name__ == "__main__":
|
|||
x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4]
|
||||
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), s))
|
||||
f.close()
|
||||
args.logger.info('event:{}, num:{}'.format(index + 1, num + 1))
|
||||
print('event:{}, num:{}'.format(index + 1, num + 1))
|
||||
|
||||
end = time.time()
|
||||
args.logger.info("============num {} time {}".format(num, (end-start)*1000))
|
||||
print("============num {} time {}".format(num, (end-start)*1000))
|
||||
start = end
|
||||
|
||||
if args.eval:
|
||||
args.logger.info('==========start eval===============')
|
||||
args.logger.info("test output path = {}".format(save_path))
|
||||
if config.eval:
|
||||
print('==========start eval===============')
|
||||
print("test output path = {}".format(save_path))
|
||||
if os.path.isdir(save_path):
|
||||
evaluation(save_path, args.ground_truth_path)
|
||||
evaluation(save_path, config.ground_truth_path)
|
||||
else:
|
||||
args.logger.info('no test output path')
|
||||
args.logger.info('==========end eval===============')
|
||||
print('no test output path')
|
||||
print('==========end eval===============')
|
||||
|
||||
if args.ckpt_name != "":
|
||||
if config.ckpt_name != "":
|
||||
break
|
||||
|
||||
args.logger.info('==========end testing===============')
|
||||
print('==========end testing===============')
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_centerface()
|
||||
|
|
|
@ -18,7 +18,6 @@ Train centerface and get network model files(.ckpt)
|
|||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
@ -35,7 +34,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
from mindspore.profiler.profiling import Profiler
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.utils import get_logger
|
||||
from src.utils import AverageMeter
|
||||
from src.lr_scheduler import warmup_step_lr
|
||||
from src.lr_scheduler import warmup_cosine_annealing_lr, \
|
||||
|
@ -44,78 +42,23 @@ from src.lr_scheduler import MultiStepLR
|
|||
from src.var_init import default_recurisive_init
|
||||
from src.centerface import CenterfaceMobilev2
|
||||
from src.utils import load_backbone, get_param_groups
|
||||
from src.config import ConfigCenterface
|
||||
|
||||
from src.centerface import CenterFaceWithLossCell, TrainingWrapper
|
||||
from src.dataset import GetDataLoader
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
set_seed(1)
|
||||
dev_id = int(os.getenv('DEVICE_ID'))
|
||||
dev_id = get_device_id()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
|
||||
device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False)
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore coco training')
|
||||
|
||||
# dataset related
|
||||
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
|
||||
parser.add_argument('--annot_path', type=str, default='', help='train data annotation path')
|
||||
parser.add_argument('--img_dir', type=str, default='', help='train data img dir')
|
||||
parser.add_argument('--per_batch_size', default=8, type=int, help='batch size for per gpu')
|
||||
|
||||
# network related
|
||||
parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone'
|
||||
' model to load')
|
||||
parser.add_argument('--resume', default='', type=str, help='path of pretrained centerface_model')
|
||||
|
||||
# optimizer and lr related
|
||||
parser.add_argument('--lr_scheduler', default='multistep', type=str,
|
||||
help='lr-scheduler, option type: exponential, cosine_annealing')
|
||||
parser.add_argument('--lr', default=4e-3, type=float, help='learning rate of the training')
|
||||
parser.add_argument('--lr_epochs', type=str, default='90,120', help='epoch of lr changing')
|
||||
parser.add_argument('--lr_gamma', type=float, default=0.1,
|
||||
help='decrease lr by a factor of exponential lr_scheduler')
|
||||
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
|
||||
parser.add_argument('--t_max', type=int, default=140, help='T-max in cosine_annealing scheduler')
|
||||
parser.add_argument('--max_epoch', type=int, default=140, help='max epoch num to train the model')
|
||||
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--optimizer', default='adam', type=str,
|
||||
help='optimizer type, default: adam')
|
||||
|
||||
# loss related
|
||||
parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
|
||||
parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE')
|
||||
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')
|
||||
|
||||
# logging related
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||
parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
|
||||
|
||||
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||
|
||||
# distributed related
|
||||
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
|
||||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||
|
||||
# roma obs
|
||||
parser.add_argument('--train_url', type=str, default="", help='train url')
|
||||
|
||||
# profiler init, can open when you debug. if train, donot open, since it cost memory and disk space
|
||||
parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler')
|
||||
|
||||
# reset default config
|
||||
parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
|
||||
parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
|
||||
args.t_max = args.max_epoch
|
||||
|
||||
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
||||
if config.lr_scheduler == 'cosine_annealing' and config.max_epoch > config.t_max:
|
||||
config.t_max = config.max_epoch
|
||||
|
||||
config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
|
||||
|
||||
def convert_training_shape(args_):
|
||||
"""
|
||||
|
@ -135,34 +78,40 @@ class InternalCallbackParam(dict):
|
|||
self[para_name] = para_value
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_centerface():
    pass

if __name__ == "__main__":
    train_centerface()
    print('\ntrain.py config:\n', config)
    # init distributed
    if args.is_distributed:
    if config.is_distributed:
        init()
    args.rank = get_rank()
    args.group_size = get_group_size()
    config.rank = get_rank()
    config.group_size = get_group_size()

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    config.rank_save_ckpt_flag = 0
    if config.is_save_on_master:
        if config.rank == 0:
            config.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1
        config.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)
    config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args.need_profiler:
        profiler = Profiler(output_path=args.outputs_dir)
    if config.need_profiler:
        profiler = Profiler(output_path=config.outputs_dir)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
    if config.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
@@ -176,14 +125,14 @@ if __name__ == "__main__":
    # init; to avoid overflow, the std of some weights should be small enough
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    if config.pretrained_backbone:
        network = load_backbone(network, config.pretrained_backbone, config)
        print('load pre-trained backbone {} into network'.format(config.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')
        print('Not load pre-trained backbone, please be careful')

    if os.path.isfile(args.resume):
        param_dict = load_checkpoint(args.resume)
    if os.path.isfile(config.resume):
        param_dict = load_checkpoint(config.resume)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
@@ -194,107 +143,94 @@ if __name__ == "__main__":
                param_dict_new[key] = values

        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume))
        print('load_model {} success'.format(config.resume))
    else:
        args.logger.info('{} not set/exists or not a pre-trained file'.format(args.resume))
        print('{} not set/exists or not a pre-trained file'.format(config.resume))

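# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# The key filtering above is the usual "weights only" resume: entries whose names start with
# 'moments.', 'moment1.' or 'moment2.' are optimizer state, so they are left out while the remaining
# keys are copied into param_dict_new before load_param_into_net restores the network weights.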
network = CenterFaceWithLossCell(network)
    args.logger.info('finish get network')
    print('finish get network')

    config = ConfigCenterface()
    config.data_dir = args.data_dir
    config.annot_path = args.annot_path
    config.img_dir = args.img_dir

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor
    # -------------reset config-----------------
    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]

    if args.resize_rate:
        config.resize_rate = args.resize_rate
    if config.training_shape:
        config.multi_scale = [convert_training_shape(config)]

    # data loader
    data_loader, args.steps_per_epoch = GetDataLoader(per_batch_size=args.per_batch_size,
                                                      max_epoch=args.max_epoch,
                                                      rank=args.rank,
                                                      group_size=args.group_size,
                                                      config=config,
                                                      split='train')
    args.steps_per_epoch = args.steps_per_epoch // args.max_epoch
    args.logger.info('Finish loading dataset')
    data_loader, config.steps_per_epoch = GetDataLoader(per_batch_size=config.per_batch_size, \
        max_epoch=config.max_epoch, rank=config.rank, group_size=config.group_size, \
        config=config, split='train')
    config.steps_per_epoch = config.steps_per_epoch // config.max_epoch
    print('Finish loading dataset')
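# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# GetDataLoader appears to report the total number of steps across all epochs, which is why
# steps_per_epoch is divided by max_epoch right after loading; the per-epoch value then drives
# ckpt_interval below and the epoch bookkeeping in the training loop.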

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch
    if not config.ckpt_interval:
        config.ckpt_interval = config.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'multistep':
        lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch, args.max_epoch,
                             args.warmup_epochs)
    if config.lr_scheduler == 'multistep':
        lr_fun = MultiStepLR(config.lr, config.lr_epochs, config.lr_gamma, config.steps_per_epoch, config.max_epoch,
                             config.warmup_epochs)
        lr = lr_fun.get_lr()
    elif args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma
    elif config.lr_scheduler == 'exponential':
        lr = warmup_step_lr(config.lr,
                            config.lr_epochs,
                            config.steps_per_epoch,
                            config.warmup_epochs,
                            config.max_epoch,
                            gamma=config.lr_gamma
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.t_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_v2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.t_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.t_max,
                                               args.eta_min)
    elif config.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(config.lr,
                                        config.steps_per_epoch,
                                        config.warmup_epochs,
                                        config.max_epoch,
                                        config.t_max,
                                        config.eta_min)
    elif config.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_v2(config.lr,
                                           config.steps_per_epoch,
                                           config.warmup_epochs,
                                           config.max_epoch,
                                           config.t_max,
                                           config.eta_min)
    elif config.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(config.lr,
                                               config.steps_per_epoch,
                                               config.warmup_epochs,
                                               config.max_epoch,
                                               config.t_max,
                                               config.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
        raise NotImplementedError(config.lr_scheduler)
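# --- illustrative sketch added by the editor (not part of the original train.py diff) --------------
# A minimal version of what a warmup + cosine-annealing schedule typically computes, assuming a
# linear per-step warmup followed by a per-epoch cosine decay from base_lr down to eta_min over
# t_max epochs; the project's real helpers (warmup_cosine_annealing_lr and friends) may differ in detail.
import math

def sketch_warmup_cosine_annealing_lr(base_lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0.0):
    total_steps = steps_per_epoch * max_epoch
    warmup_steps = int(steps_per_epoch * warmup_epochs)
    lr_each_step = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            # linear warmup from 0 towards base_lr
            lr_each_step.append(base_lr * (step + 1) / warmup_steps)
        else:
            # cosine decay, annealed once per epoch
            cur_epoch = step // steps_per_epoch
            cos_factor = (1 + math.cos(math.pi * cur_epoch / t_max)) / 2
            lr_each_step.append(eta_min + (base_lr - eta_min) * cos_factor)
    return lr_each_step
# ----------------------------------------------------------------------------------------------------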

    if args.optimizer == "adam":
    if config.optimizer == "adam":
        opt = Adam(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
        args.logger.info("use adam optimizer")
    elif args.optimizer == "sgd":
                   weight_decay=config.weight_decay,
                   loss_scale=config.loss_scale)
        print("use adam optimizer")
    elif config.optimizer == "sgd":
        opt = SGD(params=get_param_groups(network),
                  learning_rate=Tensor(lr),
                  momentum=args.momentum,
                  weight_decay=args.weight_decay,
                  loss_scale=args.loss_scale)
                  momentum=config.momentum,
                  weight_decay=config.weight_decay,
                  loss_scale=config.loss_scale)
    else:
        opt = Momentum(params=get_param_groups(network),
                       learning_rate=Tensor(lr),
                       momentum=args.momentum,
                       weight_decay=args.weight_decay,
                       loss_scale=args.loss_scale)
                       momentum=config.momentum,
                       weight_decay=config.weight_decay,
                       loss_scale=config.loss_scale)

    network = TrainingWrapper(network, opt, sens=args.loss_scale)
    network = TrainingWrapper(network, opt, sens=config.loss_scale)
    network.set_train()

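# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# loss_scale is used twice above: it is passed as sens to TrainingWrapper, so gradients are computed
# against a statically scaled loss (the usual FP16 loss-scaling trick), and it is also given to the
# optimizer, which is then expected to unscale the gradients before applying the update, so the
# effective step size is unchanged.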
    if args.rank_save_ckpt_flag:
    if config.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
        ckpt_max_num = config.max_epoch * config.steps_per_epoch // config.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
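# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# With ckpt_interval left unset it falls back to steps_per_epoch (see above), so
# ckpt_max_num = max_epoch * steps_per_epoch // steps_per_epoch = max_epoch; with the default
# max_epoch of 140 this keeps one checkpoint per epoch, up to 140 files.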
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
                                  directory=config.outputs_dir,
                                  prefix='{}'.format(config.rank))
        cb_params = InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
@@ -302,14 +238,14 @@ if __name__ == "__main__":
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    args.logger.info('args.steps_per_epoch = {} args.ckpt_interval ={}'.format(args.steps_per_epoch,
                                                                               args.ckpt_interval))
    print('config.steps_per_epoch = {} config.ckpt_interval ={}'.format(config.steps_per_epoch, \
        config.ckpt_interval))

    t_end = time.time()

    for i_all, batch_load in enumerate(data_loader):
        i = i_all % args.steps_per_epoch
        epoch = i_all // args.steps_per_epoch + 1
        i = i_all % config.steps_per_epoch
        epoch = i_all // config.steps_per_epoch + 1
        images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks = batch_load
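# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# The unpacked batch follows the usual CenterNet-style CenterFace targets: hm is the centre heatmap,
# wh the box-size targets, hm_offset the sub-pixel centre offsets, landmarks the facial-landmark
# targets, and reg_mask/ind/wight_mask/hps_mask mark which positions and channels are valid for the
# losses ('wight_mask' is the identifier used by the source).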

        images = Tensor(images)
@@ -327,33 +263,28 @@ if __name__ == "__main__":
        overflow = np.all(overflow.asnumpy())
        loss = loss.asnumpy()
        loss_meter.update(loss)
        args.logger.info('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format(epoch,
                                                                                                      i,
                                                                                                      loss_meter,
                                                                                                      loss,
                                                                                                      overflow,
                                                                                                      scaling.asnumpy()
                                                                                                      ))
        print('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format( \
            epoch, i, loss_meter, loss, overflow, scaling.asnumpy()))

        if args.rank_save_ckpt_flag:
        if config.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_epoch_num = epoch
            cb_params.cur_step_num = i + 1 + (epoch-1)*args.steps_per_epoch
            cb_params.batch_num = i + 2 + (epoch-1)*args.steps_per_epoch
            cb_params.cur_step_num = i + 1 + (epoch-1)*config.steps_per_epoch
            cb_params.batch_num = i + 2 + (epoch-1)*config.steps_per_epoch
            ckpt_cb.step_end(run_context)

        if (i_all+1) % args.steps_per_epoch == 0:
        if (i_all+1) % config.steps_per_epoch == 0:
            time_used = time.time() - t_end
            fps = args.per_batch_size * args.steps_per_epoch * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
            fps = config.per_batch_size * config.steps_per_epoch * config.group_size / time_used
            if config.rank == 0:
                print(
                    'epoch[{}], {}, {:.2f} imgs/sec, lr:{}'
                    .format(epoch, loss_meter, fps, lr[i + (epoch-1)*args.steps_per_epoch])
                    .format(epoch, loss_meter, fps, lr[i + (epoch-1)*config.steps_per_epoch])
                )
            t_end = time.time()
            loss_meter.reset()
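# --- illustrative note added by the editor (not part of the original train.py diff) ----------------
# The throughput printed above is images per second across all devices:
# fps = per_batch_size * steps_per_epoch * group_size / time_used. For a purely hypothetical epoch
# with per_batch_size=8, steps_per_epoch=1500, group_size=8 and time_used=60 s this gives
# 8 * 1500 * 8 / 60 = 1600 imgs/sec.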

    if args.need_profiler:
    if config.need_profiler:
        profiler.analyse()

    args.logger.info('==========end training===============')
    print('==========end training===============')