diff --git a/model_zoo/research/cv/StackedHourglass/README_CN.md b/model_zoo/research/cv/StackedHourglass/README_CN.md
new file mode 100644
index 00000000000..72a526ed260
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/README_CN.md
@@ -0,0 +1,193 @@
+# Contents
+
+- [Contents](#contents)
+- [Stacked Hourglass Description](#stacked-hourglass-description)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Usage](#usage)
+        - [Results](#results)
+    - [Evaluation Process](#evaluation-process)
+        - [Usage](#usage-1)
+        - [Results](#results-1)
+    - [Export](#export)
+- [Model Description](#model-description)
+    - [Training Performance (2HG)](#training-performance-2hg)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo](#modelzoo)
+
+# Stacked Hourglass Description
+
+Stacked Hourglass is a model for human pose estimation. It extracts features with stacked hourglass modules and, at the end, outputs the predicted position of each keypoint through heatmaps.
+
+[Paper: Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937v2)
+
+# Dataset
+
+Dataset used: [MPII Human Pose Dataset](http://human-pose.mpi-inf.mpg.de/)
+
+- Dataset size:
+    - Training: 22246 images
+    - Test: 2958 images
+    - Number of keypoints: 16 (head, neck, shoulders, elbows, wrists, thorax, pelvis, hips, knees, ankles)
+
+> Note: the original MPII annotations are provided in .mat format, which is hard to process. Please download and use the alternative annotations instead: [https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot](https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot)
+
+# Environment Requirements
+
+- Hardware (Ascend)
+    - Prepare the hardware environment with Ascend.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
+
+# Script Description
+
+## Script and Sample Code
+
+```text
+├──scripts
+│   ├──run_distribute_train.sh      # distributed training script
+│   ├──run_eval.sh                  # evaluation script
+│   └──run_standalone_train.sh      # standalone training script
+├──src
+│   ├──dataset
+│   │   ├──DatasetGenerator.py      # dataset generator and target heatmap generation
+│   │   └──MPIIDataLoader.py        # MPII data loading and preprocessing
+│   ├──models
+│   │   ├──layers.py                # network submodules
+│   │   ├──loss.py                  # heatmap loss definition
+│   │   └──StackedHourglassNet.py   # overall network definition
+│   ├──utils
+│   │   ├──img.py                   # general image processing utilities
+│   │   └──inference.py             # inference utilities, including accuracy computation
+│   └── config.py                   # parameter configuration
+├── eval.py                         # evaluation script
+├── export.py                       # export script
+├── README_CN.md                    # project description
+└── train.py                        # training script
+```
+
+## Script Parameters
+
+Parameters used for model training and evaluation can be set in config.py; they can also be passed in as command-line arguments at run time.
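+
+For example, here is a minimal sketch that overrides a few of the defaults defined in `src/config.py` (the values shown are the defaults, and the paths are placeholders for your own dataset location):
+
+```sh
+python train.py \
+    --annot_dir ./MPII/annot \
+    --img_dir ./MPII/images \
+    --batch_size 32 \
+    --num_epoch 100 \
+    --initial_lr 1e-3
+```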
+
+## Training Process
+
+### Usage
+
+For standalone training, first select the target device: `export DEVICE_ID=x`, where `x` is the ID of the target device. Then start training:
+
+```sh
+python train.py
+```
+
+Alternatively, use the standalone training script:
+
+```sh
+./scripts/run_standalone_train.sh [device ID] [annot path] [image path]
+```
+
+For multi-device training, use the distributed training script:
+
+```sh
+./scripts/run_distribute_train.sh [rank table file path] [number of Ascend devices] [annot path] [image path]
+```
+
+### Results
+
+Checkpoint files are stored under the current path. Training output is written to `loss.txt` by default, while errors and messages go to `err.txt`, for example:
+
+```text
+loading data...
+Done (t=14.61s)
+train data size: 22246
+epoch: 1 step: 695, loss is 0.00068435294
+epoch time: 954584.373 ms, per step time: 1373.503 ms
+epoch: 2 step: 695, loss is 0.00067576126
+epoch time: 755549.341 ms, per step time: 1087.121 ms
+epoch: 3 step: 695, loss is 0.00057179347
+epoch time: 750856.373 ms, per step time: 1080.369 ms
+epoch: 4 step: 695, loss is 0.00055218843
+
+[...]
+```
+
+## Evaluation Process
+
+### Usage
+
+Before running evaluation, select the target device: `export DEVICE_ID=x`, where `x` is the ID of the target device. Then launch evaluation with Python; the path to the ckpt file must be given:
+
+```sh
+python eval.py --ckpt_file [ckpt file path]
+```
+
+Alternatively, use the evaluation script:
+
+```sh
+./scripts/run_eval.sh [device ID] [ckpt file path] [annot path] [image path]
+```
+
+### Results
+
+Evaluation results are written to `result.txt` by default, while errors and messages go to `err.txt`.
+
+```text
+all :
+Val PCK @, 0.5 , total : 0.882 , count: 44239
+Tra PCK @, 0.5 , total : 0.938 , count: 4443
+Val PCK @, 0.5 , ankle : 0.765 , count: 4234
+Tra PCK @, 0.5 , ankle : 0.847 , count: 392
+Val PCK @, 0.5 , knee : 0.819 , count: 4963
+Tra PCK @, 0.5 , knee : 0.91 , count: 499
+Val PCK @, 0.5 , hip : 0.871 , count: 5777
+Tra PCK @, 0.5 , hip : 0.918 , count: 587
+
+[...]
+```
+
+## Export
+
+The model can be exported with the `export.py` script:
+
+```sh
+python export.py --ckpt_file [ckpt file path]
+```
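+
+The output file name and format can also be set via the corresponding options in `src/config.py` (`MINDIR` is the default `--file_format`), for example:
+
+```sh
+python export.py --ckpt_file [ckpt file path] --file_name stackedhourglass --file_format MINDIR
+```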
+
+# Model Description
+
+## Training Performance (2HG)
+
+| Parameter                  | Ascend                                      |
+| -------------------------- | ------------------------------------------- |
+| Model name                 | Stacked Hourglass Networks                  |
+| Environment                | Ascend 910A                                 |
+| Upload date                | 2021-7-5                                    |
+| MindSpore version          | 1.2.0                                       |
+| Dataset                    | MPII Human Pose Dataset                     |
+| Training parameters        | see config.py                               |
+| Optimizer                  | Adam (with exponential learning rate decay) |
+| Loss function              | HeatMap Loss (MSE-like)                     |
+| Final loss                 | 0.00036373272                               |
+| Accuracy                   | 88.2%                                       |
+| Total training time (1p)   | 20h                                         |
+| Total evaluation time (1p) | 21min                                       |
+| Number of parameters       | 8429088                                     |
+
+# Description of Random Situation
+
+A random seed is set in the `train.py` script.
+
+# ModelZoo
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
diff --git a/model_zoo/research/cv/StackedHourglass/eval.py b/model_zoo/research/cv/StackedHourglass/eval.py
new file mode 100644
index 00000000000..0f75ce3a921
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/eval.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+run model eval
+"""
+import os
+
+from mindspore import context, load_checkpoint, load_param_into_net
+
+from src.config import parse_args
+from src.models.StackedHourglassNet import StackedHourglassNet
+from src.utils.inference import MPIIEval, get_img, inference
+
+args = parse_args()
+
+if __name__ == "__main__":
+    if not os.path.exists(args.ckpt_file):
+        print("ckpt file not valid")
+        exit()
+
+    if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
+        print("Dataset not found.")
+        exit()
+
+    # Set context mode
+    if args.context_mode == "GRAPH":
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
+    else:
+        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
+
+    # Import net
+    net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
+    param_dict = load_checkpoint(args.ckpt_file)
+    load_param_into_net(net, param_dict)
+
+    gts = []
+    preds = []
+    normalizing = []
+
+    num_eval = args.num_eval
+    num_train = args.train_num_eval
+    for anns, img, c, s, n in get_img(num_eval, num_train):
+        gts.append(anns)
+        ans = inference(img, net, c, s)
+        if ans.size > 0:
+            ans = ans[:, :, :3]
+
+        # (num preds, joints, x/y/visible)
+        pred = []
+        for i in range(ans.shape[0]):
+            pred.append({"keypoints": ans[i, :, :]})
+        preds.append(pred)
+        normalizing.append(n)
+
+    mpii_eval = MPIIEval()
+    mpii_eval.eval(preds, gts, normalizing, num_train)
diff --git a/model_zoo/research/cv/StackedHourglass/export.py b/model_zoo/research/cv/StackedHourglass/export.py
new file mode 100644
index 00000000000..0b91999ade8
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/export.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +export model +""" +import os + +import numpy as np +from mindspore import Tensor, context, export, load_checkpoint, load_param_into_net + +from src.config import parse_args +from src.models.StackedHourglassNet import StackedHourglassNet + +args = parse_args() + +if __name__ == "__main__": + if not os.path.exists(args.ckpt_file): + print("ckpt file not valid") + exit() + + # Set context mode + if args.context_mode == "GRAPH": + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False) + else: + context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target) + + # Import net + net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim) + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.zeros([args.batch_size, args.input_res, args.input_res, 3], np.float32)) + export(net, input_arr, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/research/cv/StackedHourglass/requirements.txt b/model_zoo/research/cv/StackedHourglass/requirements.txt new file mode 100644 index 00000000000..0218ac509a0 --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/requirements.txt @@ -0,0 +1,23 @@ +asttokens==2.0.5 +astunparse==1.6.3 +attrs==21.2.0 +cached-property==1.5.2 +certifi==2018.10.15 +cffi==1.14.5 +decorator==5.0.9 +easydict==1.9 +h5py==3.3.0 +imageio==2.9.0 +mpmath==1.2.1 +numpy==1.21.0 +opencv-python==4.5.2.54 +packaging==21.0 +Pillow==8.3.1 +protobuf==3.17.3 +psutil==5.8.0 +pycparser==2.20 +pyparsing==2.4.7 +scipy==1.7.0 +six==1.16.0 +sympy==1.8 +tqdm==4.61.2 diff --git a/model_zoo/research/cv/StackedHourglass/scripts/run_distribute_train.sh b/model_zoo/research/cv/StackedHourglass/scripts/run_distribute_train.sh new file mode 100755 index 00000000000..03b2ef21b7c --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/scripts/run_distribute_train.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 4 ]; then + echo "Usage: ./scripts/run_distribute_train.sh [rank file path] [rank size] [annot path] [image path]" + exit 1 +fi + +export RANK_TABLE_FILE="$1" +export RANK_SIZE="$2" +ANNOT_DIR=$3 +IMG_DIR=$4 + +CURDIR=$(pwd) + +for ((i = 0; i < $2; i++)) +do + export DEVICE_ID="$i" + export RANK_ID="$i" + + cd $CURDIR + TARGET="./Train${DEVICE_ID}" + rm -rf $TARGET + mkdir $TARGET + cp *.py $TARGET + cp -r src $TARGET + cd $TARGET + + echo "training for rank $i" + nohup python train.py \ + --annot_dir=$ANNOT_DIR \ + --img_dir=$IMG_DIR \ + --parallel True > loss.txt 2> err.txt & +done \ No newline at end of file diff --git a/model_zoo/research/cv/StackedHourglass/scripts/run_eval.sh b/model_zoo/research/cv/StackedHourglass/scripts/run_eval.sh new file mode 100755 index 00000000000..72c3b42f9d5 --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/scripts/run_eval.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 4 ]; then + echo "Usage: ./scripts/run_eval.sh [device ID] [ckpt path] [annot path] [image path]" + exit 1 +fi + +export DEVICE_ID=$1 +PATH_CHECKPOINT=$2 +ANNOT_DIR=$3 +IMG_DIR=$4 + +TARGET="./Eval" + +rm -rf $TARGET +mkdir $TARGET +cp *.py $TARGET +cp -r src $TARGET + +cd $TARGET + +nohup python eval.py \ + --ckpt_file=$PATH_CHECKPOINT \ + --annot_dir=$ANNOT_DIR \ + --img_dir=$IMG_DIR > result.txt 2> err.txt & diff --git a/model_zoo/research/cv/StackedHourglass/scripts/run_standalone_train.sh b/model_zoo/research/cv/StackedHourglass/scripts/run_standalone_train.sh new file mode 100755 index 00000000000..17403f7eac1 --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/scripts/run_standalone_train.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 3 ]; then + echo "Usage: ./scripts/run_standalone_train.sh [device ID] [annot path] [image path]" + exit 1 +fi + +export DEVICE_ID=$1 +ANNOT_DIR=$2 +IMG_DIR=$3 + +TARGET="./Train" + +rm -rf $TARGET +mkdir $TARGET +cp *.py $TARGET +cp -r src $TARGET + +cd $TARGET + +nohup python train.py \ + --annot_dir=$ANNOT_DIR \ + --img_dir=$IMG_DIR > loss.txt 2> err.txt & diff --git a/model_zoo/research/cv/StackedHourglass/src/config.py b/model_zoo/research/cv/StackedHourglass/src/config.py new file mode 100644 index 00000000000..76720e16857 --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/src/config.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +model config +""" +import argparse +import ast + + +def parse_args(): + """ + parse arguments + """ + parser = argparse.ArgumentParser(description="MindSpore Stacked Hourglass") + + # Model + parser.add_argument("--nstack", type=int, default=2) + parser.add_argument("--inp_dim", type=int, default=256) + parser.add_argument("--oup_dim", type=int, default=16) + parser.add_argument("--input_res", type=int, default=256) + parser.add_argument("--output_res", type=int, default=64) + parser.add_argument("--annot_dir", type=str, default="./MPII/annot") + parser.add_argument("--img_dir", type=str, default="./MPII/images") + # Context + parser.add_argument("--context_mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"]) + parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "CPU"]) + # Train + parser.add_argument("--parallel", type=ast.literal_eval, default=False) + parser.add_argument("--amp_level", type=str, default="O2", choices=["O0", "O1", "O2", "O3"]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_epoch", type=int, default=100) + parser.add_argument("--save_checkpoint_epochs", type=int, default=5) + parser.add_argument("--keep_checkpoint_max", type=int, default=20) + parser.add_argument("--loss_log_interval", type=int, default=1) + parser.add_argument("--initial_lr", type=float, default=1e-3) + parser.add_argument("--decay_rate", type=float, default=0.985) + parser.add_argument("--decay_epoch", type=int, default=1) + # Valid + parser.add_argument("--num_eval", type=int, default=2958) + parser.add_argument("--train_num_eval", type=int, default=300) + parser.add_argument("--ckpt_file", type=str, default="") + # Export + parser.add_argument("--file_name", type=str, default="stackedhourglass") + parser.add_argument("--file_format", type=str, default="MINDIR") + + args = parser.parse_known_args()[0] + + return args diff --git a/model_zoo/research/cv/StackedHourglass/src/dataset/DatasetGenerator.py b/model_zoo/research/cv/StackedHourglass/src/dataset/DatasetGenerator.py new file mode 100644 index 00000000000..94228f07bd5 --- /dev/null 
+++ b/model_zoo/research/cv/StackedHourglass/src/dataset/DatasetGenerator.py
@@ -0,0 +1,161 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+dataset classes
+"""
+import cv2
+import numpy as np
+
+import src.utils.img
+from src.dataset.MPIIDataLoader import flipped_parts
+
+
+class GenerateHeatmap:
+    """
+    generate the training target heatmaps
+    """
+
+    def __init__(self, output_res, num_parts):
+        self.output_res = output_res
+        self.num_parts = num_parts
+        sigma = self.output_res / 64
+        self.sigma = sigma
+        # Precompute a gaussian patch of side 6 * sigma + 3, centered in the patch
+        size = 6 * sigma + 3
+        x = np.arange(0, size, 1, float)
+        y = x[:, np.newaxis]
+        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
+        self.g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+
+    def __call__(self, keypoints):
+        hms = np.zeros(shape=(self.num_parts, self.output_res, self.output_res), dtype=np.float32)
+        sigma = self.sigma
+        for p in keypoints:
+            for idx, pt in enumerate(p):
+                if pt[0] > 0:
+                    x, y = int(pt[0]), int(pt[1])
+                    if x < 0 or y < 0 or x >= self.output_res or y >= self.output_res:
+                        continue
+                    # Corners of the gaussian patch in heatmap coordinates
+                    ul = int(x - 3 * sigma - 1), int(y - 3 * sigma - 1)
+                    br = int(x + 3 * sigma + 2), int(y + 3 * sigma + 2)
+
+                    # Clip the patch window at the heatmap borders
+                    c, d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0]
+                    a, b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1]
+
+                    cc, dd = max(0, ul[0]), min(br[0], self.output_res)
+                    aa, bb = max(0, ul[1]), min(br[1], self.output_res)
+                    hms[idx, aa:bb, cc:dd] = np.maximum(hms[idx, aa:bb, cc:dd], self.g[a:b, c:d])
+        return hms
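+
+
+# A minimal usage sketch of GenerateHeatmap (illustrative only; shapes follow the
+# defaults used below, where joint 0 is the right ankle in the MPII ordering):
+#
+#   gen = GenerateHeatmap(output_res=64, num_parts=16)
+#   kps = np.zeros((1, 16, 3))   # one person, 16 joints, (x, y, visible)
+#   kps[0, 0] = (32, 40, 1)      # hypothetical keypoint in heatmap coordinates
+#   hms = gen(kps)               # -> float32 heatmaps of shape (16, 64, 64)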
+
+
+class DatasetGenerator:
+    """
+    mindspore general dataset generator
+    """
+
+    def __init__(self, input_res, output_res, ds, index):
+        self.input_res = input_res
+        self.output_res = output_res
+        self.generateHeatmap = GenerateHeatmap(self.output_res, 16)
+        self.ds = ds
+        self.index = index
+
+    def __len__(self):
+        return len(self.index)
+
+    def __getitem__(self, idx):
+        return self.loadImage(self.index[idx])
+
+    def loadImage(self, idx):
+        """
+        load and preprocess image
+        """
+        ds = self.ds
+
+        # Load + Crop
+        orig_img = ds.get_img(idx)
+        orig_keypoints = ds.get_kps(idx)
+        kptmp = orig_keypoints.copy()
+        c = ds.get_center(idx)
+        s = ds.get_scale(idx)
+
+        cropped = src.utils.img.crop(orig_img, c, s, (self.input_res, self.input_res))
+        for i in range(np.shape(orig_keypoints)[1]):
+            if orig_keypoints[0, i, 0] > 0:
+                orig_keypoints[0, i, :2] = src.utils.img.transform(
+                    orig_keypoints[0, i, :2], c, s, (self.input_res, self.input_res)
+                )
+        keypoints = np.copy(orig_keypoints)
+
+        # Random rotation and scale augmentation
+        height, width = cropped.shape[0:2]
+        center = np.array((width / 2, height / 2))
+        scale = max(height, width) / 200
+
+        aug_rot = (np.random.random() * 2 - 1) * 30.0
+        aug_scale = np.random.random() * (1.25 - 0.75) + 0.75
+        scale *= aug_scale
+
+        mat_mask = src.utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2]
+
+        mat = src.utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2]
+        inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32) / 255
+        keypoints[:, :, 0:2] = src.utils.img.kpt_affine(keypoints[:, :, 0:2], mat_mask)
+        # With probability 1/2, jitter colors and flip the image and keypoints horizontally
+        if np.random.randint(2) == 0:
+            inp = self.preprocess(inp)
+            inp = inp[:, ::-1]
+            keypoints = keypoints[:, flipped_parts["mpii"]]
+            keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0]
+            orig_keypoints = orig_keypoints[:, flipped_parts["mpii"]]
+            orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0]
+
+        # If keypoint is invisible, set to 0
+        for i in range(np.shape(orig_keypoints)[1]):
+            if kptmp[0, i, 0] == 0 and kptmp[0, i, 1] == 0:
+                keypoints[0, i, 0] = 0
+                keypoints[0, i, 1] = 0
+                orig_keypoints[0, i, 0] = 0
+                orig_keypoints[0, i, 1] = 0
+
+        # Generate target heatmap
+        heatmaps = self.generateHeatmap(keypoints)
+
+        return inp.astype(np.float32), heatmaps.astype(np.float32)
+
+    def preprocess(self, data):
+        """
+        preprocess images
+        """
+        # Random hue and saturation
+        data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
+        delta = (np.random.random() * 2 - 1) * 0.2
+        data[:, :, 0] = np.mod(data[:, :, 0] + (delta * 360 + 360.0), 360.0)
+
+        delta_saturation = np.random.random() + 0.5
+        data[:, :, 1] *= delta_saturation
+        data[:, :, 1] = np.maximum(np.minimum(data[:, :, 1], 1), 0)
+        data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB)
+
+        # Random brightness
+        delta = (np.random.random() * 2 - 1) * 0.3
+        data += delta
+
+        # Random contrast
+        mean = data.mean(axis=2, keepdims=True)
+        data = (data - mean) * (np.random.random() + 0.5) + mean
+        data = np.minimum(np.maximum(data, 0), 1)
+        return data
diff --git a/model_zoo/research/cv/StackedHourglass/src/dataset/MPIIDataLoader.py b/model_zoo/research/cv/StackedHourglass/src/dataset/MPIIDataLoader.py
new file mode 100644
index 00000000000..c5d5029e4a9
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/src/dataset/MPIIDataLoader.py
@@ -0,0 +1,163 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +MPII dataset loader +""" +import os +import time + +import h5py +import numpy as np +from imageio import imread + +from src.config import parse_args + +args = parse_args() + + +class MPII: + """ + MPII dataset loader + """ + + def __init__(self): + print("loading data...") + tic = time.time() + + train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r") + val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r") + + self.t_center = train_f["center"][()] + t_scale = train_f["scale"][()] + t_part = train_f["part"][()] + t_visible = train_f["visible"][()] + t_normalize = train_f["normalize"][()] + t_imgname = [None] * len(self.t_center) + for i in range(len(self.t_center)): + t_imgname[i] = train_f["imgname"][i].decode("UTF-8") + + self.v_center = val_f["center"][()] + v_scale = val_f["scale"][()] + v_part = val_f["part"][()] + v_visible = val_f["visible"][()] + v_normalize = val_f["normalize"][()] + v_imgname = [None] * len(self.v_center) + for i in range(len(self.v_center)): + v_imgname[i] = val_f["imgname"][i].decode("UTF-8") + + self.center = np.append(self.t_center, self.v_center, axis=0) + self.scale = np.append(t_scale, v_scale) + self.part = np.append(t_part, v_part, axis=0) + self.visible = np.append(t_visible, v_visible, axis=0) + self.normalize = np.append(t_normalize, v_normalize) + self.imgname = t_imgname + v_imgname + + print("Done (t={:0.2f}s)".format(time.time() - tic)) + + self.num_examples_train, self.num_examples_val = self.getLength() + + def getLength(self): + """ + get dataset length + """ + return len(self.t_center), len(self.v_center) + + def setup_val_split(self): + """ + get index for train and validation imgs + index for validation images starts after that of train images + so that loadImage can tell them apart + """ + valid = [i + self.num_examples_train for i in range(self.num_examples_val)] + train = [i for i in range(self.num_examples_train)] + return np.array(train), np.array(valid) + + def get_img(self, idx): + """ + get image + """ + imgname = self.imgname[idx] + path = os.path.join(args.img_dir, imgname) + img = imread(path) + return img + + def get_path(self, idx): + """ + get image path + """ + imgname = self.imgname[idx] + path = os.path.join(args.img_dir, imgname) + return path + + def get_kps(self, idx): + """ + get key points + """ + part = self.part[idx] + visible = self.visible[idx] + kp2 = np.insert(part, 2, visible, axis=1) + kps = np.zeros((1, 16, 3)) + kps[0] = kp2 + return kps + + def get_normalized(self, idx): + """ + get normalized value + """ + n = self.normalize[idx] + return n + + def get_center(self, idx): + """ + get center of the person + """ + c = self.center[idx] + return c + + def get_scale(self, idx): + """ + get scale of the person + """ + s = self.scale[idx] + return s + + +# Part reference +parts = { + "mpii": [ + "rank", + "rkne", + "rhip", + "lhip", + "lkne", + "lank", + "pelv", + "thrx", + "neck", + "head", + "rwri", + "relb", + "rsho", + "lsho", + "lelb", + "lwri", + ] +} + +flipped_parts = {"mpii": [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]} + +part_pairs = {"mpii": [[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]} + +pair_names = {"mpii": ["ankle", "knee", "hip", "pelvis", "thorax", "neck", "head", "wrist", "elbow", "shoulder"]} diff --git a/model_zoo/research/cv/StackedHourglass/src/models/StackedHourglassNet.py b/model_zoo/research/cv/StackedHourglass/src/models/StackedHourglassNet.py 
new file mode 100644 index 00000000000..6b5e5f7686e --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/src/models/StackedHourglassNet.py @@ -0,0 +1,83 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Stacked Hourglass Model +""" +import mindspore.nn as nn +import mindspore.ops as ops +import mindspore.ops.operations as P + +from src.models.layers import Conv, ConvBNReLU, Hourglass, Residual + + +class StackedHourglassNet(nn.Cell): + """ + Stacked Hourglass Network + """ + + def __init__(self, nstack, inp_dim, oup_dim): + super(StackedHourglassNet, self).__init__() + + self.nstack = nstack + + self.input_transpose = P.Transpose() + + self.pre = nn.SequentialCell( + [ + ConvBNReLU(3, 64, 7, 2), + Residual(64, 128), + nn.MaxPool2d(2, 2), + Residual(128, 128), + Residual(128, inp_dim), + ] + ) + + self.hgs = nn.CellList( + [nn.SequentialCell([Hourglass(4, inp_dim),]) for i in range(nstack)] + ) + + self.features = nn.CellList( + [nn.SequentialCell([Residual(inp_dim, inp_dim), ConvBNReLU(inp_dim, inp_dim, 1)]) for i in range(nstack)] + ) + + self.outs = nn.CellList([Conv(inp_dim, oup_dim, 1) for i in range(nstack)]) + self.merge_features = nn.CellList([Conv(inp_dim, inp_dim, 1) for i in range(nstack - 1)]) + self.merge_preds = nn.CellList([Conv(oup_dim, inp_dim, 1) for i in range(nstack - 1)]) + self.output_stack = ops.Stack(axis=1) + + def construct(self, imgs): + """ + forward + """ + # x size (batch, 3, 256, 256) + x = self.input_transpose( + imgs, + ( + 0, + 3, + 1, + 2, + ), + ) + x = self.pre(x) + combined_hm_preds = [] + for i in range(self.nstack): + hg = self.hgs[i](x) + feature = self.features[i](hg) + preds = self.outs[i](feature) + combined_hm_preds.append(preds) + if i < self.nstack - 1: + x = x + self.merge_preds[i](preds) + self.merge_features[i](feature) + return self.output_stack(combined_hm_preds) diff --git a/model_zoo/research/cv/StackedHourglass/src/models/layers.py b/model_zoo/research/cv/StackedHourglass/src/models/layers.py new file mode 100644 index 00000000000..1cec0cf143d --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/src/models/layers.py @@ -0,0 +1,147 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +model layers +""" +import mindspore.nn as nn +import mindspore.ops as ops + + +class Conv(nn.Cell): + """ + conv 2d + """ + + def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1): + super(Conv, self).__init__() + self.inp_dim = inp_dim + self.conv = nn.Conv2d( + in_channels=inp_dim, + out_channels=out_dim, + kernel_size=kernel_size, + stride=stride, + pad_mode="pad", + padding=(kernel_size - 1) // 2, + has_bias=True, + ) + + def construct(self, x): + """ + forward + """ + + x = self.conv(x) + return x + + +class ConvBNReLU(nn.Cell): + """ + conv 2d with batch normalize and relu + """ + + def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1): + super(ConvBNReLU, self).__init__() + self.inp_dim = inp_dim + self.conv = Conv(inp_dim, out_dim, kernel_size, stride) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(num_features=out_dim, momentum=0.9) + + def construct(self, x): + """ + forward + """ + + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Residual(nn.Cell): + """ + residual block + """ + + def __init__(self, inp_dim, out_dim): + super(Residual, self).__init__() + self.relu = nn.ReLU() + self.bn1 = nn.BatchNorm2d(num_features=inp_dim, momentum=0.9) + self.conv1 = Conv(inp_dim, out_dim // 2, 1) + self.bn2 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2) + self.conv2 = Conv(out_dim // 2, out_dim // 2, 3) + self.bn3 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2) + self.conv3 = Conv(out_dim // 2, out_dim, 1) + self.skip_layer = Conv(inp_dim, out_dim, 1) + if inp_dim == out_dim: + self.need_skip = False + else: + self.need_skip = True + + def construct(self, x): + """ + forward + """ + if self.need_skip: + residual = self.skip_layer(x) + else: + residual = x + out = self.bn1(x) + out = self.relu(out) + out = self.conv1(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + out += residual + return out + + +class Hourglass(nn.Cell): + """ + hourglass module + """ + + def __init__(self, n, f): + super(Hourglass, self).__init__() + self.up1 = Residual(f, f) + # Down sampling + self.pool1 = nn.MaxPool2d(2, 2) + self.low1 = Residual(f, f) + self.n = n + # Use Hourglass recursively + if self.n > 1: + self.low2 = Hourglass(n - 1, f) + else: + self.low2 = Residual(f, f) + self.low3 = Residual(f, f) + + # Set upsample size + sz = [0, 8, 16, 32, 64] + self.up2 = ops.ResizeNearestNeighbor((sz[n], sz[n])) + + def construct(self, x): + """ + forward + """ + + up1 = self.up1(x) + pool1 = self.pool1(x) + low1 = self.low1(pool1) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 diff --git a/model_zoo/research/cv/StackedHourglass/src/models/loss.py b/model_zoo/research/cv/StackedHourglass/src/models/loss.py new file mode 100644 index 00000000000..505fedac2ad --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/src/models/loss.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+define heatmap loss
+"""
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+
+
+class HeatmapLoss(nn.Cell):
+    """
+    loss for detection heatmap
+    """
+
+    def __init__(self):
+        super(HeatmapLoss, self).__init__()
+        self.loss_function = nn.MSELoss()
+        self.transpose = P.Transpose()
+
+    def construct(self, pred, gt):
+        """
+        calculate loss
+        """
+        # pred shape: (batch, nstack, n_joints, res, res), e.g. (batch, 2, 16, 64, 64)
+        # gt shape: (batch, n_joints, res, res), e.g. (batch, 16, 64, 64)
+        # Use broadcasting to compare every stack's prediction with the same target
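+        # For example, with nstack=2 and batch_size=32 (the config.py defaults),
+        # the transpose below gives pred_t of shape (2, 32, 16, 64, 64); MSELoss
+        # then broadcasts gt (32, 16, 64, 64) across the leading stack dimension
+        # and averages the squared error over all stacks at once.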
+        pred_t = self.transpose(pred, (1, 0, 2, 3, 4))
+        loss = self.loss_function(pred_t, gt)
+
+        return loss
diff --git a/model_zoo/research/cv/StackedHourglass/src/utils/img.py b/model_zoo/research/cv/StackedHourglass/src/utils/img.py
new file mode 100644
index 00000000000..31fa5e8baa8
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/src/utils/img.py
@@ -0,0 +1,108 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+general image utils
+"""
+import cv2
+import numpy as np
+
+
+def get_transform(center, scale, res, rot=0):
+    """
+    generate transform matrix
+    """
+    h = 200 * scale
+    t = np.zeros((3, 3))
+    t[0, 0] = float(res[1]) / h
+    t[1, 1] = float(res[0]) / h
+    t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
+    t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
+    t[2, 2] = 1
+    if rot != 0:
+        rot = -rot  # To match direction of rotation from cropping
+        rot_mat = np.zeros((3, 3))
+        rot_rad = rot * np.pi / 180
+        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+        rot_mat[0, :2] = [cs, -sn]
+        rot_mat[1, :2] = [sn, cs]
+        rot_mat[2, 2] = 1
+        # Need to rotate around center
+        t_mat = np.eye(3)
+        t_mat[0, 2] = -res[1] / 2
+        t_mat[1, 2] = -res[0] / 2
+        t_inv = t_mat.copy()
+        t_inv[:2, 2] *= -1
+        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
+    return t
+
+
+def transform(pt, center, scale, res, invert=0, rot=0):
+    """
+    transform points
+    """
+    t = get_transform(center, scale, res, rot=rot)
+    if invert:
+        t = np.linalg.inv(t)
+    new_pt = np.array([pt[0], pt[1], 1.0]).T
+    new_pt = np.dot(t, new_pt)
+    return new_pt[:2].astype(int)
+
+
+def crop(img, center, scale, res):
+    """
+    crop images
+    """
+    # Upper-left corner
+    ul = np.array(transform([0, 0], center, scale, res, invert=1))
+    # Bottom-right corner
+    br = np.array(transform(res, center, scale, res, invert=1))
+
+    new_shape = [br[1] - ul[1], br[0] - ul[0]]
+    if len(img.shape) > 2:
+        new_shape += [img.shape[2]]
+    new_img = np.zeros(new_shape)
+
+    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
+    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
+    old_x = max(0, ul[0]), min(len(img[0]), br[0])
+    old_y = max(0, ul[1]), min(len(img), br[1])
+    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
+
+    return cv2.resize(new_img, res)
+
+
+def inv_mat(mat):
+    """
+    get inverse matrix
+    """
+    ans = np.linalg.pinv(np.array(mat).tolist() + [[0, 0, 1]])
+    return ans[:2]
+
+
+def kpt_affine(kpt, mat):
+    """
+    apply an affine transform to keypoints
+    """
+    kpt = np.array(kpt)
+    shape = kpt.shape
+    kpt = kpt.reshape(-1, 2)
+    return np.dot(np.concatenate((kpt, kpt[:, 0:1] * 0 + 1), axis=1), mat.T).reshape(shape)
+
+
+def resize(im, res):
+    """
+    resize image
+    """
+    return np.array([cv2.resize(im[i], res) for i in range(im.shape[0])])
diff --git a/model_zoo/research/cv/StackedHourglass/src/utils/inference.py b/model_zoo/research/cv/StackedHourglass/src/utils/inference.py
new file mode 100644
index 00000000000..d419b098402
--- /dev/null
+++ b/model_zoo/research/cv/StackedHourglass/src/utils/inference.py
@@ -0,0 +1,373 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +inference module +""" +import copy +import os + +import cv2 +import h5py +import mindspore +import mindspore.nn as nn +import mindspore.ops as ops +import numpy as np +import tqdm + +import src.dataset.MPIIDataLoader as ds +import src.utils.img +from src.config import parse_args + +args = parse_args() + + +def match_format(dic): + """ + get match format + """ + loc = dic["loc_k"][0, :, 0, :] + val = dic["val_k"][0, :, :] + ans = np.hstack((loc, val)) + ans = np.expand_dims(ans, axis=0) + ret = [] + ret.append(ans) + return ret + + +class MaxPool2dFilter(nn.Cell): + """ + maxpool 2d for filter + """ + + def __init__(self): + super(MaxPool2dFilter, self).__init__() + self.pool = nn.MaxPool2d(3, 1, "same") + self.eq = ops.Equal() + + def construct(self, x): + """ + forward + """ + maxm = self.pool(x) + return self.eq(maxm, x) + + +class HeatmapParser: + """ + parse heatmap + """ + + def __init__(self): + self.maxpool2dfilter = MaxPool2dFilter().to_float(mindspore.float16) # avoid float error + self.topk = ops.TopK(sorted=True) + self.stack = ops.Stack(axis=3) + + def nms(self, det): + """ + keep max in 3x3 + """ + maxm = self.maxpool2dfilter(det) + det = det * maxm + return det + + def calc(self, det): + """ + calc distance + """ + det = self.nms(det) + w = det.shape[3] + det = det.view(det.shape[0], det.shape[1], -1) + val_k, ind = self.topk(det, 1) + + x = ind % w + y = (ind.astype(mindspore.float32) / w).astype(mindspore.int32) + ind_k = self.stack((x, y)) + ans = {"loc_k": ind_k, "val_k": val_k} + return {key: ans[key].asnumpy() for key in ans} + + def adjust(self, ans, det): + """ + adjust for joint + """ + for people in ans: + for i in people: + for joint_id, joint in enumerate(i): + if joint[2] > 0: + y, x = joint[0:2] + xx, yy = int(x), int(y) + tmp = det[0][joint_id] + if tmp[xx, min(yy + 1, tmp.shape[1] - 1)] > tmp[xx, max(yy - 1, 0)]: + y += 0.25 + else: + y -= 0.25 + + if tmp[min(xx + 1, tmp.shape[0] - 1), yy] > tmp[max(0, xx - 1), yy]: + x += 0.25 + else: + x -= 0.25 + ans[0][0, joint_id, 0:2] = (y + 0.5, x + 0.5) + return ans + + def parse(self, det, adjust=True): + """ + parse heatmap + """ + ans = match_format(self.calc(det)) + if adjust: + ans = self.adjust(ans, det) + return ans + + +parser = HeatmapParser() + + +def post_process(det, mat_, trainval, c=None, s=None, resolution=None): + """ + post process for parser + """ + mat = np.linalg.pinv(np.array(mat_).tolist() + [[0, 0, 1]])[:2] + cropped_preds = parser.parse(mindspore.Tensor(np.float32([det])))[0] + + if cropped_preds.size > 0: + cropped_preds[:, :, :2] = src.utils.img.kpt_affine(cropped_preds[:, :, :2] * 4, mat) # size 1x16x3 + + preds = np.copy(cropped_preds) + # revert to origin image + if trainval != "cropped": + for j in range(preds.shape[1]): + preds[0, j, :2] = src.utils.img.transform(preds[0, j, :2], c, s, resolution, invert=1) + return preds + + +def inference(img, net, c, s): + """ + forward pass at test time + calls post_process to post process results + """ + + height, width = img.shape[0:2] + center = (width / 2, height / 2) + scale = max(height, width) / 200 + res = (args.input_res, args.input_res) + + mat_ = src.utils.img.get_transform(center, scale, res)[:2] + inp = img / 255 + + tmp1 = net(mindspore.Tensor([inp], dtype=mindspore.float32)).asnumpy() + tmp2 = net(mindspore.Tensor([inp[:, ::-1]], dtype=mindspore.float32)).asnumpy() + + tmp = np.concatenate((tmp1, tmp2), axis=0) + + det = tmp[0, -1] + tmp[1, -1, :, :, 
::-1][ds.flipped_parts["mpii"]] + + if det is None: + return [], [] + det = det / 2 + + det = np.minimum(det, 1) + + return post_process(det, mat_, "valid", c, s, res) + + +class MPIIEval: + """ + eval for MPII dataset + """ + + template = { + "all": { + "total": 0, + "ankle": 0, + "knee": 0, + "hip": 0, + "pelvis": 0, + "thorax": 0, + "neck": 0, + "head": 0, + "wrist": 0, + "elbow": 0, + "shoulder": 0, + }, + "visible": { + "total": 0, + "ankle": 0, + "knee": 0, + "hip": 0, + "pelvis": 0, + "thorax": 0, + "neck": 0, + "head": 0, + "wrist": 0, + "elbow": 0, + "shoulder": 0, + }, + "not visible": { + "total": 0, + "ankle": 0, + "knee": 0, + "hip": 0, + "pelvis": 0, + "thorax": 0, + "neck": 0, + "head": 0, + "wrist": 0, + "elbow": 0, + "shoulder": 0, + }, + } + + joint_map = [ + "ankle", + "knee", + "hip", + "hip", + "knee", + "ankle", + "pelvis", + "thorax", + "neck", + "head", + "wrist", + "elbow", + "shoulder", + "shoulder", + "elbow", + "wrist", + ] + + def __init__(self): + self.correct = copy.deepcopy(self.template) + self.count = copy.deepcopy(self.template) + self.correct_train = copy.deepcopy(self.template) + self.count_train = copy.deepcopy(self.template) + + def eval(self, pred, gt, normalizing, num_train, bound=0.5): + """ + use PCK with threshold of .5 of normalized distance (presumably head size) + """ + idx = 0 + for p, g, normalize in zip(pred, gt, normalizing): + for j in range(g.shape[1]): + vis = "visible" + if g[0, j, 0] == 0: # Not in image + continue + if g[0, j, 2] == 0: + vis = "not visible" + joint = self.joint_map[j] + + if idx >= num_train: + self.count["all"]["total"] += 1 + self.count["all"][joint] += 1 + self.count[vis]["total"] += 1 + self.count[vis][joint] += 1 + else: + self.count_train["all"]["total"] += 1 + self.count_train["all"][joint] += 1 + self.count_train[vis]["total"] += 1 + self.count_train[vis][joint] += 1 + error = np.linalg.norm(p[0]["keypoints"][j, :2] - g[0, j, :2]) / normalize + if idx >= num_train: + if bound > error: + self.correct["all"]["total"] += 1 + self.correct["all"][joint] += 1 + self.correct[vis]["total"] += 1 + self.correct[vis][joint] += 1 + else: + if bound > error: + self.correct_train["all"]["total"] += 1 + self.correct_train["all"][joint] += 1 + self.correct_train[vis]["total"] += 1 + self.correct_train[vis][joint] += 1 + idx += 1 + + self.output_result(bound) + + def output_result(self, bound): + """ + output split via train/valid + """ + for k in self.correct: + print(k, ":") + for key in self.correct[k]: + print( + "Val PCK @,", + bound, + ",", + key, + ":", + round(self.correct[k][key] / max(self.count[k][key], 1), 3), + ", count:", + self.count[k][key], + ) + print( + "Tra PCK @,", + bound, + ",", + key, + ":", + round(self.correct_train[k][key] / max(self.count_train[k][key], 1), 3), + ", count:", + self.count_train[k][key], + ) + print("\n") + + +def get_img(num_eval=2958, num_train=300): + """ + load validation and training images + """ + input_res = args.input_res + val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r") + + tr = tqdm.tqdm(range(0, num_train), total=num_train) + # Train set + train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r") + for i in tr: + path_t = "%s/%s" % (args.img_dir, train_f["imgname"][i].decode("UTF-8")) + + orig_img = cv2.imread(path_t)[:, :, ::-1] + c = train_f["center"][i] + s = train_f["scale"][i] + im = src.utils.img.crop(orig_img, c, s, (input_res, input_res)) + + kp = train_f["part"][i] + vis = train_f["visible"][i] + kp2 = np.insert(kp, 2, vis, axis=1) + kps 
= np.zeros((1, 16, 3)) + kps[0] = kp2 + + n = train_f["normalize"][i] + + yield kps, im, c, s, n + + tr2 = tqdm.tqdm(range(0, num_eval), total=num_eval) + # Valid + for i in tr2: + path_t = "%s/%s" % (args.img_dir, val_f["imgname"][i].decode("UTF-8")) + + orig_img = cv2.imread(path_t)[:, :, ::-1] + c = val_f["center"][i] + s = val_f["scale"][i] + im = src.utils.img.crop(orig_img, c, s, (input_res, input_res)) + + kp = val_f["part"][i] + vis = val_f["visible"][i] + kp2 = np.insert(kp, 2, vis, axis=1) + kps = np.zeros((1, 16, 3)) + kps[0] = kp2 + + n = val_f["normalize"][i] + + yield kps, im, c, s, n diff --git a/model_zoo/research/cv/StackedHourglass/train.py b/model_zoo/research/cv/StackedHourglass/train.py new file mode 100644 index 00000000000..12c3052413e --- /dev/null +++ b/model_zoo/research/cv/StackedHourglass/train.py @@ -0,0 +1,101 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +run model train +""" +import math +import os + +import mindspore.dataset as ds +import mindspore.nn as nn +from mindspore import Model, context +from mindspore.common import set_seed +from mindspore.communication.management import get_group_size, get_rank, init +from mindspore.train.callback import (CheckpointConfig, LossMonitor, + ModelCheckpoint, TimeMonitor) + +from src.config import parse_args +from src.dataset.DatasetGenerator import DatasetGenerator +from src.dataset.MPIIDataLoader import MPII +from src.models.loss import HeatmapLoss +from src.models.StackedHourglassNet import StackedHourglassNet + +set_seed(1) + +args = parse_args() + +if __name__ == "__main__": + if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir): + print("Dataset not found.") + exit() + + # Set context mode + if args.context_mode == "GRAPH": + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False) + else: + context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target) + + if args.parallel: + # Parallel mode + context.reset_auto_parallel_context() + init() + context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL, gradients_mean=True) + args.rank_id = get_rank() + args.group_size = get_group_size() + else: + args.rank_id = 0 + args.group_size = 1 + + net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim) + + # Process dataset + mpii = MPII() + train, valid = mpii.setup_val_split() + + train_generator = DatasetGenerator(args.input_res, args.output_res, mpii, train) + train_size = len(train_generator) + train_sampler = ds.DistributedSampler(num_shards=args.group_size, shard_id=args.rank_id, shuffle=True) + train_data = ds.GeneratorDataset(train_generator, ["data", "label"], sampler=train_sampler) + train_data = train_data.batch(args.batch_size, True, args.group_size) + + print("train data size:", train_size) + step_per_epoch = math.ceil(train_size / args.batch_size / 
args.group_size) + + # Define loss function + loss_func = HeatmapLoss() + # Define optimizer + lr_decay = nn.exponential_decay_lr( + args.initial_lr, args.decay_rate, args.num_epoch * step_per_epoch, step_per_epoch, args.decay_epoch + ) + optimizer = nn.Adam(net.trainable_params(), lr_decay) + + # Define model + model = Model(net, loss_func, optimizer, amp_level=args.amp_level, keep_batchnorm_fp32=False) + + # Define callback functions + callbacks = [] + callbacks.append(LossMonitor(args.loss_log_interval)) + callbacks.append(TimeMonitor(train_size)) + + # Save checkpoint file + if args.rank_id == 0: + config_ck = CheckpointConfig( + save_checkpoint_steps=args.save_checkpoint_epochs * step_per_epoch, + keep_checkpoint_max=args.keep_checkpoint_max, + ) + ckpoint = ModelCheckpoint("ckpt", config=config_ck) + callbacks.append(ckpoint) + + model.train(args.num_epoch, train_data, callbacks=callbacks, dataset_sink_mode=True)