forked from mindspore-Ecosystem/mindspore
add simplepose implementation
This commit is contained in:
parent
ef0b483eb4
commit
15d8cccd7b
|
@ -0,0 +1,340 @@
|
||||||
|
# Contents
|
||||||
|
|
||||||
|
- [Description](#description)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Features](#features)
|
||||||
|
- [Mixed Precision](#mixed-precision)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Dataset Preparation](#dataset-preparation)
|
||||||
|
- [Model Checkpoints](#model-checkpoints)
|
||||||
|
- [Running](#running)
|
||||||
|
- [Script Description](#script-description)
|
||||||
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
|
- [Script Parameters](#script-parameters)
|
||||||
|
- [Training Process](#training-process)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Distributed Training](#distributed-training)
|
||||||
|
- [Evaluation Process](#evaluation-process)
|
||||||
|
- [Model Description](#model-description)
|
||||||
|
- [Performance](#performance)
|
||||||
|
- [Description of Random Situation](#description-of-random-situation)
|
||||||
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
# [Description](#contents)
|
||||||
|
|
||||||
|
SimplePoseNet is a convolution-based neural network for the task of human pose estimation and tracking. It provides baseline methods that are surprisingly simple and effective, thus helpful for inspiring and evaluating new ideas for the field. State-of-the-art results are achieved on challenging benchmarks. More detail about this model can be found in:
|
||||||
|
|
||||||
|
B. Xiao, H. Wu, and Y. Wei, “Simple baselines for human pose estimation and tracking,” in Proc. Eur. Conf. Comput. Vis., 2018, pp. 472–487.
|
||||||
|
|
||||||
|
This repository contains a Mindspore implementation of SimplePoseNet based upon Microsoft's original Pytorch implementation (<https://github.com/microsoft/human-pose-estimation.pytorch>). The training and validating scripts are also included, and the evaluation results are shown in the [Performance](#performance) section.
|
||||||
|
|
||||||
|
# [Model Architecture](#contents)
|
||||||
|
|
||||||
|
The overall network architecture of SimplePoseNet is shown below:
|
||||||
|
|
||||||
|
[Link](https://arxiv.org/pdf/1804.06208.pdf)
|
||||||
|
|
||||||
|
# [Dataset](#contents)
|
||||||
|
|
||||||
|
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||||
|
|
||||||
|
Dataset used: COCO2017
|
||||||
|
|
||||||
|
- Dataset size:
|
||||||
|
- Train: 19G, 118,287 images
|
||||||
|
- Test: 788MB, 5,000 images
|
||||||
|
- Data format: JPG images
|
||||||
|
- Note: Data will be processed in `src/dataset.py`
|
||||||
|
- Person detection result for validation: Detection result provided by author in the [repository](https://github.com/microsoft/human-pose-estimation.pytorch)
|
||||||
|
|
||||||
|
# [Features](#contents)
|
||||||
|
|
||||||
|
## [Mixed Precision](#contents)
|
||||||
|
|
||||||
|
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||||
|
|
||||||
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
|
To run the python scripts in the repository, you need to prepare the environment as follow:
|
||||||
|
|
||||||
|
- Hardware
|
||||||
|
- Prepare hardware environment with Ascend. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to [ascend@huawei.com](mailto:ascend@huawei.com). Once approved, you can get the resources.
|
||||||
|
- Python and dependencies
|
||||||
|
- python 3.7
|
||||||
|
- mindspore 1.0.1
|
||||||
|
- easydict 1.9
|
||||||
|
- opencv-python 4.3.0.36
|
||||||
|
- pycocotools 2.0
|
||||||
|
- For more information, please check the resources below:
|
||||||
|
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||||
|
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||||
|
|
||||||
|
# [Quick Start](#contents)
|
||||||
|
|
||||||
|
## [Dataset Preparation](#contents)
|
||||||
|
|
||||||
|
SimplePoseNet use COCO2017 dataset to train and validate in this repository. Download the dataset from [official website](https://cocodataset.org/). You can place the dataset anywhere and tell the scripts where it is by modifying the `DATASET.ROOT` setting in configuration file `src/config.py`. For more information about the configuration file, please refer to [Script Parameters](#script-parameters).
|
||||||
|
|
||||||
|
You also need the person detection result of COCO val2017 to reproduce the multi-person pose estimation results, as mentioned in [Dataset](#dataset). Please checkout the author's repository, download and extract them under `<ROOT>/experiments/`, and make them look like this:
|
||||||
|
|
||||||
|
```text
|
||||||
|
└─ <ROOT>
|
||||||
|
└─ experiments
|
||||||
|
└─ COCO_val2017_detections_AP_H_56_person.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Model Checkpoints](#contents)
|
||||||
|
|
||||||
|
Before you start your training process, you need to obtain mindspore imagenet pretrained models. The model weight file can be obtained by running the Resnet training script in [official model zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet). We also provide a pretrained model that can be used to train SimplePoseNet directly in [GoogleDrive](https://drive.google.com/file/d/1r3Hs0QNys0HyNtsQhSvx6IKdyRkC-3Hh/view?usp=sharing). The model file should be placed under `<ROOT>/models/` like this:
|
||||||
|
|
||||||
|
```text
|
||||||
|
└─ <ROOT>
|
||||||
|
└─ models
|
||||||
|
└─resnet50.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Running](#contents)
|
||||||
|
|
||||||
|
To train the model, run the shell script `scripts/train_standalone.sh` with the format below:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
|
||||||
|
```
|
||||||
|
|
||||||
|
To validate the model, change the settings in `src/config.py` to the path of the model you want to validate. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.TEST.MODEL_FILE='results/xxxx.ckpt'
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, run the shell script `scripts/eval.sh` with the format below:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/eval.sh [device_id]
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Script Description](#contents)
|
||||||
|
|
||||||
|
## [Script and Sample Code](#contents)
|
||||||
|
|
||||||
|
The structure of the files in this repository is shown below.
|
||||||
|
|
||||||
|
```text
|
||||||
|
└─ mindspore-simpleposenet
|
||||||
|
├─ scripts
|
||||||
|
│ ├─ eval.sh // launch ascend standalone evaluation
|
||||||
|
│ ├─ train_distributed.sh // launch ascend distributed training
|
||||||
|
│ └─ train_standalone.sh // launch ascend standalone training
|
||||||
|
├─ src
|
||||||
|
│ ├─utils
|
||||||
|
│ │ ├─ transform.py // utils about image transformation
|
||||||
|
│ │ └─ nms.py // utils about nms
|
||||||
|
│ ├─evaluate
|
||||||
|
│ │ └─ coco_eval.py // evaluate result by coco
|
||||||
|
│ ├─ config.py // network and running config
|
||||||
|
│ ├─ dataset.py // dataset processor and provider
|
||||||
|
│ ├─ model.py // SimplePoseNet implementation
|
||||||
|
│ ├─ network_define.py // define loss
|
||||||
|
│ └─ predict.py // predict keypoints from heatmaps
|
||||||
|
├─ eval.py // evaluation script
|
||||||
|
├─ param_convert.py // model parameters convertion script
|
||||||
|
├─ train.py // training script
|
||||||
|
└─ README.md // descriptions about this repository
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Script Parameters](#contents)
|
||||||
|
|
||||||
|
Configurations for both training and evaluation are set in `src/config.py`. All the settings are shown following.
|
||||||
|
|
||||||
|
- config for SimplePoseNet on COCO2017 dataset:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# pose_resnet related params
|
||||||
|
POSE_RESNET.HEATMAP_SIZE = [48, 64] # heatmap size
|
||||||
|
POSE_RESNET.SIGMA = 2 # Gaussian hyperparameter in heatmap generation
|
||||||
|
POSE_RESNET.FINAL_CONV_KERNEL = 1 # final convolution kernel size
|
||||||
|
POSE_RESNET.DECONV_WITH_BIAS = False # deconvolution bias
|
||||||
|
POSE_RESNET.NUM_DECONV_LAYERS = 3 # the number of deconvolution layers
|
||||||
|
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256] # the filter size of deconvolution layers
|
||||||
|
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4] # kernel size of deconvolution layers
|
||||||
|
POSE_RESNET.NUM_LAYERS = 50 # number of layers(for resnet)
|
||||||
|
# common params for NETWORK
|
||||||
|
config.MODEL.NAME = 'pose_resnet' # model name
|
||||||
|
config.MODEL.INIT_WEIGHTS = True # init model weights by resnet
|
||||||
|
config.MODEL.PRETRAINED = './models/resnet50.ckpt' # pretrained model
|
||||||
|
config.MODEL.NUM_JOINTS = 17 # the number of keypoints
|
||||||
|
config.MODEL.IMAGE_SIZE = [192, 256] # image size
|
||||||
|
# dataset
|
||||||
|
config.DATASET.ROOT = '/data/coco2017/' # coco2017 dataset root
|
||||||
|
config.DATASET.TEST_SET = 'val2017' # folder name of test set
|
||||||
|
config.DATASET.TRAIN_SET = 'train2017' # folder name of train set
|
||||||
|
# data augmentation
|
||||||
|
config.DATASET.FLIP = True # random flip
|
||||||
|
config.DATASET.ROT_FACTOR = 40 # random rotation
|
||||||
|
config.DATASET.SCALE_FACTOR = 0.3 # random scale
|
||||||
|
# for train
|
||||||
|
config.TRAIN.BATCH_SIZE = 64 # batch size
|
||||||
|
config.TRAIN.BEGIN_EPOCH = 0 # begin epoch
|
||||||
|
config.TRAIN.END_EPOCH = 140 # end epoch
|
||||||
|
config.TRAIN.LR = 0.001 # initial learning rate
|
||||||
|
config.TRAIN.LR_FACTOR = 0.1 # learning rate reduce factor
|
||||||
|
config.TRAIN.LR_STEP = [90,120] # step to reduce lr
|
||||||
|
# test
|
||||||
|
config.TEST.BATCH_SIZE = 32 # batch size
|
||||||
|
config.TEST.FLIP_TEST = True # flip test
|
||||||
|
config.TEST.POST_PROCESS = True # post process
|
||||||
|
config.TEST.SHIFT_HEATMAP = True # shift heatmap
|
||||||
|
config.TEST.USE_GT_BBOX = False # use groundtruth bbox
|
||||||
|
config.TEST.MODEL_FILE = '' # model file to test
|
||||||
|
# detect bbox file
|
||||||
|
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
|
||||||
|
# nms
|
||||||
|
config.TEST.OKS_THRE = 0.9 # oks threshold
|
||||||
|
config.TEST.IN_VIS_THRE = 0.2 # visible threshold
|
||||||
|
config.TEST.BBOX_THRE = 1.0 # bbox threshold
|
||||||
|
config.TEST.IMAGE_THRE = 0.0 # image threshold
|
||||||
|
config.TEST.NMS_THRE = 1.0 # nms threshold
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Training Process](#contents)
|
||||||
|
|
||||||
|
### [Training](#contents)
|
||||||
|
|
||||||
|
#### Running on Ascend
|
||||||
|
|
||||||
|
Run `scripts/train_standalone.sh` to train the model standalone. The usage of the script is:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, you can run the shell command below to launch the training procedure.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/train_standalone.sh 0 results/standalone/
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will run training in the background, you can view the results through the file `train_log[X].txt` as follows:
|
||||||
|
|
||||||
|
```text
|
||||||
|
loading parse...
|
||||||
|
batch size :128
|
||||||
|
loading dataset from /data/coco2017/train2017
|
||||||
|
loaded 149813 records from coco dataset.
|
||||||
|
loading pretrained model ./models/resnet50.ckpt
|
||||||
|
start training, epoch size = 140
|
||||||
|
epoch: 1 step: 1170, loss is 0.000699
|
||||||
|
Epoch time: 492271.194, per step time: 420.745
|
||||||
|
epoch: 2 step: 1170, loss is 0.000586
|
||||||
|
Epoch time: 456265.617, per step time: 389.971
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The model checkpoint will be saved into `[ckpt_path_to_save]`.
|
||||||
|
|
||||||
|
### [Distributed Training](#contents)
|
||||||
|
|
||||||
|
#### Running on Ascend
|
||||||
|
|
||||||
|
Run `scripts/train_distributed.sh` to train the model distributed. The usage of the script is:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/train_distributed.sh [rank_table] [ckpt_path_to_save] [device_number]
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, you can run the shell command below to launch the distributed training procedure.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/train_distributed.sh /home/rank_table.json results/distributed/ 4
|
||||||
|
```
|
||||||
|
|
||||||
|
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt` as follows:
|
||||||
|
|
||||||
|
```text
|
||||||
|
loading parse...
|
||||||
|
batch size :64
|
||||||
|
loading dataset from /data/coco2017/train2017
|
||||||
|
loaded 149813 records from coco dataset.
|
||||||
|
loading pretrained model ./models/resnet50.ckpt
|
||||||
|
start training, epoch size = 140
|
||||||
|
epoch: 1 step: 585, loss is 0.0007944
|
||||||
|
Epoch time: 236219.684, per step time: 403.794
|
||||||
|
epoch: 2 step: 585, loss is 0.000617
|
||||||
|
Epoch time: 164792.001, per step time: 281.696
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The model checkpoint will be saved into `[ckpt_path_to_save]`.
|
||||||
|
|
||||||
|
## [Evaluation Process](#contents)
|
||||||
|
|
||||||
|
### Running on Ascend
|
||||||
|
|
||||||
|
Change the settings in `src/config.py` to the path of the model you want to validate. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.TEST.MODEL_FILE='results/xxxx.ckpt'
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/eval.sh [device_id]
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, you can run the shell command below to launch the validation procedure.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sh scripts/eval.sh 0
|
||||||
|
```
|
||||||
|
|
||||||
|
The above shell command will run validation procedure in the background. You can view the results through the file `eval_log[X].txt`. The result will be achieved as follows:
|
||||||
|
|
||||||
|
```text
|
||||||
|
use flip test: True
|
||||||
|
loading model ckpt from results/distributed/sim-140_1170.ckpt
|
||||||
|
loading dataset from /data/coco2017/val2017
|
||||||
|
loading bbox file from experiments/COCO_val2017_detections_AP_H_56_person.json
|
||||||
|
Total boxes: 104125
|
||||||
|
1024 samples validated in 18.133189916610718 seconds
|
||||||
|
2048 samples validated in 4.724390745162964 seconds
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Model Description](#contents)
|
||||||
|
|
||||||
|
## [Performance](#contents)
|
||||||
|
|
||||||
|
### SimplePoseNet on COCO2017 with detector
|
||||||
|
|
||||||
|
#### Performance parameters
|
||||||
|
|
||||||
|
| Parameters | Standalone | Distributed |
|
||||||
|
| ------------------- | --------------------------- | --------------------------- |
|
||||||
|
| Model Version | SimplePoseNet | SimplePoseNet |
|
||||||
|
| Resource | Ascend 910 | 4 Ascend 910 cards |
|
||||||
|
| Uploaded Date | 12/18/2020 (month/day/year) | 12/18/2020 (month/day/year) |
|
||||||
|
| MindSpore Version | 1.1.0 | 1.1.0 |
|
||||||
|
| Dataset | COCO2017 | COCO2017 |
|
||||||
|
| Training Parameters | epoch=140, batch_size=128 | epoch=140, batch_size=64 |
|
||||||
|
| Optimizer | Adam | Adam |
|
||||||
|
| Loss Function | Mean Squared Error | Mean Squared Error |
|
||||||
|
| Outputs | heatmap | heatmap |
|
||||||
|
| Train Performance | mAP: 70.4 | mAP: 70.4 |
|
||||||
|
| Speed | 1pc: 389.915 ms/step | 4pc: 281.356 ms/step |
|
||||||
|
|
||||||
|
#### Note
|
||||||
|
|
||||||
|
- Flip test is used.
|
||||||
|
- Person detector has person AP of 56.4 on COCO val2017 dataset.
|
||||||
|
- The dataset preprocessing and general training configurations are shown in [Script Parameters](#script-parameters) section.
|
||||||
|
|
||||||
|
# [Description of Random Situation](#contents)
|
||||||
|
|
||||||
|
In `src/dataset.py`, we set the seed inside “create_dataset" function. We also use random seed in `src/model.py` to initial network weights.
|
||||||
|
|
||||||
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from mindspore import Tensor, float32, context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.dataset import flip_pairs, keypoint_dataset
|
||||||
|
from src.evaluate.coco_eval import evaluate
|
||||||
|
from src.model import get_pose_net
|
||||||
|
from src.utils.transform import flip_back
|
||||||
|
from src.predict import get_final_preds
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='Train keypoints network')
|
||||||
|
parser.add_argument("--train_url", type=str, default="", help="")
|
||||||
|
parser.add_argument("--data_url", type=str, default="", help="data")
|
||||||
|
# output
|
||||||
|
parser.add_argument('--output-url',
|
||||||
|
help='output dir',
|
||||||
|
type=str)
|
||||||
|
# training
|
||||||
|
parser.add_argument('--workers',
|
||||||
|
help='num of dataloader workers',
|
||||||
|
default=8,
|
||||||
|
type=int)
|
||||||
|
parser.add_argument('--model-file',
|
||||||
|
help='model state file',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--use-detect-bbox',
|
||||||
|
help='use detect bbox',
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--flip-test',
|
||||||
|
help='use flip test',
|
||||||
|
default=True,
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--post-process',
|
||||||
|
help='use post process',
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--shift-heatmap',
|
||||||
|
help='shift heatmap',
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--coco-bbox-file',
|
||||||
|
help='coco detection bbox file',
|
||||||
|
type=str)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def reset_config(cfg, args):
|
||||||
|
if args.use_detect_bbox:
|
||||||
|
cfg.TEST.USE_GT_BBOX = not args.use_detect_bbox
|
||||||
|
if args.flip_test:
|
||||||
|
cfg.TEST.FLIP_TEST = args.flip_test
|
||||||
|
print('use flip test:', cfg.TEST.FLIP_TEST)
|
||||||
|
if args.post_process:
|
||||||
|
cfg.TEST.POST_PROCESS = args.post_process
|
||||||
|
if args.shift_heatmap:
|
||||||
|
cfg.TEST.SHIFT_HEATMAP = args.shift_heatmap
|
||||||
|
if args.model_file:
|
||||||
|
cfg.TEST.MODEL_FILE = args.model_file
|
||||||
|
if args.coco_bbox_file:
|
||||||
|
cfg.TEST.COCO_BBOX_FILE = args.coco_bbox_file
|
||||||
|
|
||||||
|
|
||||||
|
def validate(cfg, val_dataset, model, output_dir):
|
||||||
|
# switch to evaluate mode
|
||||||
|
model.set_train(False)
|
||||||
|
|
||||||
|
# init record
|
||||||
|
num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
|
||||||
|
all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
|
||||||
|
dtype=np.float32)
|
||||||
|
all_boxes = np.zeros((num_samples, 2))
|
||||||
|
image_id = []
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
# start eval
|
||||||
|
start = time.time()
|
||||||
|
for item in val_dataset.create_dict_iterator():
|
||||||
|
# input data
|
||||||
|
inputs = item['image'].asnumpy()
|
||||||
|
# compute output
|
||||||
|
output = model(Tensor(inputs, float32)).asnumpy()
|
||||||
|
if cfg.TEST.FLIP_TEST:
|
||||||
|
inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
|
||||||
|
output_flipped = model(inputs_flipped)
|
||||||
|
output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)
|
||||||
|
|
||||||
|
# feature is not aligned, shift flipped heatmap for higher accuracy
|
||||||
|
if cfg.TEST.SHIFT_HEATMAP:
|
||||||
|
output_flipped[:, :, :, 1:] = \
|
||||||
|
output_flipped.copy()[:, :, :, 0:-1]
|
||||||
|
# output_flipped[:, :, :, 0] = 0
|
||||||
|
|
||||||
|
output = (output + output_flipped) * 0.5
|
||||||
|
|
||||||
|
# meta data
|
||||||
|
c = item['center'].asnumpy()
|
||||||
|
s = item['scale'].asnumpy()
|
||||||
|
score = item['score'].asnumpy()
|
||||||
|
file_id = list(item['id'].asnumpy())
|
||||||
|
|
||||||
|
# pred by heatmaps
|
||||||
|
preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
|
||||||
|
num_images, _ = preds.shape[:2]
|
||||||
|
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
|
||||||
|
all_preds[idx:idx + num_images, :, 2:3] = maxvals
|
||||||
|
# double check this all_boxes parts
|
||||||
|
all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
|
||||||
|
all_boxes[idx:idx + num_images, 1] = score
|
||||||
|
image_id.extend(file_id)
|
||||||
|
idx += num_images
|
||||||
|
if idx % 1024 == 0:
|
||||||
|
print('{} samples validated in {} seconds'.format(idx, time.time() - start))
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
|
||||||
|
_, perf_indicator = evaluate(
|
||||||
|
cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id)
|
||||||
|
print("AP:", perf_indicator)
|
||||||
|
return perf_indicator
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# init seed
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
# set context
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
# update config
|
||||||
|
reset_config(config, args)
|
||||||
|
|
||||||
|
# init model
|
||||||
|
model = get_pose_net(config, is_train=False)
|
||||||
|
|
||||||
|
# load parameters
|
||||||
|
ckpt_name = config.TEST.MODEL_FILE
|
||||||
|
print('loading model ckpt from {}'.format(ckpt_name))
|
||||||
|
load_param_into_net(model, load_checkpoint(ckpt_name))
|
||||||
|
|
||||||
|
# Data loading code
|
||||||
|
valid_dataset, _ = keypoint_dataset(
|
||||||
|
config,
|
||||||
|
bbox_file=config.TEST.COCO_BBOX_FILE,
|
||||||
|
train_mode=False,
|
||||||
|
num_parallel_workers=args.workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluate on validation set
|
||||||
|
validate(config, valid_dataset, model, ckpt_name.split('.')[0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
export DEVICE_ID=$1
|
||||||
|
|
||||||
|
python eval.py > eval_log$1.txt 2>&1 &
|
|
@ -0,0 +1,44 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [SAVE_CKPT_PATH] [RANK_SIZE]
|
||||||
|
|
||||||
|
export RANK_TABLE_FILE=$1
|
||||||
|
echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
|
||||||
|
export RANK_SIZE=$3
|
||||||
|
SAVE_PATH=$2
|
||||||
|
|
||||||
|
device=(0 1 2 3)
|
||||||
|
for((i=0;i<RANK_SIZE;i++))
|
||||||
|
do
|
||||||
|
export DEVICE_ID=${device[$i]}
|
||||||
|
export RANK_ID=$i
|
||||||
|
|
||||||
|
rm -rf ./train_parallel$i
|
||||||
|
mkdir ./train_parallel$i
|
||||||
|
echo "start training for rank $i, device $DEVICE_ID"
|
||||||
|
|
||||||
|
cd ./train_parallel$i ||exit
|
||||||
|
env > env.log
|
||||||
|
cd ../
|
||||||
|
python train.py \
|
||||||
|
--run-distribute \
|
||||||
|
--ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &
|
||||||
|
|
||||||
|
echo "python train.py \
|
||||||
|
--run-distribute \
|
||||||
|
--ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &"
|
||||||
|
|
||||||
|
done
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# Usage: train_standalone.sh [DEVICE_ID] [SAVE_CKPT_PATH]
|
||||||
|
export DEVICE_ID=$1
|
||||||
|
|
||||||
|
python train.py \
|
||||||
|
--ckpt-path=$2 --batch-size=128\
|
||||||
|
> train_log$1.txt 2>&1 &
|
||||||
|
echo " python train.py --ckpt-path=$2 --batch-size=128 > train_log$1.txt 2>&1 &"
|
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
config = edict()
|
||||||
|
|
||||||
|
# pose_resnet related params
|
||||||
|
POSE_RESNET = edict()
|
||||||
|
POSE_RESNET.NUM_LAYERS = 50
|
||||||
|
POSE_RESNET.DECONV_WITH_BIAS = False
|
||||||
|
POSE_RESNET.NUM_DECONV_LAYERS = 3
|
||||||
|
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
|
||||||
|
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
|
||||||
|
POSE_RESNET.FINAL_CONV_KERNEL = 1
|
||||||
|
POSE_RESNET.TARGET_TYPE = 'gaussian'
|
||||||
|
POSE_RESNET.HEATMAP_SIZE = [48, 64] # width * height, ex: 24 * 32
|
||||||
|
POSE_RESNET.SIGMA = 2
|
||||||
|
|
||||||
|
MODEL_EXTRAS = {
|
||||||
|
'pose_resnet': POSE_RESNET,
|
||||||
|
}
|
||||||
|
|
||||||
|
# common params for NETWORK
|
||||||
|
config.MODEL = edict()
|
||||||
|
config.MODEL.NAME = 'pose_resnet'
|
||||||
|
config.MODEL.INIT_WEIGHTS = True
|
||||||
|
config.MODEL.PRETRAINED = './models/resnet50.ckpt'
|
||||||
|
config.MODEL.NUM_JOINTS = 17
|
||||||
|
config.MODEL.IMAGE_SIZE = [192, 256] # width * height, ex: 192 * 256
|
||||||
|
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
config.DATASET = edict()
|
||||||
|
config.DATASET.ROOT = '/data/coco2017/'
|
||||||
|
config.DATASET.TEST_SET = 'val2017'
|
||||||
|
config.DATASET.TRAIN_SET = 'train2017'
|
||||||
|
# data augmentation
|
||||||
|
config.DATASET.FLIP = True
|
||||||
|
config.DATASET.ROT_FACTOR = 40
|
||||||
|
config.DATASET.SCALE_FACTOR = 0.3
|
||||||
|
|
||||||
|
# for train
|
||||||
|
config.TRAIN = edict()
|
||||||
|
config.TRAIN.BATCH_SIZE = 64
|
||||||
|
config.TRAIN.BEGIN_EPOCH = 0
|
||||||
|
config.TRAIN.END_EPOCH = 140
|
||||||
|
config.TRAIN.LR = 0.001
|
||||||
|
config.TRAIN.LR_FACTOR = 0.1
|
||||||
|
config.TRAIN.LR_STEP = [90, 120]
|
||||||
|
|
||||||
|
# test
|
||||||
|
config.TEST = edict()
|
||||||
|
config.TEST.BATCH_SIZE = 32
|
||||||
|
config.TEST.FLIP_TEST = True
|
||||||
|
config.TEST.POST_PROCESS = True
|
||||||
|
config.TEST.SHIFT_HEATMAP = True
|
||||||
|
config.TEST.USE_GT_BBOX = False
|
||||||
|
config.TEST.MODEL_FILE = ''
|
||||||
|
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
|
||||||
|
# nms
|
||||||
|
config.TEST.OKS_THRE = 0.9
|
||||||
|
config.TEST.IN_VIS_THRE = 0.2
|
||||||
|
config.TEST.BBOX_THRE = 1.0
|
||||||
|
config.TEST.IMAGE_THRE = 0.0
|
||||||
|
config.TEST.NMS_THRE = 1.0
|
|
@ -0,0 +1,378 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import mindspore.dataset as de
|
||||||
|
import mindspore.dataset.vision.c_transforms as V_C
|
||||||
|
|
||||||
|
from src.utils.transform import fliplr_joints, get_affine_transform, affine_transform
|
||||||
|
|
||||||
|
de.config.set_seed(1)
|
||||||
|
flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
|
||||||
|
[9, 10], [11, 12], [13, 14], [15, 16]]
|
||||||
|
|
||||||
|
|
||||||
|
class KeypointDatasetGenerator:
|
||||||
|
def __init__(self, cfg, is_train=False):
|
||||||
|
# config file
|
||||||
|
self.image_thre = cfg.TEST.IMAGE_THRE
|
||||||
|
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE, dtype=np.int32)
|
||||||
|
self.image_width = cfg.MODEL.IMAGE_SIZE[0]
|
||||||
|
self.image_height = cfg.MODEL.IMAGE_SIZE[1]
|
||||||
|
self.aspect_ratio = self.image_width * 1.0 / self.image_height
|
||||||
|
self.heatmap_size = np.array(cfg.MODEL.EXTRA.HEATMAP_SIZE, dtype=np.int32)
|
||||||
|
self.sigma = cfg.MODEL.EXTRA.SIGMA
|
||||||
|
self.target_type = cfg.MODEL.EXTRA.TARGET_TYPE
|
||||||
|
|
||||||
|
# data argumentation
|
||||||
|
self.scale_factor = cfg.DATASET.SCALE_FACTOR
|
||||||
|
self.rotation_factor = cfg.DATASET.ROT_FACTOR
|
||||||
|
self.flip = cfg.DATASET.FLIP
|
||||||
|
|
||||||
|
# dataset informations
|
||||||
|
self.db = []
|
||||||
|
self.is_train = is_train
|
||||||
|
|
||||||
|
# for coco dataset
|
||||||
|
self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
|
||||||
|
[9, 10], [11, 12], [13, 14], [15, 16]]
|
||||||
|
self.num_joints = 17
|
||||||
|
|
||||||
|
def load_gt_dataset(self, image_path, ann_file):
|
||||||
|
# reset db
|
||||||
|
self.db = []
|
||||||
|
|
||||||
|
# load json file and decode
|
||||||
|
with open(ann_file, "rb") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
json_dict = json.loads(lines[0].decode("utf-8"))
|
||||||
|
|
||||||
|
# traversal all the ann items
|
||||||
|
objs = {}
|
||||||
|
cnt = 0
|
||||||
|
for item in json_dict['annotations']:
|
||||||
|
# exclude iscrowd and no-keypoint record
|
||||||
|
if item['iscrowd'] != 0 or item['num_keypoints'] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# assert the record is valid
|
||||||
|
assert item['iscrowd'] == 0, 'is crowd'
|
||||||
|
assert item['category_id'] == 1, 'is not people'
|
||||||
|
assert item['area'] > 0, 'area le 0'
|
||||||
|
assert item['num_keypoints'] > 0, 'has no keypoint'
|
||||||
|
assert max(item['keypoints']) > 0
|
||||||
|
|
||||||
|
image_id = item['image_id']
|
||||||
|
obj = [{'num_keypoints': item['num_keypoints'], 'keypoints': item['keypoints'], 'bbox': item['bbox']}]
|
||||||
|
objs[image_id] = obj if image_id not in objs else objs[image_id] + obj
|
||||||
|
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
print('loaded %d records from coco dataset.' % cnt)
|
||||||
|
|
||||||
|
# traversal all the image items
|
||||||
|
for item in json_dict['images']:
|
||||||
|
image_id = item['id']
|
||||||
|
width = item['width']
|
||||||
|
height = item['height']
|
||||||
|
# exclude image not in records
|
||||||
|
if image_id not in objs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# sanitize bboxes
|
||||||
|
valid_objs = []
|
||||||
|
for obj in objs[image_id]:
|
||||||
|
x, y, w, h = obj['bbox']
|
||||||
|
x1 = max(0, x)
|
||||||
|
y1 = max(0, y)
|
||||||
|
x2 = min(width - 1, x1 + max(0, w - 1))
|
||||||
|
y2 = min(height - 1, y1 + max(0, h - 1))
|
||||||
|
if x2 >= x1 and y2 >= y1:
|
||||||
|
tmp_obj = deepcopy(obj)
|
||||||
|
tmp_obj['bbox'] = np.array((x1, y1, x2, y2)) - np.array((0, 0, x1, y1))
|
||||||
|
valid_objs.append(tmp_obj)
|
||||||
|
else:
|
||||||
|
assert False, 'invalid bbox!'
|
||||||
|
# rewrite
|
||||||
|
objs[image_id] = valid_objs
|
||||||
|
|
||||||
|
for obj in objs[image_id]:
|
||||||
|
# keypoints
|
||||||
|
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
for ipt in range(self.num_joints):
|
||||||
|
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
|
||||||
|
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
|
||||||
|
joints_3d[ipt, 2] = 0
|
||||||
|
t_vis = obj['keypoints'][ipt * 3 + 2]
|
||||||
|
if t_vis > 1:
|
||||||
|
t_vis = 1
|
||||||
|
joints_3d_vis[ipt, 0] = t_vis
|
||||||
|
joints_3d_vis[ipt, 1] = t_vis
|
||||||
|
joints_3d_vis[ipt, 2] = 0
|
||||||
|
|
||||||
|
scale, center = self._bbox2sc(obj['bbox'])
|
||||||
|
|
||||||
|
# reform and save
|
||||||
|
self.db.append({
|
||||||
|
'id': int(item['id']),
|
||||||
|
'image': os.path.join(image_path, item['file_name']),
|
||||||
|
'center': center,
|
||||||
|
'scale': scale,
|
||||||
|
'joints_3d': joints_3d,
|
||||||
|
'joints_3d_vis': joints_3d_vis,
|
||||||
|
})
|
||||||
|
|
||||||
|
def load_detect_dataset(self, image_path, ann_file, bbox_file):
|
||||||
|
# reset self.db
|
||||||
|
self.db = []
|
||||||
|
|
||||||
|
# open detect file
|
||||||
|
all_boxes = None
|
||||||
|
with open(bbox_file, 'r') as f:
|
||||||
|
all_boxes = json.load(f)
|
||||||
|
|
||||||
|
assert all_boxes, 'Loading %s fail!' % bbox_file
|
||||||
|
print('Total boxes: {}'.format(len(all_boxes)))
|
||||||
|
|
||||||
|
# load json file and decode
|
||||||
|
with open(ann_file, "rb") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
json_dict = json.loads(lines[0].decode("utf-8"))
|
||||||
|
|
||||||
|
# build a map from id to file name
|
||||||
|
index_to_filename = {}
|
||||||
|
for item in json_dict['images']:
|
||||||
|
index_to_filename[item['id']] = item['file_name']
|
||||||
|
|
||||||
|
# load each item into db
|
||||||
|
for det_res in all_boxes:
|
||||||
|
if det_res['category_id'] != 1:
|
||||||
|
continue
|
||||||
|
# load image
|
||||||
|
image = os.path.join(image_path,
|
||||||
|
index_to_filename[det_res['image_id']])
|
||||||
|
|
||||||
|
bbox = det_res['bbox']
|
||||||
|
score = det_res['score']
|
||||||
|
if score < self.image_thre:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scale, center = self._bbox2sc(bbox)
|
||||||
|
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
joints_3d_vis = np.ones((self.num_joints, 3), dtype=np.float)
|
||||||
|
|
||||||
|
self.db.append({
|
||||||
|
'id': int(det_res['image_id']),
|
||||||
|
'image': image,
|
||||||
|
'center': center,
|
||||||
|
'scale': scale,
|
||||||
|
'score': score,
|
||||||
|
'joints_3d': joints_3d,
|
||||||
|
'joints_3d_vis': joints_3d_vis,
|
||||||
|
})
|
||||||
|
|
||||||
|
def _bbox2sc(self, bbox):
|
||||||
|
"""
|
||||||
|
reform xywh to meet the need of aspect ratio
|
||||||
|
"""
|
||||||
|
x, y, w, h = bbox[:4]
|
||||||
|
center = np.zeros((2), dtype=np.float32)
|
||||||
|
center[0] = x + w * 0.5
|
||||||
|
center[1] = y + h * 0.5
|
||||||
|
|
||||||
|
if w > self.aspect_ratio * h:
|
||||||
|
h = w * 1.0 / self.aspect_ratio
|
||||||
|
elif w < self.aspect_ratio * h:
|
||||||
|
w = h * self.aspect_ratio
|
||||||
|
scale = np.array(
|
||||||
|
[w * 1.0 / 200, h * 1.0 / 200], dtype=np.float32)
|
||||||
|
if center[0] != -1:
|
||||||
|
scale = scale * 1.25
|
||||||
|
|
||||||
|
return scale, center
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
db_rec = deepcopy(self.db[idx])
|
||||||
|
|
||||||
|
image_file = db_rec['image']
|
||||||
|
|
||||||
|
data_numpy = cv2.imread(
|
||||||
|
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
|
||||||
|
|
||||||
|
if data_numpy is None:
|
||||||
|
print('[ERROR] fail to read {}'.format(image_file))
|
||||||
|
raise ValueError('Fail to read {}'.format(image_file))
|
||||||
|
|
||||||
|
joints = db_rec['joints_3d']
|
||||||
|
joints_vis = db_rec['joints_3d_vis']
|
||||||
|
|
||||||
|
c = db_rec['center']
|
||||||
|
s = db_rec['scale']
|
||||||
|
score = db_rec['score'] if 'score' in db_rec else 1
|
||||||
|
r = 0
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
sf = self.scale_factor
|
||||||
|
rf = self.rotation_factor
|
||||||
|
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
|
||||||
|
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
|
||||||
|
if random.random() <= 0.6 else 0
|
||||||
|
|
||||||
|
if self.flip and random.random() <= 0.5:
|
||||||
|
data_numpy = data_numpy[:, ::-1, :]
|
||||||
|
joints, joints_vis = fliplr_joints(
|
||||||
|
joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
|
||||||
|
c[0] = data_numpy.shape[1] - c[0] - 1
|
||||||
|
|
||||||
|
trans = get_affine_transform(c, s, r, self.image_size)
|
||||||
|
image = cv2.warpAffine(
|
||||||
|
data_numpy,
|
||||||
|
trans,
|
||||||
|
(int(self.image_size[0]), int(self.image_size[1])),
|
||||||
|
flags=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
for i in range(self.num_joints):
|
||||||
|
if joints_vis[i, 0] > 0.0:
|
||||||
|
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
|
||||||
|
|
||||||
|
target, target_weight = self.generate_heatmap(joints, joints_vis)
|
||||||
|
|
||||||
|
return image, target, target_weight, s, c, score, db_rec['id']
|
||||||
|
|
||||||
|
def generate_heatmap(self, joints, joints_vis):
|
||||||
|
target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
|
||||||
|
target_weight[:, 0] = joints_vis[:, 0]
|
||||||
|
|
||||||
|
assert self.target_type == 'gaussian', \
|
||||||
|
'Only support gaussian map now!'
|
||||||
|
|
||||||
|
if self.target_type == 'gaussian':
|
||||||
|
target = np.zeros((self.num_joints,
|
||||||
|
self.heatmap_size[1],
|
||||||
|
self.heatmap_size[0]),
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
|
tmp_size = self.sigma * 3
|
||||||
|
|
||||||
|
for joint_id in range(self.num_joints):
|
||||||
|
feat_stride = self.image_size / self.heatmap_size
|
||||||
|
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
|
||||||
|
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
|
||||||
|
# Check that any part of the gaussian is in-bounds
|
||||||
|
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
||||||
|
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
||||||
|
if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
|
||||||
|
or br[0] < 0 or br[1] < 0:
|
||||||
|
# If not, just return the image as is
|
||||||
|
target_weight[joint_id] = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# # Generate gaussian
|
||||||
|
size = 2 * tmp_size + 1
|
||||||
|
x = np.arange(0, size, 1, np.float32)
|
||||||
|
y = x[:, np.newaxis]
|
||||||
|
x0 = y0 = size // 2
|
||||||
|
# The gaussian is not normalized, we want the center value to equal 1
|
||||||
|
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
|
||||||
|
|
||||||
|
# Usable gaussian range
|
||||||
|
g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
|
||||||
|
g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
|
||||||
|
# Image range
|
||||||
|
img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
|
||||||
|
img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
|
||||||
|
|
||||||
|
v = target_weight[joint_id]
|
||||||
|
if v > 0.5:
|
||||||
|
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
||||||
|
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
||||||
|
|
||||||
|
return target, target_weight
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.db)
|
||||||
|
|
||||||
|
|
||||||
|
def keypoint_dataset(config,
|
||||||
|
ann_file=None,
|
||||||
|
image_path=None,
|
||||||
|
bbox_file=None,
|
||||||
|
rank=0,
|
||||||
|
group_size=1,
|
||||||
|
train_mode=True,
|
||||||
|
num_parallel_workers=8,
|
||||||
|
transform=None,
|
||||||
|
shuffle=None):
|
||||||
|
"""
|
||||||
|
A function that returns an imagenet dataset for classification. The mode of input dataset should be "folder" .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rank (int): The shard ID within num_shards (default=None).
|
||||||
|
group_size (int): Number of shards that the dataset should be divided
|
||||||
|
into (default=None).
|
||||||
|
mode (str): "train" or others. Default: " train".
|
||||||
|
num_parallel_workers (int): Number of workers to read the data. Default: None.
|
||||||
|
"""
|
||||||
|
# config
|
||||||
|
per_batch_size = config.TRAIN.BATCH_SIZE if train_mode else config.TEST.BATCH_SIZE
|
||||||
|
image_path = image_path if image_path else os.path.join(config.DATASET.ROOT,
|
||||||
|
config.DATASET.TRAIN_SET
|
||||||
|
if train_mode else config.DATASET.TEST_SET)
|
||||||
|
print('loading dataset from {}'.format(image_path))
|
||||||
|
ann_file = ann_file if ann_file else os.path.join(config.DATASET.ROOT,
|
||||||
|
'annotations/person_keypoints_{}2017.json'.format(
|
||||||
|
'train' if train_mode else 'val'))
|
||||||
|
shuffle = shuffle if shuffle is not None else train_mode
|
||||||
|
# gen dataset db
|
||||||
|
dataset_generator = KeypointDatasetGenerator(config, is_train=train_mode)
|
||||||
|
|
||||||
|
if not train_mode and not config.TEST.USE_GT_BBOX:
|
||||||
|
print('loading bbox file from {}'.format(bbox_file))
|
||||||
|
dataset_generator.load_detect_dataset(image_path, ann_file, bbox_file)
|
||||||
|
else:
|
||||||
|
dataset_generator.load_gt_dataset(image_path, ann_file)
|
||||||
|
|
||||||
|
# construct dataset
|
||||||
|
de_dataset = de.GeneratorDataset(dataset_generator,
|
||||||
|
column_names=["image", "target", "weight", "scale", "center", "score", "id"],
|
||||||
|
num_parallel_workers=num_parallel_workers,
|
||||||
|
num_shards=group_size,
|
||||||
|
shard_id=rank,
|
||||||
|
shuffle=shuffle)
|
||||||
|
|
||||||
|
# inputs map functions
|
||||||
|
if transform is None:
|
||||||
|
transform_img = [
|
||||||
|
V_C.Rescale(1.0 / 255.0, 0.0),
|
||||||
|
V_C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
V_C.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
transform_img = transform
|
||||||
|
de_dataset = de_dataset.map(input_columns="image",
|
||||||
|
num_parallel_workers=num_parallel_workers,
|
||||||
|
operations=transform_img)
|
||||||
|
|
||||||
|
# batch
|
||||||
|
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=train_mode)
|
||||||
|
|
||||||
|
return de_dataset, dataset_generator
|
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict, OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
|
has_coco = True
|
||||||
|
except ImportError:
|
||||||
|
has_coco = False
|
||||||
|
|
||||||
|
from src.utils.nms import oks_nms
|
||||||
|
|
||||||
|
|
||||||
|
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for img, items in img_kpts.items():
|
||||||
|
item_size = len(items)
|
||||||
|
if not items:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# keypoints array at coco format
|
||||||
|
kpts = np.array([items[k]['keypoints']
|
||||||
|
for k in range(item_size)])
|
||||||
|
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
|
||||||
|
keypoints[:, 0::3] = kpts[:, :, 0]
|
||||||
|
keypoints[:, 1::3] = kpts[:, :, 1]
|
||||||
|
keypoints[:, 2::3] = kpts[:, :, 2]
|
||||||
|
|
||||||
|
result = [{'image_id': int(img),
|
||||||
|
'keypoints': list(keypoints[k]),
|
||||||
|
'score': items[k]['score'],
|
||||||
|
'category_id': 1,
|
||||||
|
} for k in range(item_size)]
|
||||||
|
results.extend(result)
|
||||||
|
|
||||||
|
with open(res_file, 'w') as f:
|
||||||
|
json.dump(results, f, sort_keys=True, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
|
||||||
|
coco = COCO(ann_path)
|
||||||
|
coco_dt = coco.loadRes(res_file)
|
||||||
|
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
|
||||||
|
coco_eval.params.useSegm = None
|
||||||
|
coco_eval.evaluate()
|
||||||
|
coco_eval.accumulate()
|
||||||
|
coco_eval.summarize()
|
||||||
|
stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
|
||||||
|
|
||||||
|
info_str = []
|
||||||
|
for ind, name in enumerate(stats_names):
|
||||||
|
info_str.append((name, coco_eval.stats[ind]))
|
||||||
|
|
||||||
|
eval_file = os.path.join(
|
||||||
|
res_folder, 'keypoints_results.pkl')
|
||||||
|
|
||||||
|
with open(eval_file, 'wb') as f:
|
||||||
|
pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
print('coco eval results saved to %s' % eval_file)
|
||||||
|
|
||||||
|
return info_str
|
||||||
|
|
||||||
|
|
||||||
|
# need double check this API and classes field
|
||||||
|
def evaluate(cfg, preds, output_dir, all_boxes, img_id):
|
||||||
|
res_folder = os.path.join(output_dir, 'results')
|
||||||
|
if not os.path.exists(res_folder):
|
||||||
|
os.makedirs(res_folder)
|
||||||
|
res_file = os.path.join(res_folder, 'keypoints_results.json')
|
||||||
|
# image -> list(keypoints/area/score)
|
||||||
|
img_kpts_dict = defaultdict(list)
|
||||||
|
for idx, file_id in enumerate(img_id):
|
||||||
|
img_kpts_dict[file_id].append({
|
||||||
|
'keypoints': preds[idx],
|
||||||
|
'area': all_boxes[idx][0],
|
||||||
|
'score': all_boxes[idx][1],
|
||||||
|
})
|
||||||
|
|
||||||
|
# rescoring and oks nms
|
||||||
|
num_joints = cfg.MODEL.NUM_JOINTS
|
||||||
|
in_vis_thre = cfg.TEST.IN_VIS_THRE
|
||||||
|
oks_thre = cfg.TEST.OKS_THRE
|
||||||
|
oks_nmsed_kpts = {}
|
||||||
|
for img, items in img_kpts_dict.items():
|
||||||
|
for item in items:
|
||||||
|
kpt_score = 0
|
||||||
|
valid_num = 0
|
||||||
|
for n_jt in range(num_joints):
|
||||||
|
max_jt = item['keypoints'][n_jt][2]
|
||||||
|
if max_jt > in_vis_thre:
|
||||||
|
kpt_score = kpt_score + max_jt
|
||||||
|
valid_num = valid_num + 1
|
||||||
|
if valid_num != 0:
|
||||||
|
kpt_score = kpt_score / valid_num
|
||||||
|
# rescoring
|
||||||
|
item['score'] = kpt_score * item['score']
|
||||||
|
keep = oks_nms(items, oks_thre)
|
||||||
|
if not keep:
|
||||||
|
oks_nmsed_kpts[img] = items
|
||||||
|
else:
|
||||||
|
oks_nmsed_kpts[img] = [items[kep] for kep in keep]
|
||||||
|
|
||||||
|
# evaluate and save
|
||||||
|
image_set = cfg.DATASET.TEST_SET
|
||||||
|
_write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
|
||||||
|
if 'test' not in image_set and has_coco:
|
||||||
|
ann_path = os.path.join(cfg.DATASET.ROOT, 'annotations',
|
||||||
|
'person_keypoints_' + image_set + '.json')
|
||||||
|
info_str = _do_python_keypoint_eval(
|
||||||
|
res_file, res_folder, ann_path)
|
||||||
|
name_value = OrderedDict(info_str)
|
||||||
|
return name_value, name_value['AP']
|
||||||
|
return {'Null': 0}, 0
|
|
@ -0,0 +1,225 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops.operations as F
|
||||||
|
from mindspore.common.initializer import Normal
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore import ParameterTuple
|
||||||
|
|
||||||
|
BN_MOMENTUM = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPool2dPytorch(nn.Cell):
|
||||||
|
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
|
||||||
|
super(MaxPool2dPytorch, self).__init__()
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
|
||||||
|
self.reverse = F.ReverseV2(axis=[2, 3])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.reverse(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
x = self.reverse(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Cell):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||||
|
pad_mode='pad', padding=1, has_bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
|
||||||
|
has_bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
|
||||||
|
momentum=BN_MOMENTUM)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.down_sample_layer = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.down_sample_layer is not None:
|
||||||
|
residual = self.down_sample_layer(x)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PoseResNet(nn.Cell):
|
||||||
|
|
||||||
|
def __init__(self, block, layers, cfg, pytorch_mode=True):
|
||||||
|
self.inplanes = 64
|
||||||
|
extra = cfg.MODEL.EXTRA
|
||||||
|
self.deconv_with_bias = extra.DECONV_WITH_BIAS
|
||||||
|
|
||||||
|
super(PoseResNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2,
|
||||||
|
pad_mode='pad', padding=3, has_bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
if pytorch_mode:
|
||||||
|
self.maxpool = MaxPool2dPytorch(kernel_size=3, stride=2, pad_mode='same')
|
||||||
|
print("use pytorch-style maxpool")
|
||||||
|
else:
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||||
|
print("use mindspore-style maxpool")
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||||
|
|
||||||
|
# used for deconv layers
|
||||||
|
self.deconv_layers = self._make_deconv_layer(
|
||||||
|
extra.NUM_DECONV_LAYERS,
|
||||||
|
extra.NUM_DECONV_FILTERS,
|
||||||
|
extra.NUM_DECONV_KERNELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = nn.Conv2d(
|
||||||
|
in_channels=extra.NUM_DECONV_FILTERS[-1],
|
||||||
|
out_channels=cfg.MODEL.NUM_JOINTS,
|
||||||
|
kernel_size=extra.FINAL_CONV_KERNEL,
|
||||||
|
stride=1,
|
||||||
|
pad_mode='pad',
|
||||||
|
padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0,
|
||||||
|
has_bias=True,
|
||||||
|
weight_init=Normal(0.001),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.SequentialCell(OrderedDict([
|
||||||
|
('0', nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||||
|
kernel_size=1, stride=stride, has_bias=False)),
|
||||||
|
('1', nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM)),
|
||||||
|
]))
|
||||||
|
|
||||||
|
layers = OrderedDict()
|
||||||
|
layers['0'] = block(self.inplanes, planes, stride, downsample)
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for i in range(1, blocks):
|
||||||
|
layers['{}'.format(i)] = block(self.inplanes, planes)
|
||||||
|
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def _get_deconv_cfg(self, deconv_kernel):
|
||||||
|
assert deconv_kernel == 4, 'only support kernel_size = 4 for deconvolution layers'
|
||||||
|
if deconv_kernel == 4:
|
||||||
|
padding = 1
|
||||||
|
output_padding = 0
|
||||||
|
elif deconv_kernel == 3:
|
||||||
|
padding = 1
|
||||||
|
output_padding = 1
|
||||||
|
elif deconv_kernel == 2:
|
||||||
|
padding = 0
|
||||||
|
output_padding = 0
|
||||||
|
|
||||||
|
return deconv_kernel, padding, output_padding
|
||||||
|
|
||||||
|
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
|
||||||
|
assert num_layers == len(num_filters), \
|
||||||
|
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
|
||||||
|
assert num_layers == len(num_kernels), \
|
||||||
|
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
|
||||||
|
|
||||||
|
layers = OrderedDict()
|
||||||
|
for i in range(num_layers):
|
||||||
|
kernel, padding, _ = \
|
||||||
|
self._get_deconv_cfg(num_kernels[i])
|
||||||
|
|
||||||
|
planes = num_filters[i]
|
||||||
|
layers['deconv_{}'.format(i)] = nn.SequentialCell(OrderedDict([
|
||||||
|
('deconv', nn.Conv2dTranspose(
|
||||||
|
in_channels=self.inplanes,
|
||||||
|
out_channels=planes,
|
||||||
|
kernel_size=kernel,
|
||||||
|
stride=2,
|
||||||
|
pad_mode='pad',
|
||||||
|
padding=padding,
|
||||||
|
has_bias=self.deconv_with_bias,
|
||||||
|
weight_init=Normal(0.001),
|
||||||
|
)),
|
||||||
|
('bn', nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)),
|
||||||
|
('relu', nn.ReLU()),
|
||||||
|
]))
|
||||||
|
self.inplanes = planes
|
||||||
|
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.deconv_layers(x)
|
||||||
|
x = self.final_layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=''):
|
||||||
|
if os.path.isfile(pretrained):
|
||||||
|
# load params from pretrained
|
||||||
|
param_dict = load_checkpoint(pretrained)
|
||||||
|
weight = ParameterTuple(self.trainable_params())
|
||||||
|
for w in weight:
|
||||||
|
if w.name.split('.')[0] not in ('deconv_layers', 'final_layer'):
|
||||||
|
assert w.name in param_dict, "parameter %s not in checkpoint" % w.name
|
||||||
|
load_param_into_net(self, param_dict)
|
||||||
|
print('loading pretrained model {}'.format(pretrained))
|
||||||
|
else:
|
||||||
|
assert False, '{} is not a file'.format(pretrained)
|
||||||
|
|
||||||
|
|
||||||
|
resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),
|
||||||
|
101: (Bottleneck, [3, 4, 23, 3]),
|
||||||
|
152: (Bottleneck, [3, 8, 36, 3])}
|
||||||
|
|
||||||
|
|
||||||
|
def get_pose_net(cfg, is_train, ckpt_path=None, pytorch_mode=False):
|
||||||
|
num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
|
||||||
|
|
||||||
|
block_class, layers = resnet_spec[num_layers]
|
||||||
|
model = PoseResNet(block_class, layers, cfg, pytorch_mode=pytorch_mode)
|
||||||
|
|
||||||
|
if is_train and cfg.MODEL.INIT_WEIGHTS:
|
||||||
|
model.init_weights(ckpt_path)
|
||||||
|
|
||||||
|
return model
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
|
||||||
|
class JointsMSELoss(_Loss):
|
||||||
|
def __init__(self, use_target_weight):
|
||||||
|
super(JointsMSELoss, self).__init__()
|
||||||
|
self.criterion = nn.MSELoss(reduction='mean')
|
||||||
|
self.use_target_weight = use_target_weight
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.squeeze = P.Squeeze(1)
|
||||||
|
self.mul = P.Mul()
|
||||||
|
|
||||||
|
def construct(self, output, target, target_weight):
|
||||||
|
batch_size = F.shape(output)[0]
|
||||||
|
num_joints = F.shape(output)[1]
|
||||||
|
|
||||||
|
split = P.Split(1, num_joints)
|
||||||
|
heatmaps_pred = self.reshape(output, (batch_size, num_joints, -1))
|
||||||
|
heatmaps_pred = split(heatmaps_pred)
|
||||||
|
|
||||||
|
heatmaps_gt = self.reshape(target, (batch_size, num_joints, -1))
|
||||||
|
heatmaps_gt = split(heatmaps_gt)
|
||||||
|
loss = 0
|
||||||
|
for idx in range(num_joints):
|
||||||
|
heatmap_pred = self.squeeze(heatmaps_pred[idx])
|
||||||
|
heatmap_gt = self.squeeze(heatmaps_gt[idx])
|
||||||
|
if self.use_target_weight:
|
||||||
|
loss += 0.5 * self.criterion(
|
||||||
|
self.mul(heatmap_pred, target_weight[:, idx]),
|
||||||
|
self.mul(heatmap_gt, target_weight[:, idx])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
|
||||||
|
return loss / num_joints
|
||||||
|
|
||||||
|
|
||||||
|
class WithLossCell(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wrap the network with loss function to compute loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backbone (Cell): The target network to wrap.
|
||||||
|
loss_fn (Cell): The loss function used to compute loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, backbone, loss_fn):
|
||||||
|
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||||
|
self._backbone = backbone
|
||||||
|
self._loss_fn = loss_fn
|
||||||
|
|
||||||
|
def construct(self, image, target, weight, scale=None,
|
||||||
|
center=None, score=None, idx=None):
|
||||||
|
out = self._backbone(image)
|
||||||
|
output = F.mixed_precision_cast(mstype.float32, out)
|
||||||
|
target = F.mixed_precision_cast(mstype.float32, target)
|
||||||
|
weight = F.mixed_precision_cast(mstype.float32, weight)
|
||||||
|
return self._loss_fn(output, target, weight)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def backbone_network(self):
|
||||||
|
"""
|
||||||
|
Get the backbone network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cell, return backbone network.
|
||||||
|
"""
|
||||||
|
return self._backbone
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.utils.transform import transform_preds
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_preds(batch_heatmaps):
|
||||||
|
'''
|
||||||
|
get predictions from score maps
|
||||||
|
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
|
||||||
|
'''
|
||||||
|
assert isinstance(batch_heatmaps, np.ndarray), \
|
||||||
|
'batch_heatmaps should be numpy.ndarray'
|
||||||
|
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
|
||||||
|
|
||||||
|
batch_size = batch_heatmaps.shape[0]
|
||||||
|
num_joints = batch_heatmaps.shape[1]
|
||||||
|
width = batch_heatmaps.shape[3]
|
||||||
|
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
|
||||||
|
idx = np.argmax(heatmaps_reshaped, 2)
|
||||||
|
maxvals = np.amax(heatmaps_reshaped, 2)
|
||||||
|
|
||||||
|
maxvals = maxvals.reshape((batch_size, num_joints, 1))
|
||||||
|
idx = idx.reshape((batch_size, num_joints, 1))
|
||||||
|
|
||||||
|
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
|
||||||
|
|
||||||
|
preds[:, :, 0] = (preds[:, :, 0]) % width
|
||||||
|
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
|
||||||
|
|
||||||
|
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
|
||||||
|
pred_mask = pred_mask.astype(np.float32)
|
||||||
|
|
||||||
|
preds *= pred_mask
|
||||||
|
return preds, maxvals
|
||||||
|
|
||||||
|
|
||||||
|
def get_final_preds(config, batch_heatmaps, center, scale):
|
||||||
|
coords, maxvals = get_max_preds(batch_heatmaps)
|
||||||
|
|
||||||
|
heatmap_height = batch_heatmaps.shape[2]
|
||||||
|
heatmap_width = batch_heatmaps.shape[3]
|
||||||
|
|
||||||
|
# post-processing
|
||||||
|
if config.TEST.POST_PROCESS:
|
||||||
|
for n in range(coords.shape[0]):
|
||||||
|
for p in range(coords.shape[1]):
|
||||||
|
hm = batch_heatmaps[n][p]
|
||||||
|
px = int(math.floor(coords[n][p][0] + 0.5))
|
||||||
|
py = int(math.floor(coords[n][p][1] + 0.5))
|
||||||
|
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
|
||||||
|
diff = np.array([hm[py][px + 1] - hm[py][px - 1],
|
||||||
|
hm[py + 1][px] - hm[py - 1][px]])
|
||||||
|
coords[n][p] += np.sign(diff) * .25
|
||||||
|
|
||||||
|
preds = coords.copy()
|
||||||
|
|
||||||
|
# Transform back
|
||||||
|
for i in range(coords.shape[0]):
|
||||||
|
preds[i] = transform_preds(coords[i], center[i], scale[i],
|
||||||
|
[heatmap_width, heatmap_height])
|
||||||
|
|
||||||
|
return preds, maxvals
|
|
@ -0,0 +1,55 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
|
||||||
|
if not isinstance(sigmas, np.ndarray):
|
||||||
|
sigmas = np.array(
|
||||||
|
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
||||||
|
vas = (sigmas * 2) ** 2
|
||||||
|
xg = g[0::3]
|
||||||
|
yg = g[1::3]
|
||||||
|
vg = g[2::3]
|
||||||
|
ious = np.zeros((d.shape[0]))
|
||||||
|
for n_d in range(0, d.shape[0]):
|
||||||
|
xd = d[n_d, 0::3]
|
||||||
|
yd = d[n_d, 1::3]
|
||||||
|
vd = d[n_d, 2::3]
|
||||||
|
dx = xd - xg
|
||||||
|
dy = yd - yg
|
||||||
|
e = (dx ** 2 + dy ** 2) / vas / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
|
||||||
|
if in_vis_thre is not None:
|
||||||
|
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
|
||||||
|
e = e[ind]
|
||||||
|
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
|
||||||
|
return ious
|
||||||
|
|
||||||
|
|
||||||
|
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
|
||||||
|
"""
|
||||||
|
greedily select boxes with high confidence and overlap with current maximum <= thresh
|
||||||
|
rule out overlap >= thresh, overlap = oks
|
||||||
|
:param kpts_db
|
||||||
|
:param thresh: retain overlap < thresh
|
||||||
|
:return: indexes to keep
|
||||||
|
"""
|
||||||
|
kpts_size = len(kpts_db)
|
||||||
|
if kpts_size == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores = np.array([kpts_db[i]['score'] for i in range(kpts_size)])
|
||||||
|
kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(kpts_size)])
|
||||||
|
areas = np.array([kpts_db[i]['area'] for i in range(kpts_size)])
|
||||||
|
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
|
||||||
|
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
|
||||||
|
|
||||||
|
inds = np.where(oks_ovr <= thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return keep
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def fliplr_joints(joints, joints_vis, width, matched_parts):
|
||||||
|
"""
|
||||||
|
flip coords
|
||||||
|
"""
|
||||||
|
# Flip horizontal
|
||||||
|
joints[:, 0] = width - joints[:, 0] - 1
|
||||||
|
|
||||||
|
# Change left-right parts
|
||||||
|
for pair in matched_parts:
|
||||||
|
joints[pair[0], :], joints[pair[1], :] = \
|
||||||
|
joints[pair[1], :], joints[pair[0], :].copy()
|
||||||
|
joints_vis[pair[0]], joints_vis[pair[1]] = \
|
||||||
|
joints_vis[pair[1]], joints_vis[pair[0]].copy()
|
||||||
|
|
||||||
|
return joints * joints_vis, joints_vis
|
||||||
|
|
||||||
|
|
||||||
|
def flip_back(output_flipped, matched_parts):
|
||||||
|
'''
|
||||||
|
ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width)
|
||||||
|
'''
|
||||||
|
assert output_flipped.ndim == 4, \
|
||||||
|
'output_flipped should be [batch_size, num_joints, height, width]'
|
||||||
|
|
||||||
|
output_flipped = output_flipped[:, :, :, ::-1]
|
||||||
|
|
||||||
|
for pair in matched_parts:
|
||||||
|
tmp = output_flipped[:, pair[0], :, :].copy()
|
||||||
|
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
|
||||||
|
output_flipped[:, pair[1], :, :] = tmp
|
||||||
|
|
||||||
|
return output_flipped
|
||||||
|
|
||||||
|
|
||||||
|
def transform_preds(coords, center, scale, output_size):
|
||||||
|
target_coords = np.zeros(coords.shape)
|
||||||
|
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
|
||||||
|
for p in range(coords.shape[0]):
|
||||||
|
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
|
||||||
|
return target_coords
|
||||||
|
|
||||||
|
|
||||||
|
def get_affine_transform(center,
|
||||||
|
scale,
|
||||||
|
rot,
|
||||||
|
output_size,
|
||||||
|
shift=np.array([0, 0], dtype=np.float32),
|
||||||
|
inv=0):
|
||||||
|
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
|
||||||
|
print(scale)
|
||||||
|
scale = np.array([scale, scale])
|
||||||
|
|
||||||
|
scale_tmp = scale * 200.0
|
||||||
|
src_w = scale_tmp[0]
|
||||||
|
dst_w = output_size[0]
|
||||||
|
dst_h = output_size[1]
|
||||||
|
|
||||||
|
rot_rad = np.pi * rot / 180
|
||||||
|
src_dir = _get_dir([0, src_w * -0.5], rot_rad)
|
||||||
|
dst_dir = np.array([0, dst_w * -0.5], np.float32)
|
||||||
|
|
||||||
|
src = np.zeros((3, 2), dtype=np.float32)
|
||||||
|
dst = np.zeros((3, 2), dtype=np.float32)
|
||||||
|
src[0, :] = center + scale_tmp * shift
|
||||||
|
src[1, :] = center + src_dir + scale_tmp * shift
|
||||||
|
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||||
|
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
||||||
|
|
||||||
|
src[2:, :] = _get_3rd_point(src[0, :], src[1, :])
|
||||||
|
dst[2:, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
||||||
|
|
||||||
|
if inv:
|
||||||
|
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||||
|
else:
|
||||||
|
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||||
|
|
||||||
|
return trans
|
||||||
|
|
||||||
|
|
||||||
|
def affine_transform(pt, t):
|
||||||
|
new_pt = np.array([pt[0], pt[1], 1.]).T
|
||||||
|
new_pt = np.dot(t, new_pt)
|
||||||
|
return new_pt[:2]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_3rd_point(a, b):
|
||||||
|
direct = a - b
|
||||||
|
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dir(src_point, rot_rad):
|
||||||
|
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||||
|
|
||||||
|
src_result = [0, 0]
|
||||||
|
src_result[0] = src_point[0] * cs - src_point[1] * sn
|
||||||
|
src_result[1] = src_point[0] * sn + src_point[1] * cs
|
||||||
|
|
||||||
|
return src_result
|
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.communication.management import init, get_group_size, get_rank
|
||||||
|
from mindspore.train import Model
|
||||||
|
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
|
||||||
|
from mindspore.nn.optim import Adam
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.model import get_pose_net
|
||||||
|
from src.network_define import JointsMSELoss, WithLossCell
|
||||||
|
from src.dataset import keypoint_dataset
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(begin_epoch,
|
||||||
|
total_epochs,
|
||||||
|
steps_per_epoch,
|
||||||
|
lr_init=0.1,
|
||||||
|
factor=0.1,
|
||||||
|
epoch_number_to_drop=(90, 120)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate learning rate array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_epoch (int): Initial epoch of training.
|
||||||
|
total_epochs (int): Total epoch of training.
|
||||||
|
steps_per_epoch (float): Steps of one epoch.
|
||||||
|
lr_init (float): Initial learning rate. Default: 0.316.
|
||||||
|
factor:Factor of lr to drop.
|
||||||
|
epoch_number_to_drop:Learing rate will drop after these epochs.
|
||||||
|
Returns:
|
||||||
|
np.array, learning rate array.
|
||||||
|
"""
|
||||||
|
lr_each_step = []
|
||||||
|
total_steps = steps_per_epoch * total_epochs
|
||||||
|
step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
|
||||||
|
for i in range(int(total_steps)):
|
||||||
|
if i in step_number_to_drop:
|
||||||
|
lr_init = lr_init * factor
|
||||||
|
lr_each_step.append(lr_init)
|
||||||
|
current_step = steps_per_epoch * begin_epoch
|
||||||
|
lr_each_step = np.array(lr_each_step, dtype=np.float32)
|
||||||
|
learning_rate = lr_each_step[current_step:]
|
||||||
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Simpleposenet training")
|
||||||
|
parser.add_argument("--run-distribute",
|
||||||
|
help="Run distribute, default is false.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--ckpt-path', type=str, help='ckpt path to save')
|
||||||
|
parser.add_argument('--batch-size', type=int, help='training batch size')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# load parse and config
|
||||||
|
print("loading parse...")
|
||||||
|
args = parse_args()
|
||||||
|
if args.batch_size:
|
||||||
|
config.TRAIN.BATCH_SIZE = args.batch_size
|
||||||
|
print('batch size :{}'.format(config.TRAIN.BATCH_SIZE))
|
||||||
|
|
||||||
|
# distribution and context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend",
|
||||||
|
save_graphs=False,
|
||||||
|
device_id=device_id)
|
||||||
|
|
||||||
|
if args.run_distribute:
|
||||||
|
init()
|
||||||
|
rank = get_rank()
|
||||||
|
device_num = get_group_size()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num,
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
else:
|
||||||
|
rank = 0
|
||||||
|
device_num = 1
|
||||||
|
|
||||||
|
# only rank = 0 can write
|
||||||
|
rank_save_flag = False
|
||||||
|
if rank == 0 or device_num == 1:
|
||||||
|
rank_save_flag = True
|
||||||
|
|
||||||
|
# create dataset
|
||||||
|
dataset, _ = keypoint_dataset(config,
|
||||||
|
rank=rank,
|
||||||
|
group_size=device_num,
|
||||||
|
train_mode=True,
|
||||||
|
num_parallel_workers=8)
|
||||||
|
|
||||||
|
# network
|
||||||
|
net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
|
||||||
|
loss = JointsMSELoss(use_target_weight=True)
|
||||||
|
net_with_loss = WithLossCell(net, loss)
|
||||||
|
|
||||||
|
# lr schedule and optim
|
||||||
|
dataset_size = dataset.get_dataset_size()
|
||||||
|
lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
|
||||||
|
config.TRAIN.END_EPOCH,
|
||||||
|
dataset_size,
|
||||||
|
lr_init=config.TRAIN.LR,
|
||||||
|
factor=config.TRAIN.LR_FACTOR,
|
||||||
|
epoch_number_to_drop=config.TRAIN.LR_STEP))
|
||||||
|
opt = Adam(net.trainable_params(), learning_rate=lr)
|
||||||
|
|
||||||
|
# callback
|
||||||
|
time_cb = TimeMonitor(data_size=dataset_size)
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
cb = [time_cb, loss_cb]
|
||||||
|
if args.ckpt_path and rank_save_flag:
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix="simplepose", directory=args.ckpt_path, config=config_ck)
|
||||||
|
cb.append(ckpoint_cb)
|
||||||
|
# train model
|
||||||
|
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
|
||||||
|
epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
|
||||||
|
print('start training, epoch size = %d' % epoch_size)
|
||||||
|
model.train(epoch_size, dataset, callbacks=cb)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue