add simple_baselines

tongzh 2021-04-01 23:33:23 +08:00 committed by root
parent f0016f5574
commit 99e9bd13bb
15 changed files with 1933 additions and 0 deletions

@@ -0,0 +1,263 @@
# Contents
<!-- TOC -->
- [simple_baselines Description](#simple_baselines-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# simple_baselines Description
## Overview
The simple_baselines network was proposed by Bin Xiao et al. of Microsoft Research Asia. The authors argue that popular methods for human pose estimation and tracking have become overly complex: existing models differ considerably in structure yet achieve similar performance. They propose a simple and effective baseline that adds a few deconvolution layers on top of a ResNet backbone, which is arguably the simplest way to estimate heatmaps from deep, low-resolution feature maps, and which helps inspire and evaluate new ideas in this field.
For details, see [Paper 1](https://arxiv.org/pdf/1804.06208.pdf). The MindSpore implementation of simple_baselines is based on the PyTorch version released by Microsoft Research Asia (<https://github.com/microsoft/human-pose-estimation.pytorch>).
## Paper
1. [Paper](https://arxiv.org/pdf/1804.06208.pdf): Bin Xiao, Haiping Wu, Yichen Wei. "Simple Baselines for Human Pose Estimation and Tracking"
# Model Architecture
The overall network architecture of simple_baselines is described in the [paper](https://arxiv.org/pdf/1804.06208.pdf): a ResNet backbone, followed by a few deconvolution layers that upsample the feature map, and a 1x1 convolution that predicts one heatmap per keypoint.
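A simplified MindSpore sketch of the head is shown below. It mirrors the deconvolution head built in src/pose_resnet.py (the full file also defines the ResNet backbone, weight initialization and pretrained-checkpoint loading), with the layer sizes taken from the network parameters in src/config.py; it is an illustration, not the full model.
```python
import mindspore.nn as nn

# Simplified sketch: three stride-2 deconvolution blocks upsample the ResNet-50 feature map
# (2048 x 8 x 6 for a 256 x 192 input) to 256 x 64 x 48, then a 1x1 convolution predicts
# one heatmap per COCO keypoint.
deconv_head = nn.SequentialCell([
    nn.Conv2dTranspose(2048, 256, kernel_size=4, stride=2, padding=1, pad_mode='pad'),
    nn.BatchNorm2d(256), nn.ReLU(),
    nn.Conv2dTranspose(256, 256, kernel_size=4, stride=2, padding=1, pad_mode='pad'),
    nn.BatchNorm2d(256), nn.ReLU(),
    nn.Conv2dTranspose(256, 256, kernel_size=4, stride=2, padding=1, pad_mode='pad'),
    nn.BatchNorm2d(256), nn.ReLU(),
    nn.Conv2d(256, 17, kernel_size=1),  # 17 keypoints -> 17 heatmaps of size 64 x 48
])
```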
# Dataset
Dataset used: [COCO2017]
- Dataset size:
    - Training set: 19.56 GB, 118,287 images
    - Test set: 825 MB, 5,000 images
- Data format: JPG images
    - Note: data is processed in src/dataset.py
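As a rough usage sketch (assuming the COCO images and annotation files are laid out under config.DATASET.ROOT as expected by src/config.py), the processed dataset can be built and inspected with the CreateDatasetCoco helper defined in src/dataset.py:
```python
from src.dataset import CreateDatasetCoco

# Build the training pipeline; batch size, flip, scale and rotation augmentation come from src/config.py.
train_ds = CreateDatasetCoco(train_mode=True)
print('batches per epoch:', train_ds.get_dataset_size())

for item in train_ds.create_dict_iterator():
    # image: (64, 3, 256, 192) float32, target: (64, 17, 64, 48) heatmaps, weight: (64, 17, 1)
    print(item['image'].shape, item['target'].shape, item['weight'].shape)
    break
```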
# Features
## Mixed Precision
Training uses [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html): supported hardware computes with both single-precision and half-precision data, which speeds up training of deep neural networks while preserving the accuracy reachable with pure single-precision training. Mixed precision increases throughput and reduces memory usage, making it possible to train larger models or larger batches on specific hardware.
Taking FP16 operators as an example, if the input data type is FP32, the MindSpore backend automatically lowers the precision to process the data. You can open the INFO log and search for "reduce precision" to see which operators had their precision reduced.
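In this repository mixed precision is enabled through the amp_level argument of mindspore.Model in train.py. A condensed sketch is shown below; all helpers come from this repository's src/ package, and the hyperparameters are taken from src/config.py.
```python
from mindspore import Model
from mindspore.nn.optim import Adam

from src.config import config
from src.dataset import CreateDatasetCoco
from src.network_with_loss import JointsMSELoss, PoseResNetWithLoss
from src.pose_resnet import GetPoseResNet

net = GetPoseResNet(config)                            # loads the resnet50.ckpt pretrained backbone
loss = JointsMSELoss(config.LOSS.USE_TARGET_WEIGHT)
net_with_loss = PoseResNetWithLoss(net, loss)
dataset = CreateDatasetCoco(train_mode=True)
opt = Adam(net.trainable_params(), learning_rate=config.TRAIN.LR)

# amp_level="O2" runs most of the network in FP16 while keeping BatchNorm and the loss in FP32.
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
model.train(config.TRAIN.END_EPOCH, dataset)
```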
# Environment Requirements
- Hardware (Ascend)
    - Prepare an Ascend processor to set up the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
- Pretrained model
  Before training, obtain a MindSpore ImageNet pretrained model. The weight file can be produced by running the ResNet training script in the [official model zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet); the resulting pretrained file is named resnet50.ckpt.
- Dataset preparation
  The simple_baselines network uses the COCO2017 dataset for training and inference. The dataset can be downloaded from the [official website](https://cocodataset.org/).
- Running on Ascend
```text
# Distributed training
Usage: bash run_distribute_train.sh [RANK_SIZE]
# Standalone training
Usage: bash run_standalone_train.sh [DEVICE_ID]
# Run evaluation example
Usage: bash run_eval.sh [DEVICE_ID]
```
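Before launching the scripts, it helps to know where the code looks for its inputs. The defaults below come from src/config.py; set config.DATASET.ROOT and config.TRAIN.CKPT_PATH to your own paths (or rely on the ModelArts cache directories when --is_model_arts is True).
```python
from src.config import config

# COCO2017 images and annotation files are resolved relative to DATASET.ROOT:
print(config.DATASET.ROOT + config.DATASET.TRAIN_SET)    # train2017 images
print(config.DATASET.ROOT + config.DATASET.TRAIN_JSON)   # person_keypoints_train2017.json
# The ImageNet-pretrained backbone is read from TRAIN.CKPT_PATH + MODEL.PRETRAINED:
print(config.TRAIN.CKPT_PATH + config.MODEL.PRETRAINED)  # resnet50.ckpt
```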
# Script Description
## Script and Sample Code
```shell
└──simple_baselines
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh      # launch distributed training on Ascend (8 devices)
    ├── run_eval.sh                  # launch evaluation on Ascend
    ├── run_standalone_train.sh      # launch standalone training on Ascend (single device)
  ├── src
    ├── utils
      ├── coco.py                    # evaluation of results on the COCO dataset
      ├── inference.py               # keypoint prediction from heatmaps
      ├── nms.py                     # nms
      ├── transforms.py              # image processing transforms
    ├── config.py                    # parameter configuration
    ├── dataset.py                   # data preprocessing
    ├── network_with_loss.py         # loss function definition
    └── pose_resnet.py               # backbone network definition
  ├── eval.py                        # evaluation network
  └── train.py                       # training network
```
## Script Parameters
Parameters are configured in src/config.py.
- Model parameters:
```python
config.MODEL.INIT_WEIGHTS = True                         # initialize model weights
config.MODEL.PRETRAINED = 'resnet50.ckpt'                # pretrained model
config.MODEL.NUM_JOINTS = 17                             # number of keypoints
config.MODEL.IMAGE_SIZE = [192, 256]                     # image size (width, height)
```
- Network parameters:
```python
config.NETWORK.NUM_LAYERS = 50                           # number of layers in the ResNet backbone
config.NETWORK.DECONV_WITH_BIAS = False                  # use bias in the deconvolution layers
config.NETWORK.NUM_DECONV_LAYERS = 3                     # number of deconvolution layers
config.NETWORK.NUM_DECONV_FILTERS = [256, 256, 256]      # number of filters in each deconvolution layer
config.NETWORK.NUM_DECONV_KERNELS = [4, 4, 4]            # kernel size of each deconvolution layer
config.NETWORK.FINAL_CONV_KERNEL = 1                     # kernel size of the final convolution layer
config.NETWORK.HEATMAP_SIZE = [48, 64]                   # heatmap size (width, height)
```
- Training parameters:
```python
config.TRAIN.SHUFFLE = True                              # shuffle the training data
config.TRAIN.BATCH_SIZE = 64                             # training batch size
config.DATASET.FLIP = True                               # random horizontal flip
config.DATASET.SCALE_FACTOR = 0.3                        # random scale factor
config.DATASET.ROT_FACTOR = 40                           # random rotation factor
config.TRAIN.BEGIN_EPOCH = 0                             # starting epoch
config.TRAIN.END_EPOCH = 140                             # final epoch
config.TRAIN.LR = 0.001                                  # initial learning rate
config.TRAIN.LR_FACTOR = 0.1                             # learning rate decay factor
```
- Evaluation parameters:
```python
config.TEST.BATCH_SIZE = 32                              # evaluation batch size
config.TEST.FLIP_TEST = True                             # flip test
config.TEST.USE_GT_BBOX = False                          # use ground-truth bounding boxes
```
- nms parameters:
```python
config.TEST.OKS_THRE = 0.9                               # OKS threshold
config.TEST.IN_VIS_THRE = 0.2                            # keypoint visibility threshold
config.TEST.BBOX_THRE = 1.0                              # bounding-box threshold
config.TEST.IMAGE_THRE = 0.0                             # detection score threshold
config.TEST.NMS_THRE = 1.0                               # nms threshold
```
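TRAIN.LR, TRAIN.LR_FACTOR and TRAIN.LR_STEP (default [90, 120] in src/config.py) define a step schedule: the learning rate starts at 0.001 and is multiplied by 0.1 at epochs 90 and 120. Below is a standalone sketch of the same logic implemented by get_lr() in train.py (2340 steps per epoch matches the 8-device training log later in this document).
```python
import numpy as np

def step_lr(lr_init, factor, total_epochs, steps_per_epoch, drop_epochs=(90, 120)):
    """Per-step learning-rate list with step drops, mirroring get_lr() in train.py."""
    drop_steps = {steps_per_epoch * e for e in drop_epochs}
    lr, lr_each_step = lr_init, []
    for step in range(total_epochs * steps_per_epoch):
        if step in drop_steps:
            lr *= factor
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)

# With the defaults above: 0.001 for epochs 0-89, 0.0001 for 90-119, 0.00001 for 120-139.
lr_schedule = step_lr(lr_init=0.001, factor=0.1, total_epochs=140, steps_per_epoch=2340)
```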
## Training Process
### Usage
#### Running on Ascend
```text
# Distributed training
Usage: bash run_distribute_train.sh [RANK_SIZE]
# Standalone training
Usage: bash run_standalone_train.sh [DEVICE_ID]
# Run evaluation example
Usage: bash run_eval.sh [DEVICE_ID]
```
### Result
- Training simple_baselines on the COCO2017 dataset
```text
Distributed training results (8P)
epoch:1 step:2340, loss is 0.0008106
epoch:2 step:2340, loss is 0.0006160
epoch:3 step:2340, loss is 0.0006480
epoch:4 step:2340, loss is 0.0005620
epoch:5 step:2340, loss is 0.0005207
...
epoch:138 step:2340, loss is 0.0003183
epoch:139 step:2340, loss is 0.0002866
epoch:140 step:2340, loss is 0.0003393
```
## Evaluation Process
### Usage
#### Running on Ascend
Set "config.TEST.MODEL_FILE" in config.py to the checkpoint you want to evaluate.
```bash
# Evaluation
bash run_eval.sh [DEVICE_ID]
```
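eval.py resolves the checkpoint path by concatenating config.DATASET.ROOT (or the ModelArts cache directory) with config.TEST.MODEL_FILE and then restores the weights. A condensed sketch of that loading step:
```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.pose_resnet import GetPoseResNet

model = GetPoseResNet(config)
# Outside ModelArts, eval.py joins DATASET.ROOT with TEST.MODEL_FILE to locate the checkpoint.
ckpt_path = config.DATASET.ROOT + config.TEST.MODEL_FILE
load_param_into_net(model, load_checkpoint(ckpt_path))
model.set_train(False)
```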
### Result
The val2017 split of the COCO2017 dataset is used to evaluate simple_baselines. Results are shown below:
```text
coco eval results saved to /cache/train_output/multi_train_poseresnet_v5_2-140_2340/keypoints_results.pkl
AP: 0.704
```
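The AP number above is produced by the standard COCO keypoint evaluation: src/utils/coco.py writes the rescored predictions to keypoints_results.json and then runs pycocotools, roughly as follows (paths shown are the defaults for val2017).
```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Sketch of the evaluation step in src/utils/coco.py (_do_python_keypoint_eval).
coco_gt = COCO('annotations/person_keypoints_val2017.json')
coco_dt = coco_gt.loadRes('keypoints_results.json')   # predictions written by _write_coco_keypoint_results
coco_eval = COCOeval(coco_gt, coco_dt, 'keypoints')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
print('AP:', coco_eval.stats[0])                      # 0.704 for the checkpoint above
```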
# Model Description
## Performance
### Evaluation Performance
#### Performance on COCO2017
| Parameters           | Ascend 910                                          |
| -------------------- | --------------------------------------------------- |
| Model version        | simple_baselines                                    |
| Resources            | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB  |
| Upload date          | 2021-03-29                                          |
| MindSpore version    | 1.1.0                                               |
| Dataset              | COCO2017                                            |
| Training parameters  | epoch=140, batch_size=64                            |
| Optimizer            | Adam                                                |
| Loss function        | Mean Squared Error                                  |
| Output               | heatmap                                             |
| Speed                | 1pc: 251.4 ms/step                                  |
| Training performance | AP: 0.704                                           |
# Description of Random Situation
A random seed for dataset creation is set in src/dataset.py (ds.config.set_seed), and initial network weights (the pretrained backbone plus fixed normal initializers) are set in src/pose_resnet.py.
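Concretely, the fixed seeds and where they are applied in this repository (values come from src/config.py):
```python
from mindspore.common import set_seed
import mindspore.dataset as ds

from src.config import config

set_seed(config.GENERAL.TRAIN_SEED)               # 1, applied in train.py (eval.py uses EVAL_SEED)
ds.config.set_seed(config.GENERAL.DATASET_SEED)   # 1, applied at import time in src/dataset.py
```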
# ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
This file evaluates the model used.
'''
from __future__ import division
import argparse
import os
import time
import numpy as np
from mindspore import Tensor, float32, context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.pose_resnet import GetPoseResNet
from src.dataset import flip_pairs
from src.dataset import CreateDatasetCoco
from src.utils.coco import evaluate
from src.utils.transforms import flip_back
from src.utils.inference import get_final_preds
if config.MODELARTS.IS_MODEL_ARTS:
import moxing as mox
set_seed(config.GENERAL.EVAL_SEED)
device_id = int(os.getenv('DEVICE_ID'))
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate')
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of evaluate outputs.')
args = parser.parse_args()
return args
def validate(cfg, val_dataset, model, output_dir, ann_path):
'''
validate
'''
model.set_train(False)
num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
dtype=np.float32)
all_boxes = np.zeros((num_samples, 2))
image_id = []
idx = 0
start = time.time()
for item in val_dataset.create_dict_iterator():
inputs = item['image'].asnumpy()
output = model(Tensor(inputs, float32)).asnumpy()
if cfg.TEST.FLIP_TEST:
inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
output_flipped = model(inputs_flipped)
output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)
if cfg.TEST.SHIFT_HEATMAP:
output_flipped[:, :, :, 1:] = \
output_flipped.copy()[:, :, :, 0:-1]
output = (output + output_flipped) * 0.5
c = item['center'].asnumpy()
s = item['scale'].asnumpy()
score = item['score'].asnumpy()
file_id = list(item['id'].asnumpy())
preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
num_images, _ = preds.shape[:2]
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
all_preds[idx:idx + num_images, :, 2:3] = maxvals
all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
all_boxes[idx:idx + num_images, 1] = score
image_id.extend(file_id)
idx += num_images
if idx % 1024 == 0:
print('{} samples validated in {} seconds'.format(idx, time.time() - start))
start = time.time()
print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
_, perf_indicator = evaluate(cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id, ann_path)
print("AP:", perf_indicator)
return perf_indicator
def main():
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False, device_id=device_id)
args = parse_args()
if config.MODELARTS.IS_MODEL_ARTS:
mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)
model = GetPoseResNet(config)
ckpt_name = ''
if config.MODELARTS.IS_MODEL_ARTS:
ckpt_name = config.MODELARTS.CACHE_INPUT
else:
ckpt_name = config.DATASET.ROOT
ckpt_name = ckpt_name + config.TEST.MODEL_FILE
print('loading model ckpt from {}'.format(ckpt_name))
load_param_into_net(model, load_checkpoint(ckpt_name))
valid_dataset = CreateDatasetCoco(
train_mode=False,
num_parallel_workers=config.TEST.NUM_PARALLEL_WORKERS,
)
ckpt_name = ckpt_name.split('/')
ckpt_name = ckpt_name[len(ckpt_name) - 1]
ckpt_name = ckpt_name.split('.')[0]
output_dir = ''
ann_path = ''
if config.MODELARTS.IS_MODEL_ARTS:
output_dir = config.MODELARTS.CACHE_OUTPUT
ann_path = config.MODELARTS.CACHE_INPUT
else:
output_dir = config.TEST.OUTPUT_DIR
ann_path = config.DATASET.ROOT
output_dir = output_dir + ckpt_name
ann_path = ann_path + config.DATASET.TEST_JSON
validate(config, valid_dataset, model, output_dir, ann_path)
if config.MODELARTS.IS_MODEL_ARTS:
mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)
if __name__ == '__main__':
main()

@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into a MindIR model#################
python export.py --ckpt_url [CKPT_PATH]
"""
import argparse
import numpy as np
import mindspore.common.dtype as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import config
from src.pose_resnet import GetPoseResNet
parser = argparse.ArgumentParser(description='simple_baselines')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_url", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="simple_baselines", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["MINDIR"], default='MINDIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
    net = GetPoseResNet(config)
param_dict = load_checkpoint(args.ckpt_url)
load_param_into_net(net, param_dict)
    # NCHW input: batch 64, 3 channels, height 256, width 192, matching config.MODEL.IMAGE_SIZE = [192, 256]
    input_arr = Tensor(np.ones([64, 3, 256, 192]), ms.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "========================================================================"
echo "Please run the script as: "
echo "bash run.sh RANK_SIZE"
echo "For example: bash run_distribute.sh 8"
echo "It is better to use the absolute path."
echo "========================================================================"
set -e
RANK_SIZE=$1
export RANK_SIZE
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
test_dist_8pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
export RANK_SIZE=8
}
test_dist_2pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
export RANK_SIZE=2
}
test_dist_${RANK_SIZE}pcs
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf distribute_train
mkdir distribute_train
cd distribute_train
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cd ./device$i
mkdir src
cd src
mkdir utils
cd ../../../
cp ./train.py ./distribute_train/device$i
cp ./src/*.py ./distribute_train/device$i/src
cp ./src/utils/*.py ./distribute_train/device$i/src/utils
cd ./distribute_train/device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
python train.py --is_model_arts False --run_distribute True > train$i.log 2>&1 &
echo "$i finish"
cd ../
done
if [ $? -eq 0 ];then
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../

@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
python eval.py > eval_log$1.txt 2>&1 &

@@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "========================================================================"
echo "Please run the script as: "
echo "bash run_standalone_train.sh"
echo "For example: bash run_standalone_train.sh"
echo "It is better to use the absolute path."
echo "========================================================================"
echo "start training for device $DEVICE_ID"
export DEVICE_ID=$1
python -u ../train.py --device_id ${DEVICE_ID} --is_model_arts False --run_distribute False > train${DEVICE_ID}.log 2>&1 &
echo "finish"

@@ -0,0 +1,106 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
config
'''
from easydict import EasyDict as edict
config = edict()
#general
config.GENERAL = edict()
config.GENERAL.VERSION = 'commit'
config.GENERAL.TRAIN_SEED = 1
config.GENERAL.EVAL_SEED = 1
config.GENERAL.DATASET_SEED = 1
config.GENERAL.RUN_DISTRIBUTE = True
#model arts
config.MODELARTS = edict()
config.MODELARTS.IS_MODEL_ARTS = False
config.MODELARTS.CACHE_INPUT = '/cache/data_tzh/'
config.MODELARTS.CACHE_OUTPUT = '/cache/train_out/'
# model
config.MODEL = edict()
config.MODEL.IS_TRAINED = True
config.MODEL.INIT_WEIGHTS = True
config.MODEL.PRETRAINED = 'resnet50.ckpt'
config.MODEL.NUM_JOINTS = 17
config.MODEL.IMAGE_SIZE = [192, 256]
# network
config.NETWORK = edict()
config.NETWORK.NUM_LAYERS = 50
config.NETWORK.DECONV_WITH_BIAS = False
config.NETWORK.NUM_DECONV_LAYERS = 3
config.NETWORK.NUM_DECONV_FILTERS = [256, 256, 256]
config.NETWORK.NUM_DECONV_KERNELS = [4, 4, 4]
config.NETWORK.FINAL_CONV_KERNEL = 1
config.NETWORK.REVERSE = True
config.NETWORK.TARGET_TYPE = 'gaussian'
config.NETWORK.HEATMAP_SIZE = [48, 64]
config.NETWORK.SIGMA = 2
# loss
config.LOSS = edict()
config.LOSS.USE_TARGET_WEIGHT = True
# dataset
config.DATASET = edict()
config.DATASET.TYPE = 'COCO'
config.DATASET.ROOT = '/opt_data/xidian_wks/zhao/simple_baselines/coco2017/'
config.DATASET.TRAIN_SET = 'train2017'
config.DATASET.TRAIN_JSON = 'annotations/person_keypoints_train2017.json'
config.DATASET.TEST_SET = 'val2017'
config.DATASET.TEST_JSON = 'annotations/person_keypoints_val2017.json'
# training data augmentation
config.DATASET.FLIP = True
config.DATASET.SCALE_FACTOR = 0.3
config.DATASET.ROT_FACTOR = 40
# train
config.TRAIN = edict()
config.TRAIN.SHUFFLE = True
config.TRAIN.BATCH_SIZE = 64
config.TRAIN.BEGIN_EPOCH = 0
config.TRAIN.END_EPOCH = 140
config.TRAIN.LR = 0.001
config.TRAIN.LR_FACTOR = 0.1
config.TRAIN.LR_STEP = [90, 120]
config.TRAIN.NUM_PARALLEL_WORKERS = 8
config.TRAIN.SAVE_CKPT = True
config.TRAIN.CKPT_PATH = "/opt_data/xidian_wks/zhao/simple_baselines/"
# valid
config.TEST = edict()
config.TEST.BATCH_SIZE = 32
config.TEST.FLIP_TEST = True
config.TEST.POST_PROCESS = True
config.TEST.SHIFT_HEATMAP = True
config.TEST.USE_GT_BBOX = False
config.TEST.NUM_PARALLEL_WORKERS = 2
config.TEST.MODEL_FILE = 'multi_train_poseresnet_commit_0-140_292.ckpt'
config.TEST.COCO_BBOX_FILE = 'annotations/COCO_val2017_detections_AP_H_56_person.json'
config.TEST.OUTPUT_DIR = 'results/'
# nms
config.TEST.OKS_THRE = 0.9
config.TEST.IN_VIS_THRE = 0.2
config.TEST.BBOX_THRE = 1.0
config.TEST.IMAGE_THRE = 0.0
config.TEST.NMS_THRE = 1.0

@@ -0,0 +1,365 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
dataset processing
'''
from __future__ import division
import json
import os
from copy import deepcopy
import random
import numpy as np
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
from src.utils.transforms import fliplr_joints, get_affine_transform, affine_transform
from src.config import config
ds.config.set_seed(config.GENERAL.DATASET_SEED) # Set Random Seed
flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
class CocoDatasetGenerator:
'''
    COCO2017 dataset generator: loads annotations and applies per-sample preprocessing.
'''
def __init__(self, cfg, is_train=False):
self.image_thre = cfg.TEST.IMAGE_THRE
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE, dtype=np.int32)
self.image_width = cfg.MODEL.IMAGE_SIZE[0]
self.image_height = cfg.MODEL.IMAGE_SIZE[1]
self.aspect_ratio = self.image_width * 1.0 / self.image_height
self.heatmap_size = np.array(cfg.NETWORK.HEATMAP_SIZE, dtype=np.int32)
self.sigma = cfg.NETWORK.SIGMA
self.target_type = cfg.NETWORK.TARGET_TYPE
self.scale_factor = cfg.DATASET.SCALE_FACTOR
self.rotation_factor = cfg.DATASET.ROT_FACTOR
self.flip = cfg.DATASET.FLIP
self.db = []
self.is_train = is_train
self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
self.num_joints = 17
def load_gt_dataset(self, image_path, ann_file):
'''
load_gt_dataset
'''
self.db = []
with open(ann_file, "rb") as f:
lines = f.readlines()
json_dict = json.loads(lines[0].decode("utf-8"))
objs = {}
cnt = 0
for item in json_dict['annotations']:
# exclude iscrowd and no-keypoint record
if item['iscrowd'] != 0 or item['num_keypoints'] == 0:
continue
# assert the record is valid
assert item['iscrowd'] == 0, 'is crowd'
assert item['category_id'] == 1, 'is not people'
assert item['area'] > 0, 'area le 0'
assert item['num_keypoints'] > 0, 'has no keypoint'
assert max(item['keypoints']) > 0
image_id = item['image_id']
obj = [{'num_keypoints': item['num_keypoints'], 'keypoints': item['keypoints'], 'bbox': item['bbox']}]
objs[image_id] = obj if image_id not in objs else objs[image_id] + obj
cnt += 1
print('loaded %d records from coco dataset.' % cnt)
for item in json_dict['images']:
image_id = item['id']
width = item['width']
height = item['height']
if image_id not in objs:
continue
valid_objs = []
for obj in objs[image_id]:
x, y, w, h = obj['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(width - 1, x1 + max(0, w - 1))
y2 = min(height - 1, y1 + max(0, h - 1))
if x2 >= x1 and y2 >= y1:
tmp_obj = deepcopy(obj)
tmp_obj['bbox'] = np.array((x1, y1, x2, y2)) - np.array((0, 0, x1, y1))
valid_objs.append(tmp_obj)
else:
assert False, 'invalid bbox!'
objs[image_id] = valid_objs
for obj in objs[image_id]:
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
for ipt in range(self.num_joints):
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
joints_3d[ipt, 2] = 0
t_vis = obj['keypoints'][ipt * 3 + 2]
if t_vis > 1:
t_vis = 1
joints_3d_vis[ipt, 0] = t_vis
joints_3d_vis[ipt, 1] = t_vis
joints_3d_vis[ipt, 2] = 0
scale, center = self._bbox2sc(obj['bbox'])
self.db.append({
'id': int(item['id']),
'image': os.path.join(image_path, item['file_name']),
'center': center,
'scale': scale,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
})
def load_detect_dataset(self, image_path, ann_file, bbox_file):
'''
load_detect_dataset
'''
self.db = []
all_boxes = None
with open(bbox_file, 'r') as f:
all_boxes = json.load(f)
assert all_boxes, 'Loading %s fail!' % bbox_file
print('Total boxes: {}'.format(len(all_boxes)))
with open(ann_file, "rb") as f:
lines = f.readlines()
json_dict = json.loads(lines[0].decode("utf-8"))
index_to_filename = {}
for item in json_dict['images']:
index_to_filename[item['id']] = item['file_name']
for det_res in all_boxes:
if det_res['category_id'] != 1:
continue
image = os.path.join(image_path,
index_to_filename[det_res['image_id']])
bbox = det_res['bbox']
score = det_res['score']
if score < self.image_thre:
continue
scale, center = self._bbox2sc(bbox)
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
joints_3d_vis = np.ones((self.num_joints, 3), dtype=np.float)
self.db.append({
'id': int(det_res['image_id']),
'image': image,
'center': center,
'scale': scale,
'score': score,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
})
def _bbox2sc(self, bbox):
"""
reform xywh to meet the need of aspect ratio
"""
x, y, w, h = bbox[:4]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array(
[w * 1.0 / 200, h * 1.0 / 200], dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return scale, center
def __getitem__(self, idx):
db_rec = deepcopy(self.db[idx])
image_file = db_rec['image']
data_numpy = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
if data_numpy is None:
print('[ERROR] fail to read {}'.format(image_file))
raise ValueError('Fail to read {}'.format(image_file))
joints = db_rec['joints_3d']
joints_vis = db_rec['joints_3d_vis']
c = db_rec['center']
s = db_rec['scale']
score = db_rec['score'] if 'score' in db_rec else 1
r = 0
if self.is_train:
sf = self.scale_factor
rf = self.rotation_factor
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
if random.random() <= 0.6 else 0
if self.flip and random.random() <= 0.5:
data_numpy = data_numpy[:, ::-1, :]
joints, joints_vis = fliplr_joints(
joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
c[0] = data_numpy.shape[1] - c[0] - 1
trans = get_affine_transform(c, s, r, self.image_size)
image = cv2.warpAffine(
data_numpy,
trans,
(int(self.image_size[0]), int(self.image_size[1])),
flags=cv2.INTER_LINEAR)
for i in range(self.num_joints):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
target, target_weight = self.generate_heatmap(joints, joints_vis)
return image, target, target_weight, s, c, score, db_rec['id']
def generate_heatmap(self, joints, joints_vis):
'''
generate_heatmap
'''
target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
target_weight[:, 0] = joints_vis[:, 0]
assert self.target_type == 'gaussian', \
'Only support gaussian map now!'
if self.target_type == 'gaussian':
target = np.zeros((self.num_joints,
self.heatmap_size[1],
self.heatmap_size[0]),
dtype=np.float32)
tmp_size = self.sigma * 3
for joint_id in range(self.num_joints):
feat_stride = self.image_size / self.heatmap_size
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
or br[0] < 0 or br[1] < 0:
# If not, just return the image as is
target_weight[joint_id] = 0
continue
# # Generate gaussian
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
v = target_weight[joint_id]
if v > 0.5:
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
return target, target_weight
def __len__(self):
return len(self.db)
def CreateDatasetCoco(rank=0,
group_size=1,
train_mode=True,
num_parallel_workers=8,
transform=None,
shuffle=None):
'''
CreateDatasetCoco
'''
per_batch_size = config.TRAIN.BATCH_SIZE if train_mode else config.TEST.BATCH_SIZE
image_path = ''
ann_file = ''
bbox_file = ''
if config.MODELARTS.IS_MODEL_ARTS:
image_path = config.MODELARTS.CACHE_INPUT
ann_file = config.MODELARTS.CACHE_INPUT
bbox_file = config.MODELARTS.CACHE_INPUT
else:
image_path = config.DATASET.ROOT
ann_file = config.DATASET.ROOT
bbox_file = config.DATASET.ROOT
if train_mode:
image_path = image_path + config.DATASET.TRAIN_SET
ann_file = ann_file + config.DATASET.TRAIN_JSON
else:
image_path = image_path + config.DATASET.TEST_SET
ann_file = ann_file + config.DATASET.TEST_JSON
bbox_file = bbox_file + config.TEST.COCO_BBOX_FILE
print('loading dataset from {}'.format(image_path))
shuffle = shuffle if shuffle is not None else train_mode
dataset_generator = CocoDatasetGenerator(config, is_train=train_mode)
if not train_mode and config.TEST.USE_GT_BBOX:
print('loading bbox file from {}'.format(bbox_file))
dataset_generator.load_detect_dataset(image_path, ann_file, bbox_file)
else:
dataset_generator.load_gt_dataset(image_path, ann_file)
coco_dataset = ds.GeneratorDataset(dataset_generator,
column_names=["image", "target", "weight", "scale", "center", "score", "id"],
num_parallel_workers=num_parallel_workers,
num_shards=group_size,
shard_id=rank,
shuffle=shuffle)
if transform is None:
transform_img = [
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
C.HWC2CHW()
]
else:
transform_img = transform
coco_dataset = coco_dataset.map(input_columns="image",
num_parallel_workers=num_parallel_workers,
operations=transform_img)
coco_dataset = coco_dataset.batch(per_batch_size, drop_remainder=train_mode)
return coco_dataset

@@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
network_with_loss
'''
from __future__ import division
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.loss.loss import _Loss
from mindspore.common import dtype as mstype
class JointsMSELoss(_Loss):
'''
JointsMSELoss
'''
def __init__(self, use_target_weight):
super(JointsMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='mean')
self.use_target_weight = use_target_weight
self.shape = P.Shape()
self.reshape = P.Reshape()
self.squeeze = P.Squeeze(1)
self.mul = P.Mul()
def construct(self, output, target, target_weight):
'''
construct
'''
total_shape = self.shape(output)
batch_size = total_shape[0]
num_joints = total_shape[1]
remained_size = 1
for i in range(2, len(total_shape)):
remained_size *= total_shape[i]
split = P.Split(1, num_joints)
new_shape = (batch_size, num_joints, remained_size)
heatmaps_pred = split(self.reshape(output, new_shape))
heatmaps_gt = split(self.reshape(target, new_shape))
loss = 0
for idx in range(num_joints):
heatmap_pred_squeezed = self.squeeze(heatmaps_pred[idx])
heatmap_gt_squeezed = self.squeeze(heatmaps_gt[idx])
if self.use_target_weight:
loss += 0.5 * self.criterion(self.mul(heatmap_pred_squeezed, target_weight[:, idx]),
self.mul(heatmap_gt_squeezed, target_weight[:, idx]))
else:
loss += 0.5 * self.criterion(heatmap_pred_squeezed, heatmap_gt_squeezed)
return loss / num_joints
class PoseResNetWithLoss(nn.Cell):
"""
Pack the model network and loss function together to calculate the loss value.
"""
def __init__(self, network, loss):
super(PoseResNetWithLoss, self).__init__()
self.network = network
self.loss = loss
def construct(self, image, target, weight, scale=None, center=None, score=None, idx=None):
output = self.network(image)
output = F.mixed_precision_cast(mstype.float32, output)
target = F.mixed_precision_cast(mstype.float32, target)
weight = F.mixed_precision_cast(mstype.float32, weight)
return self.loss(output, target, weight)

@@ -0,0 +1,222 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
simple_baselines network
'''
from __future__ import division
import os
import mindspore.nn as nn
import mindspore.common.initializer as init
import mindspore.ops.operations as F
from mindspore.train.serialization import load_checkpoint, load_param_into_net
BN_MOMENTUM = 0.1
class MPReverse(nn.Cell):
'''
MPReverse
'''
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
super(MPReverse, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
self.reverse = F.ReverseV2(axis=[2, 3])
def construct(self, x):
x = self.reverse(x)
x = self.maxpool(x)
x = self.reverse(x)
return x
class Bottleneck(nn.Cell):
'''
model part of network
'''
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
stride=stride, padding=1, has_bias=False, pad_mode='pad')
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, has_bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def construct(self, x):
'''
construct
'''
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class PoseResNet(nn.Cell):
'''
PoseResNet
'''
def __init__(self, block, layers, cfg):
self.inplanes = 64
self.deconv_with_bias = cfg.NETWORK.DECONV_WITH_BIAS
super(PoseResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, has_bias=False, pad_mode='pad')
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU()
self.maxpool = MPReverse(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# used for deconv layers
self.deconv_layers = self._make_deconv_layer(
cfg.NETWORK.NUM_DECONV_LAYERS,
cfg.NETWORK.NUM_DECONV_FILTERS,
cfg.NETWORK.NUM_DECONV_KERNELS,
)
self.final_layer = nn.Conv2d(
in_channels=cfg.NETWORK.NUM_DECONV_FILTERS[-1],
out_channels=cfg.MODEL.NUM_JOINTS,
kernel_size=cfg.NETWORK.FINAL_CONV_KERNEL,
stride=1,
padding=1 if cfg.NETWORK.FINAL_CONV_KERNEL == 3 else 0,
pad_mode='pad',
has_bias=True,
weight_init=init.Normal(0.001)
)
def _make_layer(self, block, planes, blocks, stride=1):
'''
_make_layer
'''
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell([nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, has_bias=False),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM)])
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
return nn.SequentialCell(layers)
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
'''
_make_deconv_layer
'''
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
layers = []
for i in range(num_layers):
kernel = num_kernels[i]
padding = 1
planes = num_filters[i]
layers.append(nn.Conv2dTranspose(
in_channels=self.inplanes,
out_channels=planes,
kernel_size=kernel,
stride=2,
padding=padding,
has_bias=self.deconv_with_bias,
pad_mode='pad',
weight_init=init.Normal(0.001)
))
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
layers.append(nn.ReLU())
self.inplanes = planes
return nn.SequentialCell(layers)
def construct(self, x):
'''
construct
'''
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.deconv_layers(x)
x = self.final_layer(x)
return x
def init_weights(self, pretrained=''):
if os.path.isfile(pretrained):
# load params from pretrained
param_dict = load_checkpoint(pretrained)
load_param_into_net(self, param_dict)
print('=> loading pretrained model {}'.format(pretrained))
else:
            print('=> imagenet pretrained model does not exist')
raise ValueError('{} is not a file'.format(pretrained))
resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),
101: (Bottleneck, [3, 4, 23, 3]),
152: (Bottleneck, [3, 8, 36, 3])}
def GetPoseResNet(cfg):
'''
GetPoseResNet
'''
num_layers = cfg.NETWORK.NUM_LAYERS
block_class, layers = resnet_spec[num_layers]
network = PoseResNet(block_class, layers, cfg)
if cfg.MODEL.IS_TRAINED and cfg.MODEL.INIT_WEIGHTS:
pretrained = ''
if cfg.MODELARTS.IS_MODEL_ARTS:
pretrained = cfg.MODELARTS.CACHE_INPUT + cfg.MODEL.PRETRAINED
else:
pretrained = cfg.TRAIN.CKPT_PATH + cfg.MODEL.PRETRAINED
network.init_weights(pretrained)
return network

@@ -0,0 +1,137 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
coco
'''
from __future__ import division
import json
import os
import pickle
from collections import defaultdict, OrderedDict
import numpy as np
try:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
has_coco = True
except ImportError:
has_coco = False
from src.utils.nms import oks_nms
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
'''
_write_coco_keypoint_results
'''
results = []
for img, items in img_kpts.items():
item_size = len(items)
if not items:
continue
kpts = np.array([items[k]['keypoints']
for k in range(item_size)])
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
keypoints[:, 0::3] = kpts[:, :, 0]
keypoints[:, 1::3] = kpts[:, :, 1]
keypoints[:, 2::3] = kpts[:, :, 2]
result = [{'image_id': int(img),
'keypoints': list(keypoints[k]),
'score': items[k]['score'],
'category_id': 1,
} for k in range(item_size)]
results.extend(result)
with open(res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
'''
_do_python_keypoint_eval
'''
coco = COCO(ann_path)
coco_dt = coco.loadRes(res_file)
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
coco_eval.params.useSegm = None
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
info_str = []
for ind, name in enumerate(stats_names):
info_str.append((name, coco_eval.stats[ind]))
eval_file = os.path.join(
res_folder, 'keypoints_results.pkl')
with open(eval_file, 'wb') as f:
pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
print('coco eval results saved to %s' % eval_file)
return info_str
def evaluate(cfg, preds, output_dir, all_boxes, img_id, ann_path):
'''
evaluate
'''
if not os.path.exists(output_dir):
os.makedirs(output_dir)
res_file = os.path.join(output_dir, 'keypoints_results.json')
img_kpts_dict = defaultdict(list)
for idx, file_id in enumerate(img_id):
img_kpts_dict[file_id].append({
'keypoints': preds[idx],
'area': all_boxes[idx][0],
'score': all_boxes[idx][1],
})
# rescoring and oks nms
num_joints = cfg.MODEL.NUM_JOINTS
in_vis_thre = cfg.TEST.IN_VIS_THRE
oks_thre = cfg.TEST.OKS_THRE
oks_nmsed_kpts = {}
for img, items in img_kpts_dict.items():
for item in items:
kpt_score = 0
valid_num = 0
for n_jt in range(num_joints):
max_jt = item['keypoints'][n_jt][2]
if max_jt > in_vis_thre:
kpt_score = kpt_score + max_jt
valid_num = valid_num + 1
if valid_num != 0:
kpt_score = kpt_score / valid_num
item['score'] = kpt_score * item['score']
keep = oks_nms(items, oks_thre)
if not keep:
oks_nmsed_kpts[img] = items
else:
oks_nmsed_kpts[img] = [items[kep] for kep in keep]
# evaluate and save
image_set = cfg.DATASET.TEST_SET
_write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
if 'test' not in image_set and has_coco:
ann_path = ann_path if ann_path else os.path.join(cfg.DATASET.ROOT, 'annotations',
'person_keypoints_' + image_set + '.json')
info_str = _do_python_keypoint_eval(res_file, output_dir, ann_path)
name_value = OrderedDict(info_str)
return name_value, name_value['AP']
return {'Null': 0}, 0

@@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
inference
'''
from __future__ import division
import math
import numpy as np
from src.utils.transforms import transform_preds
def get_max_preds(batch_heatmaps):
'''
get predictions from score maps
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
'''
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
def get_final_preds(config, batch_heatmaps, center, scale):
'''
get_final_preds
'''
coords, maxvals = get_max_preds(batch_heatmaps)
heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]
# post-processing
if config.TEST.POST_PROCESS:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
diff = np.array([hm[py][px+1] - hm[py][px-1],
hm[py+1][px]-hm[py-1][px]])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center[i], scale[i],
[heatmap_width, heatmap_height])
return preds, maxvals

@@ -0,0 +1,74 @@
# ------------------------------------------------------------------------------
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
nms operation
'''
from __future__ import division
import numpy as np
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
'''
oks_iou
'''
if not isinstance(sigmas, np.ndarray):
sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
.62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
var = (sigmas * 2) ** 2
xg = g[0::3]
yg = g[1::3]
vg = g[2::3]
ious = np.zeros((d.shape[0]))
for n_d in range(0, d.shape[0]):
xd = d[n_d, 0::3]
yd = d[n_d, 1::3]
vd = d[n_d, 2::3]
dx = xd - xg
dy = yd - yg
e = (dx ** 2 + dy ** 2) / var / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
if in_vis_thre is not None:
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
e = e[ind]
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
return ious
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
"""
greedily select boxes with high confidence and overlap with current maximum <= thresh
rule out overlap >= thresh, overlap = oks
:param kpts_db
:param thresh: retain overlap < thresh
:return: indexes to keep
"""
kpts = len(kpts_db)
if kpts == 0:
return []
scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
inds = np.where(oks_ovr <= thresh)[0]
order = order[inds + 1]
return keep

@@ -0,0 +1,137 @@
# ------------------------------------------------------------------------------
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
transforms
'''
from __future__ import division
import numpy as np
import cv2
def flip_back(output_flipped, matched_parts):
'''
    output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
'''
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def fliplr_joints(joints, joints_vis, width, matched_parts):
"""
flip coords
"""
# Flip horizontal
joints[:, 0] = width - joints[:, 0] - 1
# Change left-right parts
for pair in matched_parts:
joints[pair[0], :], joints[pair[1], :] = \
joints[pair[1], :], joints[pair[0], :].copy()
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
return joints*joints_vis, joints_vis
def transform_preds(coords, center, scale, output_size):
'''
transform_preds
'''
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
'''
get_affine_transform
'''
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale * 200.0
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[0]), int(output_size[1])),
flags=cv2.INTER_LINEAR)
return dst_img

@@ -0,0 +1,147 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
train
'''
from __future__ import division
import os
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.train import Model
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from src.config import config
from src.pose_resnet import GetPoseResNet
from src.network_with_loss import JointsMSELoss, PoseResNetWithLoss
from src.dataset import CreateDatasetCoco
if config.MODELARTS.IS_MODEL_ARTS:
import moxing as mox
set_seed(config.GENERAL.TRAIN_SEED)
def get_lr(begin_epoch,
total_epochs,
steps_per_epoch,
lr_init=0.1,
factor=0.1,
epoch_number_to_drop=(90, 120)
):
'''
get_lr
'''
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
for i in range(int(total_steps)):
if i in step_number_to_drop:
lr_init = lr_init * factor
lr_each_step.append(lr_init)
current_step = steps_per_epoch * begin_epoch
lr_each_step = np.array(lr_each_step, dtype=np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def parse_args():
parser = argparse.ArgumentParser(description="Simpleposenet training")
    parser.add_argument('--data_url', required=False, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
    parser.add_argument('--device_id', required=False, default=None, type=int, help='Device id.')
    parser.add_argument('--run_distribute', required=False, default=False,
                        type=lambda x: str(x).lower() == 'true', help='Whether to run distributed training.')
    parser.add_argument('--is_model_arts', required=False, default=False,
                        type=lambda x: str(x).lower() == 'true', help='Whether the job runs on ModelArts.')
args = parser.parse_args()
return args
def main():
print("loading parse...")
args = parse_args()
device_id = args.device_id
config.GENERAL.RUN_DISTRIBUTE = args.run_distribute
config.MODELARTS.IS_MODEL_ARTS = args.is_model_arts
if config.GENERAL.RUN_DISTRIBUTE or config.MODELARTS.IS_MODEL_ARTS:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if config.GENERAL.RUN_DISTRIBUTE:
init()
rank = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
rank = 0
device_num = 1
if config.MODELARTS.IS_MODEL_ARTS:
mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)
dataset = CreateDatasetCoco(rank=rank,
group_size=device_num,
train_mode=True,
num_parallel_workers=config.TRAIN.NUM_PARALLEL_WORKERS,
)
net = GetPoseResNet(config)
loss = JointsMSELoss(config.LOSS.USE_TARGET_WEIGHT)
net_with_loss = PoseResNetWithLoss(net, loss)
dataset_size = dataset.get_dataset_size()
lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
config.TRAIN.END_EPOCH,
dataset_size,
lr_init=config.TRAIN.LR,
factor=config.TRAIN.LR_FACTOR,
epoch_number_to_drop=config.TRAIN.LR_STEP))
opt = Adam(net.trainable_params(), learning_rate=lr)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.TRAIN.SAVE_CKPT:
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
prefix = ''
if config.GENERAL.RUN_DISTRIBUTE:
prefix = 'multi_' + 'train_poseresnet_' + config.GENERAL.VERSION + '_' + os.getenv('DEVICE_ID')
else:
prefix = 'single_' + 'train_poseresnet_' + config.GENERAL.VERSION
directory = ''
if config.MODELARTS.IS_MODEL_ARTS:
directory = config.MODELARTS.CACHE_OUTPUT + 'device_'+ os.getenv('DEVICE_ID')
elif config.GENERAL.RUN_DISTRIBUTE:
directory = config.TRAIN.CKPT_PATH + 'device_'+ os.getenv('DEVICE_ID')
else:
directory = config.TRAIN.CKPT_PATH + 'device'
ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=directory, config=config_ck)
cb.append(ckpoint_cb)
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
print("************ Start training now ************")
print('start training, epoch size = %d' % epoch_size)
model.train(epoch_size, dataset, callbacks=cb)
if config.MODELARTS.IS_MODEL_ARTS:
mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)
if __name__ == '__main__':
main()