merge openpose
This commit is contained in:
parent
a12be8b065
commit
bc2e6b6c47
|
@ -56,12 +56,12 @@ In the currently provided training script, the coco2017 data set is used as an e
|
|||
|
||||
```python
|
||||
├── dataset
|
||||
├── annotation
|
||||
├── annotations
|
||||
├─person_keypoints_train2017.json
|
||||
└─person_keypoints_val2017.json
|
||||
├─ignore_mask_train
|
||||
├─ignore_mask_val
|
||||
├─tran2017
|
||||
├─ignore_mask_train2017
|
||||
├─ignore_mask_val2017
|
||||
├─train2017
|
||||
└─val2017
|
||||
```
|
||||
|
||||
|
@ -90,15 +90,15 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
```python
|
||||
# run training example
|
||||
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > train.log 2>&1 &
|
||||
python train.py --imgpath_train ./train2017 --jsonpath_train ./person_keypoints_train2017.json --maskpath_train ./ignore_mask_train2017 > train.log 2>&1 &
|
||||
|
||||
# run distributed training example
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE]
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [IMGPATH_TRAIN] [JSONPATH_TRAIN] [MASKPATH_TRAIN]
|
||||
|
||||
# run evaluation example
|
||||
python eval.py --model_path path_to_eval_model.ckpt --imgpath_val ./dataset/val2017 --ann ./dataset/annotations/person_keypoints_val2017.json > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_eval_ascend.sh
|
||||
bash scripts/run_eval_ascend.sh [MODEL_PATH] [IMGPATH_VAL] [ANN]
|
||||
```
|
||||
|
||||
[RANK_TABLE_FILE] is the path of the multi-card information configuration table in the environment. The configuration table can be automatically generated by the tool [hccl_tool](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
@ -108,32 +108,38 @@ After installing MindSpore via the official website, you can start training and
|
|||
## [Script and Sample Code](#contents)
|
||||
|
||||
```python
|
||||
├── ModelZoo_openpose_MS_MIT
|
||||
├── openpose
|
||||
├── README.md // descriptions about openpose
|
||||
├── scripts
|
||||
│ ├──run_standalone_train.sh // shell script for distributed on Ascend
|
||||
│ ├──run_distribute_train.sh // shell script for distributed on Ascend with 8p
|
||||
│ ├──run_eval_ascend.sh // shell script for evaluation on Ascend
|
||||
├── src
|
||||
│ ├── model_utils
|
||||
│ ├── config.py # Parameter config
|
||||
│ ├── moxing_adapter.py # modelarts device configuration
|
||||
│ └── device_adapter.py # Device Config
|
||||
│ └── local_adapter.py # local device config
|
||||
│ ├──openposenet.py // Openpose architecture
|
||||
│ ├──loss.py // Loss function
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──dataset.py // Data preprocessing
|
||||
│ ├──utils.py // Utils
|
||||
│ ├──gen_ignore_mask.py // Generating mask data script
|
||||
├── export.py // model conversion script
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
├── eval.py // evaluation script
|
||||
├── mindspore_hub_config.py // hub config file
|
||||
├── default_config.yaml // config file
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
Parameters for both training and evaluation can be set in default_config.yaml
|
||||
|
||||
- config for openpose
|
||||
|
||||
```python
|
||||
'data_dir': 'path to dataset' # absolute full path to the train and evaluation datasets
|
||||
```default_config.yaml
|
||||
'imgpath_train': 'path to dataset' # absolute full path to the train and evaluation datasets
|
||||
'vgg_path': 'path to vgg model' # absolute full path to vgg19 model
|
||||
'save_model_path': 'path of saving models' # absolute full path to output models
|
||||
'load_pretrain': 'False' # whether training based on the pre-trained model
|
||||
|
@ -150,7 +156,7 @@ Parameters for both training and evaluation can be set in config.py
|
|||
'ckpt_interval': 5000 # the interval of saving a output model
|
||||
```
|
||||
|
||||
For more configuration details, please refer the script `config.py`.
|
||||
For more configuration details, please refer the script `default_config.yaml`.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
|
@ -159,7 +165,7 @@ For more configuration details, please refer the script `config.py`.
|
|||
- running on Ascend
|
||||
|
||||
```python
|
||||
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > train.log 2>&1 &
|
||||
python train.py --imgpath_train ./train2017 --jsonpath_train ./person_keypoints_train2017.json --maskpath_train ./ignore_mask_train2017 > train.log 2>&1 &
|
||||
```
|
||||
|
||||
The python command above will run in the background, you can view the results through the file `train.log`.
|
||||
|
@ -168,13 +174,70 @@ For more configuration details, please refer the script `config.py`.
|
|||
|
||||
```python
|
||||
# grep "epoch " train.log
|
||||
epoch[0], iter[0], loss[0.29211228793809957], 0.13 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
|
||||
epoch[0], iter[100], loss[0.060355084178521694], 24.92 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
|
||||
epoch[0], iter[200], loss[0.026628130997662272], 26.20 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
|
||||
epoch[0], iter[23], mean loss is 0.292112287
|
||||
epoch[0], iter[123], mean loss is 0.060355084
|
||||
epoch[0], iter[223], mean loss is 0.026628130
|
||||
...
|
||||
```
|
||||
|
||||
The model checkpoint will be saved in the directory of config.py: 'save_model_path'.
|
||||
The model checkpoint will be saved in the directory of default_config.yaml: 'save_model_path'.
|
||||
|
||||
- running on ModelArts
|
||||
- If you want to train the model on modelarts, you can refer to the [official guidance document] of modelarts (https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```python
|
||||
# Example of using distributed training dpn on modelarts :
|
||||
# Data set storage method
|
||||
# ├── openpose_dataset
|
||||
# ├── annotations
|
||||
# ├─person_keypoints_train2017.json
|
||||
# └─person_keypoints_val2017.json
|
||||
# ├─ignore_mask_train2017
|
||||
# ├─ignore_mask_val2017
|
||||
# ├─train2017
|
||||
# └─val2017
|
||||
# └─checkpoint
|
||||
# └─pre_trained
|
||||
#
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters) 。
|
||||
# a. set "enable_modelarts=True"
|
||||
# set "vgg_path=/cache/data/pre_trained/vgg19-0-97_5004.ckpt"
|
||||
# set "maskpath_train=/cache/data/ignore_mask_train2017"
|
||||
# set "jsonpath_train=/cache/data/annotations/person_keypoints_train2017"
|
||||
# set "save_model_path=/cache/train/checkpoint"
|
||||
# set "imgpath_train=/cache/data/train2017"
|
||||
#
|
||||
# b. add "enable_modelarts=True" Parameters are on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/openpose"。
|
||||
# (4) Set the model's startup file on the modelarts interface "train.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../openpose_dataset"(choices openpose_dataset Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (6) start trainning the model。
|
||||
|
||||
# Example of using model inference on modelarts
|
||||
# (1) Place the trained model to the corresponding position of the bucket。
|
||||
# (2) chocie a or b。
|
||||
# a.set "enable_modelarts=True"
|
||||
# set "ann=/cache/data/annotations/person_keypoints_val2017"
|
||||
# set "output_img_path=/cache/data/output_imgs/"
|
||||
# set "imgpath_val=/cache/data/val2017"
|
||||
# set "model_path=/cache/data/checkpoint/0-80_663.ckpt"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) Set the code path on the modelarts interface "/path/openpose"。
|
||||
# (5) Set the model's startup file on the modelarts interface "eval.py" 。
|
||||
# (6) Set the data path of the model on the modelarts interface ".../openpose_dataset"(openpose_dataset Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (7) Start model inference。
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
|
@ -187,7 +250,7 @@ For more configuration details, please refer the script `config.py`.
|
|||
```python
|
||||
python eval.py --model_path path_to_eval_model.ckpt --imgpath_val ./dataset/val2017 --ann ./dataset/annotations/person_keypoints_val2017.json > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_eval_ascend.sh
|
||||
bash scripts/run_eval_ascend.sh [MODEL_PATH] [IMGPATH_VAL] [ANN]
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
|
@ -199,6 +262,27 @@ For more configuration details, please refer the script `config.py`.
|
|||
|
||||
```
|
||||
|
||||
- Export MindIR on Modelarts
|
||||
|
||||
```Modelarts
|
||||
Export MindIR example on ModelArts
|
||||
Data storage method is the same as training
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters)。
|
||||
# a. set "enable_modelarts=True"
|
||||
# set "file_name=/cache/train/openpose"
|
||||
# set "file_format=MINDIR"
|
||||
# set "ckpt_file=/cache/data/checkpoint file name"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
# (2)Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/openpose"。
|
||||
# (4) Set the model's startup file on the modelarts interface "export.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../openpose_dataset/checkpoint"(choices openpose_dataset/checkpoint Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
checkpoint_path: "./checkpoint/"
|
||||
checkpoint_file: "./checkpoint/.ckpt"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
imgpath_train: ""
|
||||
jsonpath_train: ""
|
||||
maskpath_train: ""
|
||||
save_model_path: "./checkpoint/"
|
||||
load_pretrain: False
|
||||
pretrained_model_path: ""
|
||||
|
||||
# train type
|
||||
train_type: "fix_loss_scale"
|
||||
train_type_NP: "clip_grad"
|
||||
|
||||
# vgg bn
|
||||
vgg_with_bn: False
|
||||
vgg_path: ""
|
||||
|
||||
#if clip_grad
|
||||
GRADIENT_CLIP_TYPE: 1
|
||||
GRADIENT_CLIP_VALUE: 10.0
|
||||
|
||||
# optimizer and lr
|
||||
optimizer: "Adam"
|
||||
optimizer_NP: "Momentum"
|
||||
group_params: True
|
||||
group_params_NP: False
|
||||
lr: 1e-4
|
||||
lr_type: "default" # chose in [default, cosine]
|
||||
lr_gamma: 0.1
|
||||
lr_steps: "100000,200000,250000"
|
||||
lr_steps_NP: "250000,300000"
|
||||
warmup_epoch: 5
|
||||
max_epoch_train: 60
|
||||
max_epoch_train_NP: 80
|
||||
loss_scale: 16384
|
||||
|
||||
|
||||
# default param
|
||||
batch_size: 10
|
||||
min_keypoints: 5
|
||||
min_area: 1024
|
||||
insize: 368
|
||||
downscale: 8
|
||||
paf_sigma: 8
|
||||
heatmap_sigma: 7
|
||||
keep_checkpoint_max: 5
|
||||
log_interval: 100
|
||||
ckpt_interval: 5304
|
||||
min_box_size: 64
|
||||
max_box_size: 512
|
||||
min_scale: 0.5
|
||||
max_scale: 2.0
|
||||
max_rotate_degree: 40
|
||||
center_perterb_max: 40
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
is_distributed: 0
|
||||
eva_num: 100
|
||||
model_path: ""
|
||||
imgpath_val: ""
|
||||
ann: ""
|
||||
output_img_path: "./output_imgs/"
|
||||
|
||||
|
||||
# inference params
|
||||
inference_img_size: 368
|
||||
inference_scales: [0.5, 1, 1.5, 2]
|
||||
heatmap_size: 320
|
||||
gaussian_sigma: 2.5
|
||||
ksize: 17
|
||||
n_integ_points: 10
|
||||
n_integ_points_thresh: 8
|
||||
heatmap_peak_thresh: 0.05
|
||||
inner_product_thresh: 0.05
|
||||
limb_length_ratio: 1.0
|
||||
length_penalty_value: 1
|
||||
n_subset_limbs_thresh: 3
|
||||
subset_score_thresh: 0.2
|
||||
|
||||
# face params
|
||||
face_inference_img_size: 368
|
||||
face_heatmap_peak_thresh: 0.1
|
||||
face_crop_scale: 1.5
|
||||
face_line_indices: [
|
||||
[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], # 轮廓
|
||||
[17, 18], [18, 19], [19, 20], [20, 21],
|
||||
[22, 23], [23, 24], [24, 25], [25, 26],
|
||||
[27, 28], [28, 29], [29, 30],
|
||||
[31, 32], [32, 33], [33, 34], [34, 35],
|
||||
[36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
|
||||
[42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
|
||||
[48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], # 唇外廓
|
||||
[60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60]
|
||||
]
|
||||
|
||||
# hand params
|
||||
hand_inference_img_size: 368
|
||||
hand_heatmap_peak_thresh: 0.1
|
||||
fingers_indices: [
|
||||
[[0, 1], [1, 2], [2, 3], [3, 4]],
|
||||
[[0, 5], [5, 6], [6, 7], [7, 8]],
|
||||
[[0, 9], [9, 10], [10, 11], [11, 12]],
|
||||
[[0, 13], [13, 14], [14, 15], [15, 16]],
|
||||
[[0, 17], [17, 18], [18, 19], [19, 20]],
|
||||
]
|
||||
|
||||
# ======================================================================================
|
||||
#export options
|
||||
device_id: 0
|
||||
export_batch_size: 1
|
||||
ckpt_file: ""
|
||||
file_name: "openpose"
|
||||
file_format: "MINDIR"
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
is_distributed: "Run distribute, default is false."
|
||||
device_id: "device id"
|
||||
export_batch_size: "batch size"
|
||||
file_name: "output file name"
|
||||
file_format: "file format choices[AIR, MINDIR, ONNX]"
|
||||
ckpt_file: "Checkpoint file path."
|
||||
train_dir: "train data dir"
|
||||
train_ann: "train annotations json"
|
||||
model_path: "path of testing model"
|
||||
imgpath_val: "path of testing imgs"
|
||||
ann: "path of annotations"
|
||||
output_img_path: "path of testing imgs"
|
|
@ -12,9 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import warnings
|
||||
import sys
|
||||
import numpy as np
|
||||
|
@ -23,34 +23,23 @@ from scipy.ndimage.filters import gaussian_filter
|
|||
from tqdm import tqdm
|
||||
from pycocotools.coco import COCO as LoadAnn
|
||||
from pycocotools.cocoeval import COCOeval as MapEval
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.config import params, JointType
|
||||
from src.openposenet import OpenPoseNet
|
||||
from src.dataset import valdata
|
||||
from src.model_utils.config import config, JointType
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
devid = get_device_id()
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
device_target=config.device_target, save_graphs=False, device_id=devid)
|
||||
show_gt = 0
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore openpose_net test')
|
||||
parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model')
|
||||
parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs')
|
||||
parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json',
|
||||
help='path of annotations')
|
||||
parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs')
|
||||
# distributed related
|
||||
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
|
||||
class NullWriter():
|
||||
|
@ -94,6 +83,7 @@ def load_model(test_net, model_path):
|
|||
|
||||
load_param_into_net(test_net, param_dict_new)
|
||||
|
||||
|
||||
def preprocess(img):
|
||||
x_data = img.astype('f')
|
||||
x_data /= 255
|
||||
|
@ -101,6 +91,7 @@ def preprocess(img):
|
|||
x_data = x_data.transpose(2, 0, 1)[None]
|
||||
return x_data
|
||||
|
||||
|
||||
def getImgsPath(img_dir_path):
|
||||
filepaths = []
|
||||
dirpaths = []
|
||||
|
@ -115,6 +106,7 @@ def getImgsPath(img_dir_path):
|
|||
dirpaths.append(dir_path)
|
||||
return filepaths
|
||||
|
||||
|
||||
def compute_optimal_size(orig_img, img_size, stride=8):
|
||||
orig_img_h, orig_img_w, _ = orig_img.shape
|
||||
aspect = orig_img_h / orig_img_w
|
||||
|
@ -132,6 +124,7 @@ def compute_optimal_size(orig_img, img_size, stride=8):
|
|||
img_h += stride - surplus
|
||||
return (img_w, img_h)
|
||||
|
||||
|
||||
def compute_peaks_from_heatmaps(heatmaps):
|
||||
|
||||
heatmaps = heatmaps[:-1]
|
||||
|
@ -139,7 +132,7 @@ def compute_peaks_from_heatmaps(heatmaps):
|
|||
all_peaks = []
|
||||
peak_counter = 0
|
||||
for i, heatmap in enumerate(heatmaps):
|
||||
heatmap = gaussian_filter(heatmap, sigma=params['gaussian_sigma'])
|
||||
heatmap = gaussian_filter(heatmap, sigma=config.gaussian_sigma)
|
||||
|
||||
map_left = np.zeros(heatmap.shape)
|
||||
map_right = np.zeros(heatmap.shape)
|
||||
|
@ -152,7 +145,7 @@ def compute_peaks_from_heatmaps(heatmaps):
|
|||
map_bottom[:, :-1] = heatmap[:, 1:]
|
||||
|
||||
peaks_binary = np.logical_and.reduce((
|
||||
heatmap > params['heatmap_peak_thresh'],
|
||||
heatmap > config.heatmap_peak_thresh,
|
||||
heatmap > map_left,
|
||||
heatmap > map_right,
|
||||
heatmap > map_top,
|
||||
|
@ -172,6 +165,7 @@ def compute_peaks_from_heatmaps(heatmaps):
|
|||
|
||||
return all_peaks
|
||||
|
||||
|
||||
def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
|
||||
candidate_connections = []
|
||||
for joint_a in cand_a:
|
||||
|
@ -180,28 +174,29 @@ def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
|
|||
norm = np.linalg.norm(vector)
|
||||
if norm == 0:
|
||||
continue
|
||||
ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points'])
|
||||
xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points'])
|
||||
ys = np.linspace(joint_a[1], joint_b[1], num=params_.n_integ_points)
|
||||
xs = np.linspace(joint_a[0], joint_b[0], num=params_.n_integ_points)
|
||||
integ_points = np.stack([ys, xs]).T.round().astype('i')
|
||||
|
||||
paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
|
||||
unit_vector = vector / norm
|
||||
inner_products = np.dot(paf_in_edge, unit_vector)
|
||||
integ_value = inner_products.sum() / len(inner_products)
|
||||
integ_value_with_dist_prior = integ_value + min(params_['limb_length_ratio'] * img_len / norm -
|
||||
params_['length_penalty_value'], 0)
|
||||
n_valid_points = sum(inner_products > params_['inner_product_thresh'])
|
||||
if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
|
||||
integ_value_with_dist_prior = integ_value + min(params_.limb_length_ratio * img_len / norm -
|
||||
params_.length_penalty_value, 0)
|
||||
n_valid_points = sum(inner_products > params_.inner_product_thresh)
|
||||
if n_valid_points > params_.n_integ_points_thresh and integ_value_with_dist_prior > 0:
|
||||
candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
|
||||
candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
|
||||
return candidate_connections
|
||||
|
||||
|
||||
def compute_connections(pafs, all_peaks, img_len, params_):
|
||||
all_connections = []
|
||||
for i in range(len(params_['limbs_point'])):
|
||||
for i in range(len(params_.limbs_point)):
|
||||
paf_index = [i * 2, i * 2 + 1]
|
||||
paf = pafs[paf_index] # shape: (2, 320, 320)
|
||||
limb_point = params_['limbs_point'][i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
|
||||
limb_point = params_.limbs_point[i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
|
||||
cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
|
||||
cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
|
||||
|
||||
|
@ -224,7 +219,7 @@ def grouping_key_points(all_connections, candidate_peaks, params_):
|
|||
subsets = -1 * np.ones((0, 20))
|
||||
|
||||
for l, connections in enumerate(all_connections):
|
||||
joint_a, joint_b = params_['limbs_point'][l]
|
||||
joint_a, joint_b = params_.limbs_point[l]
|
||||
for ind_a, ind_b, score in connections[:, :3]:
|
||||
ind_a, ind_b = int(ind_a), int(ind_b)
|
||||
joint_found_cnt = 0
|
||||
|
@ -284,11 +279,12 @@ def grouping_key_points(all_connections, candidate_peaks, params_):
|
|||
pass
|
||||
|
||||
# delete low score subsets
|
||||
keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'],
|
||||
subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh'])
|
||||
keep = np.logical_and(subsets[:, -1] >= params_.n_subset_limbs_thresh,
|
||||
subsets[:, -2] / subsets[:, -1] >= params_.subset_score_thresh)
|
||||
subsets = subsets[keep]
|
||||
return subsets
|
||||
|
||||
|
||||
def subsets_to_pose_array(subsets, all_peaks):
|
||||
person_pose_array = []
|
||||
for subset in subsets:
|
||||
|
@ -308,8 +304,8 @@ def detect(img, network):
|
|||
orig_img = img.copy()
|
||||
orig_img_h, orig_img_w, _ = orig_img.shape
|
||||
|
||||
input_w, input_h = compute_optimal_size(orig_img, params['inference_img_size']) # 368
|
||||
map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])
|
||||
input_w, input_h = compute_optimal_size(orig_img, config.inference_img_size) # 368
|
||||
map_w, map_h = compute_optimal_size(orig_img, config.inference_img_size)
|
||||
|
||||
resized_image = cv2.resize(orig_img, (input_w, input_h))
|
||||
x_data = preprocess(resized_image)
|
||||
|
@ -338,8 +334,8 @@ def detect(img, network):
|
|||
all_peaks = compute_peaks_from_heatmaps(heatmaps)
|
||||
if all_peaks.shape[0] == 0:
|
||||
return np.empty((0, len(JointType), 3)), np.empty(0)
|
||||
all_connections = compute_connections(pafs, all_peaks, map_w, params)
|
||||
subsets = grouping_key_points(all_connections, all_peaks, params)
|
||||
all_connections = compute_connections(pafs, all_peaks, map_w, config)
|
||||
subsets = grouping_key_points(all_connections, all_peaks, config)
|
||||
all_peaks[:, 1] *= orig_img_w / map_w
|
||||
all_peaks[:, 2] *= orig_img_h / map_h
|
||||
poses = subsets_to_pose_array(subsets, all_peaks)
|
||||
|
@ -369,7 +365,7 @@ def draw_person_pose(orig_img, poses):
|
|||
|
||||
# limbs
|
||||
for pose in poses.round().astype('i'):
|
||||
for i, (limb, color) in enumerate(zip(params['limbs_point'], limb_colors)):
|
||||
for i, (limb, color) in enumerate(zip(config.limbs_point, limb_colors)):
|
||||
if i not in (9, 13): # don't show ear-shoulder connection
|
||||
limb_ind = np.array(limb)
|
||||
if np.all(pose[limb_ind][:, 2] != 0):
|
||||
|
@ -383,6 +379,7 @@ def draw_person_pose(orig_img, poses):
|
|||
cv2.circle(canvas, (x, y), 3, color, -1)
|
||||
return canvas
|
||||
|
||||
|
||||
def depreprocess(img):
|
||||
x_data = img[0]
|
||||
x_data += 0.5
|
||||
|
@ -391,19 +388,24 @@ def depreprocess(img):
|
|||
x_data = x_data.transpose(1, 2, 0)
|
||||
return x_data
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=None)
|
||||
def val():
|
||||
if args.is_distributed:
|
||||
config.rank = get_rank_id()
|
||||
config.group_size = get_device_num()
|
||||
|
||||
if config.is_distributed:
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
if not os.path.exists(args.output_path):
|
||||
os.mkdir(args.output_path)
|
||||
network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
|
||||
config.rank = get_rank_id()
|
||||
config.group_size = get_device_num()
|
||||
if not os.path.exists(config.output_img_path):
|
||||
os.mkdir(config.output_img_path)
|
||||
network = OpenPoseNet(vgg_with_bn=config.vgg_with_bn)
|
||||
network.set_train(False)
|
||||
load_model(network, args.model_path)
|
||||
load_model(network, config.model_path)
|
||||
|
||||
print("load models right")
|
||||
dataset = valdata(args.ann, args.imgpath_val, args.rank, args.group_size, mode='val')
|
||||
dataset = valdata(config.ann, config.imgpath_val, config.rank, config.group_size, mode='val')
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
de_dataset = dataset.create_tuple_iterator()
|
||||
|
||||
|
@ -431,14 +433,15 @@ def val():
|
|||
print("Predict poses size is zero.", flush=True)
|
||||
img = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)
|
||||
|
||||
save_path = os.path.join(args.output_path, str(img_id)+".png")
|
||||
save_path = os.path.join(config.output_img_path, str(img_id)+".png")
|
||||
cv2.imwrite(save_path, img)
|
||||
|
||||
result_json = 'eval_result.json'
|
||||
with open(os.path.join(args.output_path, result_json), 'w') as fid:
|
||||
with open(os.path.join(config.output_img_path, result_json), 'w') as fid:
|
||||
json.dump(kpt_json, fid)
|
||||
res = evaluate_mAP(os.path.join(args.output_path, result_json), ann_file=args.ann)
|
||||
res = evaluate_mAP(os.path.join(config.output_img_path, result_json), ann_file=config.ann)
|
||||
print('result: ', res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
val()
|
||||
|
|
|
@ -14,34 +14,34 @@
|
|||
# ============================================================================
|
||||
"""export"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.openposenet import OpenPoseNet
|
||||
from src.config import params
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description="openpose export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="openpose", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=config.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=None)
|
||||
def model_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
# define net
|
||||
net = OpenPoseNet()
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
inputs = np.ones([args.batch_size, 3, params["insize"], params["insize"]]).astype(np.float32)
|
||||
export(net, Tensor(inputs), file_name=args.file_name, file_format=args.file_format)
|
||||
inputs = np.ones([config.batch_size, 3, config.insize, config.insize]).astype(np.float32)
|
||||
export(net, Tensor(inputs), file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
||||
|
|
|
@ -13,10 +13,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE]"
|
||||
echo "Usage: sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [IAMGEPATH_TRAIN] [JSONPATH_TRAIN] [MASKPATH_TRAIN]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -47,15 +46,16 @@ do
|
|||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp -r ./scripts ./train_parallel$i
|
||||
cp ./*yaml ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--train_dir train2017 \
|
||||
--group_size 8 \
|
||||
--train_ann person_keypoints_train2017.json > log.txt 2>&1 &
|
||||
--imgpath_train=$2 \
|
||||
--jsonpath_train=$3 \
|
||||
--maskpath_train=$4 > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
||||
|
|
|
@ -14,9 +14,17 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_eval_ascend.sh [MODEL_PATH] [IMPATH_VAL] [ANN]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
python eval.py \
|
||||
--model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-80_663.ckpt \
|
||||
--imgpath_val ./dataset/val2017 \
|
||||
--ann ./dataset/annotations/person_keypoints_val2017.json \
|
||||
--model_path=$1 \
|
||||
--imgpath_val=$2 \
|
||||
--ann=$3 \
|
||||
> eval.log 2>&1 &
|
||||
|
|
|
@ -14,6 +14,20 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_standalone_train.sh [IAMGEPATH_TRAIN] [JSONPATH_TRAIN] [MASKPATH_TRAIN]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_ID=0
|
||||
cd ..
|
||||
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 &
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
rm -rf train
|
||||
mkdir train
|
||||
cp -r ./src ./train
|
||||
cp -r ./scripts ./train
|
||||
cp ./*.py ./train
|
||||
cp ./*yaml ./train
|
||||
cd ./train
|
||||
python train.py --imgpath_train=$1 --jsonpath_train=$2 --maskpath_train=$3 > train.log 2>&1 &
|
||||
|
|
|
@ -1,191 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from enum import IntEnum
|
||||
|
||||
class JointType(IntEnum):
|
||||
Nose = 0
|
||||
|
||||
Neck = 1
|
||||
|
||||
RightShoulder = 2
|
||||
|
||||
RightElbow = 3
|
||||
|
||||
RightHand = 4
|
||||
|
||||
LeftShoulder = 5
|
||||
|
||||
LeftElbow = 6
|
||||
|
||||
LeftHand = 7
|
||||
|
||||
RightWaist = 8
|
||||
|
||||
RightKnee = 9
|
||||
|
||||
RightFoot = 10
|
||||
|
||||
LeftWaist = 11
|
||||
|
||||
LeftKnee = 12
|
||||
|
||||
LeftFoot = 13
|
||||
|
||||
RightEye = 14
|
||||
|
||||
LeftEye = 15
|
||||
|
||||
RightEar = 16
|
||||
|
||||
LeftEar = 17
|
||||
|
||||
params = {
|
||||
# paths
|
||||
'data_dir': './dataset',
|
||||
'save_model_path': './checkpoints/',
|
||||
'load_pretrain': False,
|
||||
'pretrained_model_path': "",
|
||||
|
||||
# train type
|
||||
'train_type': 'fix_loss_scale', # chose in ['clip_grad', 'fix_loss_scale']
|
||||
'train_type_NP': 'clip_grad',
|
||||
|
||||
# vgg bn
|
||||
'vgg_with_bn': False,
|
||||
'vgg_path': './vgg_model/vgg19-0-97_5004.ckpt',
|
||||
|
||||
# if clip_grad
|
||||
'GRADIENT_CLIP_TYPE': 1,
|
||||
'GRADIENT_CLIP_VALUE': 10.0,
|
||||
|
||||
# optimizer and lr
|
||||
'optimizer': "Adam", # chose in ['Momentum', 'Adam']
|
||||
'optimizer_NP': "Momentum",
|
||||
'group_params': True,
|
||||
'group_params_NP': False,
|
||||
'lr': 1e-4,
|
||||
'lr_type': 'default', # chose in ["default", "cosine"]
|
||||
'lr_gamma': 0.1, # if default
|
||||
'lr_steps': '100000,200000,250000', # if default
|
||||
'lr_steps_NP': '250000,300000', # if default
|
||||
'warmup_epoch': 5, # if cosine
|
||||
'max_epoch_train': 60,
|
||||
'max_epoch_train_NP': 80,
|
||||
|
||||
'loss_scale': 16384,
|
||||
|
||||
# default param
|
||||
'batch_size': 10,
|
||||
'min_keypoints': 5,
|
||||
'min_area': 32 * 32,
|
||||
'insize': 368,
|
||||
'downscale': 8,
|
||||
'paf_sigma': 8,
|
||||
'heatmap_sigma': 7,
|
||||
'eva_num': 100,
|
||||
'keep_checkpoint_max': 1,
|
||||
'log_interval': 100,
|
||||
'ckpt_interval': 5304,
|
||||
|
||||
'min_box_size': 64,
|
||||
'max_box_size': 512,
|
||||
'min_scale': 0.5,
|
||||
'max_scale': 2.0,
|
||||
'max_rotate_degree': 40,
|
||||
'center_perterb_max': 40,
|
||||
|
||||
# inference params
|
||||
'inference_img_size': 368,
|
||||
'inference_scales': [0.5, 1, 1.5, 2],
|
||||
# 'inference_scales': [1.0],
|
||||
'heatmap_size': 320,
|
||||
'gaussian_sigma': 2.5,
|
||||
'ksize': 17,
|
||||
'n_integ_points': 10,
|
||||
'n_integ_points_thresh': 8,
|
||||
'heatmap_peak_thresh': 0.05,
|
||||
'inner_product_thresh': 0.05,
|
||||
'limb_length_ratio': 1.0,
|
||||
'length_penalty_value': 1,
|
||||
'n_subset_limbs_thresh': 3,
|
||||
'subset_score_thresh': 0.2,
|
||||
'limbs_point': [
|
||||
[JointType.Neck, JointType.RightWaist],
|
||||
[JointType.RightWaist, JointType.RightKnee],
|
||||
[JointType.RightKnee, JointType.RightFoot],
|
||||
[JointType.Neck, JointType.LeftWaist],
|
||||
[JointType.LeftWaist, JointType.LeftKnee],
|
||||
[JointType.LeftKnee, JointType.LeftFoot],
|
||||
[JointType.Neck, JointType.RightShoulder],
|
||||
[JointType.RightShoulder, JointType.RightElbow],
|
||||
[JointType.RightElbow, JointType.RightHand],
|
||||
[JointType.RightShoulder, JointType.RightEar],
|
||||
[JointType.Neck, JointType.LeftShoulder],
|
||||
[JointType.LeftShoulder, JointType.LeftElbow],
|
||||
[JointType.LeftElbow, JointType.LeftHand],
|
||||
[JointType.LeftShoulder, JointType.LeftEar],
|
||||
[JointType.Neck, JointType.Nose],
|
||||
[JointType.Nose, JointType.RightEye],
|
||||
[JointType.Nose, JointType.LeftEye],
|
||||
[JointType.RightEye, JointType.RightEar],
|
||||
[JointType.LeftEye, JointType.LeftEar]
|
||||
],
|
||||
'joint_indices': [
|
||||
JointType.Nose,
|
||||
JointType.LeftEye,
|
||||
JointType.RightEye,
|
||||
JointType.LeftEar,
|
||||
JointType.RightEar,
|
||||
JointType.LeftShoulder,
|
||||
JointType.RightShoulder,
|
||||
JointType.LeftElbow,
|
||||
JointType.RightElbow,
|
||||
JointType.LeftHand,
|
||||
JointType.RightHand,
|
||||
JointType.LeftWaist,
|
||||
JointType.RightWaist,
|
||||
JointType.LeftKnee,
|
||||
JointType.RightKnee,
|
||||
JointType.LeftFoot,
|
||||
JointType.RightFoot
|
||||
],
|
||||
|
||||
# face params
|
||||
'face_inference_img_size': 368,
|
||||
'face_heatmap_peak_thresh': 0.1,
|
||||
'face_crop_scale': 1.5,
|
||||
'face_line_indices': [
|
||||
[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], # 轮廓
|
||||
[17, 18], [18, 19], [19, 20], [20, 21],
|
||||
[22, 23], [23, 24], [24, 25], [25, 26],
|
||||
[27, 28], [28, 29], [29, 30],
|
||||
[31, 32], [32, 33], [33, 34], [34, 35],
|
||||
[36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
|
||||
[42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
|
||||
[48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], # 唇外廓
|
||||
[60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60]
|
||||
],
|
||||
|
||||
# hand params
|
||||
'hand_inference_img_size': 368,
|
||||
'hand_heatmap_peak_thresh': 0.1,
|
||||
'fingers_indices': [
|
||||
[[0, 1], [1, 2], [2, 3], [3, 4]],
|
||||
[[0, 5], [5, 6], [6, 7], [7, 8]],
|
||||
[[0, 9], [9, 10], [10, 11], [11, 12]],
|
||||
[[0, 13], [13, 14], [14, 15], [15, 16]],
|
||||
[[0, 17], [17, 18], [18, 19], [19, 20]],
|
||||
],
|
||||
}
|
|
@ -18,10 +18,10 @@ import random
|
|||
import numpy as np
|
||||
import cv2
|
||||
from pycocotools.coco import COCO as ReadJson
|
||||
|
||||
import mindspore.dataset as de
|
||||
from src.model_utils.config import config, JointType
|
||||
|
||||
|
||||
from src.config import JointType, params
|
||||
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
|
@ -60,8 +60,8 @@ class txtdataset():
|
|||
valid_annotations_for_img = []
|
||||
for annotation in annotations_for_img:
|
||||
# if too few keypoints or too small
|
||||
if annotation['num_keypoints'] >= params['min_keypoints'] and \
|
||||
annotation['area'] > params['min_area']:
|
||||
if annotation['num_keypoints'] >= config.min_keypoints and \
|
||||
annotation['area'] > config.min_area:
|
||||
person_cnt += 1
|
||||
valid_annotations_for_img.append(annotation)
|
||||
|
||||
|
@ -129,11 +129,11 @@ class txtdataset():
|
|||
joint_bboxes = self.get_pose_bboxes(poses)
|
||||
bbox_sizes = ((joint_bboxes[:, 2:] - joint_bboxes[:, :2] + 1) ** 2).sum(axis=1) ** 0.5
|
||||
|
||||
min_scale = params['min_box_size'] / bbox_sizes.min()
|
||||
max_scale = params['max_box_size'] / bbox_sizes.max()
|
||||
min_scale = config.min_box_size / bbox_sizes.min()
|
||||
max_scale = config.max_box_size / bbox_sizes.max()
|
||||
|
||||
min_scale = min(max(min_scale, params['min_scale']), 1)
|
||||
max_scale = min(max(max_scale, 1), params['max_scale'])
|
||||
min_scale = min(max(min_scale, config.min_scale), 1)
|
||||
max_scale = min(max(max_scale, 1), config.max_scale)
|
||||
|
||||
scale = float((max_scale - min_scale) * random.random() + min_scale)
|
||||
shape = (round(w * scale), round(h * scale))
|
||||
|
@ -143,7 +143,7 @@ class txtdataset():
|
|||
|
||||
def random_rotate_img(self, img, mask, poses):
|
||||
h, w, _ = img.shape
|
||||
degree = np.random.randn() / 3 * params['max_rotate_degree']
|
||||
degree = np.random.randn() / 3 * config.max_rotate_degree
|
||||
rad = degree * math.pi / 180
|
||||
center = (w / 2, h / 2)
|
||||
R = cv2.getRotationMatrix2D(center, degree, 1)
|
||||
|
@ -169,7 +169,7 @@ class txtdataset():
|
|||
bbox_center = bbox[:2] + (bbox[2:] - bbox[:2]) / 2
|
||||
|
||||
r_xy = np.random.rand(2)
|
||||
perturb = ((r_xy - 0.5) * 2 * params['center_perterb_max'])
|
||||
perturb = ((r_xy - 0.5) * 2 * config.center_perterb_max)
|
||||
center = (bbox_center + perturb + 0.5).astype('i')
|
||||
|
||||
crop_img = np.zeros((insize, insize, 3), 'uint8') + 127.5
|
||||
|
@ -329,7 +329,7 @@ class txtdataset():
|
|||
def generate_pafs(self, img, poses, paf_sigma):
|
||||
pafs = np.zeros((0,) + img.shape[:-1])
|
||||
|
||||
for limb in params['limbs_point']:
|
||||
for limb in config.limbs_point:
|
||||
paf = np.zeros((2,) + img.shape[:-1])
|
||||
paf_flags = np.zeros(paf.shape) # for constant paf
|
||||
|
||||
|
@ -376,7 +376,7 @@ class txtdataset():
|
|||
resize_shape = (img.shape[0]//8, img.shape[1]//8, 3)
|
||||
pafs = np.zeros((0,) + resize_shape[:-1])
|
||||
|
||||
for limb in params['limbs_point']:
|
||||
for limb in config.limbs_point:
|
||||
paf = np.zeros((2,) + resize_shape[:-1])
|
||||
paf_flags = np.zeros(paf.shape) # for constant paf
|
||||
|
||||
|
@ -410,7 +410,7 @@ class txtdataset():
|
|||
valid_annotations_for_img = []
|
||||
for annotation in annotations_for_img:
|
||||
# if too few keypoints or too small
|
||||
if annotation['num_keypoints'] >= params['min_keypoints'] and annotation['area'] > params['min_area']:
|
||||
if annotation['num_keypoints'] >= config.min_keypoints and annotation['area'] > config.min_area:
|
||||
person_cnt += 1
|
||||
valid_annotations_for_img.append(annotation)
|
||||
|
||||
|
@ -440,7 +440,7 @@ class txtdataset():
|
|||
pose = np.zeros((1, len(JointType), 3), dtype=np.int32)
|
||||
|
||||
# convert poses position
|
||||
for i, joint_index in enumerate(params['joint_indices']):
|
||||
for i, joint_index in enumerate(config.joint_indices):
|
||||
pose[0][joint_index] = ann_pose[i]
|
||||
|
||||
# compute neck position
|
||||
|
@ -470,9 +470,9 @@ class txtdataset():
|
|||
resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
|
||||
shape=(self.insize, self.insize))
|
||||
|
||||
resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
|
||||
resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, config.heatmap_sigma)
|
||||
|
||||
resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])
|
||||
resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, config.paf_sigma)
|
||||
|
||||
ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
|
||||
resized_ignore_mask = self.resize_output(ignore_mask)
|
||||
|
@ -540,10 +540,11 @@ class DistributedSampler():
|
|||
def __len__(self):
|
||||
return self.num_samplers
|
||||
|
||||
|
||||
def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
|
||||
#cv2.setNumThreads(0)
|
||||
val = ReadJson(jsonpath)
|
||||
dataset = txtdataset(val, imgpath, maskpath, params['insize'], mode=mode)
|
||||
dataset = txtdataset(val, imgpath, maskpath, config.insize, mode=mode)
|
||||
sampler = DistributedSampler(dataset, rank, group_size)
|
||||
ds = de.GeneratorDataset(dataset, ['img', 'img_id'], num_parallel_workers=8, sampler=sampler)
|
||||
ds = ds.repeat(1)
|
||||
|
@ -554,7 +555,7 @@ def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mo
|
|||
multiprocessing=True, num_worker=20):
|
||||
|
||||
train = ReadJson(jsonpath)
|
||||
dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
|
||||
dataset = txtdataset(train, imgpath, maskpath, config.insize, mode=mode)
|
||||
if group_size == 1:
|
||||
de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
|
||||
shuffle=shuffle,
|
||||
|
|
|
@ -22,8 +22,7 @@ from mindspore.context import ParallelMode, get_auto_parallel_context
|
|||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
from src.config import params
|
||||
from src.model_utils.config import config
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
time_stamp_init = False
|
||||
|
@ -32,8 +31,8 @@ grad_scale = C.MultitypeFuncGraph("grad_scale")
|
|||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE']
|
||||
GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE']
|
||||
GRADIENT_CLIP_TYPE = config.GRADIENT_CLIP_TYPE
|
||||
GRADIENT_CLIP_VALUE = config.GRADIENT_CLIP_VALUE
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
|
|
@ -0,0 +1,219 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
from enum import IntEnum
|
||||
import yaml
|
||||
|
||||
|
||||
global_yaml = '../../default_config.yaml'
|
||||
|
||||
|
||||
class JointType(IntEnum):
|
||||
Nose = 0
|
||||
|
||||
Neck = 1
|
||||
|
||||
RightShoulder = 2
|
||||
|
||||
RightElbow = 3
|
||||
|
||||
RightHand = 4
|
||||
|
||||
LeftShoulder = 5
|
||||
|
||||
LeftElbow = 6
|
||||
|
||||
LeftHand = 7
|
||||
|
||||
RightWaist = 8
|
||||
|
||||
RightKnee = 9
|
||||
|
||||
RightFoot = 10
|
||||
|
||||
LeftWaist = 11
|
||||
|
||||
LeftKnee = 12
|
||||
|
||||
LeftFoot = 13
|
||||
|
||||
RightEye = 14
|
||||
|
||||
LeftEye = 15
|
||||
|
||||
RightEar = 16
|
||||
|
||||
LeftEar = 17
|
||||
|
||||
|
||||
limbs_point = [
|
||||
[JointType.Neck, JointType.RightWaist],
|
||||
[JointType.RightWaist, JointType.RightKnee],
|
||||
[JointType.RightKnee, JointType.RightFoot],
|
||||
[JointType.Neck, JointType.LeftWaist],
|
||||
[JointType.LeftWaist, JointType.LeftKnee],
|
||||
[JointType.LeftKnee, JointType.LeftFoot],
|
||||
[JointType.Neck, JointType.RightShoulder],
|
||||
[JointType.RightShoulder, JointType.RightElbow],
|
||||
[JointType.RightElbow, JointType.RightHand],
|
||||
[JointType.RightShoulder, JointType.RightEar],
|
||||
[JointType.Neck, JointType.LeftShoulder],
|
||||
[JointType.LeftShoulder, JointType.LeftElbow],
|
||||
[JointType.LeftElbow, JointType.LeftHand],
|
||||
[JointType.LeftShoulder, JointType.LeftEar],
|
||||
[JointType.Neck, JointType.Nose],
|
||||
[JointType.Nose, JointType.RightEye],
|
||||
[JointType.Nose, JointType.LeftEye],
|
||||
[JointType.RightEye, JointType.RightEar],
|
||||
[JointType.LeftEye, JointType.LeftEar]
|
||||
]
|
||||
|
||||
|
||||
joint_indices = [
|
||||
JointType.Nose,
|
||||
JointType.LeftEye,
|
||||
JointType.RightEye,
|
||||
JointType.LeftEar,
|
||||
JointType.RightEar,
|
||||
JointType.LeftShoulder,
|
||||
JointType.RightShoulder,
|
||||
JointType.LeftElbow,
|
||||
JointType.RightElbow,
|
||||
JointType.LeftHand,
|
||||
JointType.RightHand,
|
||||
JointType.LeftWaist,
|
||||
JointType.RightWaist,
|
||||
JointType.LeftKnee,
|
||||
JointType.RightKnee,
|
||||
JointType.LeftFoot,
|
||||
JointType.RightFoot
|
||||
]
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r', encoding='utf-8') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, global_yaml),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
|
||||
configs = Config(final_config)
|
||||
configs.limbs_point = limbs_point
|
||||
configs.joint_indices = joint_indices
|
||||
pprint(configs)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -13,86 +13,84 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from ast import literal_eval as liter
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.nn.optim import Adam, Momentum
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.openposenet import OpenPoseNet
|
||||
from src.loss import openpose_loss, BuildTrainNetwork, TrainOneStepWithClipGradientCell
|
||||
from src.config import params
|
||||
from src.utils import get_lr, load_model, MyLossMonitor
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_rank_id, get_device_num
|
||||
|
||||
|
||||
mindspore.common.seed.set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
parser = argparse.ArgumentParser('mindspore openpose training')
|
||||
parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir')
|
||||
parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json',
|
||||
help='train annotations json')
|
||||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||
args, _ = parser.parse_known_args()
|
||||
args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann)
|
||||
args.imgpath_train = os.path.join(params['data_dir'], args.train_dir)
|
||||
args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train')
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
"""Train function."""
|
||||
config.lr = liter(config.lr)
|
||||
config.outputs_dir = config.save_model_path
|
||||
device_num = get_device_num()
|
||||
|
||||
args.outputs_dir = params['save_model_path']
|
||||
|
||||
if args.group_size > 1:
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_{}/".format(str(get_rank())))
|
||||
args.rank = get_rank()
|
||||
config.rank = get_rank_id()
|
||||
config.outputs_dir = os.path.join(config.outputs_dir, "ckpt_{}/".format(config.rank))
|
||||
else:
|
||||
args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/")
|
||||
args.rank = 0
|
||||
config.outputs_dir = os.path.join(config.outputs_dir, "ckpt_0/")
|
||||
config.rank = 0
|
||||
|
||||
if args.group_size > 1:
|
||||
args.max_epoch = params["max_epoch_train_NP"]
|
||||
args.loss_scale = params['loss_scale'] / 2
|
||||
args.lr_steps = list(map(int, params["lr_steps_NP"].split(',')))
|
||||
params['train_type'] = params['train_type_NP']
|
||||
params['optimizer'] = params['optimizer_NP']
|
||||
params['group_params'] = params['group_params_NP']
|
||||
if device_num > 1:
|
||||
config.max_epoch = config.max_epoch_train_NP
|
||||
config.loss_scale = config.loss_scale / 2
|
||||
config.lr_steps = list(map(int, config.lr_steps_NP.split(',')))
|
||||
config.train_type = config.train_type_NP
|
||||
config.optimizer = config.optimizer_NP
|
||||
config.group_params = config.group_params_NP
|
||||
else:
|
||||
args.max_epoch = params["max_epoch_train"]
|
||||
args.loss_scale = params['loss_scale']
|
||||
args.lr_steps = list(map(int, params["lr_steps"].split(',')))
|
||||
config.max_epoch = config.max_epoch_train
|
||||
config.loss_scale = config.loss_scale
|
||||
config.lr_steps = list(map(int, config.lr_steps.split(',')))
|
||||
|
||||
# create network
|
||||
print('start create network')
|
||||
criterion = openpose_loss()
|
||||
criterion.add_flags_recursive(fp32=True)
|
||||
network = OpenPoseNet(vggpath=params['vgg_path'], vgg_with_bn=params['vgg_with_bn'])
|
||||
if params["load_pretrain"]:
|
||||
print("load pretrain model:", params["pretrained_model_path"])
|
||||
load_model(network, params["pretrained_model_path"])
|
||||
network = OpenPoseNet(vggpath=config.vgg_path, vgg_with_bn=config.vgg_with_bn)
|
||||
if config.load_pretrain:
|
||||
print("load pretrain model:", config.pretrained_model_path)
|
||||
load_model(network, config.pretrained_model_path)
|
||||
train_net = BuildTrainNetwork(network, criterion)
|
||||
|
||||
# create dataset
|
||||
if os.path.exists(args.jsonpath_train) and os.path.exists(args.imgpath_train) \
|
||||
and os.path.exists(args.maskpath_train):
|
||||
if os.path.exists(config.jsonpath_train) and os.path.exists(config.imgpath_train) \
|
||||
and os.path.exists(config.maskpath_train):
|
||||
print('start create dataset')
|
||||
else:
|
||||
print('Error: wrong data path')
|
||||
return 0
|
||||
|
||||
num_worker = 20 if args.group_size > 1 else 48
|
||||
de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train,
|
||||
batch_size=params['batch_size'],
|
||||
rank=args.rank,
|
||||
group_size=args.group_size,
|
||||
num_worker = 20 if device_num > 1 else 48
|
||||
de_dataset_train = create_dataset(config.jsonpath_train, config.imgpath_train, config.maskpath_train,
|
||||
batch_size=config.batch_size,
|
||||
rank=config.rank,
|
||||
group_size=device_num,
|
||||
num_worker=num_worker,
|
||||
multiprocessing=True,
|
||||
shuffle=True,
|
||||
|
@ -101,17 +99,17 @@ def train():
|
|||
print("steps_per_epoch: ", steps_per_epoch)
|
||||
|
||||
# lr scheduler
|
||||
lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size,
|
||||
params['lr_gamma'],
|
||||
lr_stage, lr_base, lr_vgg = get_lr(config.lr * device_num,
|
||||
config.lr_gamma,
|
||||
steps_per_epoch,
|
||||
args.max_epoch,
|
||||
args.lr_steps,
|
||||
args.group_size,
|
||||
lr_type=params['lr_type'],
|
||||
warmup_epoch=params['warmup_epoch'])
|
||||
config.max_epoch,
|
||||
config.lr_steps,
|
||||
device_num,
|
||||
lr_type=config.lr_type,
|
||||
warmup_epoch=config.warmup_epoch)
|
||||
|
||||
# optimizer
|
||||
if params['group_params']:
|
||||
if config.group_params:
|
||||
vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params()))
|
||||
base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params()))
|
||||
stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params()))
|
||||
|
@ -120,47 +118,47 @@ def train():
|
|||
{'params': base_params, 'lr': lr_base},
|
||||
{'params': stages_params, 'lr': lr_stage}]
|
||||
|
||||
if params['optimizer'] == "Momentum":
|
||||
if config.optimizer == "Momentum":
|
||||
opt = Momentum(group_params, learning_rate=lr_stage, momentum=0.9)
|
||||
elif params['optimizer'] == "Adam":
|
||||
elif config.optimizer == "Adam":
|
||||
opt = Adam(group_params)
|
||||
else:
|
||||
raise ValueError("optimizer not support.")
|
||||
else:
|
||||
if params['optimizer'] == "Momentum":
|
||||
if config.optimizer == "Momentum":
|
||||
opt = Momentum(train_net.trainable_params(), learning_rate=lr_stage, momentum=0.9)
|
||||
elif params['optimizer'] == "Adam":
|
||||
elif config.optimizer == "Adam":
|
||||
opt = Adam(train_net.trainable_params(), learning_rate=lr_stage)
|
||||
else:
|
||||
raise ValueError("optimizer not support.")
|
||||
|
||||
# callback
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'],
|
||||
keep_checkpoint_max=params["keep_checkpoint_max"])
|
||||
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.rank), directory=config.outputs_dir, config=config_ck)
|
||||
time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size())
|
||||
if args.rank == 0:
|
||||
if config.rank == 0:
|
||||
callback_list = [MyLossMonitor(), time_cb, ckpoint_cb]
|
||||
else:
|
||||
callback_list = [MyLossMonitor(), time_cb]
|
||||
|
||||
# train
|
||||
if params['train_type'] == 'clip_grad':
|
||||
train_net = TrainOneStepWithClipGradientCell(train_net, opt, sens=args.loss_scale)
|
||||
if config.train_type == 'clip_grad':
|
||||
train_net = TrainOneStepWithClipGradientCell(train_net, opt, sens=config.loss_scale)
|
||||
train_net.set_train()
|
||||
model = Model(train_net)
|
||||
elif params['train_type'] == 'fix_loss_scale':
|
||||
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
|
||||
elif config.train_type == 'fix_loss_scale':
|
||||
loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
train_net.set_train()
|
||||
model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)
|
||||
else:
|
||||
raise ValueError("Type {} is not support.".format(params['train_type']))
|
||||
raise ValueError("Type {} is not support.".format(config.train_type))
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(args.max_epoch, de_dataset_train, callbacks=callback_list,
|
||||
model.train(config.max_epoch, de_dataset_train, callbacks=callback_list,
|
||||
dataset_sink_mode=False)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mindspore.common.seed.set_seed(1)
|
||||
train()
|
||||
|
|
Loading…
Reference in New Issue