add simplepose implementation

This commit is contained in:
rmdyh 2020-12-25 17:19:49 +08:00
parent ef0b483eb4
commit 15d8cccd7b
14 changed files with 1898 additions and 0 deletions

View File

@ -0,0 +1,340 @@
# Contents
- [Description](#description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Dataset Preparation](#dataset-preparation)
- [Model Checkpoints](#model-checkpoints)
- [Running](#running)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Description](#contents)
SimplePoseNet is a convolution-based neural network for the task of human pose estimation and tracking. It provides baseline methods that are surprisingly simple and effective, and thus helpful for inspiring and evaluating new ideas in the field. State-of-the-art results are achieved on challenging benchmarks. More details about this model can be found in:
B. Xiao, H. Wu, and Y. Wei, “Simple baselines for human pose estimation and tracking,” in Proc. Eur. Conf. Comput. Vis., 2018, pp. 472–487.
This repository contains a MindSpore implementation of SimplePoseNet based upon Microsoft's original PyTorch implementation (<https://github.com/microsoft/human-pose-estimation.pytorch>). The training and validation scripts are also included, and the evaluation results are shown in the [Performance](#performance) section.
# [Model Architecture](#contents)
The overall network architecture of SimplePoseNet is shown below:
[Link](https://arxiv.org/pdf/1804.06208.pdf)
# [Dataset](#contents)
Note that you can run the scripts with the dataset mentioned in the original paper or with any other dataset widely used in this domain. The following sections describe how to run the scripts with the dataset below.
Dataset used: COCO2017
- Dataset size:
- Train: 19G, 118,287 images
- Test: 788MB, 5,000 images
- Data format: JPG images
- Note: Data will be processed in `src/dataset.py`
- Person detection results for validation: detection results provided by the authors in the [repository](https://github.com/microsoft/human-pose-estimation.pytorch)
# [Features](#contents)
## [Mixed Precision](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for `reduce precision`.
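In this repository, mixed precision is enabled through the `amp_level` argument of `mindspore.train.Model`; a minimal sketch mirroring the setup in `train.py`:
```python
from mindspore.nn.optim import Adam
from mindspore.train import Model

from src.config import config
from src.model import get_pose_net
from src.network_define import JointsMSELoss, WithLossCell

# build the network with loss and the optimizer, as train.py does
net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
net_with_loss = WithLossCell(net, JointsMSELoss(use_target_weight=True))
opt = Adam(net.trainable_params(), learning_rate=config.TRAIN.LR)

# amp_level="O2" runs most of the network in FP16 while keeping numerically
# sensitive parts in FP32; this is the level used by train.py
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
```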
# [Environment Requirements](#contents)
To run the Python scripts in the repository, you need to prepare the environment as follows:
- Hardware
- Prepare hardware environment with Ascend. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to [ascend@huawei.com](mailto:ascend@huawei.com). Once approved, you can get the resources.
- Python and dependencies
- python 3.7
- mindspore 1.0.1
- easydict 1.9
- opencv-python 4.3.0.36
- pycocotools 2.0
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
## [Dataset Preparation](#contents)
SimplePoseNet uses the COCO2017 dataset for training and validation in this repository. Download the dataset from the [official website](https://cocodataset.org/). You can place the dataset anywhere and tell the scripts where it is by modifying the `DATASET.ROOT` setting in the configuration file `src/config.py`. For more information about the configuration file, please refer to [Script Parameters](#script-parameters).
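For example, if your copy of COCO2017 is stored under `/path/to/coco2017/` (a placeholder path), the only change needed in `src/config.py` is:
```python
# src/config.py
config.DATASET.ROOT = '/path/to/coco2017/'  # directory containing train2017/, val2017/ and annotations/
```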
You also need the person detection results on COCO val2017 to reproduce the multi-person pose estimation results, as mentioned in [Dataset](#dataset). Please check out the authors' repository, then download and extract the detection file under `<ROOT>/experiments/` so that the directory looks like this:
```text
└─ <ROOT>
└─ experiments
└─ COCO_val2017_detections_AP_H_56_person.json
```
## [Model Checkpoints](#contents)
Before starting the training process, you need a MindSpore ImageNet-pretrained ResNet model. The weight file can be obtained by running the ResNet training script in the [official model zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet). We also provide a pretrained model that can be used to train SimplePoseNet directly on [Google Drive](https://drive.google.com/file/d/1r3Hs0QNys0HyNtsQhSvx6IKdyRkC-3Hh/view?usp=sharing). The model file should be placed under `<ROOT>/models/` like this:
```text
└─ <ROOT>
└─ models
└─resnet50.ckpt
```
## [Running](#contents)
To train the model, run the shell script `scripts/train_standalone.sh` with the format below:
```shell
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
```
To validate the model, change the `config.TEST.MODEL_FILE` setting in `src/config.py` to the path of the checkpoint you want to validate. For example:
```python
config.TEST.MODEL_FILE='results/xxxx.ckpt'
```
Then, run the shell script `scripts/eval.sh` with the format below:
```shell
sh scripts/eval.sh [device_id]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The structure of the files in this repository is shown below.
```text
└─ mindspore-simpleposenet
├─ scripts
│ ├─ eval.sh // launch ascend standalone evaluation
│ ├─ train_distributed.sh // launch ascend distributed training
│ └─ train_standalone.sh // launch ascend standalone training
├─ src
│ ├─utils
│ │ ├─ transform.py // utils about image transformation
│ │ └─ nms.py // utils about nms
│ ├─evaluate
│ │ └─ coco_eval.py // evaluate result by coco
│ ├─ config.py // network and running config
│ ├─ dataset.py // dataset processor and provider
│ ├─ model.py // SimplePoseNet implementation
│ ├─ network_define.py // define loss
│ └─ predict.py // predict keypoints from heatmaps
├─ eval.py // evaluation script
├─ param_convert.py // model parameter conversion script
├─ train.py // training script
└─ README.md // descriptions about this repository
```
## [Script Parameters](#contents)
Configurations for both training and evaluation are set in `src/config.py`. All of the settings are shown below.
- config for SimplePoseNet on COCO2017 dataset:
```python
# pose_resnet related params
POSE_RESNET.HEATMAP_SIZE = [48, 64] # heatmap size
POSE_RESNET.SIGMA = 2 # Gaussian hyperparameter in heatmap generation
POSE_RESNET.FINAL_CONV_KERNEL = 1 # final convolution kernel size
POSE_RESNET.DECONV_WITH_BIAS = False # deconvolution bias
POSE_RESNET.NUM_DECONV_LAYERS = 3 # the number of deconvolution layers
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256] # number of filters in each deconvolution layer
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4] # kernel size of deconvolution layers
POSE_RESNET.NUM_LAYERS = 50 # number of layers (for ResNet)
# common params for NETWORK
config.MODEL.NAME = 'pose_resnet' # model name
config.MODEL.INIT_WEIGHTS = True # init model weights by resnet
config.MODEL.PRETRAINED = './models/resnet50.ckpt' # pretrained model
config.MODEL.NUM_JOINTS = 17 # the number of keypoints
config.MODEL.IMAGE_SIZE = [192, 256] # image size
# dataset
config.DATASET.ROOT = '/data/coco2017/' # coco2017 dataset root
config.DATASET.TEST_SET = 'val2017' # folder name of test set
config.DATASET.TRAIN_SET = 'train2017' # folder name of train set
# data augmentation
config.DATASET.FLIP = True # random flip
config.DATASET.ROT_FACTOR = 40 # random rotation
config.DATASET.SCALE_FACTOR = 0.3 # random scale
# for train
config.TRAIN.BATCH_SIZE = 64 # batch size
config.TRAIN.BEGIN_EPOCH = 0 # begin epoch
config.TRAIN.END_EPOCH = 140 # end epoch
config.TRAIN.LR = 0.001 # initial learning rate
config.TRAIN.LR_FACTOR = 0.1 # learning rate reduce factor
config.TRAIN.LR_STEP = [90,120] # step to reduce lr
# test
config.TEST.BATCH_SIZE = 32 # batch size
config.TEST.FLIP_TEST = True # flip test
config.TEST.POST_PROCESS = True # post process
config.TEST.SHIFT_HEATMAP = True # shift heatmap
config.TEST.USE_GT_BBOX = False # use groundtruth bbox
config.TEST.MODEL_FILE = '' # model file to test
# detect bbox file
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
# nms
config.TEST.OKS_THRE = 0.9 # oks threshold
config.TEST.IN_VIS_THRE = 0.2 # visible threshold
config.TEST.BBOX_THRE = 1.0 # bbox threshold
config.TEST.IMAGE_THRE = 0.0 # image threshold
config.TEST.NMS_THRE = 1.0 # nms threshold
```
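As an example of how these values interact, the learning-rate schedule built by `train.py` is a plain step decay driven by `TRAIN.LR`, `TRAIN.LR_FACTOR` and `TRAIN.LR_STEP`; a minimal sketch of the same idea (the step count per epoch is illustrative):
```python
import numpy as np

def step_decay_lr(lr_init=0.001, factor=0.1, drop_epochs=(90, 120),
                  total_epochs=140, steps_per_epoch=1170):
    """One learning-rate value per training step; lr is multiplied by `factor` at each drop epoch."""
    drop_steps = {e * steps_per_epoch for e in drop_epochs}
    lr, lr_each_step = lr_init, []
    for step in range(total_epochs * steps_per_epoch):
        if step in drop_steps:
            lr *= factor
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)

# With the defaults above, lr is 1e-3 until epoch 90, 1e-4 until epoch 120, and 1e-5 afterwards.
```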
## [Training Process](#contents)
### [Training](#contents)
#### Running on Ascend
Run `scripts/train_standalone.sh` to train the model on a single device. The usage of the script is:
```shell
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
```
For example, you can run the shell command below to launch the training procedure.
```shell
sh scripts/train_standalone.sh 0 results/standalone/
```
The script will run training in the background; you can view the results in the file `train_log[X].txt`, which will look like this:
```text
loading parse...
batch size :128
loading dataset from /data/coco2017/train2017
loaded 149813 records from coco dataset.
loading pretrained model ./models/resnet50.ckpt
start training, epoch size = 140
epoch: 1 step: 1170, loss is 0.000699
Epoch time: 492271.194, per step time: 420.745
epoch: 2 step: 1170, loss is 0.000586
Epoch time: 456265.617, per step time: 389.971
...
```
The model checkpoint will be saved into `[ckpt_path_to_save]`.
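To evaluate one of these checkpoints later, either point `config.TEST.MODEL_FILE` at it (see [Evaluation Process](#evaluation-process)) or load it manually with the same MindSpore APIs that `eval.py` uses; a minimal sketch (the checkpoint file name is illustrative):
```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.model import get_pose_net

# build the inference network and restore the trained weights, as eval.py does
net = get_pose_net(config, is_train=False)
load_param_into_net(net, load_checkpoint('results/standalone/simplepose-140_1170.ckpt'))  # illustrative path
net.set_train(False)
```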
### [Distributed Training](#contents)
#### Running on Ascend
Run `scripts/train_distributed.sh` to train the model in distributed mode. The usage of the script is:
```shell
sh scripts/train_distributed.sh [rank_table] [ckpt_path_to_save] [device_number]
```
For example, you can run the shell command below to launch the distributed training procedure.
```shell
sh scripts/train_distributed.sh /home/rank_table.json results/distributed/ 4
```
The above shell script will run distributed training in the background. You can view the results in the file `train_parallel[X]/log.txt`, which will look like this:
```text
loading parse...
batch size :64
loading dataset from /data/coco2017/train2017
loaded 149813 records from coco dataset.
loading pretrained model ./models/resnet50.ckpt
start training, epoch size = 140
epoch: 1 step: 585, loss is 0.0007944
Epoch time: 236219.684, per step time: 403.794
epoch: 2 step: 585, loss is 0.000617
Epoch time: 164792.001, per step time: 281.696
...
```
The model checkpoint will be saved into `[ckpt_path_to_save]`.
## [Evaluation Process](#contents)
### Running on Ascend
Change the `config.TEST.MODEL_FILE` setting in `src/config.py` to the path of the checkpoint you want to validate. For example:
```python
config.TEST.MODEL_FILE='results/xxxx.ckpt'
```
Then, run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:
```shell
sh scripts/eval.sh [device_id]
```
For example, you can run the shell command below to launch the validation procedure.
```shell
sh scripts/eval.sh 0
```
The above shell command will run the validation procedure in the background. You can view the results in the file `eval_log[X].txt`, which will look like this:
```text
use flip test: True
loading model ckpt from results/distributed/sim-140_1170.ckpt
loading dataset from /data/coco2017/val2017
loading bbox file from experiments/COCO_val2017_detections_AP_H_56_person.json
Total boxes: 104125
1024 samples validated in 18.133189916610718 seconds
2048 samples validated in 4.724390745162964 seconds
...
```
# [Model Description](#contents)
## [Performance](#contents)
### SimplePoseNet on COCO2017 with detector
#### Performance parameters
| Parameters | Standalone | Distributed |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | SimplePoseNet | SimplePoseNet |
| Resource | Ascend 910 | 4 Ascend 910 cards |
| Uploaded Date | 12/18/2020 (month/day/year) | 12/18/2020 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | COCO2017 | COCO2017 |
| Training Parameters | epoch=140, batch_size=128 | epoch=140, batch_size=64 |
| Optimizer | Adam | Adam |
| Loss Function | Mean Squared Error | Mean Squared Error |
| Outputs | heatmap | heatmap |
| Train Performance | mAP: 70.4 | mAP: 70.4 |
| Speed | 1pc: 389.915 ms/step | 4pc: 281.356 ms/step |
#### Note
- Flip test is used.
- Person detector has person AP of 56.4 on COCO val2017 dataset.
- The dataset preprocessing and general training configurations are shown in the [Script Parameters](#script-parameters) section.
# [Description of Random Situation](#contents)
In `src/dataset.py`, we set the random seed used for dataset creation. We also use random weight initialization in `src/model.py`, which is controlled by the global seed set in `train.py` and `eval.py`.
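Concretely, the fixed seeds come from two calls, copied here from `train.py` and `src/dataset.py`:
```python
import mindspore.dataset as de
from mindspore.common import set_seed

set_seed(1)            # global seed set in train.py and eval.py
de.config.set_seed(1)  # dataset pipeline seed set at the top of src/dataset.py
```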
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,180 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import numpy as np
from mindspore import Tensor, float32, context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.dataset import flip_pairs, keypoint_dataset
from src.evaluate.coco_eval import evaluate
from src.model import get_pose_net
from src.utils.transform import flip_back
from src.predict import get_final_preds
def parse_args():
parser = argparse.ArgumentParser(description='Train keypoints network')
parser.add_argument("--train_url", type=str, default="", help="")
parser.add_argument("--data_url", type=str, default="", help="data")
# output
parser.add_argument('--output-url',
help='output dir',
type=str)
# training
parser.add_argument('--workers',
help='num of dataloader workers',
default=8,
type=int)
parser.add_argument('--model-file',
help='model state file',
type=str)
parser.add_argument('--use-detect-bbox',
help='use detect bbox',
action='store_true')
parser.add_argument('--flip-test',
help='use flip test',
default=True,
action='store_true')
parser.add_argument('--post-process',
help='use post process',
action='store_true')
parser.add_argument('--shift-heatmap',
help='shift heatmap',
action='store_true')
parser.add_argument('--coco-bbox-file',
help='coco detection bbox file',
type=str)
args = parser.parse_args()
return args
def reset_config(cfg, args):
if args.use_detect_bbox:
cfg.TEST.USE_GT_BBOX = not args.use_detect_bbox
if args.flip_test:
cfg.TEST.FLIP_TEST = args.flip_test
print('use flip test:', cfg.TEST.FLIP_TEST)
if args.post_process:
cfg.TEST.POST_PROCESS = args.post_process
if args.shift_heatmap:
cfg.TEST.SHIFT_HEATMAP = args.shift_heatmap
if args.model_file:
cfg.TEST.MODEL_FILE = args.model_file
if args.coco_bbox_file:
cfg.TEST.COCO_BBOX_FILE = args.coco_bbox_file
def validate(cfg, val_dataset, model, output_dir):
# switch to evaluate mode
model.set_train(False)
# init record
num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
dtype=np.float32)
all_boxes = np.zeros((num_samples, 2))
image_id = []
idx = 0
# start eval
start = time.time()
for item in val_dataset.create_dict_iterator():
# input data
inputs = item['image'].asnumpy()
# compute output
output = model(Tensor(inputs, float32)).asnumpy()
if cfg.TEST.FLIP_TEST:
inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
output_flipped = model(inputs_flipped)
output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)
# feature is not aligned, shift flipped heatmap for higher accuracy
if cfg.TEST.SHIFT_HEATMAP:
output_flipped[:, :, :, 1:] = \
output_flipped.copy()[:, :, :, 0:-1]
# output_flipped[:, :, :, 0] = 0
output = (output + output_flipped) * 0.5
# meta data
c = item['center'].asnumpy()
s = item['scale'].asnumpy()
score = item['score'].asnumpy()
file_id = list(item['id'].asnumpy())
# pred by heatmaps
preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
num_images, _ = preds.shape[:2]
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
all_preds[idx:idx + num_images, :, 2:3] = maxvals
# double check this all_boxes parts
all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
all_boxes[idx:idx + num_images, 1] = score
image_id.extend(file_id)
idx += num_images
if idx % 1024 == 0:
print('{} samples validated in {} seconds'.format(idx, time.time() - start))
start = time.time()
print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
_, perf_indicator = evaluate(
cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id)
print("AP:", perf_indicator)
return perf_indicator
def main():
# init seed
set_seed(1)
# set context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False, device_id=device_id)
args = parse_args()
# update config
reset_config(config, args)
# init model
model = get_pose_net(config, is_train=False)
# load parameters
ckpt_name = config.TEST.MODEL_FILE
print('loading model ckpt from {}'.format(ckpt_name))
load_param_into_net(model, load_checkpoint(ckpt_name))
# Data loading code
valid_dataset, _ = keypoint_dataset(
config,
bbox_file=config.TEST.COCO_BBOX_FILE,
train_mode=False,
num_parallel_workers=args.workers,
)
# evaluate on validation set
validate(config, valid_dataset, model, ckpt_name.split('.')[0])
if __name__ == '__main__':
main()

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
python eval.py > eval_log$1.txt 2>&1 &

View File

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [SAVE_CKPT_PATH] [RANK_SIZE]
export RANK_TABLE_FILE=$1
echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
export RANK_SIZE=$3
SAVE_PATH=$2
device=(0 1 2 3)
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=${device[$i]}
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
echo "start training for rank $i, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
cd ../
python train.py \
--run-distribute \
--ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &
echo "python train.py \
--run-distribute \
--ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &"
done

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Usage: train_standalone.sh [DEVICE_ID] [SAVE_CKPT_PATH]
export DEVICE_ID=$1
python train.py \
--ckpt-path=$2 --batch-size=128\
> train_log$1.txt 2>&1 &
echo " python train.py --ckpt-path=$2 --batch-size=128 > train_log$1.txt 2>&1 &"

View File

@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict as edict
config = edict()
# pose_resnet related params
POSE_RESNET = edict()
POSE_RESNET.NUM_LAYERS = 50
POSE_RESNET.DECONV_WITH_BIAS = False
POSE_RESNET.NUM_DECONV_LAYERS = 3
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
POSE_RESNET.FINAL_CONV_KERNEL = 1
POSE_RESNET.TARGET_TYPE = 'gaussian'
POSE_RESNET.HEATMAP_SIZE = [48, 64] # width * height, ex: 24 * 32
POSE_RESNET.SIGMA = 2
MODEL_EXTRAS = {
'pose_resnet': POSE_RESNET,
}
# common params for NETWORK
config.MODEL = edict()
config.MODEL.NAME = 'pose_resnet'
config.MODEL.INIT_WEIGHTS = True
config.MODEL.PRETRAINED = './models/resnet50.ckpt'
config.MODEL.NUM_JOINTS = 17
config.MODEL.IMAGE_SIZE = [192, 256] # width * height, ex: 192 * 256
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
# dataset
config.DATASET = edict()
config.DATASET.ROOT = '/data/coco2017/'
config.DATASET.TEST_SET = 'val2017'
config.DATASET.TRAIN_SET = 'train2017'
# data augmentation
config.DATASET.FLIP = True
config.DATASET.ROT_FACTOR = 40
config.DATASET.SCALE_FACTOR = 0.3
# for train
config.TRAIN = edict()
config.TRAIN.BATCH_SIZE = 64
config.TRAIN.BEGIN_EPOCH = 0
config.TRAIN.END_EPOCH = 140
config.TRAIN.LR = 0.001
config.TRAIN.LR_FACTOR = 0.1
config.TRAIN.LR_STEP = [90, 120]
# test
config.TEST = edict()
config.TEST.BATCH_SIZE = 32
config.TEST.FLIP_TEST = True
config.TEST.POST_PROCESS = True
config.TEST.SHIFT_HEATMAP = True
config.TEST.USE_GT_BBOX = False
config.TEST.MODEL_FILE = ''
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
# nms
config.TEST.OKS_THRE = 0.9
config.TEST.IN_VIS_THRE = 0.2
config.TEST.BBOX_THRE = 1.0
config.TEST.IMAGE_THRE = 0.0
config.TEST.NMS_THRE = 1.0

View File

@ -0,0 +1,378 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
from copy import deepcopy
import random
import numpy as np
import cv2
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as V_C
from src.utils.transform import fliplr_joints, get_affine_transform, affine_transform
de.config.set_seed(1)
flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
class KeypointDatasetGenerator:
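"""Generate (image, heatmap target, target weight, scale, center, score, id) samples from COCO-style annotations."""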
def __init__(self, cfg, is_train=False):
# config file
self.image_thre = cfg.TEST.IMAGE_THRE
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE, dtype=np.int32)
self.image_width = cfg.MODEL.IMAGE_SIZE[0]
self.image_height = cfg.MODEL.IMAGE_SIZE[1]
self.aspect_ratio = self.image_width * 1.0 / self.image_height
self.heatmap_size = np.array(cfg.MODEL.EXTRA.HEATMAP_SIZE, dtype=np.int32)
self.sigma = cfg.MODEL.EXTRA.SIGMA
self.target_type = cfg.MODEL.EXTRA.TARGET_TYPE
# data augmentation
self.scale_factor = cfg.DATASET.SCALE_FACTOR
self.rotation_factor = cfg.DATASET.ROT_FACTOR
self.flip = cfg.DATASET.FLIP
# dataset information
self.db = []
self.is_train = is_train
# for coco dataset
self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
self.num_joints = 17
def load_gt_dataset(self, image_path, ann_file):
# reset db
self.db = []
# load json file and decode
with open(ann_file, "rb") as f:
lines = f.readlines()
json_dict = json.loads(lines[0].decode("utf-8"))
# traverse all the annotation items
objs = {}
cnt = 0
for item in json_dict['annotations']:
# exclude iscrowd and no-keypoint record
if item['iscrowd'] != 0 or item['num_keypoints'] == 0:
continue
# assert the record is valid
assert item['iscrowd'] == 0, 'is crowd'
assert item['category_id'] == 1, 'is not people'
assert item['area'] > 0, 'area le 0'
assert item['num_keypoints'] > 0, 'has no keypoint'
assert max(item['keypoints']) > 0
image_id = item['image_id']
obj = [{'num_keypoints': item['num_keypoints'], 'keypoints': item['keypoints'], 'bbox': item['bbox']}]
objs[image_id] = obj if image_id not in objs else objs[image_id] + obj
cnt += 1
print('loaded %d records from coco dataset.' % cnt)
# traverse all the image items
for item in json_dict['images']:
image_id = item['id']
width = item['width']
height = item['height']
# exclude image not in records
if image_id not in objs:
continue
# sanitize bboxes
valid_objs = []
for obj in objs[image_id]:
x, y, w, h = obj['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(width - 1, x1 + max(0, w - 1))
y2 = min(height - 1, y1 + max(0, h - 1))
if x2 >= x1 and y2 >= y1:
tmp_obj = deepcopy(obj)
tmp_obj['bbox'] = np.array((x1, y1, x2, y2)) - np.array((0, 0, x1, y1))
valid_objs.append(tmp_obj)
else:
assert False, 'invalid bbox!'
# rewrite
objs[image_id] = valid_objs
for obj in objs[image_id]:
# keypoints
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
for ipt in range(self.num_joints):
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
joints_3d[ipt, 2] = 0
t_vis = obj['keypoints'][ipt * 3 + 2]
if t_vis > 1:
t_vis = 1
joints_3d_vis[ipt, 0] = t_vis
joints_3d_vis[ipt, 1] = t_vis
joints_3d_vis[ipt, 2] = 0
scale, center = self._bbox2sc(obj['bbox'])
# reform and save
self.db.append({
'id': int(item['id']),
'image': os.path.join(image_path, item['file_name']),
'center': center,
'scale': scale,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
})
def load_detect_dataset(self, image_path, ann_file, bbox_file):
# reset self.db
self.db = []
# open detect file
all_boxes = None
with open(bbox_file, 'r') as f:
all_boxes = json.load(f)
assert all_boxes, 'Loading %s fail!' % bbox_file
print('Total boxes: {}'.format(len(all_boxes)))
# load json file and decode
with open(ann_file, "rb") as f:
lines = f.readlines()
json_dict = json.loads(lines[0].decode("utf-8"))
# build a map from id to file name
index_to_filename = {}
for item in json_dict['images']:
index_to_filename[item['id']] = item['file_name']
# load each item into db
for det_res in all_boxes:
if det_res['category_id'] != 1:
continue
# load image
image = os.path.join(image_path,
index_to_filename[det_res['image_id']])
bbox = det_res['bbox']
score = det_res['score']
if score < self.image_thre:
continue
scale, center = self._bbox2sc(bbox)
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
joints_3d_vis = np.ones((self.num_joints, 3), dtype=np.float)
self.db.append({
'id': int(det_res['image_id']),
'image': image,
'center': center,
'scale': scale,
'score': score,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
})
def _bbox2sc(self, bbox):
"""
convert an xywh bbox into (scale, center) that matches the required aspect ratio
"""
x, y, w, h = bbox[:4]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array(
[w * 1.0 / 200, h * 1.0 / 200], dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return scale, center
def __getitem__(self, idx):
db_rec = deepcopy(self.db[idx])
image_file = db_rec['image']
data_numpy = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
if data_numpy is None:
print('[ERROR] fail to read {}'.format(image_file))
raise ValueError('Fail to read {}'.format(image_file))
joints = db_rec['joints_3d']
joints_vis = db_rec['joints_3d_vis']
c = db_rec['center']
s = db_rec['scale']
score = db_rec['score'] if 'score' in db_rec else 1
r = 0
if self.is_train:
sf = self.scale_factor
rf = self.rotation_factor
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
if random.random() <= 0.6 else 0
if self.flip and random.random() <= 0.5:
data_numpy = data_numpy[:, ::-1, :]
joints, joints_vis = fliplr_joints(
joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
c[0] = data_numpy.shape[1] - c[0] - 1
trans = get_affine_transform(c, s, r, self.image_size)
image = cv2.warpAffine(
data_numpy,
trans,
(int(self.image_size[0]), int(self.image_size[1])),
flags=cv2.INTER_LINEAR)
for i in range(self.num_joints):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
target, target_weight = self.generate_heatmap(joints, joints_vis)
return image, target, target_weight, s, c, score, db_rec['id']
def generate_heatmap(self, joints, joints_vis):
target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
target_weight[:, 0] = joints_vis[:, 0]
assert self.target_type == 'gaussian', \
'Only support gaussian map now!'
if self.target_type == 'gaussian':
target = np.zeros((self.num_joints,
self.heatmap_size[1],
self.heatmap_size[0]),
dtype=np.float32)
tmp_size = self.sigma * 3
for joint_id in range(self.num_joints):
feat_stride = self.image_size / self.heatmap_size
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
or br[0] < 0 or br[1] < 0:
# If not, just return the image as is
target_weight[joint_id] = 0
continue
# Generate gaussian
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
v = target_weight[joint_id]
if v > 0.5:
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
return target, target_weight
def __len__(self):
return len(self.db)
def keypoint_dataset(config,
ann_file=None,
image_path=None,
bbox_file=None,
rank=0,
group_size=1,
train_mode=True,
num_parallel_workers=8,
transform=None,
shuffle=None):
"""
A function that returns an imagenet dataset for classification. The mode of input dataset should be "folder" .
Args:
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided
into (default=None).
mode (str): "train" or others. Default: " train".
num_parallel_workers (int): Number of workers to read the data. Default: None.
"""
# config
per_batch_size = config.TRAIN.BATCH_SIZE if train_mode else config.TEST.BATCH_SIZE
image_path = image_path if image_path else os.path.join(config.DATASET.ROOT,
config.DATASET.TRAIN_SET
if train_mode else config.DATASET.TEST_SET)
print('loading dataset from {}'.format(image_path))
ann_file = ann_file if ann_file else os.path.join(config.DATASET.ROOT,
'annotations/person_keypoints_{}2017.json'.format(
'train' if train_mode else 'val'))
shuffle = shuffle if shuffle is not None else train_mode
# gen dataset db
dataset_generator = KeypointDatasetGenerator(config, is_train=train_mode)
if not train_mode and not config.TEST.USE_GT_BBOX:
print('loading bbox file from {}'.format(bbox_file))
dataset_generator.load_detect_dataset(image_path, ann_file, bbox_file)
else:
dataset_generator.load_gt_dataset(image_path, ann_file)
# construct dataset
de_dataset = de.GeneratorDataset(dataset_generator,
column_names=["image", "target", "weight", "scale", "center", "score", "id"],
num_parallel_workers=num_parallel_workers,
num_shards=group_size,
shard_id=rank,
shuffle=shuffle)
# inputs map functions
if transform is None:
transform_img = [
V_C.Rescale(1.0 / 255.0, 0.0),
V_C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
V_C.HWC2CHW()
]
else:
transform_img = transform
de_dataset = de_dataset.map(input_columns="image",
num_parallel_workers=num_parallel_workers,
operations=transform_img)
# batch
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=train_mode)
return de_dataset, dataset_generator

View File

@ -0,0 +1,132 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
import pickle
from collections import defaultdict, OrderedDict
import numpy as np
try:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
has_coco = True
except ImportError:
has_coco = False
from src.utils.nms import oks_nms
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
results = []
for img, items in img_kpts.items():
item_size = len(items)
if not items:
continue
# keypoints array at coco format
kpts = np.array([items[k]['keypoints']
for k in range(item_size)])
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
keypoints[:, 0::3] = kpts[:, :, 0]
keypoints[:, 1::3] = kpts[:, :, 1]
keypoints[:, 2::3] = kpts[:, :, 2]
result = [{'image_id': int(img),
'keypoints': list(keypoints[k]),
'score': items[k]['score'],
'category_id': 1,
} for k in range(item_size)]
results.extend(result)
with open(res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
coco = COCO(ann_path)
coco_dt = coco.loadRes(res_file)
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
coco_eval.params.useSegm = None
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
info_str = []
for ind, name in enumerate(stats_names):
info_str.append((name, coco_eval.stats[ind]))
eval_file = os.path.join(
res_folder, 'keypoints_results.pkl')
with open(eval_file, 'wb') as f:
pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
print('coco eval results saved to %s' % eval_file)
return info_str
# need double check this API and classes field
def evaluate(cfg, preds, output_dir, all_boxes, img_id):
res_folder = os.path.join(output_dir, 'results')
if not os.path.exists(res_folder):
os.makedirs(res_folder)
res_file = os.path.join(res_folder, 'keypoints_results.json')
# image -> list(keypoints/area/score)
img_kpts_dict = defaultdict(list)
for idx, file_id in enumerate(img_id):
img_kpts_dict[file_id].append({
'keypoints': preds[idx],
'area': all_boxes[idx][0],
'score': all_boxes[idx][1],
})
# rescoring and oks nms
num_joints = cfg.MODEL.NUM_JOINTS
in_vis_thre = cfg.TEST.IN_VIS_THRE
oks_thre = cfg.TEST.OKS_THRE
oks_nmsed_kpts = {}
for img, items in img_kpts_dict.items():
for item in items:
kpt_score = 0
valid_num = 0
for n_jt in range(num_joints):
max_jt = item['keypoints'][n_jt][2]
if max_jt > in_vis_thre:
kpt_score = kpt_score + max_jt
valid_num = valid_num + 1
if valid_num != 0:
kpt_score = kpt_score / valid_num
# rescoring
item['score'] = kpt_score * item['score']
keep = oks_nms(items, oks_thre)
if not keep:
oks_nmsed_kpts[img] = items
else:
oks_nmsed_kpts[img] = [items[kep] for kep in keep]
# evaluate and save
image_set = cfg.DATASET.TEST_SET
_write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
if 'test' not in image_set and has_coco:
ann_path = os.path.join(cfg.DATASET.ROOT, 'annotations',
'person_keypoints_' + image_set + '.json')
info_str = _do_python_keypoint_eval(
res_file, res_folder, ann_path)
name_value = OrderedDict(info_str)
return name_value, name_value['AP']
return {'Null': 0}, 0

View File

@ -0,0 +1,225 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from collections import OrderedDict
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore.common.initializer import Normal
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import ParameterTuple
BN_MOMENTUM = 0.1
class MaxPool2dPytorch(nn.Cell):
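"""MaxPool2d wrapper that reverses the H and W axes before and after pooling so the padding matches PyTorch's MaxPool2d behaviour."""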
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
super(MaxPool2dPytorch, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
self.reverse = F.ReverseV2(axis=[2, 3])
def construct(self, x):
x = self.reverse(x)
x = self.maxpool(x)
x = self.reverse(x)
return x
class Bottleneck(nn.Cell):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
pad_mode='pad', padding=1, has_bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
has_bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU()
self.down_sample_layer = downsample
self.stride = stride
def construct(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample_layer is not None:
residual = self.down_sample_layer(x)
out += residual
out = self.relu(out)
return out
class PoseResNet(nn.Cell):
def __init__(self, block, layers, cfg, pytorch_mode=True):
self.inplanes = 64
extra = cfg.MODEL.EXTRA
self.deconv_with_bias = extra.DECONV_WITH_BIAS
super(PoseResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2,
pad_mode='pad', padding=3, has_bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU()
if pytorch_mode:
self.maxpool = MaxPool2dPytorch(kernel_size=3, stride=2, pad_mode='same')
print("use pytorch-style maxpool")
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
print("use mindspore-style maxpool")
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# used for deconv layers
self.deconv_layers = self._make_deconv_layer(
extra.NUM_DECONV_LAYERS,
extra.NUM_DECONV_FILTERS,
extra.NUM_DECONV_KERNELS,
)
self.final_layer = nn.Conv2d(
in_channels=extra.NUM_DECONV_FILTERS[-1],
out_channels=cfg.MODEL.NUM_JOINTS,
kernel_size=extra.FINAL_CONV_KERNEL,
stride=1,
pad_mode='pad',
padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0,
has_bias=True,
weight_init=Normal(0.001),
)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell(OrderedDict([
('0', nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, has_bias=False)),
('1', nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM)),
]))
layers = OrderedDict()
layers['0'] = block(self.inplanes, planes, stride, downsample)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers['{}'.format(i)] = block(self.inplanes, planes)
return nn.SequentialCell(layers)
def _get_deconv_cfg(self, deconv_kernel):
assert deconv_kernel == 4, 'only support kernel_size = 4 for deconvolution layers'
if deconv_kernel == 4:
padding = 1
output_padding = 0
elif deconv_kernel == 3:
padding = 1
output_padding = 1
elif deconv_kernel == 2:
padding = 0
output_padding = 0
return deconv_kernel, padding, output_padding
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
assert num_layers == len(num_filters), \
'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
assert num_layers == len(num_kernels), \
'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
layers = OrderedDict()
for i in range(num_layers):
kernel, padding, _ = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers['deconv_{}'.format(i)] = nn.SequentialCell(OrderedDict([
('deconv', nn.Conv2dTranspose(
in_channels=self.inplanes,
out_channels=planes,
kernel_size=kernel,
stride=2,
pad_mode='pad',
padding=padding,
has_bias=self.deconv_with_bias,
weight_init=Normal(0.001),
)),
('bn', nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)),
('relu', nn.ReLU()),
]))
self.inplanes = planes
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.deconv_layers(x)
x = self.final_layer(x)
return x
def init_weights(self, pretrained=''):
if os.path.isfile(pretrained):
# load params from pretrained
param_dict = load_checkpoint(pretrained)
weight = ParameterTuple(self.trainable_params())
for w in weight:
if w.name.split('.')[0] not in ('deconv_layers', 'final_layer'):
assert w.name in param_dict, "parameter %s not in checkpoint" % w.name
load_param_into_net(self, param_dict)
print('loading pretrained model {}'.format(pretrained))
else:
assert False, '{} is not a file'.format(pretrained)
resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),
101: (Bottleneck, [3, 4, 23, 3]),
152: (Bottleneck, [3, 8, 36, 3])}
def get_pose_net(cfg, is_train, ckpt_path=None, pytorch_mode=False):
num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
block_class, layers = resnet_spec[num_layers]
model = PoseResNet(block_class, layers, cfg, pytorch_mode=pytorch_mode)
if is_train and cfg.MODEL.INIT_WEIGHTS:
model.init_weights(ckpt_path)
return model

View File

@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.loss.loss import _Loss
from mindspore.common import dtype as mstype
class JointsMSELoss(_Loss):
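"""Joint-wise mean squared error between predicted and target heatmaps, optionally weighted by per-joint target weights."""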
def __init__(self, use_target_weight):
super(JointsMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='mean')
self.use_target_weight = use_target_weight
self.reshape = P.Reshape()
self.squeeze = P.Squeeze(1)
self.mul = P.Mul()
def construct(self, output, target, target_weight):
batch_size = F.shape(output)[0]
num_joints = F.shape(output)[1]
split = P.Split(1, num_joints)
heatmaps_pred = self.reshape(output, (batch_size, num_joints, -1))
heatmaps_pred = split(heatmaps_pred)
heatmaps_gt = self.reshape(target, (batch_size, num_joints, -1))
heatmaps_gt = split(heatmaps_gt)
loss = 0
for idx in range(num_joints):
heatmap_pred = self.squeeze(heatmaps_pred[idx])
heatmap_gt = self.squeeze(heatmaps_gt[idx])
if self.use_target_weight:
loss += 0.5 * self.criterion(
self.mul(heatmap_pred, target_weight[:, idx]),
self.mul(heatmap_gt, target_weight[:, idx])
)
else:
loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
return loss / num_joints
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, image, target, weight, scale=None,
center=None, score=None, idx=None):
out = self._backbone(image)
output = F.mixed_precision_cast(mstype.float32, out)
target = F.mixed_precision_cast(mstype.float32, target)
weight = F.mixed_precision_cast(mstype.float32, weight)
return self._loss_fn(output, target, weight)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone

View File

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
from src.utils.transform import transform_preds
def get_max_preds(batch_heatmaps):
'''
get predictions from score maps
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
'''
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
def get_final_preds(config, batch_heatmaps, center, scale):
coords, maxvals = get_max_preds(batch_heatmaps)
heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]
# post-processing
if config.TEST.POST_PROCESS:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = np.array([hm[py][px + 1] - hm[py][px - 1],
hm[py + 1][px] - hm[py - 1][px]])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center[i], scale[i],
[heatmap_width, heatmap_height])
return preds, maxvals

View File

@ -0,0 +1,55 @@
import numpy as np
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
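"""Compute object keypoint similarity (OKS) between keypoint vector g and each row of d, given the corresponding areas a_g and a_d."""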
if not isinstance(sigmas, np.ndarray):
sigmas = np.array(
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
vas = (sigmas * 2) ** 2
xg = g[0::3]
yg = g[1::3]
vg = g[2::3]
ious = np.zeros((d.shape[0]))
for n_d in range(0, d.shape[0]):
xd = d[n_d, 0::3]
yd = d[n_d, 1::3]
vd = d[n_d, 2::3]
dx = xd - xg
dy = yd - yg
e = (dx ** 2 + dy ** 2) / vas / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
if in_vis_thre is not None:
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
e = e[ind]
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
return ious
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
"""
greedily select boxes with high confidence and overlap with current maximum <= thresh
rule out overlap >= thresh, overlap = oks
:param kpts_db
:param thresh: retain overlap < thresh
:return: indexes to keep
"""
kpts_size = len(kpts_db)
if kpts_size == 0:
return []
scores = np.array([kpts_db[i]['score'] for i in range(kpts_size)])
kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(kpts_size)])
areas = np.array([kpts_db[i]['area'] for i in range(kpts_size)])
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
inds = np.where(oks_ovr <= thresh)[0]
order = order[inds + 1]
return keep

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import cv2
def fliplr_joints(joints, joints_vis, width, matched_parts):
"""
flip coords
"""
# Flip horizontal
joints[:, 0] = width - joints[:, 0] - 1
# Change left-right parts
for pair in matched_parts:
joints[pair[0], :], joints[pair[1], :] = \
joints[pair[1], :], joints[pair[0], :].copy()
joints_vis[pair[0]], joints_vis[pair[1]] = \
joints_vis[pair[1]], joints_vis[pair[0]].copy()
return joints * joints_vis, joints_vis
def flip_back(output_flipped, matched_parts):
'''
output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
'''
assert output_flipped.ndim == 4, \
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def transform_preds(coords, center, scale, output_size):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
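"""Build the 2x3 affine matrix mapping a person box described by (center, scale, rot) to output_size; inv=1 returns the inverse mapping."""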
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale * 200.0
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = _get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2:, :] = _get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def _get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def _get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result

View File

@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train import Model
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from src.config import config
from src.model import get_pose_net
from src.network_define import JointsMSELoss, WithLossCell
from src.dataset import keypoint_dataset
set_seed(1)
device_id = int(os.getenv('DEVICE_ID'))
def get_lr(begin_epoch,
total_epochs,
steps_per_epoch,
lr_init=0.1,
factor=0.1,
epoch_number_to_drop=(90, 120)
):
"""
Generate learning rate array.
Args:
begin_epoch (int): Initial epoch of training.
total_epochs (int): Total epoch of training.
steps_per_epoch (float): Steps of one epoch.
lr_init (float): Initial learning rate. Default: 0.1.
factor (float): Factor by which the learning rate is multiplied at each drop. Default: 0.1.
epoch_number_to_drop (tuple): Learning rate will drop after these epochs. Default: (90, 120).
Returns:
np.array, learning rate array.
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
for i in range(int(total_steps)):
if i in step_number_to_drop:
lr_init = lr_init * factor
lr_each_step.append(lr_init)
current_step = steps_per_epoch * begin_epoch
lr_each_step = np.array(lr_each_step, dtype=np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def parse_args():
parser = argparse.ArgumentParser(description="Simpleposenet training")
parser.add_argument("--run-distribute",
help="Run distribute, default is false.",
action='store_true')
parser.add_argument('--ckpt-path', type=str, help='ckpt path to save')
parser.add_argument('--batch-size', type=int, help='training batch size')
args = parser.parse_args()
return args
def main():
# load parse and config
print("loading parse...")
args = parse_args()
if args.batch_size:
config.TRAIN.BATCH_SIZE = args.batch_size
print('batch size :{}'.format(config.TRAIN.BATCH_SIZE))
# distribution and context
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if args.run_distribute:
init()
rank = get_rank()
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
rank = 0
device_num = 1
# only rank = 0 can write
rank_save_flag = False
if rank == 0 or device_num == 1:
rank_save_flag = True
# create dataset
dataset, _ = keypoint_dataset(config,
rank=rank,
group_size=device_num,
train_mode=True,
num_parallel_workers=8)
# network
net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = WithLossCell(net, loss)
# lr schedule and optim
dataset_size = dataset.get_dataset_size()
lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
config.TRAIN.END_EPOCH,
dataset_size,
lr_init=config.TRAIN.LR,
factor=config.TRAIN.LR_FACTOR,
epoch_number_to_drop=config.TRAIN.LR_STEP))
opt = Adam(net.trainable_params(), learning_rate=lr)
# callback
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if args.ckpt_path and rank_save_flag:
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
ckpoint_cb = ModelCheckpoint(prefix="simplepose", directory=args.ckpt_path, config=config_ck)
cb.append(ckpoint_cb)
# train model
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
print('start training, epoch size = %d' % epoch_size)
model.train(epoch_size, dataset, callbacks=cb)
if __name__ == '__main__':
main()