forked from mindspore-Ecosystem/mindspore
!19379 Add Stacked Hourglass Network for Model Zoo
Merge pull request !19379 from cometeme/master
commit 885bd46760
@ -0,0 +1,193 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [Stacked Hourglass Description](#stacked-hourglass-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Running](#running)
        - [Results](#results)
    - [Evaluation Process](#evaluation-process)
        - [Running](#running-1)
        - [Results](#results-1)
    - [Export](#export)
- [Model Description](#model-description)
    - [Training Performance (2HG)](#training-performance-2hg)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo](#modelzoo)

<!-- /TOC -->

# Stacked Hourglass Description

Stacked Hourglass is a model for human pose estimation. It extracts features with a stack of hourglass modules and finally outputs heatmaps that give the model's predicted position for each keypoint.

[Paper: Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937v2)

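To make the heatmap output concrete, here is a minimal sketch of how a position could be read from each heatmap by taking the location of its maximum. It is an illustration only and not the repository's inference code: the function name `decode_heatmaps` is hypothetical, and the 16 heatmaps of size 64x64 in the example are assumed values chosen to match the `--oup_dim` and `--output_res` defaults.

```python
import numpy as np


def decode_heatmaps(heatmaps):
    """Return (x, y, score) per keypoint from an array of shape (num_keypoints, H, W)."""
    num_keypoints, height, width = heatmaps.shape
    flat = heatmaps.reshape(num_keypoints, -1)
    idx = flat.argmax(axis=1)  # index of the peak in each flattened heatmap
    ys, xs = np.unravel_index(idx, (height, width))
    scores = flat[np.arange(num_keypoints), idx]
    return np.stack([xs, ys, scores], axis=1)


# Hypothetical example: 16 empty 64x64 heatmaps with a single peak in the first one.
hms = np.zeros((16, 64, 64), dtype=np.float32)
hms[0, 10, 20] = 1.0  # row (y) = 10, column (x) = 20
keypoints = decode_heatmaps(hms)
print(keypoints[0])
```

The coordinates are in heatmap resolution; mapping them back to the original image would still require the crop center and scale, which this sketch leaves out.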
# Dataset

Dataset used: [MPII Human Pose Dataset](http://human-pose.mpi-inf.mpg.de/)

- Dataset size:
    - Training: 22246 images
    - Test: 2958 images
- Number of keypoints: 16 (head, neck, shoulders, elbows, wrists, thorax, pelvis, hips, knees, ankles)

> Note: The original annotations in the MPII dataset are in .mat format, which is hard to process. Please download and use this alternative annotation set instead: [https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot](https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot)

# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

# Script Description

## Scripts and Sample Code

```text
├──scripts
│   ├──run_distribute_train.sh        # distributed training script
│   ├──run_eval.sh                    # evaluation script
│   └──run_standalone_train.sh        # single-device training script
├──src
│   ├──dataset
│   │   ├──DatasetGenerator.py        # dataset definition and target heatmap generation
│   │   └──MPIIDataLoader.py          # MPII data loading and preprocessing
│   ├──models
│   │   ├──layers.py                  # network sub-module definitions
│   │   ├──loss.py                    # HeatMap Loss definition
│   │   └──StackedHourglassNet.py     # overall network definition
│   ├──utils
│   │   ├──img.py                     # general image processing utilities
│   │   └──inference.py               # inference functions, including accuracy computation
│   └── config.py                     # parameter configuration
├── eval.py                           # evaluation script
├── export.py                         # export script
├── README_CN.md                      # project description
└── train.py                          # training script
```

## Script Parameters

The parameters used for training and evaluation can be set in config.py, or passed as command-line arguments at run time.

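As a rough illustration of how defaults and command-line overrides interact, the sketch below rebuilds two of the options with `argparse`. The option names, types and defaults (`--batch_size`, `--parallel`) are taken from `src/config.py` in this commit; the helper name `build_parser` and the override values are made up for the example.

```python
import argparse
import ast


def build_parser():
    """Hypothetical helper that mirrors two options from src/config.py."""
    parser = argparse.ArgumentParser(description="config sketch")
    parser.add_argument("--batch_size", type=int, default=32)
    # ast.literal_eval turns the string "True" into the boolean True.
    parser.add_argument("--parallel", type=ast.literal_eval, default=False)
    return parser


# No arguments: the defaults apply.
defaults = build_parser().parse_known_args([])[0]
# Explicit arguments: the defaults are overridden.
overridden = build_parser().parse_known_args(["--batch_size", "16", "--parallel", "True"])[0]

print(defaults.batch_size, defaults.parallel)
print(overridden.batch_size, overridden.parallel)
```

Because `parse_known_args` is used, unrecognized options are ignored rather than rejected, so a mistyped option name silently falls back to its default.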
## Training Process

### Running

For single-device training, first select the target device with `export DEVICE_ID=x`, where `x` is the ID of the target device. Then start training:

```sh
python train.py
```

Alternatively, use the single-device training script:

```sh
./scripts/run_standalone_train.sh [device ID] [annotation path] [image path]
```

For multi-device training, use the distributed training script:

```sh
./scripts/run_distribute_train.sh [rank table file path] [number of Ascend devices] [annotation path] [image path]
```

### Results

The ckpt files are saved in the current directory. By default, the training output is written to `loss.txt`, and errors and informational messages are written to `err.txt`. For example:

```text
loading data...
Done (t=14.61s)
train data size: 22246
epoch: 1 step: 695, loss is 0.00068435294
epoch time: 954584.373 ms, per step time: 1373.503 ms
epoch: 2 step: 695, loss is 0.00067576126
epoch time: 755549.341 ms, per step time: 1087.121 ms
epoch: 3 step: 695, loss is 0.00057179347
epoch time: 750856.373 ms, per step time: 1080.369 ms
epoch: 4 step: 695, loss is 0.00055218843

[...]
```

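The numbers in this log can be cross-checked with a little arithmetic. The sketch below assumes the default batch size of 32 from `config.py`, single-device training, and that the last incomplete batch is dropped; none of this is stated in the log itself.

```python
train_size = 22246  # "train data size" from the log
batch_size = 32     # --batch_size default in config.py

# Assumption: the last incomplete batch is dropped, so integer division applies.
steps_per_epoch = train_size // batch_size

# Per-step time reported for epoch 2, in milliseconds.
per_step_ms = 1087.121
epoch_ms = steps_per_epoch * per_step_ms

print("steps per epoch:", steps_per_epoch)
print("estimated epoch time (min):", round(epoch_ms / 1000 / 60, 1))
```

Under these assumptions the step count matches the `step: 695` entries, and the product of step count and per-step time agrees with the logged epoch time to within a millisecond.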
## Evaluation Process

### Running

Before running the evaluation, select the target device with `export DEVICE_ID=x`, where `x` is the ID of the target device. Then start the evaluation with python, specifying the path to the ckpt file:

```sh
python eval.py --ckpt_file <path to ckpt file>
```

Alternatively, use the evaluation script:

```sh
./scripts/run_eval.sh [device ID] [ckpt file path] [annotation path] [image path]
```

### Results

By default, the evaluation results are written to `result.txt`, and errors and informational messages are written to `err.txt`.

```text
all :
Val PCK @, 0.5 , total : 0.882 , count: 44239
Tra PCK @, 0.5 , total : 0.938 , count: 4443
Val PCK @, 0.5 , ankle : 0.765 , count: 4234
Tra PCK @, 0.5 , ankle : 0.847 , count: 392
Val PCK @, 0.5 , knee : 0.819 , count: 4963
Tra PCK @, 0.5 , knee : 0.91 , count: 499
Val PCK @, 0.5 , hip : 0.871 , count: 5777
Tra PCK @, 0.5 , hip : 0.918 , count: 587

[...]
```

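For readers unfamiliar with the metric, the following is a minimal sketch of how a PCK score at threshold 0.5 is commonly computed: a keypoint counts as correct when its distance to the ground truth, divided by a per-person normalization length, is within the threshold. This is not the code in `src/utils/inference.py`; the function name `pck`, the use of `<=`, and the sample data are assumptions made for illustration.

```python
import numpy as np


def pck(pred, gt, normalize, visible, threshold=0.5):
    """Fraction of visible keypoints whose normalized error is within the threshold.

    pred, gt: (N, K, 2) coordinates; normalize: (N,) lengths; visible: (N, K) booleans.
    """
    dist = np.linalg.norm(pred - gt, axis=2) / normalize[:, None]
    correct = (dist <= threshold) & visible
    return correct.sum() / visible.sum()


# Hypothetical data: one person, three keypoints, normalization length of 10 pixels.
gt = np.array([[[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]])
pred = np.array([[[12.0, 10.0], [20.0, 28.0], [30.0, 30.0]]])
normalize = np.array([10.0])
visible = np.array([[True, True, True]])

score = pck(pred, gt, normalize, visible)
print(score)
```

In this example the normalized errors are 0.2, 0.8 and 0.0, so two of the three keypoints fall within the threshold.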
## Export

Use the `export.py` script to export the model:

```sh
python export.py --ckpt_file [ckpt file path]
```

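The exported model expects its input in NHWC layout: `export.py` builds a dummy input of shape `[batch_size, input_res, input_res, 3]`, and the network permutes the axes with `(0, 3, 1, 2)` before its first convolution. The NumPy sketch below only illustrates that layout change with the default sizes from `config.py`; it does not call MindSpore.

```python
import numpy as np

batch_size, input_res = 32, 256  # defaults from config.py

# Dummy input in NHWC layout, shaped like the one built in export.py.
nhwc = np.zeros([batch_size, input_res, input_res, 3], np.float32)

# The same axis permutation the network applies to its input.
nchw = nhwc.transpose(0, 3, 1, 2)

print(nhwc.shape, nchw.shape)
```

Since the batch size is baked into the dummy input, an export made with the defaults may only accept batches of that size; passing a different `--batch_size` to `export.py` would change it.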
# Model Description

## Training Performance (2HG)

| Parameter                  | Ascend                                      |
| -------------------------- | ------------------------------------------- |
| Model name                 | Stacked Hourglass Networks                  |
| Environment                | Ascend 910A                                 |
| Upload date                | 2021-7-5                                    |
| MindSpore version          | 1.2.0                                       |
| Dataset                    | MPII Human Pose Dataset                     |
| Training parameters        | See config.py                               |
| Optimizer                  | Adam (with exponential learning rate decay) |
| Loss function              | HeatMap Loss (MSE-like)                     |
| Final loss                 | 0.00036373272                               |
| Accuracy                   | 88.2%                                       |
| Total training time (1p)   | 20h                                         |
| Total evaluation time (1p) | 21min                                       |
| Number of parameters       | 8429088                                     |

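The "exponential learning rate decay" in the table can be sketched as follows. The defaults (`initial_lr=1e-3`, `decay_rate=0.985`, `decay_epoch=1`) come from `config.py`, but `train.py` is not part of this excerpt, so the exact formula, and whether the rate changes per step or per epoch, is an assumption.

```python
def learning_rate(epoch, initial_lr=1e-3, decay_rate=0.985, decay_epoch=1):
    """Assumed schedule: multiply the rate by decay_rate once every decay_epoch epochs."""
    return initial_lr * decay_rate ** (epoch / decay_epoch)


# Print the assumed learning rate at a few epochs of a 100-epoch run.
for epoch in (0, 1, 10, 99):
    print(epoch, learning_rate(epoch))
```

With these defaults the rate shrinks by 1.5% per epoch, which gives a smooth decay over the default 100 epochs instead of a few large drops.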
# Description of Random Situation

A random seed is set in the `train.py` script.

# ModelZoo

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
run model eval
"""
import os

from mindspore import context, load_checkpoint, load_param_into_net

from src.config import parse_args
from src.models.StackedHourglassNet import StackedHourglassNet
from src.utils.inference import MPIIEval, get_img, inference

args = parse_args()

if __name__ == "__main__":
    if not os.path.exists(args.ckpt_file):
        print("ckpt file not valid")
        exit()

    if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
        print("Dataset not found.")
        exit()

    # Set context mode
    if args.context_mode == "GRAPH":
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)

    # Import net
    net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)

    gts = []
    preds = []
    normalizing = []

    num_eval = args.num_eval
    num_train = args.train_num_eval
    for anns, img, c, s, n in get_img(num_eval, num_train):
        gts.append(anns)
        ans = inference(img, net, c, s)
        if ans.size > 0:
            ans = ans[:, :, :3]

        # (num preds, joints, x/y/visible)
        pred = []
        for i in range(ans.shape[0]):
            pred.append({"keypoints": ans[i, :, :]})
        preds.append(pred)
        normalizing.append(n)

    mpii_eval = MPIIEval()
    mpii_eval.eval(preds, gts, normalizing, num_train)
@ -0,0 +1,45 @@
"""
|
||||
export model
|
||||
"""
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context, export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import parse_args
|
||||
from src.models.StackedHourglassNet import StackedHourglassNet
|
||||
|
||||
args = parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.path.exists(args.ckpt_file):
|
||||
print("ckpt file not valid")
|
||||
exit()
|
||||
|
||||
# Set context mode
|
||||
if args.context_mode == "GRAPH":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
|
||||
|
||||
# Import net
|
||||
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args.batch_size, args.input_res, args.input_res, 3], np.float32))
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,23 @@
asttokens==2.0.5
astunparse==1.6.3
attrs==21.2.0
cached-property==1.5.2
certifi==2018.10.15
cffi==1.14.5
decorator==5.0.9
easydict==1.9
h5py==3.3.0
imageio==2.9.0
mpmath==1.2.1
numpy==1.21.0
opencv-python==4.5.2.54
packaging==21.0
Pillow==8.3.1
protobuf==3.17.3
psutil==5.8.0
pycparser==2.20
pyparsing==2.4.7
scipy==1.7.0
six==1.16.0
sympy==1.8
tqdm==4.61.2
@ -0,0 +1,47 @@
#!/bin/bash
if [ $# != 4 ]; then
    echo "Usage: ./scripts/run_distribute_train.sh [rank file path] [rank size] [annot path] [image path]"
    exit 1
fi

export RANK_TABLE_FILE="$1"
export RANK_SIZE="$2"
ANNOT_DIR=$3
IMG_DIR=$4

CURDIR=$(pwd)

for ((i = 0; i < $2; i++))
do
    export DEVICE_ID="$i"
    export RANK_ID="$i"

    cd $CURDIR
    TARGET="./Train${DEVICE_ID}"
    rm -rf $TARGET
    mkdir $TARGET
    cp *.py $TARGET
    cp -r src $TARGET
    cd $TARGET

    echo "training for rank $i"
    nohup python train.py \
        --annot_dir=$ANNOT_DIR \
        --img_dir=$IMG_DIR \
        --parallel True > loss.txt 2> err.txt &
done
@ -0,0 +1,39 @@
#!/bin/bash
if [ $# != 4 ]; then
    echo "Usage: ./scripts/run_eval.sh [device ID] [ckpt path] [annot path] [image path]"
    exit 1
fi

export DEVICE_ID=$1
PATH_CHECKPOINT=$2
ANNOT_DIR=$3
IMG_DIR=$4

TARGET="./Eval"

rm -rf $TARGET
mkdir $TARGET
cp *.py $TARGET
cp -r src $TARGET

cd $TARGET

nohup python eval.py \
    --ckpt_file=$PATH_CHECKPOINT \
    --annot_dir=$ANNOT_DIR \
    --img_dir=$IMG_DIR > result.txt 2> err.txt &
@ -0,0 +1,37 @@
#!/bin/bash
if [ $# != 3 ]; then
    echo "Usage: ./scripts/run_standalone_train.sh [device ID] [annot path] [image path]"
    exit 1
fi

export DEVICE_ID=$1
ANNOT_DIR=$2
IMG_DIR=$3

TARGET="./Train"

rm -rf $TARGET
mkdir $TARGET
cp *.py $TARGET
cp -r src $TARGET

cd $TARGET

nohup python train.py \
    --annot_dir=$ANNOT_DIR \
    --img_dir=$IMG_DIR > loss.txt 2> err.txt &
@ -0,0 +1,60 @@
"""
|
||||
model config
|
||||
"""
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="MindSpore Stacked Hourglass")
|
||||
|
||||
# Model
|
||||
parser.add_argument("--nstack", type=int, default=2)
|
||||
parser.add_argument("--inp_dim", type=int, default=256)
|
||||
parser.add_argument("--oup_dim", type=int, default=16)
|
||||
parser.add_argument("--input_res", type=int, default=256)
|
||||
parser.add_argument("--output_res", type=int, default=64)
|
||||
parser.add_argument("--annot_dir", type=str, default="./MPII/annot")
|
||||
parser.add_argument("--img_dir", type=str, default="./MPII/images")
|
||||
# Context
|
||||
parser.add_argument("--context_mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"])
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "CPU"])
|
||||
# Train
|
||||
parser.add_argument("--parallel", type=ast.literal_eval, default=False)
|
||||
parser.add_argument("--amp_level", type=str, default="O2", choices=["O0", "O1", "O2", "O3"])
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--num_epoch", type=int, default=100)
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=5)
|
||||
parser.add_argument("--keep_checkpoint_max", type=int, default=20)
|
||||
parser.add_argument("--loss_log_interval", type=int, default=1)
|
||||
parser.add_argument("--initial_lr", type=float, default=1e-3)
|
||||
parser.add_argument("--decay_rate", type=float, default=0.985)
|
||||
parser.add_argument("--decay_epoch", type=int, default=1)
|
||||
# Valid
|
||||
parser.add_argument("--num_eval", type=int, default=2958)
|
||||
parser.add_argument("--train_num_eval", type=int, default=300)
|
||||
parser.add_argument("--ckpt_file", type=str, default="")
|
||||
# Export
|
||||
parser.add_argument("--file_name", type=str, default="stackedhourglass")
|
||||
parser.add_argument("--file_format", type=str, default="MINDIR")
|
||||
|
||||
args = parser.parse_known_args()[0]
|
||||
|
||||
return args
|
|
@ -0,0 +1,161 @@
"""
|
||||
dataset classes
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import src.utils.img
|
||||
from src.dataset.MPIIDataLoader import flipped_parts
|
||||
|
||||
|
||||
class GenerateHeatmap:
|
||||
"""
|
||||
get train target heatmap
|
||||
"""
|
||||
|
||||
def __init__(self, output_res, num_parts):
|
||||
self.output_res = output_res
|
||||
self.num_parts = num_parts
|
||||
sigma = self.output_res / 64
|
||||
self.sigma = sigma
|
||||
size = 6 * sigma + 3
|
||||
x = np.arange(0, size, 1, float)
|
||||
y = x[:, np.newaxis]
|
||||
x0, y0 = 3 * sigma + 1, 3 * sigma + 1
|
||||
self.g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
|
||||
def __call__(self, keypoints):
|
||||
hms = np.zeros(shape=(self.num_parts, self.output_res, self.output_res), dtype=np.float32)
|
||||
sigma = self.sigma
|
||||
for p in keypoints:
|
||||
for idx, pt in enumerate(p):
|
||||
if pt[0] > 0:
|
||||
x, y = int(pt[0]), int(pt[1])
|
||||
if x < 0 or y < 0 or x >= self.output_res or y >= self.output_res:
|
||||
continue
|
||||
ul = int(x - 3 * sigma - 1), int(y - 3 * sigma - 1)
|
||||
br = int(x + 3 * sigma + 2), int(y + 3 * sigma + 2)
|
||||
|
||||
c, d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0]
|
||||
a, b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1]
|
||||
|
||||
cc, dd = max(0, ul[0]), min(br[0], self.output_res)
|
||||
aa, bb = max(0, ul[1]), min(br[1], self.output_res)
|
||||
hms[idx, aa:bb, cc:dd] = np.maximum(hms[idx, aa:bb, cc:dd], self.g[a:b, c:d])
|
||||
return hms
|
||||
|
||||
|
||||
class DatasetGenerator:
    """
    mindspore general dataset generator
    """

    def __init__(self, input_res, output_res, ds, index):
        self.input_res = input_res
        self.output_res = output_res
        self.generateHeatmap = GenerateHeatmap(self.output_res, 16)
        self.ds = ds
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        # print(f"loading...{idx}")
        return self.loadImage(self.index[idx])

    def loadImage(self, idx):
        """
        load and preprocess image
        """
        ds = self.ds

        # Load + Crop
        orig_img = ds.get_img(idx)
        orig_keypoints = ds.get_kps(idx)
        kptmp = orig_keypoints.copy()
        c = ds.get_center(idx)
        s = ds.get_scale(idx)

        cropped = src.utils.img.crop(orig_img, c, s, (self.input_res, self.input_res))
        for i in range(np.shape(orig_keypoints)[1]):
            if orig_keypoints[0, i, 0] > 0:
                orig_keypoints[0, i, :2] = src.utils.img.transform(
                    orig_keypoints[0, i, :2], c, s, (self.input_res, self.input_res)
                )
        keypoints = np.copy(orig_keypoints)

        # Random Crop
        height, width = cropped.shape[0:2]
        center = np.array((width / 2, height / 2))
        scale = max(height, width) / 200

        aug_rot = 0

        aug_rot = (np.random.random() * 2 - 1) * 30.0
        aug_scale = np.random.random() * (1.25 - 0.75) + 0.75
        scale *= aug_scale

        mat_mask = src.utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2]

        mat = src.utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2]
        inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32) / 255
        keypoints[:, :, 0:2] = src.utils.img.kpt_affine(keypoints[:, :, 0:2], mat_mask)
        if np.random.randint(2) == 0:
            inp = self.preprocess(inp)
            inp = inp[:, ::-1]
            keypoints = keypoints[:, flipped_parts["mpii"]]
            keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0]
            orig_keypoints = orig_keypoints[:, flipped_parts["mpii"]]
            orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0]

        # If keypoint is invisible, set to 0
        for i in range(np.shape(orig_keypoints)[1]):
            if kptmp[0, i, 0] == 0 and kptmp[0, i, 1] == 0:
                keypoints[0, i, 0] = 0
                keypoints[0, i, 1] = 0
                orig_keypoints[0, i, 0] = 0
                orig_keypoints[0, i, 1] = 0

        # Generate target heatmap
        heatmaps = self.generateHeatmap(keypoints)

        return inp.astype(np.float32), heatmaps.astype(np.float32)

    def preprocess(self, data):
        """
        preprocess images
        """
        # Random hue and saturation
        data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
        delta = (np.random.random() * 2 - 1) * 0.2
        data[:, :, 0] = np.mod(data[:, :, 0] + (delta * 360 + 360.0), 360.0)

        delta_sature = np.random.random() + 0.5
        data[:, :, 1] *= delta_sature
        data[:, :, 1] = np.maximum(np.minimum(data[:, :, 1], 1), 0)
        data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB)

        # Random brightness
        delta = (np.random.random() * 2 - 1) * 0.3
        data += delta

        # Random contrast
        mean = data.mean(axis=2, keepdims=True)
        data = (data - mean) * (np.random.random() + 0.5) + mean
        data = np.minimum(np.maximum(data, 0), 1)
        return data
@ -0,0 +1,163 @@
"""
|
||||
MPII dataset loader
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from imageio import imread
|
||||
|
||||
from src.config import parse_args
|
||||
|
||||
args = parse_args()
|
||||
|
||||
|
||||
class MPII:
|
||||
"""
|
||||
MPII dataset loader
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
print("loading data...")
|
||||
tic = time.time()
|
||||
|
||||
train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r")
|
||||
val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r")
|
||||
|
||||
self.t_center = train_f["center"][()]
|
||||
t_scale = train_f["scale"][()]
|
||||
t_part = train_f["part"][()]
|
||||
t_visible = train_f["visible"][()]
|
||||
t_normalize = train_f["normalize"][()]
|
||||
t_imgname = [None] * len(self.t_center)
|
||||
for i in range(len(self.t_center)):
|
||||
t_imgname[i] = train_f["imgname"][i].decode("UTF-8")
|
||||
|
||||
self.v_center = val_f["center"][()]
|
||||
v_scale = val_f["scale"][()]
|
||||
v_part = val_f["part"][()]
|
||||
v_visible = val_f["visible"][()]
|
||||
v_normalize = val_f["normalize"][()]
|
||||
v_imgname = [None] * len(self.v_center)
|
||||
for i in range(len(self.v_center)):
|
||||
v_imgname[i] = val_f["imgname"][i].decode("UTF-8")
|
||||
|
||||
self.center = np.append(self.t_center, self.v_center, axis=0)
|
||||
self.scale = np.append(t_scale, v_scale)
|
||||
self.part = np.append(t_part, v_part, axis=0)
|
||||
self.visible = np.append(t_visible, v_visible, axis=0)
|
||||
self.normalize = np.append(t_normalize, v_normalize)
|
||||
self.imgname = t_imgname + v_imgname
|
||||
|
||||
print("Done (t={:0.2f}s)".format(time.time() - tic))
|
||||
|
||||
self.num_examples_train, self.num_examples_val = self.getLength()
|
||||
|
||||
def getLength(self):
|
||||
"""
|
||||
get dataset length
|
||||
"""
|
||||
return len(self.t_center), len(self.v_center)
|
||||
|
||||
def setup_val_split(self):
|
||||
"""
|
||||
get index for train and validation imgs
|
||||
index for validation images starts after that of train images
|
||||
so that loadImage can tell them apart
|
||||
"""
|
||||
valid = [i + self.num_examples_train for i in range(self.num_examples_val)]
|
||||
train = [i for i in range(self.num_examples_train)]
|
||||
return np.array(train), np.array(valid)
|
||||
|
||||
def get_img(self, idx):
|
||||
"""
|
||||
get image
|
||||
"""
|
||||
imgname = self.imgname[idx]
|
||||
path = os.path.join(args.img_dir, imgname)
|
||||
img = imread(path)
|
||||
return img
|
||||
|
||||
def get_path(self, idx):
|
||||
"""
|
||||
get image path
|
||||
"""
|
||||
imgname = self.imgname[idx]
|
||||
path = os.path.join(args.img_dir, imgname)
|
||||
return path
|
||||
|
||||
def get_kps(self, idx):
|
||||
"""
|
||||
get key points
|
||||
"""
|
||||
part = self.part[idx]
|
||||
visible = self.visible[idx]
|
||||
kp2 = np.insert(part, 2, visible, axis=1)
|
||||
kps = np.zeros((1, 16, 3))
|
||||
kps[0] = kp2
|
||||
return kps
|
||||
|
||||
def get_normalized(self, idx):
|
||||
"""
|
||||
get normalized value
|
||||
"""
|
||||
n = self.normalize[idx]
|
||||
return n
|
||||
|
||||
def get_center(self, idx):
|
||||
"""
|
||||
get center of the person
|
||||
"""
|
||||
c = self.center[idx]
|
||||
return c
|
||||
|
||||
def get_scale(self, idx):
|
||||
"""
|
||||
get scale of the person
|
||||
"""
|
||||
s = self.scale[idx]
|
||||
return s
|
||||
|
||||
|
||||
# Part reference
parts = {
    "mpii": [
        "rank",
        "rkne",
        "rhip",
        "lhip",
        "lkne",
        "lank",
        "pelv",
        "thrx",
        "neck",
        "head",
        "rwri",
        "relb",
        "rsho",
        "lsho",
        "lelb",
        "lwri",
    ]
}

flipped_parts = {"mpii": [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]}

part_pairs = {"mpii": [[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]}

pair_names = {"mpii": ["ankle", "knee", "hip", "pelvis", "thorax", "neck", "head", "wrist", "elbow", "shoulder"]}
@ -0,0 +1,83 @@
"""
|
||||
Stacked Hourglass Model
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
from src.models.layers import Conv, ConvBNReLU, Hourglass, Residual
|
||||
|
||||
|
||||
class StackedHourglassNet(nn.Cell):
|
||||
"""
|
||||
Stacked Hourglass Network
|
||||
"""
|
||||
|
||||
def __init__(self, nstack, inp_dim, oup_dim):
|
||||
super(StackedHourglassNet, self).__init__()
|
||||
|
||||
self.nstack = nstack
|
||||
|
||||
self.input_transpose = P.Transpose()
|
||||
|
||||
self.pre = nn.SequentialCell(
|
||||
[
|
||||
ConvBNReLU(3, 64, 7, 2),
|
||||
Residual(64, 128),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Residual(128, 128),
|
||||
Residual(128, inp_dim),
|
||||
]
|
||||
)
|
||||
|
||||
self.hgs = nn.CellList(
|
||||
[nn.SequentialCell([Hourglass(4, inp_dim),]) for i in range(nstack)]
|
||||
)
|
||||
|
||||
self.features = nn.CellList(
|
||||
[nn.SequentialCell([Residual(inp_dim, inp_dim), ConvBNReLU(inp_dim, inp_dim, 1)]) for i in range(nstack)]
|
||||
)
|
||||
|
||||
self.outs = nn.CellList([Conv(inp_dim, oup_dim, 1) for i in range(nstack)])
|
||||
self.merge_features = nn.CellList([Conv(inp_dim, inp_dim, 1) for i in range(nstack - 1)])
|
||||
self.merge_preds = nn.CellList([Conv(oup_dim, inp_dim, 1) for i in range(nstack - 1)])
|
||||
self.output_stack = ops.Stack(axis=1)
|
||||
|
||||
def construct(self, imgs):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
# x size (batch, 3, 256, 256)
|
||||
x = self.input_transpose(
|
||||
imgs,
|
||||
(
|
||||
0,
|
||||
3,
|
||||
1,
|
||||
2,
|
||||
),
|
||||
)
|
||||
x = self.pre(x)
|
||||
combined_hm_preds = []
|
||||
for i in range(self.nstack):
|
||||
hg = self.hgs[i](x)
|
||||
feature = self.features[i](hg)
|
||||
preds = self.outs[i](feature)
|
||||
combined_hm_preds.append(preds)
|
||||
if i < self.nstack - 1:
|
||||
x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
|
||||
return self.output_stack(combined_hm_preds)
|
|
@ -0,0 +1,147 @@
"""
|
||||
model layers
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Conv(nn.Cell):
|
||||
"""
|
||||
conv 2d
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1):
|
||||
super(Conv, self).__init__()
|
||||
self.inp_dim = inp_dim
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=inp_dim,
|
||||
out_channels=out_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode="pad",
|
||||
padding=(kernel_size - 1) // 2,
|
||||
has_bias=True,
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
|
||||
x = self.conv(x)
|
||||
return x
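The `padding=(kernel_size - 1) // 2` rule in `Conv` keeps the spatial size unchanged for odd kernels at stride 1. A minimal sketch of the standard convolution output-size arithmetic follows; the helper name `conv_out_size` is hypothetical and not part of this commit.

```python
def conv_out_size(in_size, kernel_size, stride=1):
    """Output length along one spatial axis for the padding rule used by Conv."""
    padding = (kernel_size - 1) // 2  # same expression as in Conv.__init__
    return (in_size + 2 * padding - kernel_size) // stride + 1


print(conv_out_size(64, 3))            # 3x3 kernel, stride 1
print(conv_out_size(64, 1))            # 1x1 kernel, stride 1
print(conv_out_size(256, 7, stride=2)) # 7x7 kernel, stride 2
```

With stride 1 the size is preserved for both 1x1 and 3x3 kernels, which is why the residual additions in the classes below can sum tensors without any resizing.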
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Cell):
|
||||
"""
|
||||
2D convolution followed by batch normalization and ReLU
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.inp_dim = inp_dim
|
||||
self.conv = Conv(inp_dim, out_dim, kernel_size, stride)
|
||||
self.relu = nn.ReLU()
|
||||
self.bn = nn.BatchNorm2d(num_features=out_dim, momentum=0.9)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Residual(nn.Cell):
|
||||
"""
|
||||
residual block
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim):
|
||||
super(Residual, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.bn1 = nn.BatchNorm2d(num_features=inp_dim, momentum=0.9)
|
||||
self.conv1 = Conv(inp_dim, out_dim // 2, 1)
|
||||
self.bn2 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2)
|
||||
self.conv2 = Conv(out_dim // 2, out_dim // 2, 3)
|
||||
self.bn3 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2)
|
||||
self.conv3 = Conv(out_dim // 2, out_dim, 1)
|
||||
self.skip_layer = Conv(inp_dim, out_dim, 1)
|
||||
if inp_dim == out_dim:
|
||||
self.need_skip = False
|
||||
else:
|
||||
self.need_skip = True
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
if self.need_skip:
|
||||
residual = self.skip_layer(x)
|
||||
else:
|
||||
residual = x
|
||||
out = self.bn1(x)
|
||||
out = self.relu(out)
|
||||
out = self.conv1(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn3(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out += residual
|
||||
return out
|
||||
|
||||
|
||||
class Hourglass(nn.Cell):
|
||||
"""
|
||||
hourglass module
|
||||
"""
|
||||
|
||||
def __init__(self, n, f):
|
||||
super(Hourglass, self).__init__()
|
||||
self.up1 = Residual(f, f)
|
||||
# Down sampling
|
||||
self.pool1 = nn.MaxPool2d(2, 2)
|
||||
self.low1 = Residual(f, f)
|
||||
self.n = n
|
||||
# Use Hourglass recursively
|
||||
if self.n > 1:
|
||||
self.low2 = Hourglass(n - 1, f)
|
||||
else:
|
||||
self.low2 = Residual(f, f)
|
||||
self.low3 = Residual(f, f)
|
||||
|
||||
# Set upsample size
|
||||
sz = [0, 8, 16, 32, 64]
|
||||
self.up2 = ops.ResizeNearestNeighbor((sz[n], sz[n]))
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
|
||||
up1 = self.up1(x)
|
||||
pool1 = self.pool1(x)
|
||||
low1 = self.low1(pool1)
|
||||
low2 = self.low2(low1)
|
||||
low3 = self.low3(low2)
|
||||
up2 = self.up2(low3)
|
||||
return up1 + up2
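The hard-coded `sz` table in `Hourglass.__init__` only lines up if the outermost module (`n = 4`) receives a 64x64 feature map. The sketch below traces the resolution at each recursion level in plain Python to make that assumption visible; `trace_hourglass` is an illustrative helper, not code from this commit.

```python
UPSAMPLE_SIZE = [0, 8, 16, 32, 64]  # same values as `sz` in Hourglass.__init__


def trace_hourglass(n, in_size):
    """Return (level, input size, pooled size, upsample target) per recursion level."""
    levels = []
    while n >= 1:
        pooled = in_size // 2  # effect of MaxPool2d(2, 2)
        levels.append((n, in_size, pooled, UPSAMPLE_SIZE[n]))
        in_size = pooled       # the nested hourglass sees the pooled map
        n -= 1
    return levels


for level in trace_hourglass(4, 64):
    print(level)
```

At every level the upsample target equals that level's input size, so `up1 + up2` adds tensors of matching shape. A different input resolution would need a different table.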
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
define heatmap loss
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
|
||||
class HeatmapLoss(nn.Cell):
|
||||
"""
|
||||
loss for detection heatmap
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(HeatmapLoss, self).__init__()
|
||||
self.loss_function = nn.MSELoss()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, pred, gt):
|
||||
"""
|
||||
calculate loss
|
||||
"""
|
||||
# pred size (batch, nstack, 16, 64, 64), gt size (batch, 16, 64, 64)
|
||||
# Move the stack axis to the front so gt broadcasts across every stack
|
||||
pred_t = self.transpose(pred, (1, 0, 2, 3, 4))
|
||||
loss = self.loss_function(pred_t, gt)
|
||||
|
||||
return loss
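`HeatmapLoss` compares every stack's prediction against the same ground-truth heatmap, which is what gives the network its intermediate supervision. A pure-Python sketch of the same reduction is shown here, assuming mean reduction (the `nn.MSELoss` default) and heatmaps flattened to lists; `broadcast_mse` is a hypothetical name.

```python
def broadcast_mse(pred, gt):
    """pred: [batch][stack][values], gt: [batch][values].

    Returns the mean squared error over all stacks, samples and values.
    """
    total, count = 0.0, 0
    for sample_preds, sample_gt in zip(pred, gt):
        for stack_pred in sample_preds:  # every stack sees the same target
            for p, g in zip(stack_pred, sample_gt):
                total += (p - g) ** 2
                count += 1
    return total / count


pred = [[[0.0, 1.0], [0.0, 0.5]]]  # one sample, two stacks, two values each
gt = [[0.0, 1.0]]
print(broadcast_mse(pred, gt))
```

Only the second stack misses the target (by 0.5 on one value), so the squared error 0.25 is averaged over four compared values.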
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
general image utils
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_transform(center, scale, res, rot=0):
|
||||
"""
|
||||
generate transform matrix
|
||||
"""
|
||||
h = 200 * scale
|
||||
t = np.zeros((3, 3))
|
||||
t[0, 0] = float(res[1]) / h
|
||||
t[1, 1] = float(res[0]) / h
|
||||
t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
|
||||
t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
|
||||
t[2, 2] = 1
|
||||
if rot != 0:
|
||||
rot = -rot # To match direction of rotation from cropping
|
||||
rot_mat = np.zeros((3, 3))
|
||||
rot_rad = rot * np.pi / 180
|
||||
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||
rot_mat[0, :2] = [cs, -sn]
|
||||
rot_mat[1, :2] = [sn, cs]
|
||||
rot_mat[2, 2] = 1
|
||||
# Need to rotate around center
|
||||
t_mat = np.eye(3)
|
||||
t_mat[0, 2] = -res[1] / 2
|
||||
t_mat[1, 2] = -res[0] / 2
|
||||
t_inv = t_mat.copy()
|
||||
t_inv[:2, 2] *= -1
|
||||
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
||||
return t
|
||||
|
||||
|
||||
def transform(pt, center, scale, res, invert=0, rot=0):
|
||||
"""
|
||||
transform points
|
||||
"""
|
||||
t = get_transform(center, scale, res, rot=rot)
|
||||
if invert:
|
||||
t = np.linalg.inv(t)
|
||||
new_pt = np.array([pt[0], pt[1], 1.0]).T
|
||||
new_pt = np.dot(t, new_pt)
|
||||
return new_pt[:2].astype(int)
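Without rotation, the matrix built by `get_transform` reduces to a scale and shift that maps the person center to the middle of the output crop. The sketch below writes that mapping out directly; it returns floats, whereas `transform` truncates to integers, and `to_crop_coords` is an illustrative name rather than a function from this commit.

```python
def to_crop_coords(pt, center, scale, res):
    """Map an image point into crop coordinates (rot=0 case of get_transform)."""
    h = 200 * scale  # reference box size in pixels, as in get_transform
    x = res[1] * ((pt[0] - center[0]) / h + 0.5)
    y = res[0] * ((pt[1] - center[1]) / h + 0.5)
    return x, y


center, scale, res = (100, 200), 1.0, (256, 256)
print(to_crop_coords(center, center, scale, res))      # center of the crop
print(to_crop_coords((0, 100), center, scale, res))    # upper-left corner of the crop
```

The second call shows that the point half a box width above and to the left of the center lands on the crop origin, which is how `crop` recovers the source rectangle through the inverse transform.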
|
||||
|
||||
|
||||
def crop(img, center, scale, res):
|
||||
"""
|
||||
crop images
|
||||
"""
|
||||
# Upper-left corner
|
||||
ul = np.array(transform([0, 0], center, scale, res, invert=1))
|
||||
# Bottom-right corner
|
||||
br = np.array(transform(res, center, scale, res, invert=1))
|
||||
|
||||
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||
if len(img.shape) > 2:
|
||||
new_shape += [img.shape[2]]
|
||||
new_img = np.zeros(new_shape)
|
||||
|
||||
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
||||
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
||||
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
||||
old_y = max(0, ul[1]), min(len(img), br[1])
|
||||
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
|
||||
|
||||
return cv2.resize(new_img, res)
|
||||
|
||||
|
||||
def inv_mat(mat):
|
||||
"""
|
||||
get the inverse of a 2x3 affine matrix
|
||||
"""
|
||||
ans = np.linalg.pinv(np.array(mat).tolist() + [[0, 0, 1]])
|
||||
return ans[:2]
|
||||
|
||||
|
||||
def kpt_affine(kpt, mat):
|
||||
"""
|
||||
apply an affine transform to key points
|
||||
"""
|
||||
kpt = np.array(kpt)
|
||||
shape = kpt.shape
|
||||
kpt = kpt.reshape(-1, 2)
|
||||
return np.dot(np.concatenate((kpt, kpt[:, 0:1] * 0 + 1), axis=1), mat.T).reshape(shape)
|
||||
|
||||
|
||||
def resize(im, res):
|
||||
"""
|
||||
resize image
|
||||
"""
|
||||
return np.array([cv2.resize(im[i], res) for i in range(im.shape[0])])
|
|
@ -0,0 +1,373 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
inference module
|
||||
"""
|
||||
import copy
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
import src.dataset.MPIIDataLoader as ds
|
||||
import src.utils.img
|
||||
from src.config import parse_args
|
||||
|
||||
args = parse_args()
|
||||
|
||||
|
||||
def match_format(dic):
|
||||
"""
|
||||
pack peak locations and scores into a (1, joints, 3) array
|
||||
"""
|
||||
loc = dic["loc_k"][0, :, 0, :]
|
||||
val = dic["val_k"][0, :, :]
|
||||
ans = np.hstack((loc, val))
|
||||
ans = np.expand_dims(ans, axis=0)
|
||||
ret = []
|
||||
ret.append(ans)
|
||||
return ret
|
||||
|
||||
|
||||
class MaxPool2dFilter(nn.Cell):
|
||||
"""
|
||||
maxpool 2d for filter
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MaxPool2dFilter, self).__init__()
|
||||
self.pool = nn.MaxPool2d(3, 1, "same")
|
||||
self.eq = ops.Equal()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
maxm = self.pool(x)
|
||||
return self.eq(maxm, x)
|
||||
|
||||
|
||||
class HeatmapParser:
|
||||
"""
|
||||
parse heatmap
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.maxpool2dfilter = MaxPool2dFilter().to_float(mindspore.float16) # avoid float error
|
||||
self.topk = ops.TopK(sorted=True)
|
||||
self.stack = ops.Stack(axis=3)
|
||||
|
||||
def nms(self, det):
|
||||
"""
|
||||
keep only local maxima within each 3x3 window
|
||||
"""
|
||||
maxm = self.maxpool2dfilter(det)
|
||||
det = det * maxm
|
||||
return det
|
||||
|
||||
def calc(self, det):
|
||||
"""
|
||||
find the top-1 peak location and score for each joint
|
||||
"""
|
||||
det = self.nms(det)
|
||||
w = det.shape[3]
|
||||
det = det.view(det.shape[0], det.shape[1], -1)
|
||||
val_k, ind = self.topk(det, 1)
|
||||
|
||||
x = ind % w
|
||||
y = (ind.astype(mindspore.float32) / w).astype(mindspore.int32)
|
||||
ind_k = self.stack((x, y))
|
||||
ans = {"loc_k": ind_k, "val_k": val_k}
|
||||
return {key: ans[key].asnumpy() for key in ans}
|
||||
|
||||
def adjust(self, ans, det):
|
||||
"""
|
||||
shift each joint a quarter pixel toward its higher neighbor
|
||||
"""
|
||||
for people in ans:
|
||||
for i in people:
|
||||
for joint_id, joint in enumerate(i):
|
||||
if joint[2] > 0:
|
||||
y, x = joint[0:2]
|
||||
xx, yy = int(x), int(y)
|
||||
tmp = det[0][joint_id]
|
||||
if tmp[xx, min(yy + 1, tmp.shape[1] - 1)] > tmp[xx, max(yy - 1, 0)]:
|
||||
y += 0.25
|
||||
else:
|
||||
y -= 0.25
|
||||
|
||||
if tmp[min(xx + 1, tmp.shape[0] - 1), yy] > tmp[max(0, xx - 1), yy]:
|
||||
x += 0.25
|
||||
else:
|
||||
x -= 0.25
|
||||
ans[0][0, joint_id, 0:2] = (y + 0.5, x + 0.5)
|
||||
return ans
|
||||
|
||||
def parse(self, det, adjust=True):
|
||||
"""
|
||||
parse heatmap
|
||||
"""
|
||||
ans = match_format(self.calc(det))
|
||||
if adjust:
|
||||
ans = self.adjust(ans, det)
|
||||
return ans
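The `calc` and `adjust` steps above amount to an argmax over the flattened heatmap followed by a quarter-pixel nudge toward the stronger neighbor. A simplified single-heatmap sketch in plain Python is given below; it skips the NMS step and the score check, and `refine_peak` is a hypothetical helper.

```python
def refine_peak(heatmap):
    """Return (x, y, score) for one heatmap given as a list of rows."""
    height, width = len(heatmap), len(heatmap[0])
    flat = [v for row in heatmap for v in row]
    ind = max(range(len(flat)), key=flat.__getitem__)  # top-1 index
    col, row = ind % width, ind // width               # flat index -> (x, y)

    x, y = float(col), float(row)
    # Compare horizontal neighbors, clamping at the borders.
    if heatmap[row][min(col + 1, width - 1)] > heatmap[row][max(col - 1, 0)]:
        x += 0.25
    else:
        x -= 0.25
    # Compare vertical neighbors, clamping at the borders.
    if heatmap[min(row + 1, height - 1)][col] > heatmap[max(row - 1, 0)][col]:
        y += 0.25
    else:
        y -= 0.25
    return x + 0.5, y + 0.5, flat[ind]


hm = [
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.2, 1.0, 0.6],
    [0.0, 0.0, 0.5, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
print(refine_peak(hm))
```

As in `adjust`, a tie between neighbors falls into the `else` branch and moves the estimate by -0.25, and the final +0.5 shifts the result to the pixel center.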
|
||||
|
||||
|
||||
parser = HeatmapParser()
|
||||
|
||||
|
||||
def post_process(det, mat_, trainval, c=None, s=None, resolution=None):
|
||||
"""
|
||||
post process for parser
|
||||
"""
|
||||
mat = np.linalg.pinv(np.array(mat_).tolist() + [[0, 0, 1]])[:2]
|
||||
cropped_preds = parser.parse(mindspore.Tensor(np.float32([det])))[0]
|
||||
|
||||
if cropped_preds.size > 0:
|
||||
cropped_preds[:, :, :2] = src.utils.img.kpt_affine(cropped_preds[:, :, :2] * 4, mat) # size 1x16x3
|
||||
|
||||
preds = np.copy(cropped_preds)
|
||||
# map predictions back to the original image coordinates
|
||||
if trainval != "cropped":
|
||||
for j in range(preds.shape[1]):
|
||||
preds[0, j, :2] = src.utils.img.transform(preds[0, j, :2], c, s, resolution, invert=1)
|
||||
return preds
|
||||
|
||||
|
||||
def inference(img, net, c, s):
|
||||
"""
|
||||
forward pass at test time
|
||||
calls post_process to post process results
|
||||
"""
|
||||
|
||||
height, width = img.shape[0:2]
|
||||
center = (width / 2, height / 2)
|
||||
scale = max(height, width) / 200
|
||||
res = (args.input_res, args.input_res)
|
||||
|
||||
mat_ = src.utils.img.get_transform(center, scale, res)[:2]
|
||||
inp = img / 255
|
||||
|
||||
tmp1 = net(mindspore.Tensor([inp], dtype=mindspore.float32)).asnumpy()
|
||||
tmp2 = net(mindspore.Tensor([inp[:, ::-1]], dtype=mindspore.float32)).asnumpy()
|
||||
|
||||
tmp = np.concatenate((tmp1, tmp2), axis=0)
|
||||
|
||||
det = tmp[0, -1] + tmp[1, -1, :, :, ::-1][ds.flipped_parts["mpii"]]
|
||||
|
||||
if det is None:
|
||||
return [], []
|
||||
det = det / 2
|
||||
|
||||
det = np.minimum(det, 1)
|
||||
|
||||
return post_process(det, mat_, "valid", c, s, res)
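`inference` averages the heatmaps of the image and of its horizontal mirror. The mirrored output has to be flipped back and its left/right joint channels swapped before averaging. The sketch below shows that merge on nested lists; the permutation is an assumption inferred from the symmetric `joint_map` order in `MPIIEval` (the actual `ds.flipped_parts["mpii"]` is defined elsewhere), and `merge_flip` is a hypothetical name.

```python
# Assumed left/right swap for the 16 MPII joints (not taken from this file).
FLIPPED = [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]


def merge_flip(det, det_flipped, perm=FLIPPED):
    """det, det_flipped: [joint][row][col] heatmaps; returns their clipped average."""
    merged = []
    for j in range(len(det)):
        # Take the mirrored joint's channel and undo the horizontal flip.
        unflipped = [row[::-1] for row in det_flipped[perm[j]]]
        merged.append([
            [min((a + b) / 2, 1) for a, b in zip(r1, r2)]
            for r1, r2 in zip(det[j], unflipped)
        ])
    return merged


# Toy input: each joint has a 1x2 heatmap with its response in the left column.
det = [[[j / 100, 0.0]] for j in range(16)]
det_flipped = [None] * 16
for j in range(16):
    det_flipped[FLIPPED[j]] = [[0.0, j / 100]]  # what a consistent net would see mirrored
print(merge_flip(det, det_flipped) == det)
```

When the network responds consistently to the mirrored image, the merged heatmaps equal the originals, which is the property the toy input checks.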
|
||||
|
||||
|
||||
class MPIIEval:
|
||||
"""
|
||||
eval for MPII dataset
|
||||
"""
|
||||
|
||||
template = {
|
||||
"all": {
|
||||
"total": 0,
|
||||
"ankle": 0,
|
||||
"knee": 0,
|
||||
"hip": 0,
|
||||
"pelvis": 0,
|
||||
"thorax": 0,
|
||||
"neck": 0,
|
||||
"head": 0,
|
||||
"wrist": 0,
|
||||
"elbow": 0,
|
||||
"shoulder": 0,
|
||||
},
|
||||
"visible": {
|
||||
"total": 0,
|
||||
"ankle": 0,
|
||||
"knee": 0,
|
||||
"hip": 0,
|
||||
"pelvis": 0,
|
||||
"thorax": 0,
|
||||
"neck": 0,
|
||||
"head": 0,
|
||||
"wrist": 0,
|
||||
"elbow": 0,
|
||||
"shoulder": 0,
|
||||
},
|
||||
"not visible": {
|
||||
"total": 0,
|
||||
"ankle": 0,
|
||||
"knee": 0,
|
||||
"hip": 0,
|
||||
"pelvis": 0,
|
||||
"thorax": 0,
|
||||
"neck": 0,
|
||||
"head": 0,
|
||||
"wrist": 0,
|
||||
"elbow": 0,
|
||||
"shoulder": 0,
|
||||
},
|
||||
}
|
||||
|
||||
joint_map = [
|
||||
"ankle",
|
||||
"knee",
|
||||
"hip",
|
||||
"hip",
|
||||
"knee",
|
||||
"ankle",
|
||||
"pelvis",
|
||||
"thorax",
|
||||
"neck",
|
||||
"head",
|
||||
"wrist",
|
||||
"elbow",
|
||||
"shoulder",
|
||||
"shoulder",
|
||||
"elbow",
|
||||
"wrist",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.correct = copy.deepcopy(self.template)
|
||||
self.count = copy.deepcopy(self.template)
|
||||
self.correct_train = copy.deepcopy(self.template)
|
||||
self.count_train = copy.deepcopy(self.template)
|
||||
|
||||
def eval(self, pred, gt, normalizing, num_train, bound=0.5):
|
||||
"""
|
||||
use PCK with threshold of .5 of normalized distance (presumably head size)
|
||||
"""
|
||||
idx = 0
|
||||
for p, g, normalize in zip(pred, gt, normalizing):
|
||||
for j in range(g.shape[1]):
|
||||
vis = "visible"
|
||||
if g[0, j, 0] == 0: # Not in image
|
||||
continue
|
||||
if g[0, j, 2] == 0:
|
||||
vis = "not visible"
|
||||
joint = self.joint_map[j]
|
||||
|
||||
if idx >= num_train:
|
||||
self.count["all"]["total"] += 1
|
||||
self.count["all"][joint] += 1
|
||||
self.count[vis]["total"] += 1
|
||||
self.count[vis][joint] += 1
|
||||
else:
|
||||
self.count_train["all"]["total"] += 1
|
||||
self.count_train["all"][joint] += 1
|
||||
self.count_train[vis]["total"] += 1
|
||||
self.count_train[vis][joint] += 1
|
||||
error = np.linalg.norm(p[0]["keypoints"][j, :2] - g[0, j, :2]) / normalize
|
||||
if idx >= num_train:
|
||||
if bound > error:
|
||||
self.correct["all"]["total"] += 1
|
||||
self.correct["all"][joint] += 1
|
||||
self.correct[vis]["total"] += 1
|
||||
self.correct[vis][joint] += 1
|
||||
else:
|
||||
if bound > error:
|
||||
self.correct_train["all"]["total"] += 1
|
||||
self.correct_train["all"][joint] += 1
|
||||
self.correct_train[vis]["total"] += 1
|
||||
self.correct_train[vis][joint] += 1
|
||||
idx += 1
|
||||
|
||||
self.output_result(bound)
|
||||
|
||||
def output_result(self, bound):
|
||||
"""
|
||||
print PCK results separately for the validation and training splits
|
||||
"""
|
||||
for k in self.correct:
|
||||
print(k, ":")
|
||||
for key in self.correct[k]:
|
||||
print(
|
||||
"Val PCK @,",
|
||||
bound,
|
||||
",",
|
||||
key,
|
||||
":",
|
||||
round(self.correct[k][key] / max(self.count[k][key], 1), 3),
|
||||
", count:",
|
||||
self.count[k][key],
|
||||
)
|
||||
print(
|
||||
"Tra PCK @,",
|
||||
bound,
|
||||
",",
|
||||
key,
|
||||
":",
|
||||
round(self.correct_train[k][key] / max(self.count_train[k][key], 1), 3),
|
||||
", count:",
|
||||
self.count_train[k][key],
|
||||
)
|
||||
print("\n")
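Stripped of the per-joint and per-visibility bookkeeping, the metric computed by `MPIIEval.eval` is a PCK score: a joint counts as correct when its distance to the ground truth, divided by the sample's normalizer, is below `bound`. A minimal sketch on plain tuples follows; `pck` is an illustrative function and the inputs are made-up values.

```python
import math


def pck(preds, gts, normalizers, bound=0.5):
    """preds: [sample][joint] = (x, y); gts: [sample][joint] = (x, y, visible)."""
    correct = count = 0
    for sample_pred, sample_gt, norm in zip(preds, gts, normalizers):
        for (px, py), (gx, gy, _visible) in zip(sample_pred, sample_gt):
            if gx == 0:  # joint not annotated in the image, skipped as in eval
                continue
            count += 1
            error = math.hypot(px - gx, py - gy) / norm
            if bound > error:
                correct += 1
    return correct / max(count, 1)


preds = [[(10, 10), (50, 50), (5, 5)]]
gts = [[(10, 12, 1), (80, 50, 1), (0, 0, 0)]]
print(pck(preds, gts, normalizers=[10]))
```

The first joint is off by 0.2 normalized units and passes, the second is off by 3.0 and fails, and the third is skipped because it is not annotated.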
|
||||
|
||||
|
||||
def get_img(num_eval=2958, num_train=300):
|
||||
"""
|
||||
load validation and training images
|
||||
"""
|
||||
input_res = args.input_res
|
||||
val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r")
|
||||
|
||||
tr = tqdm.tqdm(range(0, num_train), total=num_train)
|
||||
# Train set
|
||||
train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r")
|
||||
for i in tr:
|
||||
path_t = "%s/%s" % (args.img_dir, train_f["imgname"][i].decode("UTF-8"))
|
||||
|
||||
orig_img = cv2.imread(path_t)[:, :, ::-1]
|
||||
c = train_f["center"][i]
|
||||
s = train_f["scale"][i]
|
||||
im = src.utils.img.crop(orig_img, c, s, (input_res, input_res))
|
||||
|
||||
kp = train_f["part"][i]
|
||||
vis = train_f["visible"][i]
|
||||
kp2 = np.insert(kp, 2, vis, axis=1)
|
||||
kps = np.zeros((1, 16, 3))
|
||||
kps[0] = kp2
|
||||
|
||||
n = train_f["normalize"][i]
|
||||
|
||||
yield kps, im, c, s, n
|
||||
|
||||
tr2 = tqdm.tqdm(range(0, num_eval), total=num_eval)
|
||||
# Valid
|
||||
for i in tr2:
|
||||
path_t = "%s/%s" % (args.img_dir, val_f["imgname"][i].decode("UTF-8"))
|
||||
|
||||
orig_img = cv2.imread(path_t)[:, :, ::-1]
|
||||
c = val_f["center"][i]
|
||||
s = val_f["scale"][i]
|
||||
im = src.utils.img.crop(orig_img, c, s, (input_res, input_res))
|
||||
|
||||
kp = val_f["part"][i]
|
||||
vis = val_f["visible"][i]
|
||||
kp2 = np.insert(kp, 2, vis, axis=1)
|
||||
kps = np.zeros((1, 16, 3))
|
||||
kps[0] = kp2
|
||||
|
||||
n = val_f["normalize"][i]
|
||||
|
||||
yield kps, im, c, s, n
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
run model train
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Model, context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.communication.management import get_group_size, get_rank, init
|
||||
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
|
||||
ModelCheckpoint, TimeMonitor)
|
||||
|
||||
from src.config import parse_args
|
||||
from src.dataset.DatasetGenerator import DatasetGenerator
|
||||
from src.dataset.MPIIDataLoader import MPII
|
||||
from src.models.loss import HeatmapLoss
|
||||
from src.models.StackedHourglassNet import StackedHourglassNet
|
||||
|
||||
set_seed(1)
|
||||
|
||||
args = parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
|
||||
print("Dataset not found.")
|
||||
exit()
|
||||
|
||||
# Set context mode
|
||||
if args.context_mode == "GRAPH":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
|
||||
|
||||
if args.parallel:
|
||||
# Parallel mode
|
||||
context.reset_auto_parallel_context()
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
args.rank_id = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
else:
|
||||
args.rank_id = 0
|
||||
args.group_size = 1
|
||||
|
||||
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
|
||||
|
||||
# Process dataset
|
||||
mpii = MPII()
|
||||
train, valid = mpii.setup_val_split()
|
||||
|
||||
train_generator = DatasetGenerator(args.input_res, args.output_res, mpii, train)
|
||||
train_size = len(train_generator)
|
||||
train_sampler = ds.DistributedSampler(num_shards=args.group_size, shard_id=args.rank_id, shuffle=True)
|
||||
train_data = ds.GeneratorDataset(train_generator, ["data", "label"], sampler=train_sampler)
|
||||
train_data = train_data.batch(args.batch_size, True, args.group_size)
|
||||
|
||||
print("train data size:", train_size)
|
||||
step_per_epoch = math.ceil(train_size / args.batch_size / args.group_size)
|
||||
|
||||
# Define loss function
|
||||
loss_func = HeatmapLoss()
|
||||
# Define optimizer
|
||||
lr_decay = nn.exponential_decay_lr(
|
||||
args.initial_lr, args.decay_rate, args.num_epoch * step_per_epoch, step_per_epoch, args.decay_epoch
|
||||
)
|
||||
optimizer = nn.Adam(net.trainable_params(), lr_decay)
|
||||
|
||||
# Define model
|
||||
model = Model(net, loss_func, optimizer, amp_level=args.amp_level, keep_batchnorm_fp32=False)
|
||||
|
||||
# Define callback functions
|
||||
callbacks = []
|
||||
callbacks.append(LossMonitor(args.loss_log_interval))
|
||||
callbacks.append(TimeMonitor(step_per_epoch))
|
||||
|
||||
# Save checkpoint file
|
||||
if args.rank_id == 0:
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=args.save_checkpoint_epochs * step_per_epoch,
|
||||
keep_checkpoint_max=args.keep_checkpoint_max,
|
||||
)
|
||||
ckpoint = ModelCheckpoint("ckpt", config=config_ck)
|
||||
callbacks.append(ckpoint)
|
||||
|
||||
model.train(args.num_epoch, train_data, callbacks=callbacks, dataset_sink_mode=True)
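The optimizer above receives one learning rate per training step from `nn.exponential_decay_lr`. The sketch below reproduces the schedule as I understand the documented formula, `lr * decay_rate ** (current_epoch / decay_epoch)` with `current_epoch = floor(step / step_per_epoch)`; treat it as an assumption to check against the MindSpore API reference. The function name `exponential_decay_sketch` is hypothetical.

```python
import math


def exponential_decay_sketch(learning_rate, decay_rate, total_step,
                             step_per_epoch, decay_epoch, is_stair=False):
    """Return a list with one learning rate per training step."""
    lrs = []
    for step in range(total_step):
        current_epoch = math.floor(step / step_per_epoch)
        exponent = current_epoch / decay_epoch
        if is_stair:
            exponent = math.floor(exponent)  # decay only once every decay_epoch epochs
        lrs.append(learning_rate * decay_rate ** exponent)
    return lrs


# Three epochs of two steps each, halving the rate every epoch.
print(exponential_decay_sketch(0.1, 0.5, total_step=6, step_per_epoch=2, decay_epoch=1))
```

The rate is constant within an epoch and changes only at epoch boundaries, which is why `train.py` passes `args.num_epoch * step_per_epoch` as the total number of steps.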
|