forked from mindspore-Ecosystem/mindspore
Add openpose net to modelzoo
This commit is contained in: parent 63fcdb44b5, commit a79416cfba

@ -0,0 +1,225 @@ README.md

# Contents

- [Openpose Description](#openpose-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)

# [Openpose Description](#contents)

The OpenPose network proposes a bottom-up human pose estimation algorithm that uses Part Affinity Fields (PAFs), instead of a top-down algorithm that first detects people and then regresses key points and skeletons for each detection. The advantage of OpenPose is that its computing time does not increase significantly as the number of people in the image grows, whereas a top-down algorithm depends on the detection results, so its runtime grows linearly with the number of people.

[Paper](https://arxiv.org/abs/1611.08050): Zhe Cao, Tomas Simon, Shih-En Wei, Yaser Sheikh, "Realtime Multi-Person 2D Pose Estimation using Part Affinity Fields", The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.

# [Model Architecture](#contents)

In the first step, the image is passed through a baseline CNN to extract feature maps of the input; in the paper, the authors use the first 10 layers of the VGG-19 network for this.
The feature maps are then processed in a multi-stage CNN pipeline to generate the Part Confidence Maps and Part Affinity Fields; a sketch of one such stage follows this section.
In the last step, the Confidence Maps and Part Affinity Fields generated above are processed by a greedy bipartite matching algorithm to obtain the poses for each person in the image.
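
The two-branch, multi-stage design can be sketched in MindSpore as below. This is a minimal illustration under assumed channel sizes (`mid_ch`, `paf_ch`, `heat_ch`) and layer counts; it is not a copy of `src/openposenet.py`.

```python
import mindspore.nn as nn
from mindspore.ops import operations as P


class RefineStage(nn.Cell):
    """One refinement stage: concatenates the VGG feature maps with the
    previous stage's PAFs and heatmaps and emits refined predictions."""
    def __init__(self, in_ch, mid_ch=128, paf_ch=38, heat_ch=19):
        super(RefineStage, self).__init__()
        def branch(out_ch):
            return nn.SequentialCell([
                nn.Conv2d(in_ch, mid_ch, 7, pad_mode='same', has_bias=True), nn.ReLU(),
                nn.Conv2d(mid_ch, mid_ch, 7, pad_mode='same', has_bias=True), nn.ReLU(),
                nn.Conv2d(mid_ch, out_ch, 1, pad_mode='same', has_bias=True),
            ])
        self.paf_branch = branch(paf_ch)        # 2 channels per limb (PAF x/y)
        self.heatmap_branch = branch(heat_ch)   # 18 joints + 1 background
        self.concat = P.Concat(axis=1)

    def construct(self, feat, pafs, heatmaps):
        x = self.concat((feat, pafs, heatmaps))
        return self.paf_branch(x), self.heatmap_branch(x)
```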

# [Dataset](#contents)

Prepare the datasets, including the training set, validation set, and annotations. The training and validation samples are located in the "dataset" directory; the supported datasets include the COCO2014 and COCO2017 datasets.
The training script currently provided uses the COCO2017 dataset as an example and performs data preprocessing during training. If you use a dataset in another format, please modify the dataset loading and preprocessing methods accordingly.

- Download the data from the official COCO2017 website and unzip it.

```bash
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
```

- Create the mask dataset.

  Run the mask generation script:

```bash
python gen_ignore_mask.py --train_ann ../dataset/annotations/person_keypoints_train2017.json --val_ann ../dataset/annotations/person_keypoints_val2017.json --train_dir train2017 --val_dir val2017
```

- The dataset folder is generated in the root directory and contains the following files:

```text
├── dataset
    ├── annotations
        ├─person_keypoints_train2017.json
        └─person_keypoints_val2017.json
    ├─ignore_mask_train
    ├─ignore_mask_val
    ├─train2017
    └─val2017
```

# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep neural network training process by using both single-precision and half-precision data formats, while maintaining the accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend automatically handles it with reduced precision. You can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'.
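
A hedged sketch of how mixed precision is typically enabled in MindSpore of this era: wrap the network in the `Model` class with an `amp_level` and a loss-scale manager. The names `net`, `loss_fn`, and `opt` are placeholders; `loss_scale=16386` mirrors the value in `config.py`, and the exact wrapper used by `train.py` may differ.

```python
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# net, loss_fn and opt stand for the OpenPose network, its loss and optimizer.
# amp_level="O2" casts the network to FP16 while keeping BatchNorm in FP32.
loss_scale_manager = FixedLossScaleManager(loss_scale=16386, drop_overflow_update=False)
model = Model(net, loss_fn, opt, amp_level="O2", loss_scale_manager=loss_scale_manager)
```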

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with Ascend. If you want to try it, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- Download the VGG19 model of the MindSpore version:
    - [vgg19-0-97_5004.ckpt](http://10.154.33.38:51203/tutorials/image_classification.html)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# run the training example
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > train.log 2>&1 &

# run the distributed training example
bash run_distribute_train.sh [RANK_TABLE_FILE]

# run the evaluation example
python eval.py --model_path path_to_eval_model.ckpt --imgpath_val ./dataset/val2017 --ann ./dataset/annotations/person_keypoints_val2017.json > eval.log 2>&1 &
# or
bash scripts/run_eval_ascend.sh
```

[RANK_TABLE_FILE] is the path of the multi-device information configuration table in the environment. The configuration table can be generated automatically by the tool [hccl_tool](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools); see the example below.
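
A hedged usage sketch (the generated JSON file name depends on how hccl_tools was invoked and on the host IP, so the path below is illustrative):

```bash
# generate a rank table for devices 0-7, then launch distributed training
python model_zoo/utils/hccl_tools/hccl_tools.py --device_num "[0,8)"
bash scripts/run_distribute_train.sh /path/to/hccl_8p.json
```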

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── ModelZoo_openpose_MS_MIT
    ├── README.md                    // description of openpose
    ├── scripts
    │   ├── run_standalone_train.sh  // shell script for standalone training on Ascend
    │   ├── run_distribute_train.sh  // shell script for distributed training on Ascend with 8 devices
    │   ├── run_eval_ascend.sh       // shell script for evaluation on Ascend
    ├── src
    │   ├── openposenet.py           // OpenPose architecture
    │   ├── loss.py                  // loss function
    │   ├── config.py                // parameter configuration
    │   ├── dataset.py               // data preprocessing
    │   ├── utils.py                 // utilities
    │   ├── gen_ignore_mask.py       // mask-data generation script
    ├── export.py                    // model conversion script
    ├── train.py                     // training script
    ├── eval.py                      // evaluation script
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for openpose

```python
'data_dir': 'path to dataset'               # absolute full path to the training and evaluation datasets
'vgg_path': 'path to vgg model'             # absolute full path to the vgg19 model
'save_model_path': 'path of saving models'  # absolute full path for output models
'load_pretrain': 'False'                    # whether to train from a pre-trained model
'pretrained_model_path': ''                 # path of the pre-trained model to load
'lr': 1e-4                                  # initial learning rate
'batch_size': 10                            # training batch size
'lr_gamma': 0.1                             # factor applied to lr when a step in lr_steps is reached
'lr_steps': '100000,200000,250000'          # steps at which lr is multiplied by lr_gamma
'loss_scale': 16386                         # loss scale for mixed precision
'max_epoch_train': 60                       # total training epochs
'insize': 368                               # image size used as input to the model
'keep_checkpoint_max': 5                    # keep only the last keep_checkpoint_max checkpoints
'log_interval': 100                         # interval (in steps) between log prints
'ckpt_interval': 5000                       # interval (in steps) between checkpoint saves
```

For more configuration details, please refer to the script `config.py`; the snippet below shows how the values are consumed.
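
A small usage sketch (hypothetical training-side code, not a copy of train.py); the values come from the `params` dictionary defined in `src/config.py`:

```python
from src.config import params

# read the configured hyper-parameters; the keys match config.py
batch_size = params['batch_size']                           # 10
lr_steps = [int(s) for s in params['lr_steps'].split(',')]  # [100000, 200000, 250000]
print('train for', params['max_epoch_train'], 'epochs at lr', params['lr'])
```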

## [Training Process](#contents)

### Training

- running on Ascend

```bash
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > train.log 2>&1 &
```

The python command above runs in the background; you can view the results in the file `train.log`.

After training, you will get some checkpoint files under the script folder by default. The loss values will look as follows:

```text
# grep "epoch " train.log
epoch[0], iter[0], loss[0.29211228793809957], 0.13 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
epoch[0], iter[100], loss[0.060355084178521694], 24.92 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
epoch[0], iter[200], loss[0.026628130997662272], 26.20 imgs/sec, vgglr=0.0,baselr=2.499999936844688e-05,stagelr=9.999999747378752e-05
...
```

The model checkpoints are saved in the directory set by 'save_model_path' in config.py.

## [Evaluation Process](#contents)

### Evaluation

- running on Ascend

  Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to the absolute full path, e.g., "username/openpose/outputs/*time*/0-6_30000.ckpt".

```bash
python eval.py --model_path path_to_eval_model.ckpt --imgpath_val ./dataset/val2017 --ann ./dataset/annotations/person_keypoints_val2017.json > eval.log 2>&1 &
# or
bash scripts/run_eval_ascend.sh
```

The above python command runs in the background. You can view the results in the file "eval.log". The accuracy on the validation dataset will look as follows:

```text
# grep "AP" eval.log

{'AP': 0.40030956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
```

# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                |
| -------------------------- | ----------------------------------------------------- |
| Model Version              | openpose                                              |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB     |
| Uploaded Date              | 10/20/2020 (month/day/year)                           |
| MindSpore Version          | 1.0.1-alpha                                           |
| Training Parameters        | epoch = 60, steps = 30k, batch_size = 10, lr = 0.0001 |
| Optimizer                  | Adam                                                  |
| Loss Function              | MSE                                                   |
| Outputs                    | pose                                                  |
| Speed                      | 1p: 29 imgs/s                                         |
| Total Time                 | 1p: 30h                                               |
| Checkpoint for Fine-tuning | 602.33 MB (.ckpt file)                                |

@ -0,0 +1,458 @@ eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
import argparse
import warnings
import sys
import cv2

from tqdm import tqdm
import numpy as np
from scipy.ndimage.filters import gaussian_filter
from pycocotools.coco import COCO as LoadAnn
from pycocotools.cocoeval import COCOeval as MapEval

from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import dtype as mstype

from src.dataset import valdata
from src.openposenet import OpenPoseNet
from src.config import params, JointType

warnings.filterwarnings("ignore")
devid = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 if DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend", save_graphs=False, device_id=devid)
show_gt = 0


def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
    """Run the COCO keypoint evaluation and return the AP/AR statistics as a dict."""
    class NullWriter():
        def write(self, arg):
            pass
    if silence:
        nullwrite = NullWriter()
        oldstdout = sys.stdout
        sys.stdout = nullwrite  # disable output

    Gt = LoadAnn(ann_file)
    Dt = Gt.loadRes(res_file)

    Eval = MapEval(Gt, Dt, ann_type)
    Eval.evaluate()
    Eval.accumulate()
    Eval.summarize()

    if silence:
        sys.stdout = oldstdout  # enable output

    stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)',
                   'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
    info_str = {}
    for ind, name in enumerate(stats_names):
        info_str[name] = Eval.stats[ind]

    return info_str


def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser('mindspore openpose_net test')

    parser.add_argument('--model_path', type=str, default='./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt',
                        help='path of testing model')
    parser.add_argument('--imgpath_val', type=str, default='/data0/zhy/dataset/coco/val2017',
                        help='path of testing imgs')
    parser.add_argument('--ann', type=str, default='/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json',
                        help='path of annotations')
    parser.add_argument('--output_path', type=str, default='./output_img', help='path of output imgs')
    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
    args, _ = parser.parse_known_args()
    return args


def load_model(test_net, model_path):
    """Load a checkpoint, dropping optimizer state and the 'network.' prefix."""
    assert os.path.exists(model_path)
    param_dict = load_checkpoint(model_path)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith('moment'):
            continue
        elif key.startswith('network'):
            param_dict_new[key[8:]] = values  # strip the leading 'network.'
        # else:
        #     param_dict_new[key] = values
    load_param_into_net(test_net, param_dict_new)


def preprocess(img):
    """Normalize an image and convert it to the NCHW layout with a batch axis."""
    x_data = img.astype('f')
    x_data /= 255
    x_data -= 0.5
    x_data = x_data.transpose(2, 0, 1)[None]
    return x_data
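

# Example: a (368, 368, 3) uint8 image becomes a (1, 3, 368, 368) float array
# with values in [-0.5, 0.5], matching the network's expected input layout.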


def getImgsPath(img_dir_path):
    """Recursively collect all file paths under img_dir_path."""
    filepaths = []
    dirpaths = []
    pathName = img_dir_path

    for root, dirs, files in os.walk(pathName):
        for file in files:
            file_path = os.path.join(root, file)
            filepaths.append(file_path)
        for d in dirs:
            dir_path = os.path.join(root, d)
            dirpaths.append(dir_path)
    return filepaths


def compute_optimal_size(orig_img, img_size, stride=8):
    """Scale the image so its short side is img_size, padding the long side
    up to the next multiple of stride; returns (img_w, img_h)."""
    orig_img_h, orig_img_w, _ = orig_img.shape
    aspect = orig_img_h / orig_img_w
    if orig_img_h < orig_img_w:
        img_h = img_size
        img_w = np.round(img_size / aspect).astype(int)
        surplus = img_w % stride
        if surplus != 0:
            img_w += stride - surplus
    else:
        img_w = img_size
        img_h = np.round(img_size * aspect).astype(int)
        surplus = img_h % stride
        if surplus != 0:
            img_h += stride - surplus
    return (img_w, img_h)
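

# Example: a 480x640 (h x w) image with img_size=368 gives aspect = 0.75,
# img_w = round(368 / 0.75) = 491, padded up to 496 (the next multiple of 8),
# so the function returns (496, 368).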


def compute_peaks_from_heatmaps(heatmaps):
    """Find local maxima in each joint heatmap (background channel excluded).
    Returns an array of peaks as (joint_type, x, y, score, peak_id) rows."""
    heatmaps = heatmaps[:-1]  # drop the background channel

    all_peaks = []
    peak_counter = 0
    for i, heatmap in enumerate(heatmaps):
        heatmap = gaussian_filter(heatmap, sigma=params['gaussian_sigma'])

        map_left = np.zeros(heatmap.shape)
        map_right = np.zeros(heatmap.shape)
        map_top = np.zeros(heatmap.shape)
        map_bottom = np.zeros(heatmap.shape)

        map_left[1:, :] = heatmap[:-1, :]
        map_right[:-1, :] = heatmap[1:, :]
        map_top[:, 1:] = heatmap[:, :-1]
        map_bottom[:, :-1] = heatmap[:, 1:]

        # a peak must exceed the threshold and all four neighbours
        peaks_binary = np.logical_and.reduce((
            heatmap > params['heatmap_peak_thresh'],
            heatmap > map_left,
            heatmap > map_right,
            heatmap > map_top,
            heatmap > map_bottom,
        ))

        peaks = zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])  # (x, y)

        peaks_with_score = [(i,) + peak_pos + (heatmap[peak_pos[1], peak_pos[0]],) for peak_pos in peaks]

        peaks_id = range(peak_counter, peak_counter + len(peaks_with_score))
        peaks_with_score_and_id = [peaks_with_score[i] + (peaks_id[i],) for i in range(len(peaks_id))]

        peak_counter += len(peaks_with_score_and_id)
        all_peaks.append(peaks_with_score_and_id)
    all_peaks = np.array([peak for peaks_each_category in all_peaks for peak in peaks_each_category])

    return all_peaks
def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
    """Score every candidate joint pair by integrating the PAF along the limb."""
    candidate_connections = []
    for joint_a in cand_a:
        for joint_b in cand_b:
            vector = joint_b[:2] - joint_a[:2]
            norm = np.linalg.norm(vector)
            if norm == 0:
                continue
            ys = np.linspace(joint_a[1], joint_b[1], num=cfg['n_integ_points'])
            xs = np.linspace(joint_a[0], joint_b[0], num=cfg['n_integ_points'])
            integ_points = np.stack([ys, xs]).T.round().astype('i')

            paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
            unit_vector = vector / norm
            inner_products = np.dot(paf_in_edge, unit_vector)
            integ_value = inner_products.sum() / len(inner_products)
            # penalize limbs that are long relative to the image size
            integ_value_with_dist_prior = integ_value + min(cfg['limb_length_ratio'] * img_len / norm -
                                                            cfg['length_penalty_value'], 0)
            n_valid_points = sum(inner_products > cfg['inner_product_thresh'])
            if n_valid_points > cfg['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
                candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
    candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
    return candidate_connections


def compute_connections(pafs, all_peaks, img_len, cfg):
    """Greedily pick the best-scoring, non-conflicting connection per limb type."""
    all_connections = []
    for i in range(len(cfg['limbs_point'])):
        paf_index = [i * 2, i * 2 + 1]
        paf = pafs[paf_index]  # shape: (2, 320, 320)
        limb_point = cfg['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
        cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
        cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]

        if len(cand_a) > 0 and len(cand_b) > 0:  # explicit length check: numpy arrays are ambiguous in `and`
            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg)

            connections = np.zeros((0, 3))

            for index_a, index_b, score in candidate_connections:
                if index_a not in connections[:, 0] and index_b not in connections[:, 1]:
                    connections = np.vstack([connections, [index_a, index_b, score]])
                if len(connections) >= min(len(cand_a), len(cand_b)):
                    break
            all_connections.append(connections)
        else:
            all_connections.append(np.zeros((0, 3)))
    return all_connections


def grouping_key_points(all_connections, candidate_peaks, cfg):
    """Assemble limb connections into per-person subsets; each subset has
    20 slots: 18 joint peak ids, the total score, and the joint count."""
    subsets = -1 * np.ones((0, 20))

    for l, connections in enumerate(all_connections):
        joint_a, joint_b = cfg['limbs_point'][l]
        for ind_a, ind_b, score in connections[:, :3]:
            ind_a, ind_b = int(ind_a), int(ind_b)
            joint_found_cnt = 0
            joint_found_subset_index = [-1, -1]
            for subset_ind, subset in enumerate(subsets):
                if subset[joint_a] == ind_a or subset[joint_b] == ind_b:
                    joint_found_subset_index[joint_found_cnt] = subset_ind
                    joint_found_cnt += 1

            if joint_found_cnt == 1:
                found_subset = subsets[joint_found_subset_index[0]]
                if found_subset[joint_b] != ind_b:
                    found_subset[joint_b] = ind_b
                    found_subset[-1] += 1  # increment joint count
                    found_subset[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 2:
                found_subset_1 = subsets[joint_found_subset_index[0]]
                found_subset_2 = subsets[joint_found_subset_index[1]]

                membership = ((found_subset_1 >= 0).astype(int) + (found_subset_2 >= 0).astype(int))[:-2]
                if not np.any(membership == 2):  # merge two subsets when no duplication
                    found_subset_1[:-2] += found_subset_2[:-2] + 1  # default is -1
                    found_subset_1[-2:] += found_subset_2[-2:]
                    found_subset_1[-2] += score
                    subsets = np.delete(subsets, joint_found_subset_index[1], axis=0)
                else:
                    if found_subset_1[joint_a] == -1:
                        found_subset_1[joint_a] = ind_a
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_1[joint_b] == -1:
                        found_subset_1[joint_b] = ind_b
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_b, 3] + score
                    if found_subset_2[joint_a] == -1:
                        found_subset_2[joint_a] = ind_a
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_2[joint_b] == -1:
                        found_subset_2[joint_b] = ind_b
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 0 and l != 9 and l != 13:  # limbs 9/13 are the ear-shoulder links
                row = -1 * np.ones(20)
                row[joint_a] = ind_a
                row[joint_b] = ind_b
                row[-1] = 2
                row[-2] = sum(candidate_peaks[[ind_a, ind_b], 3]) + score
                subsets = np.vstack([subsets, row])
            elif joint_found_cnt >= 3:
                pass

    # delete low score subsets
    keep = np.logical_and(subsets[:, -1] >= cfg['n_subset_limbs_thresh'],
                          subsets[:, -2] / subsets[:, -1] >= cfg['subset_score_thresh'])
    # cfg['n_subset_limbs_thresh'] = 3
    # cfg['subset_score_thresh'] = 0.2
    subsets = subsets[keep]
    return subsets


def subsets_to_pose_array(subsets, all_peaks):
    """Convert subsets of peak ids into an (N, 18, 3) array of (x, y, visibility)."""
    person_pose_array = []
    for subset in subsets:
        joints = []
        for joint_index in subset[:18].astype('i'):
            if joint_index >= 0:
                joint = all_peaks[joint_index][1:3].tolist()
                joint.append(2)
                joints.append(joint)
            else:
                joints.append([0, 0, 0])
        person_pose_array.append(np.array(joints))
    person_pose_array = np.array(person_pose_array)
    return person_pose_array
def detect(img, network):
    """Run the network on one image and decode per-person poses and scores."""
    orig_img = img.copy()
    orig_img_h, orig_img_w, _ = orig_img.shape

    input_w, input_h = compute_optimal_size(orig_img, params['inference_img_size'])  # 368
    # map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size'])  # 320
    map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])

    print("image size is: ", input_w, input_h)

    resized_image = cv2.resize(orig_img, (input_w, input_h))
    x_data = preprocess(resized_image)
    x_data = Tensor(x_data, mstype.float32)
    x_data.requires_grad = False

    logit_pafs, logit_heatmap = network(x_data)

    logit_pafs = logit_pafs[-1].asnumpy()[0]
    logit_heatmap = logit_heatmap[-1].asnumpy()[0]

    pafs = np.zeros((logit_pafs.shape[0], map_h, map_w))
    for i in range(logit_pafs.shape[0]):
        pafs[i] = cv2.resize(logit_pafs[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "pafs.png"
            cv2.imwrite(save_path, pafs[i]*255)

    heatmaps = np.zeros((logit_heatmap.shape[0], map_h, map_w))
    for i in range(logit_heatmap.shape[0]):
        heatmaps[i] = cv2.resize(logit_heatmap[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "heatmap.png"
            cv2.imwrite(save_path, heatmaps[i]*255)

    all_peaks = compute_peaks_from_heatmaps(heatmaps)
    if all_peaks.shape[0] == 0:  # explicit emptiness check: `not` on a numpy array is ambiguous
        return np.empty((0, len(JointType), 3)), np.empty(0)
    all_connections = compute_connections(pafs, all_peaks, map_w, params)
    subsets = grouping_key_points(all_connections, all_peaks, params)
    all_peaks[:, 1] *= orig_img_w / map_w
    all_peaks[:, 2] *= orig_img_h / map_h
    poses = subsets_to_pose_array(subsets, all_peaks)
    scores = subsets[:, -2]

    return poses, scores
def draw_person_pose(orig_img, poses):
    """Draw the limbs and joints of every detected pose onto a copy of the image."""
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
    if len(poses) == 0:
        return orig_img

    limb_colors = [
        [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
        [0, 85, 255], [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0.],
        [255, 0, 85], [170, 255, 0], [85, 255, 0], [170, 0, 255.], [0, 0, 255],
        [0, 0, 255], [255, 0, 255], [170, 0, 255], [255, 0, 170],
    ]

    joint_colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    canvas = orig_img.copy()

    # limbs
    for pose in poses.round().astype('i'):
        for i, (limb, color) in enumerate(zip(params['limbs_point'], limb_colors)):
            if i not in (9, 13):  # don't show ear-shoulder connection
                limb_ind = np.array(limb)
                if np.all(pose[limb_ind][:, 2] != 0):
                    joint1, joint2 = pose[limb_ind][:, :2]
                    cv2.line(canvas, tuple(joint1), tuple(joint2), color, 2)

    # joints
    for pose in poses.round().astype('i'):
        for i, ((x, y, v), color) in enumerate(zip(pose, joint_colors)):
            if v != 0:
                cv2.circle(canvas, (x, y), 3, color, -1)
    return canvas


def depreprocess(img):
    """Invert preprocess(): map a normalized CHW tensor back to an HWC uint8 image."""
    # x_data = img.astype('f')
    x_data = img[0]
    x_data += 0.5
    x_data *= 255
    x_data = x_data.astype('uint8')
    x_data = x_data.transpose(1, 2, 0)
    return x_data


def _eval():
    args = parse_args()
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)
    network = OpenPoseNet()
    network.set_train(False)
    load_model(network, args.model_path)

    print("load models right")
    dataset = valdata(args.ann, args.imgpath_val, args.rank, args.group_size, mode='val')
    dataset_size = dataset.get_dataset_size()
    de_dataset = dataset.create_tuple_iterator()

    print("eval dataset size: ", dataset_size)
    kpt_json = []
    for _, (img, img_id) in tqdm(enumerate(de_dataset), total=dataset_size):
        img = img.asnumpy()
        img_id = int((img_id.asnumpy())[0])
        poses, scores = detect(img, network)

        if len(poses) > 0:  # explicit length check: numpy-array truthiness is ambiguous
            #print("got poses")
            for index, pose in enumerate(poses):
                data = dict()

                # reorder joints from the network's layout to the COCO layout
                pose = pose[[0, 15, 14, 17, 16, 5, 2, 6, 3, 7, 4, 11, 8, 12, 9, 13, 10, 1], :].round().astype('i')

                keypoints = pose.reshape(-1).tolist()
                keypoints = keypoints[:-3]  # drop the neck joint, which COCO does not use
                data['image_id'] = img_id
                data['score'] = scores[index]
                data['category_id'] = 1
                data['keypoints'] = keypoints
                kpt_json.append(data)
        else:
            print("Predict poses size is zero.", flush=True)
        img = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)

        #print('Saving result into',str(img_id)+'.png...')
        save_path = os.path.join(args.output_path, str(img_id)+".png")
        cv2.imwrite(save_path, img)

    result_json = 'eval_result.json'
    with open(os.path.join(args.output_path, result_json), 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(os.path.join(args.output_path, result_json), ann_file=args.ann)
    print('result: ', res)


if __name__ == "__main__":
    _eval()

@ -0,0 +1,38 @@ export.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export"""

import argparse
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.openposenet import OpenPoseNet

parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()

if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    # define net
    net = OpenPoseNet()

    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    # a dummy input with the model's 1x3x368x368 input shape drives the export
    inputs = np.random.uniform(0.0, 1.0, size=[1, 3, 368, 368]).astype(np.float32)
    export(net, Tensor(inputs), file_name="openpose.air", file_format='AIR')
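    # A hedged usage example (the checkpoint name is illustrative):
    #   python export.py --checkpoint_path ./checkpoints/0-60_663.ckpt
    # This writes "openpose.air" in the current directory for Ascend deployment.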

@ -0,0 +1,61 @@ scripts/run_distribute_train.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 1 ]
then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
RANK_TABLE_FILE=$(get_real_path $1)

echo $RANK_TABLE_FILE

if [ ! -f $RANK_TABLE_FILE ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
    exit 1
fi

export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE

# launch one training process per device, each in its own working directory
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py \
        --train_dir train2017 \
        --group_size 8 \
        --train_ann person_keypoints_train2017.json > log.txt 2>&1 &
    cd ..
done

@ -0,0 +1,23 @@ scripts/run_eval_ascend.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=0
export RANK_ID=0
python eval.py \
    --model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt \
    --imgpath_val /data0/zhy/dataset/coco/val2017 \
    --ann /data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json \
    > eval.log 2>&1 &

@ -0,0 +1,18 @@ scripts/run_standalone_train.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

cd ..
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 &

@ -0,0 +1,171 @@ src/config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from enum import IntEnum


class JointType(IntEnum):
    Nose = 0
    Neck = 1
    RightShoulder = 2
    RightElbow = 3
    RightHand = 4
    LeftShoulder = 5
    LeftElbow = 6
    LeftHand = 7
    RightWaist = 8
    RightKnee = 9
    RightFoot = 10
    LeftWaist = 11
    LeftKnee = 12
    LeftFoot = 13
    RightEye = 14
    LeftEye = 15
    RightEar = 16
    LeftEar = 17


params = {
    # paths
    'data_dir': '/data0/zhy/dataset/coco',
    'vgg_path': '/data0/zhy/dataset/coco/vgg19-0-97_5004.ckpt',
    'save_model_path': './checkpoints/',
    'load_pretrain': False,
    'pretrained_model_path': "",
    # training params
    'batch_size': 10,

    'lr': 1e-4,
    'lr_gamma': 0.1,
    'lr_steps': '100000,200000,250000',
    'lr_steps_NP': '250000',

    'loss_scale': 16386,
    'max_epoch_train': 60,
    'min_keypoints': 5,
    'min_area': 32 * 32,
    'insize': 368,
    'downscale': 8,
    'paf_sigma': 8,
    'heatmap_sigma': 7,
    'eva_num': 100,
    'keep_checkpoint_max': 5,
    'log_interval': 100,
    'ckpt_interval': 663,  # 5000,

    'min_box_size': 64,
    'max_box_size': 512,
    'min_scale': 0.5,
    'max_scale': 2.0,
    'max_rotate_degree': 40,
    'center_perterb_max': 40,

    # inference params
    'inference_img_size': 368,
    'inference_scales': [0.5, 1, 1.5, 2],
    # 'inference_scales': [1.0],
    'heatmap_size': 320,
    'gaussian_sigma': 2.5,
    'ksize': 17,
    'n_integ_points': 10,
    'n_integ_points_thresh': 8,
    'heatmap_peak_thresh': 0.05,
    'inner_product_thresh': 0.05,
    'limb_length_ratio': 1.0,
    'length_penalty_value': 1,
    'n_subset_limbs_thresh': 3,
    'subset_score_thresh': 0.2,
    'limbs_point': [
        [JointType.Neck, JointType.RightWaist],
        [JointType.RightWaist, JointType.RightKnee],
        [JointType.RightKnee, JointType.RightFoot],
        [JointType.Neck, JointType.LeftWaist],
        [JointType.LeftWaist, JointType.LeftKnee],
        [JointType.LeftKnee, JointType.LeftFoot],
        [JointType.Neck, JointType.RightShoulder],
        [JointType.RightShoulder, JointType.RightElbow],
        [JointType.RightElbow, JointType.RightHand],
        [JointType.RightShoulder, JointType.RightEar],
        [JointType.Neck, JointType.LeftShoulder],
        [JointType.LeftShoulder, JointType.LeftElbow],
        [JointType.LeftElbow, JointType.LeftHand],
        [JointType.LeftShoulder, JointType.LeftEar],
        [JointType.Neck, JointType.Nose],
        [JointType.Nose, JointType.RightEye],
        [JointType.Nose, JointType.LeftEye],
        [JointType.RightEye, JointType.RightEar],
        [JointType.LeftEye, JointType.LeftEar]
    ],
    'joint_indices': [
        JointType.Nose,
        JointType.LeftEye,
        JointType.RightEye,
        JointType.LeftEar,
        JointType.RightEar,
        JointType.LeftShoulder,
        JointType.RightShoulder,
        JointType.LeftElbow,
        JointType.RightElbow,
        JointType.LeftHand,
        JointType.RightHand,
        JointType.LeftWaist,
        JointType.RightWaist,
        JointType.LeftKnee,
        JointType.RightKnee,
        JointType.LeftFoot,
        JointType.RightFoot
    ],

    # face params
    'face_inference_img_size': 368,
    'face_heatmap_peak_thresh': 0.1,
    'face_crop_scale': 1.5,
    'face_line_indices': [
        [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16],  # face contour
        [17, 18], [18, 19], [19, 20], [20, 21],
        [22, 23], [23, 24], [24, 25], [25, 26],
        [27, 28], [28, 29], [29, 30],
        [31, 32], [32, 33], [33, 34], [34, 35],
        [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
        [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
        [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],  # outer lip contour
        [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60]
    ],

    # hand params
    'hand_inference_img_size': 368,
    'hand_heatmap_peak_thresh': 0.1,
    'fingers_indices': [
        [[0, 1], [1, 2], [2, 3], [3, 4]],
        [[0, 5], [5, 6], [6, 7], [7, 8]],
        [[0, 9], [9, 10], [10, 11], [11, 12]],
        [[0, 13], [13, 14], [14, 15], [15, 16]],
        [[0, 17], [17, 18], [18, 19], [19, 20]],
    ],
}

@ -0,0 +1,511 @@ src/dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import math
import random
import cv2

import numpy as np
from pycocotools.coco import COCO as ReadJson
import mindspore.dataset as de

from src.config import JointType, params

cv2.setNumThreads(0)


class txtdataset():
    def __init__(self, train, imgpath, maskpath, insize, mode='train', n_samples=None):
        self.train = train
        self.mode = mode
        self.imgpath = imgpath
        self.maskpath = maskpath
        self.insize = insize
        self.maxtime = 0
        self.catIds = train.getCatIds(catNms=['person'])
        self.imgIds = sorted(train.getImgIds(catIds=self.catIds))
        if self.mode == 'train':
            self.clean_imgIds()
        if self.mode in ['val', 'eval'] and n_samples is not None:
            self.imgIds = random.sample(self.imgIds, n_samples)
        print('{} images: {}'.format(mode, len(self)))

    def __len__(self):
        return len(self.imgIds)

    def clean_imgIds(self):
        """Drop image ids whose annotations are all too small or too sparse."""
        print("cleaning imgids")

        for img_id in self.imgIds.copy():
            annotations = None
            anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)

            # annotation for that image
            if anno_ids:
                annotations_for_img = self.train.loadAnns(anno_ids)

                person_cnt = 0
                valid_annotations_for_img = []
                for annotation in annotations_for_img:
                    # if too few keypoints or too small
                    if annotation['num_keypoints'] >= params['min_keypoints'] and \
                            annotation['area'] > params['min_area']:
                        person_cnt += 1
                        valid_annotations_for_img.append(annotation)

                # if person annotation
                if person_cnt > 0:
                    annotations = valid_annotations_for_img
            if annotations is None:
                #print(img_id,'is removed')
                self.imgIds.remove(img_id)

    def overlay_paf(self, img, paf):
        hue = ((np.arctan2(paf[1], paf[0]) / np.pi) / -2 + 0.5)
        saturation = np.sqrt(paf[0] ** 2 + paf[1] ** 2)
        saturation[saturation > 1.0] = 1.0
        value = saturation.copy()
        hsv_paf = np.vstack((hue[np.newaxis], saturation[np.newaxis], value[np.newaxis])).transpose(1, 2, 0)
        rgb_paf = cv2.cvtColor((hsv_paf * 255).astype(np.uint8), cv2.COLOR_HSV2BGR)
        img = cv2.addWeighted(img, 0.6, rgb_paf, 0.4, 0)
        return img

    def overlay_pafs(self, img, pafs):
        mix_paf = np.zeros((2,) + img.shape[:-1])
        paf_flags = np.zeros(mix_paf.shape)  # for constant paf

        for paf in pafs.reshape((int(pafs.shape[0]/2), 2,) + pafs.shape[1:]):
            paf_flag = paf != 0  # local flag, so the accumulator is not overwritten
            paf_flags += np.broadcast_to(paf_flag[0] | paf_flag[1], paf.shape)
            mix_paf += paf

        mix_paf[paf_flags > 0] /= paf_flags[paf_flags > 0]
        img = self.overlay_paf(img, mix_paf)
        return img

    def overlay_heatmap(self, img, heatmap):
        rgb_heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
        img = cv2.addWeighted(img, 0.6, rgb_heatmap, 0.4, 0)
        return img

    def overlay_ignore_mask(self, img, ignore_mask):
        img = img * np.repeat((ignore_mask == 0).astype(np.uint8)[:, :, None], 3, axis=2)
        return img

    # -------------------- augment code --------------------------------
    def get_pose_bboxes(self, poses):
        pose_bboxes = []
        for pose in poses:
            x1 = pose[pose[:, 2] > 0][:, 0].min()
            y1 = pose[pose[:, 2] > 0][:, 1].min()
            x2 = pose[pose[:, 2] > 0][:, 0].max()
            y2 = pose[pose[:, 2] > 0][:, 1].max()
            pose_bboxes.append([x1, y1, x2, y2])
        pose_bboxes = np.array(pose_bboxes)
        return pose_bboxes

    def resize_data(self, img, ignore_mask, poses, shape):
        """resize img, mask and annotations"""
        img_h, img_w, _ = img.shape

        resized_img = cv2.resize(img, shape)
        ignore_mask = cv2.resize(ignore_mask.astype(np.uint8), shape).astype('bool')
        poses[:, :, :2] = (poses[:, :, :2] * np.array(shape) / np.array((img_w, img_h)))
        return resized_img, ignore_mask, poses

    def random_resize_img(self, img, ignore_mask, poses):
        h, w, _ = img.shape
        joint_bboxes = self.get_pose_bboxes(poses)
        bbox_sizes = ((joint_bboxes[:, 2:] - joint_bboxes[:, :2] + 1) ** 2).sum(axis=1) ** 0.5

        min_scale = params['min_box_size'] / bbox_sizes.min()
        max_scale = params['max_box_size'] / bbox_sizes.max()

        min_scale = min(max(min_scale, params['min_scale']), 1)
        max_scale = min(max(max_scale, 1), params['max_scale'])

        scale = float((max_scale - min_scale) * random.random() + min_scale)
        #scale = random.random()*1.5+0.5
        shape = (round(w * scale), round(h * scale))

        resized_img, resized_mask, resized_poses = self.resize_data(img, ignore_mask, poses, shape)
        return resized_img, resized_mask, resized_poses

    def random_rotate_img(self, img, mask, poses):
        h, w, _ = img.shape
        # degree = (random.random() - 0.5) * 2 * params['max_rotate_degree']
        degree = np.random.randn() / 3 * params['max_rotate_degree']
        rad = degree * math.pi / 180
        center = (w / 2, h / 2)
        R = cv2.getRotationMatrix2D(center, degree, 1)
        bbox = (w * abs(math.cos(rad)) + h * abs(math.sin(rad)), w * abs(math.sin(rad)) + h * abs(math.cos(rad)))
        R[0, 2] += bbox[0] / 2 - center[0]
        R[1, 2] += bbox[1] / 2 - center[1]
        rotate_img = cv2.warpAffine(img, R, (int(bbox[0]+0.5), int(bbox[1]+0.5)), flags=cv2.INTER_CUBIC,
                                    borderMode=cv2.BORDER_CONSTANT, borderValue=[127.5, 127.5, 127.5])
        rotate_mask = cv2.warpAffine(mask.astype('uint8')*255, R, (int(bbox[0]+0.5), int(bbox[1]+0.5))) > 0

        tmp_poses = np.ones_like(poses)
        tmp_poses[:, :, :2] = poses[:, :, :2].copy()
        tmp_rotate_poses = np.dot(tmp_poses, R.T)  # apply rotation matrix to the poses
        rotate_poses = poses.copy()  # to keep visibility flag
        rotate_poses[:, :, :2] = tmp_rotate_poses
        return rotate_img, rotate_mask, rotate_poses

    def random_crop_img(self, img, ignore_mask, poses):
        h, w, _ = img.shape
        insize = self.insize
        joint_bboxes = self.get_pose_bboxes(poses)
        bbox = random.choice(joint_bboxes)  # select a bbox randomly
        bbox_center = bbox[:2] + (bbox[2:] - bbox[:2]) / 2

        r_xy = np.random.rand(2)
        perturb = ((r_xy - 0.5) * 2 * params['center_perterb_max'])
        center = (bbox_center + perturb + 0.5).astype('i')

        crop_img = np.zeros((insize, insize, 3), 'uint8') + 127.5
        crop_mask = np.zeros((insize, insize), 'bool')

        offset = (center - (insize - 1) / 2 + 0.5).astype('i')
        offset_ = (center + (insize - 1) / 2 - (w - 1, h - 1) + 0.5).astype('i')

        x1, y1 = (center - (insize-1)/2 + 0.5).astype('i')
        x2, y2 = (center + (insize-1)/2 + 0.5).astype('i')

        x1 = max(x1, 0)
        y1 = max(y1, 0)
        x2 = min(x2, w-1)
        y2 = min(y2, h-1)

        x_from = -offset[0] if offset[0] < 0 else 0
        y_from = -offset[1] if offset[1] < 0 else 0
        x_to = insize - offset_[0] - 1 if offset_[0] >= 0 else insize - 1
        y_to = insize - offset_[1] - 1 if offset_[1] >= 0 else insize - 1

        crop_img[y_from:y_to+1, x_from:x_to+1] = img[y1:y2+1, x1:x2+1].copy()
        crop_mask[y_from:y_to+1, x_from:x_to+1] = ignore_mask[y1:y2+1, x1:x2+1].copy()

        poses[:, :, :2] -= offset
        return crop_img.astype('uint8'), crop_mask, poses

    def distort_color(self, img):
        img_max = np.broadcast_to(np.array(255, dtype=np.uint8), img.shape[:-1])
        img_min = np.zeros(img.shape[:-1], dtype=np.uint8)

        hsv_img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.int32)
        hsv_img[:, :, 0] = np.maximum(np.minimum(hsv_img[:, :, 0] - 10 + np.random.randint(20 + 1), img_max), img_min)  # hue
        hsv_img[:, :, 1] = np.maximum(np.minimum(hsv_img[:, :, 1] - 40 + np.random.randint(80 + 1), img_max), img_min)  # saturation
        hsv_img[:, :, 2] = np.maximum(np.minimum(hsv_img[:, :, 2] - 30 + np.random.randint(60 + 1), img_max), img_min)  # value
        hsv_img = hsv_img.astype(np.uint8)

        distorted_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
        return distorted_img

    def flip_img(self, img, mask, poses):
        flipped_img = cv2.flip(img, 1)
        flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
        poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]

        def swap_joints(poses, joint_type_1, joint_type_2):
            tmp = poses[:, joint_type_1].copy()
            poses[:, joint_type_1] = poses[:, joint_type_2]
            poses[:, joint_type_2] = tmp

        swap_joints(poses, JointType.LeftEye, JointType.RightEye)
        swap_joints(poses, JointType.LeftEar, JointType.RightEar)
        swap_joints(poses, JointType.LeftShoulder, JointType.RightShoulder)
        swap_joints(poses, JointType.LeftElbow, JointType.RightElbow)
        swap_joints(poses, JointType.LeftHand, JointType.RightHand)
        swap_joints(poses, JointType.LeftWaist, JointType.RightWaist)
        swap_joints(poses, JointType.LeftKnee, JointType.RightKnee)
        swap_joints(poses, JointType.LeftFoot, JointType.RightFoot)
        return flipped_img, flipped_mask, poses

    def augment_data(self, img, ignore_mask, poses):
        aug_img = img.copy()
        aug_img, ignore_mask, poses = self.random_resize_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_rotate_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_crop_img(aug_img, ignore_mask, poses)
        if np.random.randint(2):
            aug_img = self.distort_color(aug_img)
        if np.random.randint(2):
            aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)

        return aug_img, ignore_mask, poses
    # ------------------------------------------------------------------

    # return shape: (height, width)
    def generate_gaussian_heatmap(self, shape, joint, sigma):
        x, y = joint
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2)
        return gaussian_heatmap
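
    # Example: generate_gaussian_heatmap((368, 368), (100, 200), 7) returns a
    # 368x368 map equal to 1.0 at (x=100, y=200), decaying as
    # exp(-d**2 / (2 * 7**2)) with pixel distance d from the joint.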
|
||||
|
||||
def generate_heatmaps(self, img, poses, heatmap_sigma):
|
||||
heatmaps = np.zeros((0,) + img.shape[:-1])
|
||||
sum_heatmap = np.zeros(img.shape[:-1])
|
||||
for joint_index in range(len(JointType)):
|
||||
heatmap = np.zeros(img.shape[:-1])
|
||||
for pose in poses:
|
||||
if pose[joint_index, 2] > 0:
|
||||
jointmap = self.generate_gaussian_heatmap(img.shape[:-1], pose[joint_index][:2], heatmap_sigma)
|
||||
heatmap[jointmap > heatmap] = jointmap[jointmap > heatmap]
|
||||
sum_heatmap[jointmap > sum_heatmap] = jointmap[jointmap > sum_heatmap]
|
||||
heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape)))
|
||||
bg_heatmap = 1 - sum_heatmap # background channel
|
||||
heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
|
||||
return heatmaps.astype('f')
|
||||
|
||||
# return shape: (2, height, width)
|
||||
def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
|
||||
if np.array_equal(joint_from, joint_to): # same joint
|
||||
return np.zeros((2,) + shape[:-1])
|
||||
|
||||
joint_distance = np.linalg.norm(joint_to - joint_from)
|
||||
unit_vector = (joint_to - joint_from) / joint_distance
|
||||
rad = np.pi / 2
|
||||
# [[0, 1], [-1, 0]]
|
||||
rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
|
||||
# [[u_y], [-u_x]]
|
||||
vertical_unit_vector = np.dot(rot_matrix, unit_vector)
|
||||
grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
|
||||
grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
|
||||
horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
|
||||
horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
|
||||
vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] * \
|
||||
(grid_y - joint_from[1])
|
||||
vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width # paf_width : 8
|
||||
paf_flag = horizontal_paf_flag & vertical_paf_flag
|
||||
constant_paf = np.stack((paf_flag, paf_flag)) *\
|
||||
np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)
|
||||
|
||||
return constant_paf
|
||||
|
||||
def generate_pafs(self, img, poses, paf_sigma):
|
||||
pafs = np.zeros((0,) + img.shape[:-1])
|
||||
|
||||
for limb in params['limbs_point']:
|
||||
paf = np.zeros((2,) + img.shape[:-1])
|
||||
paf_flags = np.zeros(paf.shape) # for constant paf
|
||||
|
||||
for pose in poses:
|
||||
joint_from, joint_to = pose[limb]
|
||||
if joint_from[2] > 0 and joint_to[2] > 0:
|
||||
limb_paf = self.generate_constant_paf(img.shape, joint_from[:2], joint_to[:2], paf_sigma) # [2, 368, 368]
|
||||
limb_paf_flags = limb_paf != 0
|
||||
paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape)
|
||||
|
||||
paf += limb_paf
|
||||
|
||||
paf[paf_flags > 0] /= paf_flags[paf_flags > 0]
|
||||
pafs = np.vstack((pafs, paf))
|
||||
return pafs.astype('f')

    def get_img_annotation(self, ind=None, img_id=None):
        annotations = None
        annotations_for_img = []  # keep defined even when the image has no annotation ids

        if ind is not None:
            img_id = self.imgIds[ind]
        anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)

        # annotations for that image
        if anno_ids:
            annotations_for_img = self.train.loadAnns(anno_ids)

            person_cnt = 0
            valid_annotations_for_img = []
            for annotation in annotations_for_img:
                # skip persons with too few keypoints or too small an area
                if annotation['num_keypoints'] >= params['min_keypoints'] and annotation['area'] > params['min_area']:
                    person_cnt += 1
                    valid_annotations_for_img.append(annotation)

            # keep only images with at least one valid person annotation
            if person_cnt > 0:
                annotations = valid_annotations_for_img

        img_path = os.path.join(self.imgpath, self.train.loadImgs([img_id])[0]['file_name'])
        mask_path = os.path.join(self.maskpath, '{:012d}.png'.format(img_id))
        img = cv2.imread(img_path)
        ignore_mask = cv2.imread(mask_path, 0)
        if ignore_mask is None:
            ignore_mask = np.zeros(img.shape[:2], np.float32)
        else:
            ignore_mask[ignore_mask == 255] = 1

        if self.mode == 'eval':
            return img, img_id, annotations_for_img, ignore_mask

        return img, img_id, annotations, ignore_mask.astype('f')

    def parse_annotation(self, annotations):
        poses = np.zeros((0, len(JointType), 3), dtype=np.int32)

        for ann in annotations:
            ann_pose = np.array(ann['keypoints']).reshape(-1, 3)
            pose = np.zeros((1, len(JointType), 3), dtype=np.int32)

            # reorder COCO keypoints into the JointType layout
            for i, joint_index in enumerate(params['joint_indices']):
                pose[0][joint_index] = ann_pose[i]

            # compute the neck position as the midpoint of the two shoulders
            if pose[0][JointType.LeftShoulder][2] > 0 and pose[0][JointType.RightShoulder][2] > 0:
                pose[0][JointType.Neck][0] = int((pose[0][JointType.LeftShoulder][0] +
                                                  pose[0][JointType.RightShoulder][0]) / 2)
                pose[0][JointType.Neck][1] = int((pose[0][JointType.LeftShoulder][1] +
                                                  pose[0][JointType.RightShoulder][1]) / 2)
                pose[0][JointType.Neck][2] = 2

            poses = np.vstack((poses, pose))
        return poses

    def resize_output(self, input_np, map_h=46, map_w=46):
        if len(input_np.shape) == 3:
            output = np.zeros((input_np.shape[0], map_h, map_w))
            for i in range(input_np.shape[0]):
                output[i] = cv2.resize(input_np[i], (map_w, map_h))
            return output.astype('f')

        input_np = input_np.astype('f')
        # cv2.resize expects the target size as (width, height)
        output = cv2.resize(input_np, (map_w, map_h))
        return output

    def generate_labels(self, img, poses, ignore_mask):
        img, ignore_mask, poses = self.augment_data(img, ignore_mask, poses)
        resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
                                                                   shape=(self.insize, self.insize))

        heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma'])
        pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma'])  # params['paf_sigma']: 8

        # dilate the ignore mask so its border pixels are excluded from the loss as well
        ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
        resized_pafs = self.resize_output(pafs)
        resized_heatmaps = self.resize_output(heatmaps)
        resized_ignore_mask = self.resize_output(ignore_mask)

        return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask

    def preprocess(self, img):
        # scale pixels to [-0.5, 0.5] and convert HWC -> CHW
        x_data = img.astype('f')
        x_data /= 255
        x_data -= 0.5
        x_data = x_data.transpose(2, 0, 1)
        return x_data

    def __getitem__(self, i):
        img, img_id, annotations, ignore_mask = self.get_img_annotation(ind=i)

        if self.mode in ['eval', 'val']:
            # no need to build heatmaps/pafs for evaluation
            return img, np.array([img_id])

        # resample until an image with valid annotations is found
        while annotations is None:
            print("no valid annotations for image", img_id)
            img_id = self.imgIds[np.random.randint(len(self))]
            img, img_id, annotations, ignore_mask = self.get_img_annotation(img_id=img_id)

        poses = self.parse_annotation(annotations)

        resized_img, pafs, heatmaps, ignore_mask = self.generate_labels(img, poses, ignore_mask)
        resized_img = self.preprocess(resized_img)
        ignore_mask = 1. - ignore_mask

        return resized_img, pafs, heatmaps, ignore_mask
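
# --- Illustrative sketch (not part of the original file): the shapes one training
# --- sample is expected to carry, assuming insize=368, 46x46 output maps and the
# --- 38/19-channel heads of the network (an assumption; generate_heatmaps is
# --- defined earlier in this file).
def _demo_sample_shapes(dataset):
    img, pafs, heatmaps, mask = dataset[0]
    print(img.shape)       # (3, 368, 368), float32 in [-0.5, 0.5]
    print(pafs.shape)      # (38, 46, 46): an (x, y) channel pair per limb
    print(heatmaps.shape)  # (19, 46, 46): one confidence map per joint
    print(mask.shape)      # (46, 46): 1 where the loss is applied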


class DistributedSampler():
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))
        # pad from the front so every rank gets the same number of samples
        indices += indices[:(self.total_size - len(indices))]
        # each rank takes every group_size-th index, starting at its own rank
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples
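
# --- Illustrative sketch (not part of the original file): how the sampler shards
# --- a toy 10-element dataset across 4 ranks; each rank gets ceil(10/4) = 3
# --- indices, and the shards are disjoint before padding.
def _demo_sampler_shard():
    sampler = DistributedSampler(list(range(10)), rank=1, group_size=4, shuffle=False)
    print(list(sampler))  # [1, 5, 9]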


def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
    # cv2.setNumThreads(0)
    val = ReadJson(jsonpath)
    dataset = txtdataset(val, imgpath, maskpath, params['insize'], mode=mode)
    sampler = DistributedSampler(dataset, rank, group_size)
    ds = de.GeneratorDataset(dataset, ['img', 'img_id'], num_parallel_workers=8, sampler=sampler)
    ds = ds.repeat(1)
    return ds


def openpose(jsonpath, imgpath, maskpath, per_batch_size, max_epoch, rank, group_size, mode='train'):
    train = ReadJson(jsonpath)
    num_parallel = 48
    if group_size > 1:
        num_parallel = 20
    dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
    sampler = DistributedSampler(dataset, rank, group_size)
    de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
                                     num_parallel_workers=num_parallel, sampler=sampler, shuffle=True)
    de_dataset = de_dataset.project(columns=["image", "pafs", "heatmaps", "ignore_mask"])
    de_dataset = de_dataset.batch(batch_size=per_batch_size, drop_remainder=True, num_parallel_workers=num_parallel)
    steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(max_epoch)

    return de_dataset, steps_per_epoch


def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
                   multiprocessing=True, num_worker=20):

    train = ReadJson(jsonpath)
    dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
    if group_size == 1:
        de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker,
                                         python_multiprocessing=multiprocessing)
    else:
        de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker,
                                         python_multiprocessing=multiprocessing,
                                         num_shards=group_size,
                                         shard_id=rank)

    de_dataset = de_dataset.batch(batch_size=batch_size, drop_remainder=True)
    de_dataset = de_dataset.repeat(repeat_num)

    return de_dataset

@@ -0,0 +1,133 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import cv2

import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO as ReadJson

from config import params


class DataLoader():
    def __init__(self, coco, dir_name, data_mode='train'):
        self.train = coco
        self.dir_name = dir_name
        assert data_mode in ['train', 'val'], 'Data loading mode is invalid.'
        self.mode = data_mode
        self.catIds = coco.getCatIds()  # catNms=['person']
        self.imgIds = sorted(coco.getImgIds(catIds=self.catIds))

    def __len__(self):
        return len(self.imgIds)

    def gen_masks(self, image, anns):
        _mask_all = np.zeros(image.shape[:2], 'bool')
        _mask_miss = np.zeros(image.shape[:2], 'bool')
        for ann in anns:
            mask = self.train.annToMask(ann).astype('bool')
            if ann['iscrowd'] == 1:
                intxn = _mask_all & mask
                _mask_miss = np.bitwise_or(_mask_miss.astype(int), np.subtract(mask, intxn, dtype=np.int32))
                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
            elif ann['num_keypoints'] < params['min_keypoints'] or ann['area'] <= params['min_area']:
                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
                _mask_miss = np.bitwise_or(_mask_miss.astype(int), mask.astype(int))
            else:
                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
        return _mask_all, _mask_miss

    def draw_gen_masks(self, image, mask, color=(0, 0, 1)):
        bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        mskd = image * bimsk.astype(np.int32)
        clmsk = np.ones(bimsk.shape) * bimsk
        for ind in range(3):
            clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
        image = image + 0.7 * clmsk - 0.7 * mskd
        return image.astype(np.uint8)

    def draw_masks_and_keypoints(self, image, anns):
        for ann in anns:
            # masks
            mask = self.train.annToMask(ann).astype(np.uint8)
            if ann['iscrowd'] == 1:
                color = (0, 0, 1)
            elif ann['num_keypoints'] == 0:
                color = (0, 1, 0)
            else:
                color = (1, 0, 0)
            bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
            mskd = image * bimsk.astype(np.int32)
            clmsk = np.ones(bimsk.shape) * bimsk
            for ind in range(3):
                clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
            image = image + 0.7 * clmsk - 0.7 * mskd

            # keypoints: v == 1 labelled but not visible, v == 2 labelled and visible
            for x, y, v in np.array(ann['keypoints']).reshape(-1, 3):
                if v == 1:
                    cv2.circle(image, (x, y), 3, (255, 255, 0), -1)
                elif v == 2:
                    cv2.circle(image, (x, y), 3, (255, 0, 255), -1)
        return image.astype(np.uint8)

    def get_img_annotation(self, ind=None, image_id=None):
        if ind is not None:
            image_id = self.imgIds[ind]

        anno_ids = self.train.getAnnIds(imgIds=[image_id])
        _annotations = self.train.loadAnns(anno_ids)

        img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([image_id])[0]['file_name'])
        _image = cv2.imread(img_file)
        return _image, _annotations, image_id


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--vis', action='store_true', help='visualize annotations and ignore masks')
    parser.add_argument('--train_ann', type=str, help='train annotations json')
    parser.add_argument('--val_ann', type=str, help='val annotations json')
    parser.add_argument('--train_dir', type=str, help='name of train dir')
    parser.add_argument('--val_dir', type=str, help='name of val dir')
    args = parser.parse_args()
    path_list = [args.train_ann, args.val_ann, args.train_dir, args.val_dir]
    for index, mode in enumerate(['train', 'val']):
        train = ReadJson(path_list[index])
        data_loader = DataLoader(train, path_list[index + 2], data_mode=mode)

        save_dir = os.path.join(params['data_dir'], 'ignore_mask_{}'.format(mode))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for i in tqdm(range(len(data_loader))):
            img, annotations, img_id = data_loader.get_img_annotation(ind=i)
            mask_all, mask_miss = data_loader.gen_masks(img, annotations)

            if args.vis:
                ann_img = data_loader.draw_masks_and_keypoints(img, annotations)
                msk_img = data_loader.draw_gen_masks(img, mask_miss)
                cv2.imshow('image', np.hstack((ann_img, msk_img)))
                k = cv2.waitKey()
                if k == ord('q'):
                    break
                elif k == ord('s'):
                    cv2.imwrite('aaa.png', np.hstack((ann_img, msk_img)))

            if np.any(mask_miss) and not args.vis:
                mask_miss = mask_miss.astype(np.uint8) * 255
                save_path = os.path.join(save_dir, '{:012d}.png'.format(img_id))
                cv2.imwrite(save_path, mask_miss)

@@ -0,0 +1,207 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import mindspore.nn as nn
from mindspore import RowTensor  # missing in the original; needed by tensor_grad_scale_row_tensor
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.train.callback import Callback
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
time_stamp_init = False
time_stamp_first = 0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)


clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    # The body of this registered function was missing here; the following is the
    # conventional MindSpore clip-by-value / clip-by-norm helper (an assumption).
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


class openpose_loss(_Loss):
    def __init__(self):
        super(openpose_loss, self).__init__()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.mul = P.Mul()
        self.l2_loss = P.L2Loss()
        self.square = P.Square()
        self.reduceMean = P.ReduceMean()
        self.reduceSum = P.ReduceSum()
        self.print = P.Print()
        self.shape = P.Shape()
        self.maxoftensor = P.ArgMaxWithValue(-1)

    def mean_square_error(self, map1, map2, mask=None):
        if mask is None:
            mse = self.reduceMean((map1 - map2) ** 2)
            return mse

        squareMap = self.square(map1 - map2)
        squareMap_mask = self.mul(squareMap, mask)
        mse = self.reduceMean(squareMap_mask)
        return mse

    def construct(self, logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask):
        # ignore_mask must be a 0-1 float array (1 where the loss applies), not a boolean array
        heatmaps_loss = []
        pafs_loss = []
        total_loss = 0

        paf_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_paf)[1], 1, 1))
        heatmap_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_heatmap)[1], 1, 1))

        paf_masks = F.stop_gradient(paf_masks)
        heatmap_masks = F.stop_gradient(heatmap_masks)
        # accumulate the masked MSE of every stage output (intermediate supervision)
        for logit_paf_t, logit_heatmap_t in zip(logit_paf, logit_heatmap):
            pafs_loss_t = self.mean_square_error(logit_paf_t, gt_paf, paf_masks)
            heatmaps_loss_t = self.mean_square_error(logit_heatmap_t, gt_heatmap, heatmap_masks)

            total_loss += pafs_loss_t + heatmaps_loss_t
            heatmaps_loss.append(heatmaps_loss_t)
            pafs_loss.append(pafs_loss_t)

        return total_loss, heatmaps_loss, pafs_loss
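
# --- Illustrative sketch (not part of the original file): the masked MSE used
# --- above, in plain numpy. Masked-out pixels contribute zero to the sum but
# --- still count in the mean, matching ReduceMean over the whole map.
def _demo_masked_mse():
    import numpy as np
    pred = np.array([[1.0, 2.0], [3.0, 4.0]])
    gt = np.zeros((2, 2))
    mask = np.array([[1.0, 0.0], [1.0, 0.0]])  # ignore the right column
    print(np.mean((pred - gt) ** 2 * mask))    # (1 + 9) / 4 = 2.5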


class Depend_network(nn.Cell):
    def __init__(self, network):
        super(Depend_network, self).__init__()
        self.network = network

    def construct(self, *args):
        loss, _, _ = self.network(*args)  # loss, heatmaps_loss, pafs_loss
        return loss


class TrainingWrapper(nn.Cell):
    def __init__(self, network, optimizer, sens=1):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.depend_network = Depend_network(network)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.print = P.Print()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss, heatmaps_loss, pafs_loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        # differentiate through depend_network, which returns only the scalar loss
        grads = self.grad(self.depend_network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        return loss, heatmaps_loss, pafs_loss
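
# --- Illustrative wiring sketch (not part of the original file): TrainingWrapper
# --- expects a cell that returns (total_loss, heatmaps_loss, pafs_loss), i.e. a
# --- network-plus-criterion cell built around openpose_loss; the names below are
# --- placeholders.
#   train_cell = TrainingWrapper(net_with_loss, optimizer)
#   loss, heatmaps_loss, pafs_loss = train_cell(image, gt_paf, gt_heatmap, mask)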


class BuildTrainNetwork(nn.Cell):
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, gt_paf, gt_heatmap, mask):
        logit_pafs, logit_heatmap = self.network(input_data)
        loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
        return loss


class LossCallBack(Callback):
    """
    Monitor the loss during training.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """

    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0.")
        self._per_print_times = per_print_times
        self.count = 0
        self.loss_sum = 0

        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = time.time()
            time_stamp_init = True

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()

        self.count += 1
        self.loss_sum += float(loss)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if self.count >= 1:
            global time_stamp_first
            time_stamp_current = time.time()

            loss = self.loss_sum / self.count

            loss_file = open("./loss.log", "a+")
            loss_file.write("%lu epoch: %s step: %s, loss: %.5f" %
                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                             loss))
            loss_file.write("\n")
            loss_file.close()

            self.count = 0
            self.loss_sum = 0

@@ -0,0 +1,320 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.nn import Conv2d, ReLU
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

time_stamp_init = False
time_stamp_first = 0
loadvgg = 1


class OpenPoseNet(nn.Cell):
    insize = 368

    def __init__(self, vggpath=''):
        super(OpenPoseNet, self).__init__()
        self.base = Base_model()
        self.stage_1 = Stage_1()
        self.stage_2 = Stage_x()
        self.stage_3 = Stage_x()
        self.stage_4 = Stage_x()
        self.stage_5 = Stage_x()
        self.stage_6 = Stage_x()
        self.shape = P.Shape()
        self.cat = P.Concat(axis=1)
        self.print = P.Print()
        if loadvgg and vggpath:
            param_dict = load_checkpoint(vggpath)
            param_dict_new = {}
            trans_name = 'base.vgg_base.'
            for key, values in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    # remap checkpoint keys into base.vgg_base.*
                    param_dict_new[trans_name + key[17:]] = values
            load_param_into_net(self.base.vgg_base, param_dict_new)

    def construct(self, x):
        heatmaps = []
        pafs = []
        feature_map = self.base(x)
        h1, h2 = self.stage_1(feature_map)
        pafs.append(h1)
        heatmaps.append(h2)
        # each refinement stage sees the previous stage's PAFs and heatmaps
        # concatenated with the shared feature map
        h1, h2 = self.stage_2(self.cat((h1, h2, feature_map)))
        pafs.append(h1)
        heatmaps.append(h2)
        h1, h2 = self.stage_3(self.cat((h1, h2, feature_map)))
        pafs.append(h1)
        heatmaps.append(h2)
        h1, h2 = self.stage_4(self.cat((h1, h2, feature_map)))
        pafs.append(h1)
        heatmaps.append(h2)
        h1, h2 = self.stage_5(self.cat((h1, h2, feature_map)))
        pafs.append(h1)
        heatmaps.append(h2)
        h1, h2 = self.stage_6(self.cat((h1, h2, feature_map)))
        pafs.append(h1)
        heatmaps.append(h2)
        return pafs, heatmaps
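
# --- Illustrative shape flow (not part of the original file), assuming a 368x368
# --- input: the three 'M' poolings in the VGG front end downsample by 8, so the
# --- shared feature map is (N, 128, 46, 46); every stage then emits 38 PAF
# --- channels (2 per limb) and 19 heatmap channels at the same 46x46 resolution,
# --- which is why each Stage_x consumes 38 + 19 + 128 = 185 input channels.
#   net = OpenPoseNet()
#   pafs, heatmaps = net(x)  # x: (N, 3, 368, 368)
#   # len(pafs) == len(heatmaps) == 6
#   # pafs[-1]: (N, 38, 46, 46); heatmaps[-1]: (N, 19, 46, 46)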


class Vgg(nn.Cell):
    def __init__(self, cfg, batch_norm=False):
        # Important: with this VGG, batch_size should be <= 64; larger values can
        # trigger an unknown error
        super(Vgg, self).__init__()
        self.layers = self._make_layer(cfg, batch_norm=batch_norm)

    def construct(self, x):
        x = self.layers(x)
        return x

    def _make_layer(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')]
            else:
                conv2d = Conv2d(in_channels=in_channels,
                                out_channels=v,
                                kernel_size=3,
                                stride=1,
                                pad_mode='same',
                                has_bias=True)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
                else:
                    layers += [conv2d, nn.ReLU()]
                in_channels = v
        return nn.SequentialCell(layers)


class VGG_Base(nn.Cell):
    def __init__(self):
        super(VGG_Base, self).__init__()
        self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)

        self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)

        self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)

        self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True)
        self.relu = ReLU()
        self.max_pooling_2d = nn.MaxPool2d(kernel_size=2, stride=2)

    def construct(self, x):
        x = self.relu(self.conv1_1(x))
        x = self.relu(self.conv1_2(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.conv2_1(x))
        x = self.relu(self.conv2_2(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.conv3_1(x))
        x = self.relu(self.conv3_2(x))
        x = self.relu(self.conv3_3(x))
        x = self.relu(self.conv3_4(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.conv4_1(x))
        x = self.relu(self.conv4_2(x))
        return x


class VGG_Base_MS(nn.Cell):
    def __init__(self):
        super(VGG_Base_MS, self).__init__()
        self.Layer1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)

        self.Layer2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)

        self.Layer3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)

        self.Layer4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.Layer4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, pad_mode='same',
                               has_bias=True)
        self.relu = ReLU()
        self.max_pooling_2d = nn.MaxPool2d(kernel_size=2, stride=2)

    def construct(self, x):
        x = self.relu(self.Layer1_1(x))
        x = self.relu(self.Layer1_2(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.Layer2_1(x))
        x = self.relu(self.Layer2_2(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.Layer3_1(x))
        x = self.relu(self.Layer3_2(x))
        x = self.relu(self.Layer3_3(x))
        x = self.relu(self.Layer3_4(x))
        x = self.max_pooling_2d(x)
        x = self.relu(self.Layer4_1(x))
        x = self.relu(self.Layer4_2(x))
        return x


class Base_model(nn.Cell):
    def __init__(self):
        super(Base_model, self).__init__()
        cfgs_zh = {'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512]}
        self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=False)
        # self.vgg_base = VGG_Base()
        # self.vgg_base = VGG_Base_MS()

        self.conv4_3_CPM = Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                                  has_bias=True)
        self.conv4_4_CPM = Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                  has_bias=True)
        self.relu = ReLU()

    def construct(self, x):
        x = self.vgg_base(x)
        x = self.relu(self.conv4_3_CPM(x))
        x = self.relu(self.conv4_4_CPM(x))
        return x


class Stage_1(nn.Cell):
    def __init__(self):
        super(Stage_1, self).__init__()

        self.conv1_CPM_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv2_CPM_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv3_CPM_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv4_CPM_L1 = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv5_CPM_L1 = Conv2d(in_channels=512, out_channels=38, kernel_size=1, stride=1, pad_mode='same',
                                   has_bias=True)

        self.conv1_CPM_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv2_CPM_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv3_CPM_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv4_CPM_L2 = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, pad_mode='same',
                                   has_bias=True)
        self.conv5_CPM_L2 = Conv2d(in_channels=512, out_channels=19, kernel_size=1, stride=1, pad_mode='same',
                                   has_bias=True)

        self.relu = ReLU()

    def construct(self, x):
        h1 = self.relu(self.conv1_CPM_L1(x))  # branch1
        h1 = self.relu(self.conv2_CPM_L1(h1))
        h1 = self.relu(self.conv3_CPM_L1(h1))
        h1 = self.relu(self.conv4_CPM_L1(h1))
        h1 = self.conv5_CPM_L1(h1)

        h2 = self.relu(self.conv1_CPM_L2(x))  # branch2
        h2 = self.relu(self.conv2_CPM_L2(h2))
        h2 = self.relu(self.conv3_CPM_L2(h2))
        h2 = self.relu(self.conv4_CPM_L2(h2))
        h2 = self.conv5_CPM_L2(h2)
        return h1, h2


class Stage_x(nn.Cell):
    def __init__(self):
        super(Stage_x, self).__init__()
        self.conv1_L1 = Conv2d(in_channels=185, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv2_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv3_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv4_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv5_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv6_L1 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv7_L1 = Conv2d(in_channels=128, out_channels=38, kernel_size=1, stride=1, pad_mode='same',
                               has_bias=True)

        self.conv1_L2 = Conv2d(in_channels=185, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv2_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv3_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv4_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv5_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv6_L2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, pad_mode='same',
                               has_bias=True)
        self.conv7_L2 = Conv2d(in_channels=128, out_channels=19, kernel_size=1, stride=1, pad_mode='same',
                               has_bias=True)
        self.relu = ReLU()

    def construct(self, x):
        h1 = self.relu(self.conv1_L1(x))  # branch1
        h1 = self.relu(self.conv2_L1(h1))
        h1 = self.relu(self.conv3_L1(h1))
        h1 = self.relu(self.conv4_L1(h1))
        h1 = self.relu(self.conv5_L1(h1))
        h1 = self.relu(self.conv6_L1(h1))
        h1 = self.conv7_L1(h1)
        h2 = self.relu(self.conv1_L2(x))  # branch2
        h2 = self.relu(self.conv2_L2(h2))
        h2 = self.relu(self.conv3_L2(h2))
        h2 = self.relu(self.conv4_L2(h2))
        h2 = self.relu(self.conv5_L2(h2))
        h2 = self.relu(self.conv6_L2(h2))
        h2 = self.conv7_L2(h2)
        return h1, h2

@@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import numpy as np

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
from mindspore.common.tensor import Tensor

from src.config import params


class MyLossMonitor(LossMonitor):
    def __init__(self, per_print_times=1):
        super(MyLossMonitor, self).__init__()
        self._per_print_times = per_print_times
        self._start_time = time.time()
        self._loss_list = []

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss_list.append(loss)
            if cb_params.cur_step_num % 100 == 0:
                # report the mean loss over the last 100 steps
                print("epoch: %s, steps: [%s] mean loss is: %s" % (cb_params.cur_epoch_num, cur_step_in_epoch,
                                                                   np.array(self._loss_list).mean()), flush=True)
                self._loss_list = []

        self._start_time = time.time()


def parse_args():
    """Parse train arguments."""
    parser = argparse.ArgumentParser('mindspore openpose training')

    # dataset related
    parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir')
    parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json',
                        help='train annotations json')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed training')

    args, _ = parser.parse_known_args()

    args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann)
    args.imgpath_train = os.path.join(params['data_dir'], args.train_dir)
    args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train')

    return args


def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size):
    # piecewise-constant decay: multiply by lr_gamma at every step in lr_steps
    lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f')
    for step in lr_steps:
        step //= group_size
        lr_stage[step:] *= lr_gamma

    # the CPM base convs train at a quarter of the stage learning rate
    lr_base = lr_stage.copy()
    lr_base = lr_base / 4

    # the VGG front end is frozen for the first 2000 steps
    lr_vgg = lr_base.copy()
    vgg_freeze_step = 2000
    lr_vgg[:vgg_freeze_step] = 0
    return lr_stage, lr_base, lr_vgg
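
# --- Illustrative sketch (not part of the original file): a tiny schedule with
# --- 5 steps per epoch, 2 epochs and one decay point at step 5.
def _demo_get_lr():
    lr_stage, lr_base, lr_vgg = get_lr(0.1, 0.1, steps_per_epoch=5, max_epoch_train=2,
                                       lr_steps=[5], group_size=1)
    print(lr_stage)    # [0.1] * 5 + [0.01] * 5
    print(lr_base[0])  # 0.025: a quarter of the stage learning rate
    print(lr_vgg[0])   # 0.0: the VGG part is frozen for the first 2000 steps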


def adjust_learning_rate(init_lr, lr_gamma, steps_per_epoch, max_epoch_train, stepvalues):
    # same schedule as get_lr, but with the decay points given in epochs
    lr_stage = np.array([init_lr] * steps_per_epoch * max_epoch_train).astype('f')
    for epoch in stepvalues:
        lr_stage[epoch * steps_per_epoch:] *= lr_gamma

    lr_base = lr_stage.copy()
    lr_base = lr_base / 4

    lr_vgg = lr_base.copy()
    vgg_freeze_step = 2000
    lr_vgg[:vgg_freeze_step] = 0
    return lr_stage, lr_base, lr_vgg


def load_model(test_net, model_path):
    if model_path:
        param_dict = load_checkpoint(model_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moment'):
                # skip optimizer state
                continue
            elif key.startswith('network.'):
                # strip the 'network.' prefix saved by the training wrapper
                param_dict_new[key[8:]] = values
        load_param_into_net(test_net, param_dict_new)


class show_loss_list():
    def __init__(self, name):
        self.loss_list = np.zeros(6).astype('f')
        self.sums = 0
        self.name = name

    def add(self, list_of_tensor):
        self.sums += 1
        for i, loss_tensor in enumerate(list_of_tensor):
            self.loss_list[i] += loss_tensor.asnumpy()

    def show(self):
        print(self.name + ' stage_loss:', self.loss_list / (self.sums + 1e-8), flush=True)
        self.loss_list = np.zeros(6).astype('f')
        self.sums = 0


class AverageMeter():
    def __init__(self):
        self.loss = 0
        self.sum = 0

    def add(self, tensor):
        self.sum += 1
        self.loss += tensor.asnumpy()

    def meter(self):
        average_loss = self.loss / (self.sum + 1e-8)
        self.loss = 0
        self.sum = 0
        return average_loss

@@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn.optim import Adam

from src.dataset import create_dataset
from src.openposenet import OpenPoseNet
from src.loss import openpose_loss, BuildTrainNetwork
from src.config import params
from src.utils import parse_args, get_lr, load_model, MyLossMonitor

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)


def train():
    """Train function."""
    args = parse_args()

    args.outputs_dir = params['save_model_path']

    if args.group_size > 1:
        init()
        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_{}/".format(str(get_rank())))
        args.rank = get_rank()
    else:
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/")
        args.rank = 0

    # loss scale and lr decay steps differ between single- and multi-device runs
    if args.group_size > 1:
        args.loss_scale = params['loss_scale'] / 2
        args.lr_steps = list(map(int, params["lr_steps_NP"].split(',')))
    else:
        args.loss_scale = params['loss_scale']
        args.lr_steps = list(map(int, params["lr_steps"].split(',')))

    # create network
    print('start create network')
    criterion = openpose_loss()
    criterion.add_flags_recursive(fp32=True)
    network = OpenPoseNet(vggpath=params['vgg_path'])

    if params["load_pretrain"]:
        print("load pretrain model:", params["pretrained_model_path"])
        load_model(network, params["pretrained_model_path"])
    train_net = BuildTrainNetwork(network, criterion)

    # create dataset
    if os.path.exists(args.jsonpath_train) and os.path.exists(args.imgpath_train) \
            and os.path.exists(args.maskpath_train):
        print('start create dataset')
    else:
        print('Error: wrong data path')
        return

    num_worker = 20 if args.group_size > 1 else 48
    de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train,
                                      batch_size=params['batch_size'],
                                      rank=args.rank,
                                      group_size=args.group_size,
                                      num_worker=num_worker,
                                      multiprocessing=True,
                                      shuffle=True,
                                      repeat_num=1)
    steps_per_epoch = de_dataset_train.get_dataset_size()
    print("steps_per_epoch: ", steps_per_epoch)

    # lr scheduler: separate schedules for the VGG base, the CPM base convs and the stages
    lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size,
                                       params['lr_gamma'],
                                       steps_per_epoch,
                                       params["max_epoch_train"],
                                       args.lr_steps,
                                       args.group_size)
    vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params()))
    base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params()))
    stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params()))

    group_params = [{'params': vgg19_base_params, 'lr': lr_vgg},
                    {'params': base_params, 'lr': lr_base},
                    {'params': stages_params, 'lr': lr_stage}]

    opt = Adam(group_params, loss_scale=args.loss_scale)

    train_net.set_train(True)
    loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

    model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)

    # checkpoint at most once per epoch
    params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval'])
    config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'],
                                 keep_checkpoint_max=params["keep_checkpoint_max"])
    ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck)
    time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size())
    callback_list = [MyLossMonitor(), time_cb, ckpoint_cb]
    print("============== Starting Training ==============")
    model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list,
                dataset_sink_mode=False)


if __name__ == "__main__":
    # mindspore.common.seed.set_seed(1)
    train()