!19379 Add Stacked Hourglass Network for Model Zoo

Merge pull request !19379 from cometeme/master
i-robot 2021-07-13 04:09:20 +00:00 committed by Gitee
commit 885bd46760
16 changed files with 1689 additions and 0 deletions

README_CN.md View File

@@ -0,0 +1,193 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [Stacked Hourglass Description](#stacked-hourglass-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
        - [Results](#results)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage-1)
        - [Results](#results-1)
    - [Export](#export)
- [Model Description](#model-description)
    - [Training Performance (2HG)](#training-performance-2hg)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo](#modelzoo)
<!-- /TOC -->
# Stacked Hourglass Description
Stacked Hourglass is a model for human pose estimation. It extracts features with stacked hourglass modules and, at the end, predicts the position of each keypoint through output heatmaps.
[Paper: Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937v2)
# Dataset
Dataset used: [MPII Human Pose Dataset](http://human-pose.mpi-inf.mpg.de/)
- Dataset size:
    - Training: 22246 images
    - Test: 2958 images
    - Number of keypoints: 16 (head, neck, shoulders, elbows, wrists, thorax, pelvis, hips, knees, ankles)
> Note: The original annotations of the MPII dataset are in .mat format, which is hard to process. Please download and use the converted annotations instead: [https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot](https://github.com/princeton-vl/pytorch_stacked_hourglass/tree/master/data/MPII/annot)
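
The converted annotations are plain HDF5 files. As a quick sanity check they can be inspected with `h5py` (a minimal sketch; the keys match what `src/dataset/MPIIDataLoader.py` reads, and the shapes shown are inferred from how the loader indexes them):

```python
import h5py

with h5py.File("MPII/annot/train.h5", "r") as f:
    print(list(f.keys()))                   # center, imgname, normalize, part, scale, visible
    print(f["part"].shape)                  # (N, 16, 2): pixel coordinates of the 16 joints
    print(f["imgname"][0].decode("UTF-8"))  # image file name relative to the image directory
```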
# Environment Requirements
- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Script Description
## Scripts and Sample Code
```text
├──scripts
│   ├──run_distribute_train.sh     # distributed training script
│   ├──run_eval.sh                 # evaluation script
│   └──run_standalone_train.sh     # single-device training script
├──src
│   ├──dataset
│   │   ├──DatasetGenerator.py     # dataset definition and target heatmap generation
│   │   └──MPIIDataLoader.py       # MPII data loading and preprocessing
│   ├──models
│   │   ├──layers.py               # network sub-module definitions
│   │   ├──loss.py                 # HeatMap Loss definition
│   │   └──StackedHourglassNet.py  # overall network definition
│   ├──utils
│   │   ├──img.py                  # general image processing utilities
│   │   └──inference.py            # inference functions, including accuracy computation
│   └── config.py                  # parameter configuration
├── eval.py                        # evaluation script
├── export.py                      # export script
├── README_CN.md                   # project description
└── train.py                       # training script
```
## Script Parameters
Parameters used for model training and evaluation can be set in config.py, or passed in as command-line arguments at runtime.
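
For instance, the effective defaults can be checked directly (a minimal sketch; `parse_args` uses `parse_known_args`, so unrecognized flags are ignored rather than rejected):

```python
from src.config import parse_args

# Defaults come from src/config.py; any recognized command-line flag overrides them
args = parse_args()
print(args.batch_size, args.initial_lr, args.num_epoch)  # 32 0.001 100
```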
## Training Process
### Usage
For single-device training, first select the target device: `export DEVICE_ID=x`, where `x` is the ID of the target device. Then start training:
```sh
python train.py
```
Alternatively, use the single-device training script:
```sh
./scripts/run_standalone_train.sh [device ID] [annot path] [image path]
```
For multi-device training, use the distributed training script:
```sh
./scripts/run_distribute_train.sh [rank table file path] [number of Ascend devices] [annot path] [image path]
```
### Results
The ckpt files are stored in the current directory. Training logs are written to `loss.txt` by default, while errors and warnings go to `err.txt`, for example:
```text
loading data...
Done (t=14.61s)
train data size: 22246
epoch: 1 step: 695, loss is 0.00068435294
epoch time: 954584.373 ms, per step time: 1373.503 ms
epoch: 2 step: 695, loss is 0.00067576126
epoch time: 755549.341 ms, per step time: 1087.121 ms
epoch: 3 step: 695, loss is 0.00057179347
epoch time: 750856.373 ms, per step time: 1080.369 ms
epoch: 4 step: 695, loss is 0.00055218843
[...]
```
## Evaluation Process
### Usage
Before running evaluation, select the target device: `export DEVICE_ID=x`, where `x` is the ID of the target device. Then launch evaluation with Python, passing the path of the ckpt file:
```sh
python eval.py --ckpt_file <path to ckpt file>
```
Alternatively, use the evaluation script:
```sh
./scripts/run_eval.sh [device ID] [ckpt file path] [annot path] [image path]
```
### Results
Evaluation results are written to `result.txt` by default, while errors and warnings go to `err.txt`.
```text
all :
Val PCK @, 0.5 , total : 0.882 , count: 44239
Tra PCK @, 0.5 , total : 0.938 , count: 4443
Val PCK @, 0.5 , ankle : 0.765 , count: 4234
Tra PCK @, 0.5 , ankle : 0.847 , count: 392
Val PCK @, 0.5 , knee : 0.819 , count: 4963
Tra PCK @, 0.5 , knee : 0.91 , count: 499
Val PCK @, 0.5 , hip : 0.871 , count: 5777
Tra PCK @, 0.5 , hip : 0.918 , count: 587
[...]
```
## Export
The model can be exported with the `export.py` script:
```sh
python export.py --ckpt_file [ckpt file path]
```
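
To sanity-check the exported MINDIR file, something like the following can be used (a minimal sketch, assuming MindSpore's `load`/`nn.GraphCell` interface for MINDIR graphs; the dummy batch must match the `--batch_size` and NHWC input shape used at export):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

# Load the exported graph and run one dummy NHWC batch through it
graph = ms.load("stackedhourglass.mindir")
net = nn.GraphCell(graph)
x = ms.Tensor(np.zeros([32, 256, 256, 3], np.float32))  # default --batch_size is 32
print(net(x).shape)  # (32, nstack, 16, 64, 64): stacked heatmap predictions
```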
# Model Description
## Training Performance (2HG)
| Parameters                 | Ascend                     |
| -------------------------- | -------------------------- |
| Model name                 | Stacked Hourglass Networks |
| Environment                | Ascend 910A                |
| Upload date                | 2021-7-5                   |
| MindSpore version          | 1.2.0                      |
| Dataset                    | MPII Human Pose Dataset    |
| Training parameters        | see config.py              |
| Optimizer                  | Adam (with exponential learning rate decay) |
| Loss function              | HeatMap Loss (MSE-like)    |
| Final loss                 | 0.00036373272              |
| Accuracy (PCK@0.5)         | 88.2%                      |
| Total training time (1p)   | 20h                        |
| Total evaluation time (1p) | 21min                      |
| Number of parameters       | 8429088                    |
# Description of Random Situation
The random seed is set in the `train.py` script.
# ModelZoo
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py View File

@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
run model eval
"""
import os
from mindspore import context, load_checkpoint, load_param_into_net
from src.config import parse_args
from src.models.StackedHourglassNet import StackedHourglassNet
from src.utils.inference import MPIIEval, get_img, inference
args = parse_args()
if __name__ == "__main__":
if not os.path.exists(args.ckpt_file):
print("ckpt file not valid")
exit()
if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
print("Dataset not found.")
exit()
# Set context mode
if args.context_mode == "GRAPH":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
# Import net
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
gts = []
preds = []
normalizing = []
num_eval = args.num_eval
num_train = args.train_num_eval
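# get_img yields (annotations, cropped image, center, scale, normalizer),
# first for num_train training samples, then for num_eval validation samples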
for anns, img, c, s, n in get_img(num_eval, num_train):
gts.append(anns)
ans = inference(img, net, c, s)
if ans.size > 0:
ans = ans[:, :, :3]
# (num preds, joints, x/y/visible)
pred = []
for i in range(ans.shape[0]):
pred.append({"keypoints": ans[i, :, :]})
preds.append(pred)
normalizing.append(n)
mpii_eval = MPIIEval()
mpii_eval.eval(preds, gts, normalizing, num_train)

export.py View File

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export model
"""
import os
import numpy as np
from mindspore import Tensor, context, export, load_checkpoint, load_param_into_net
from src.config import parse_args
from src.models.StackedHourglassNet import StackedHourglassNet
args = parse_args()
if __name__ == "__main__":
if not os.path.exists(args.ckpt_file):
print("ckpt file not valid")
exit()
# Set context mode
if args.context_mode == "GRAPH":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
# Import net
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
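# The network takes NHWC input; see the transpose in StackedHourglassNet.construct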
input_arr = Tensor(np.zeros([args.batch_size, args.input_res, args.input_res, 3], np.float32))
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

requirements.txt View File

@@ -0,0 +1,23 @@
asttokens==2.0.5
astunparse==1.6.3
attrs==21.2.0
cached-property==1.5.2
certifi==2018.10.15
cffi==1.14.5
decorator==5.0.9
easydict==1.9
h5py==3.3.0
imageio==2.9.0
mpmath==1.2.1
numpy==1.21.0
opencv-python==4.5.2.54
packaging==21.0
Pillow==8.3.1
protobuf==3.17.3
psutil==5.8.0
pycparser==2.20
pyparsing==2.4.7
scipy==1.7.0
six==1.16.0
sympy==1.8
tqdm==4.61.2

scripts/run_distribute_train.sh View File

@@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
echo "Usage: ./scripts/run_distribute_train.sh [rank file path] [rank size] [annot path] [image path]"
exit 1
fi
export RANK_TABLE_FILE="$1"
export RANK_SIZE="$2"
ANNOT_DIR=$3
IMG_DIR=$4
CURDIR=$(pwd)
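# Launch one training process per device, each in its own working directory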
for ((i = 0; i < $2; i++))
do
export DEVICE_ID="$i"
export RANK_ID="$i"
cd $CURDIR
TARGET="./Train${DEVICE_ID}"
rm -rf $TARGET
mkdir $TARGET
cp *.py $TARGET
cp -r src $TARGET
cd $TARGET
echo "training for rank $i"
nohup python train.py \
--annot_dir=$ANNOT_DIR \
--img_dir=$IMG_DIR \
--parallel True > loss.txt 2> err.txt &
done

scripts/run_eval.sh View File

@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
echo "Usage: ./scripts/run_eval.sh [device ID] [ckpt path] [annot path] [image path]"
exit 1
fi
export DEVICE_ID=$1
PATH_CHECKPOINT=$2
ANNOT_DIR=$3
IMG_DIR=$4
TARGET="./Eval"
rm -rf $TARGET
mkdir $TARGET
cp *.py $TARGET
cp -r src $TARGET
cd $TARGET
nohup python eval.py \
--ckpt_file=$PATH_CHECKPOINT \
--annot_dir=$ANNOT_DIR \
--img_dir=$IMG_DIR > result.txt 2> err.txt &

scripts/run_standalone_train.sh View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: ./scripts/run_standalone_train.sh [device ID] [annot path] [image path]"
exit 1
fi
export DEVICE_ID=$1
ANNOT_DIR=$2
IMG_DIR=$3
TARGET="./Train"
rm -rf $TARGET
mkdir $TARGET
cp *.py $TARGET
cp -r src $TARGET
cd $TARGET
nohup python train.py \
--annot_dir=$ANNOT_DIR \
--img_dir=$IMG_DIR > loss.txt 2> err.txt &

src/config.py View File

@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
model config
"""
import argparse
import ast
def parse_args():
"""
parse arguments
"""
parser = argparse.ArgumentParser(description="MindSpore Stacked Hourglass")
# Model
parser.add_argument("--nstack", type=int, default=2)
parser.add_argument("--inp_dim", type=int, default=256)
parser.add_argument("--oup_dim", type=int, default=16)
parser.add_argument("--input_res", type=int, default=256)
parser.add_argument("--output_res", type=int, default=64)
parser.add_argument("--annot_dir", type=str, default="./MPII/annot")
parser.add_argument("--img_dir", type=str, default="./MPII/images")
# Context
parser.add_argument("--context_mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"])
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "CPU"])
# Train
parser.add_argument("--parallel", type=ast.literal_eval, default=False)
parser.add_argument("--amp_level", type=str, default="O2", choices=["O0", "O1", "O2", "O3"])
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_epoch", type=int, default=100)
parser.add_argument("--save_checkpoint_epochs", type=int, default=5)
parser.add_argument("--keep_checkpoint_max", type=int, default=20)
parser.add_argument("--loss_log_interval", type=int, default=1)
parser.add_argument("--initial_lr", type=float, default=1e-3)
parser.add_argument("--decay_rate", type=float, default=0.985)
parser.add_argument("--decay_epoch", type=int, default=1)
# Valid
parser.add_argument("--num_eval", type=int, default=2958)
parser.add_argument("--train_num_eval", type=int, default=300)
parser.add_argument("--ckpt_file", type=str, default="")
# Export
parser.add_argument("--file_name", type=str, default="stackedhourglass")
parser.add_argument("--file_format", type=str, default="MINDIR")
args = parser.parse_known_args()[0]
return args

src/dataset/DatasetGenerator.py View File

@@ -0,0 +1,161 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset classes
"""
import cv2
import numpy as np
import src.utils.img
from src.dataset.MPIIDataLoader import flipped_parts
class GenerateHeatmap:
"""
get train target heatmap
"""
def __init__(self, output_res, num_parts):
self.output_res = output_res
self.num_parts = num_parts
sigma = self.output_res / 64
self.sigma = sigma
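# Precompute a Gaussian patch of side 6*sigma+3 (covers +-3 sigma);
# __call__ pastes it onto the heatmap at each visible keypoint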
size = 6 * sigma + 3
x = np.arange(0, size, 1, float)
y = x[:, np.newaxis]
x0, y0 = 3 * sigma + 1, 3 * sigma + 1
self.g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
def __call__(self, keypoints):
hms = np.zeros(shape=(self.num_parts, self.output_res, self.output_res), dtype=np.float32)
sigma = self.sigma
for p in keypoints:
for idx, pt in enumerate(p):
if pt[0] > 0:
x, y = int(pt[0]), int(pt[1])
if x < 0 or y < 0 or x >= self.output_res or y >= self.output_res:
continue
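# Patch corners on the heatmap (ul/br) and the matching slice of the
# precomputed Gaussian, clipped so the patch stays inside the borders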
ul = int(x - 3 * sigma - 1), int(y - 3 * sigma - 1)
br = int(x + 3 * sigma + 2), int(y + 3 * sigma + 2)
c, d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0]
a, b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1]
cc, dd = max(0, ul[0]), min(br[0], self.output_res)
aa, bb = max(0, ul[1]), min(br[1], self.output_res)
hms[idx, aa:bb, cc:dd] = np.maximum(hms[idx, aa:bb, cc:dd], self.g[a:b, c:d])
return hms
class DatasetGenerator:
"""
mindspore general dataset generator
"""
def __init__(self, input_res, output_res, ds, index):
self.input_res = input_res
self.output_res = output_res
self.generateHeatmap = GenerateHeatmap(self.output_res, 16)
self.ds = ds
self.index = index
def __len__(self):
return len(self.index)
def __getitem__(self, idx):
return self.loadImage(self.index[idx])
def loadImage(self, idx):
"""
load and preprocess image
"""
ds = self.ds
# Load + Crop
orig_img = ds.get_img(idx)
orig_keypoints = ds.get_kps(idx)
kptmp = orig_keypoints.copy()
c = ds.get_center(idx)
s = ds.get_scale(idx)
cropped = src.utils.img.crop(orig_img, c, s, (self.input_res, self.input_res))
for i in range(np.shape(orig_keypoints)[1]):
if orig_keypoints[0, i, 0] > 0:
orig_keypoints[0, i, :2] = src.utils.img.transform(
orig_keypoints[0, i, :2], c, s, (self.input_res, self.input_res)
)
keypoints = np.copy(orig_keypoints)
# Random Crop
height, width = cropped.shape[0:2]
center = np.array((width / 2, height / 2))
scale = max(height, width) / 200
# Random rotation in [-30, 30] degrees and random scale in [0.75, 1.25]
aug_rot = (np.random.random() * 2 - 1) * 30.0
aug_scale = np.random.random() * (1.25 - 0.75) + 0.75
scale *= aug_scale
mat_mask = src.utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2]
mat = src.utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2]
inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32) / 255
keypoints[:, :, 0:2] = src.utils.img.kpt_affine(keypoints[:, :, 0:2], mat_mask)
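# Random horizontal flip: mirror the image and swap left/right joint labels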
if np.random.randint(2) == 0:
inp = self.preprocess(inp)
inp = inp[:, ::-1]
keypoints = keypoints[:, flipped_parts["mpii"]]
keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0]
orig_keypoints = orig_keypoints[:, flipped_parts["mpii"]]
orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0]
# If keypoint is invisible, set to 0
for i in range(np.shape(orig_keypoints)[1]):
if kptmp[0, i, 0] == 0 and kptmp[0, i, 1] == 0:
keypoints[0, i, 0] = 0
keypoints[0, i, 1] = 0
orig_keypoints[0, i, 0] = 0
orig_keypoints[0, i, 1] = 0
# Generate target heatmap
heatmaps = self.generateHeatmap(keypoints)
return inp.astype(np.float32), heatmaps.astype(np.float32)
def preprocess(self, data):
"""
preprocess images
"""
# Random hue and saturation
data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
delta = (np.random.random() * 2 - 1) * 0.2
data[:, :, 0] = np.mod(data[:, :, 0] + (delta * 360 + 360.0), 360.0)
delta_sature = np.random.random() + 0.5
data[:, :, 1] *= delta_sature
data[:, :, 1] = np.maximum(np.minimum(data[:, :, 1], 1), 0)
data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB)
# Random brightness
delta = (np.random.random() * 2 - 1) * 0.3
data += delta
# Random contrast
mean = data.mean(axis=2, keepdims=True)
data = (data - mean) * (np.random.random() + 0.5) + mean
data = np.minimum(np.maximum(data, 0), 1)
return data

src/dataset/MPIIDataLoader.py View File

@@ -0,0 +1,163 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MPII dataset loader
"""
import os
import time
import h5py
import numpy as np
from imageio import imread
from src.config import parse_args
args = parse_args()
class MPII:
"""
MPII dataset loader
"""
def __init__(self):
print("loading data...")
tic = time.time()
train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r")
val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r")
self.t_center = train_f["center"][()]
t_scale = train_f["scale"][()]
t_part = train_f["part"][()]
t_visible = train_f["visible"][()]
t_normalize = train_f["normalize"][()]
t_imgname = [None] * len(self.t_center)
for i in range(len(self.t_center)):
t_imgname[i] = train_f["imgname"][i].decode("UTF-8")
self.v_center = val_f["center"][()]
v_scale = val_f["scale"][()]
v_part = val_f["part"][()]
v_visible = val_f["visible"][()]
v_normalize = val_f["normalize"][()]
v_imgname = [None] * len(self.v_center)
for i in range(len(self.v_center)):
v_imgname[i] = val_f["imgname"][i].decode("UTF-8")
self.center = np.append(self.t_center, self.v_center, axis=0)
self.scale = np.append(t_scale, v_scale)
self.part = np.append(t_part, v_part, axis=0)
self.visible = np.append(t_visible, v_visible, axis=0)
self.normalize = np.append(t_normalize, v_normalize)
self.imgname = t_imgname + v_imgname
print("Done (t={:0.2f}s)".format(time.time() - tic))
self.num_examples_train, self.num_examples_val = self.getLength()
def getLength(self):
"""
get dataset length
"""
return len(self.t_center), len(self.v_center)
def setup_val_split(self):
"""
get index for train and validation imgs
index for validation images starts after that of train images
so that loadImage can tell them apart
"""
valid = [i + self.num_examples_train for i in range(self.num_examples_val)]
train = [i for i in range(self.num_examples_train)]
return np.array(train), np.array(valid)
def get_img(self, idx):
"""
get image
"""
imgname = self.imgname[idx]
path = os.path.join(args.img_dir, imgname)
img = imread(path)
return img
def get_path(self, idx):
"""
get image path
"""
imgname = self.imgname[idx]
path = os.path.join(args.img_dir, imgname)
return path
def get_kps(self, idx):
"""
get key points
"""
part = self.part[idx]
visible = self.visible[idx]
kp2 = np.insert(part, 2, visible, axis=1)
kps = np.zeros((1, 16, 3))
kps[0] = kp2
return kps
def get_normalized(self, idx):
"""
get normalized value
"""
n = self.normalize[idx]
return n
def get_center(self, idx):
"""
get center of the person
"""
c = self.center[idx]
return c
def get_scale(self, idx):
"""
get scale of the person
"""
s = self.scale[idx]
return s
# Part reference
parts = {
"mpii": [
"rank",
"rkne",
"rhip",
"lhip",
"lkne",
"lank",
"pelv",
"thrx",
"neck",
"head",
"rwri",
"relb",
"rsho",
"lsho",
"lelb",
"lwri",
]
}
flipped_parts = {"mpii": [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]}
part_pairs = {"mpii": [[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]}
pair_names = {"mpii": ["ankle", "knee", "hip", "pelvis", "thorax", "neck", "head", "wrist", "elbow", "shoulder"]}

src/models/StackedHourglassNet.py View File

@@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Stacked Hourglass Model
"""
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
from src.models.layers import Conv, ConvBNReLU, Hourglass, Residual
class StackedHourglassNet(nn.Cell):
"""
Stacked Hourglass Network
"""
def __init__(self, nstack, inp_dim, oup_dim):
super(StackedHourglassNet, self).__init__()
self.nstack = nstack
self.input_transpose = P.Transpose()
self.pre = nn.SequentialCell(
[
ConvBNReLU(3, 64, 7, 2),
Residual(64, 128),
nn.MaxPool2d(2, 2),
Residual(128, 128),
Residual(128, inp_dim),
]
)
self.hgs = nn.CellList([nn.SequentialCell([Hourglass(4, inp_dim)]) for i in range(nstack)])
self.features = nn.CellList(
[nn.SequentialCell([Residual(inp_dim, inp_dim), ConvBNReLU(inp_dim, inp_dim, 1)]) for i in range(nstack)]
)
self.outs = nn.CellList([Conv(inp_dim, oup_dim, 1) for i in range(nstack)])
self.merge_features = nn.CellList([Conv(inp_dim, inp_dim, 1) for i in range(nstack - 1)])
self.merge_preds = nn.CellList([Conv(oup_dim, inp_dim, 1) for i in range(nstack - 1)])
self.output_stack = ops.Stack(axis=1)
def construct(self, imgs):
"""
forward
"""
# imgs is NHWC (batch, 256, 256, 3); transpose to NCHW (batch, 3, 256, 256)
x = self.input_transpose(imgs, (0, 3, 1, 2))
x = self.pre(x)
combined_hm_preds = []
for i in range(self.nstack):
hg = self.hgs[i](x)
feature = self.features[i](hg)
preds = self.outs[i](feature)
combined_hm_preds.append(preds)
if i < self.nstack - 1:
x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
return self.output_stack(combined_hm_preds)

src/models/layers.py View File

@@ -0,0 +1,147 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
model layers
"""
import mindspore.nn as nn
import mindspore.ops as ops
class Conv(nn.Cell):
"""
conv 2d
"""
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1):
super(Conv, self).__init__()
self.inp_dim = inp_dim
self.conv = nn.Conv2d(
in_channels=inp_dim,
out_channels=out_dim,
kernel_size=kernel_size,
stride=stride,
pad_mode="pad",
padding=(kernel_size - 1) // 2,
has_bias=True,
)
def construct(self, x):
"""
forward
"""
x = self.conv(x)
return x
class ConvBNReLU(nn.Cell):
"""
conv 2d with batch normalize and relu
"""
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1):
super(ConvBNReLU, self).__init__()
self.inp_dim = inp_dim
self.conv = Conv(inp_dim, out_dim, kernel_size, stride)
self.relu = nn.ReLU()
self.bn = nn.BatchNorm2d(num_features=out_dim, momentum=0.9)
def construct(self, x):
"""
forward
"""
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Residual(nn.Cell):
"""
residual block
"""
def __init__(self, inp_dim, out_dim):
super(Residual, self).__init__()
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2d(num_features=inp_dim, momentum=0.9)
self.conv1 = Conv(inp_dim, out_dim // 2, 1)
self.bn2 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2)
self.conv2 = Conv(out_dim // 2, out_dim // 2, 3)
self.bn3 = nn.BatchNorm2d(momentum=0.9, num_features=out_dim // 2)
self.conv3 = Conv(out_dim // 2, out_dim, 1)
self.skip_layer = Conv(inp_dim, out_dim, 1)
if inp_dim == out_dim:
self.need_skip = False
else:
self.need_skip = True
def construct(self, x):
"""
forward
"""
if self.need_skip:
residual = self.skip_layer(x)
else:
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
out += residual
return out
class Hourglass(nn.Cell):
"""
hourglass module
"""
def __init__(self, n, f):
super(Hourglass, self).__init__()
self.up1 = Residual(f, f)
# Down sampling
self.pool1 = nn.MaxPool2d(2, 2)
self.low1 = Residual(f, f)
self.n = n
# Use Hourglass recursively
if self.n > 1:
self.low2 = Hourglass(n - 1, f)
else:
self.low2 = Residual(f, f)
self.low3 = Residual(f, f)
# Upsample back to the spatial size at this recursion depth
# (fixed table for a 64x64 heatmap: depth n maps to sz[n] pixels)
sz = [0, 8, 16, 32, 64]
self.up2 = ops.ResizeNearestNeighbor((sz[n], sz[n]))
def construct(self, x):
"""
forward
"""
up1 = self.up1(x)
pool1 = self.pool1(x)
low1 = self.low1(pool1)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3)
return up1 + up2

src/models/loss.py View File

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define heatmap loss
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
class HeatmapLoss(nn.Cell):
"""
loss for detection heatmap
"""
def __init__(self):
super(HeatmapLoss, self).__init__()
self.loss_function = nn.MSELoss()
self.transpose = P.Transpose()
def construct(self, pred, gt):
"""
calculate loss
"""
# pred: (batch, nstack, 16, 64, 64); gt: (batch, 16, 64, 64)
# Transpose pred so gt broadcasts across the nstack axis and every
# intermediate prediction is supervised against the same target
pred_t = self.transpose(pred, (1, 0, 2, 3, 4))
loss = self.loss_function(pred_t, gt)
return loss

src/utils/img.py View File

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
general image utils
"""
import cv2
import numpy as np
def get_transform(center, scale, res, rot=0):
"""
generate transform matrix
"""
h = 200 * scale
t = np.zeros((3, 3))
t[0, 0] = float(res[1]) / h
t[1, 1] = float(res[0]) / h
t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
t[2, 2] = 1
if rot != 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3, 3))
rot_rad = rot * np.pi / 180
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0, :2] = [cs, -sn]
rot_mat[1, :2] = [sn, cs]
rot_mat[2, 2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0, 2] = -res[1] / 2
t_mat[1, 2] = -res[0] / 2
t_inv = t_mat.copy()
t_inv[:2, 2] *= -1
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
return t
def transform(pt, center, scale, res, invert=0, rot=0):
"""
transform points
"""
t = get_transform(center, scale, res, rot=rot)
if invert:
t = np.linalg.inv(t)
new_pt = np.array([pt[0], pt[1], 1.0]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2].astype(int)
def crop(img, center, scale, res):
"""
crop images
"""
# Left up
ul = np.array(transform([0, 0], center, scale, res, invert=1))
# Right down
br = np.array(transform(res, center, scale, res, invert=1))
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(new_shape)
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
return cv2.resize(new_img, res)
def inv_mat(mat):
"""
get invert matrix
"""
ans = np.linalg.pinv(np.array(mat).tolist() + [[0, 0, 1]])
return ans[:2]
def kpt_affine(kpt, mat):
"""
get key point affine
"""
kpt = np.array(kpt)
shape = kpt.shape
kpt = kpt.reshape(-1, 2)
return np.dot(np.concatenate((kpt, kpt[:, 0:1] * 0 + 1), axis=1), mat.T).reshape(shape)
def resize(im, res):
"""
resize image
"""
return np.array([cv2.resize(im[i], res) for i in range(im.shape[0])])

src/utils/inference.py View File

@@ -0,0 +1,373 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
inference module
"""
import copy
import os
import cv2
import h5py
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import numpy as np
import tqdm
import src.dataset.MPIIDataLoader as ds
import src.utils.img
from src.config import parse_args
args = parse_args()
def match_format(dic):
"""
get match format
"""
loc = dic["loc_k"][0, :, 0, :]
val = dic["val_k"][0, :, :]
ans = np.hstack((loc, val))
ans = np.expand_dims(ans, axis=0)
ret = []
ret.append(ans)
return ret
class MaxPool2dFilter(nn.Cell):
"""
maxpool 2d for filter
"""
def __init__(self):
super(MaxPool2dFilter, self).__init__()
self.pool = nn.MaxPool2d(3, 1, "same")
self.eq = ops.Equal()
def construct(self, x):
"""
forward
"""
maxm = self.pool(x)
return self.eq(maxm, x)
class HeatmapParser:
"""
parse heatmap
"""
def __init__(self):
self.maxpool2dfilter = MaxPool2dFilter().to_float(mindspore.float16) # avoid float error
self.topk = ops.TopK(sorted=True)
self.stack = ops.Stack(axis=3)
def nms(self, det):
"""
keep max in 3x3
"""
maxm = self.maxpool2dfilter(det)
det = det * maxm
return det
def calc(self, det):
"""
calc distance
"""
det = self.nms(det)
w = det.shape[3]
det = det.view(det.shape[0], det.shape[1], -1)
val_k, ind = self.topk(det, 1)
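# Recover (x, y) grid coordinates from the flattened top-1 index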
x = ind % w
y = (ind.astype(mindspore.float32) / w).astype(mindspore.int32)
ind_k = self.stack((x, y))
ans = {"loc_k": ind_k, "val_k": val_k}
return {key: ans[key].asnumpy() for key in ans}
def adjust(self, ans, det):
"""
adjust for joint
"""
for people in ans:
for i in people:
for joint_id, joint in enumerate(i):
if joint[2] > 0:
y, x = joint[0:2]
xx, yy = int(x), int(y)
tmp = det[0][joint_id]
if tmp[xx, min(yy + 1, tmp.shape[1] - 1)] > tmp[xx, max(yy - 1, 0)]:
y += 0.25
else:
y -= 0.25
if tmp[min(xx + 1, tmp.shape[0] - 1), yy] > tmp[max(0, xx - 1), yy]:
x += 0.25
else:
x -= 0.25
ans[0][0, joint_id, 0:2] = (y + 0.5, x + 0.5)
return ans
def parse(self, det, adjust=True):
"""
parse heatmap
"""
ans = match_format(self.calc(det))
if adjust:
ans = self.adjust(ans, det)
return ans
parser = HeatmapParser()
def post_process(det, mat_, trainval, c=None, s=None, resolution=None):
"""
post process for parser
"""
mat = np.linalg.pinv(np.array(mat_).tolist() + [[0, 0, 1]])[:2]
cropped_preds = parser.parse(mindspore.Tensor(np.float32([det])))[0]
if cropped_preds.size > 0:
cropped_preds[:, :, :2] = src.utils.img.kpt_affine(cropped_preds[:, :, :2] * 4, mat) # size 1x16x3
preds = np.copy(cropped_preds)
# revert to origin image
if trainval != "cropped":
for j in range(preds.shape[1]):
preds[0, j, :2] = src.utils.img.transform(preds[0, j, :2], c, s, resolution, invert=1)
return preds
def inference(img, net, c, s):
"""
forward pass at test time
calls post_process to post process results
"""
height, width = img.shape[0:2]
center = (width / 2, height / 2)
scale = max(height, width) / 200
res = (args.input_res, args.input_res)
mat_ = src.utils.img.get_transform(center, scale, res)[:2]
inp = img / 255
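# Run the network on the image and its horizontal flip, then average the two
# heatmaps (the flipped prediction is mirrored back with joint channels swapped)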
tmp1 = net(mindspore.Tensor([inp], dtype=mindspore.float32)).asnumpy()
tmp2 = net(mindspore.Tensor([inp[:, ::-1]], dtype=mindspore.float32)).asnumpy()
tmp = np.concatenate((tmp1, tmp2), axis=0)
det = tmp[0, -1] + tmp[1, -1, :, :, ::-1][ds.flipped_parts["mpii"]]
if det is None:
return [], []
det = det / 2
det = np.minimum(det, 1)
return post_process(det, mat_, "valid", c, s, res)
class MPIIEval:
"""
eval for MPII dataset
"""
template = {
"all": {
"total": 0,
"ankle": 0,
"knee": 0,
"hip": 0,
"pelvis": 0,
"thorax": 0,
"neck": 0,
"head": 0,
"wrist": 0,
"elbow": 0,
"shoulder": 0,
},
"visible": {
"total": 0,
"ankle": 0,
"knee": 0,
"hip": 0,
"pelvis": 0,
"thorax": 0,
"neck": 0,
"head": 0,
"wrist": 0,
"elbow": 0,
"shoulder": 0,
},
"not visible": {
"total": 0,
"ankle": 0,
"knee": 0,
"hip": 0,
"pelvis": 0,
"thorax": 0,
"neck": 0,
"head": 0,
"wrist": 0,
"elbow": 0,
"shoulder": 0,
},
}
joint_map = [
"ankle",
"knee",
"hip",
"hip",
"knee",
"ankle",
"pelvis",
"thorax",
"neck",
"head",
"wrist",
"elbow",
"shoulder",
"shoulder",
"elbow",
"wrist",
]
def __init__(self):
self.correct = copy.deepcopy(self.template)
self.count = copy.deepcopy(self.template)
self.correct_train = copy.deepcopy(self.template)
self.count_train = copy.deepcopy(self.template)
def eval(self, pred, gt, normalizing, num_train, bound=0.5):
"""
use PCK with threshold of .5 of normalized distance (presumably head size)
"""
idx = 0
for p, g, normalize in zip(pred, gt, normalizing):
for j in range(g.shape[1]):
vis = "visible"
if g[0, j, 0] == 0: # Not in image
continue
if g[0, j, 2] == 0:
vis = "not visible"
joint = self.joint_map[j]
if idx >= num_train:
self.count["all"]["total"] += 1
self.count["all"][joint] += 1
self.count[vis]["total"] += 1
self.count[vis][joint] += 1
else:
self.count_train["all"]["total"] += 1
self.count_train["all"][joint] += 1
self.count_train[vis]["total"] += 1
self.count_train[vis][joint] += 1
error = np.linalg.norm(p[0]["keypoints"][j, :2] - g[0, j, :2]) / normalize
if idx >= num_train:
if bound > error:
self.correct["all"]["total"] += 1
self.correct["all"][joint] += 1
self.correct[vis]["total"] += 1
self.correct[vis][joint] += 1
else:
if bound > error:
self.correct_train["all"]["total"] += 1
self.correct_train["all"][joint] += 1
self.correct_train[vis]["total"] += 1
self.correct_train[vis][joint] += 1
idx += 1
self.output_result(bound)
def output_result(self, bound):
"""
output split via train/valid
"""
for k in self.correct:
print(k, ":")
for key in self.correct[k]:
print(
"Val PCK @,",
bound,
",",
key,
":",
round(self.correct[k][key] / max(self.count[k][key], 1), 3),
", count:",
self.count[k][key],
)
print(
"Tra PCK @,",
bound,
",",
key,
":",
round(self.correct_train[k][key] / max(self.count_train[k][key], 1), 3),
", count:",
self.count_train[k][key],
)
print("\n")
def get_img(num_eval=2958, num_train=300):
"""
load training and validation images
"""
input_res = args.input_res
val_f = h5py.File(os.path.join(args.annot_dir, "valid.h5"), "r")
tr = tqdm.tqdm(range(0, num_train), total=num_train)
# Train set
train_f = h5py.File(os.path.join(args.annot_dir, "train.h5"), "r")
for i in tr:
path_t = "%s/%s" % (args.img_dir, train_f["imgname"][i].decode("UTF-8"))
orig_img = cv2.imread(path_t)[:, :, ::-1]
c = train_f["center"][i]
s = train_f["scale"][i]
im = src.utils.img.crop(orig_img, c, s, (input_res, input_res))
kp = train_f["part"][i]
vis = train_f["visible"][i]
kp2 = np.insert(kp, 2, vis, axis=1)
kps = np.zeros((1, 16, 3))
kps[0] = kp2
n = train_f["normalize"][i]
yield kps, im, c, s, n
tr2 = tqdm.tqdm(range(0, num_eval), total=num_eval)
# Valid
for i in tr2:
path_t = "%s/%s" % (args.img_dir, val_f["imgname"][i].decode("UTF-8"))
orig_img = cv2.imread(path_t)[:, :, ::-1]
c = val_f["center"][i]
s = val_f["scale"][i]
im = src.utils.img.crop(orig_img, c, s, (input_res, input_res))
kp = val_f["part"][i]
vis = val_f["visible"][i]
kp2 = np.insert(kp, 2, vis, axis=1)
kps = np.zeros((1, 16, 3))
kps[0] = kp2
n = val_f["normalize"][i]
yield kps, im, c, s, n

train.py View File

@@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
run model train
"""
import math
import os
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.common import set_seed
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
ModelCheckpoint, TimeMonitor)
from src.config import parse_args
from src.dataset.DatasetGenerator import DatasetGenerator
from src.dataset.MPIIDataLoader import MPII
from src.models.loss import HeatmapLoss
from src.models.StackedHourglassNet import StackedHourglassNet
set_seed(1)
args = parse_args()
if __name__ == "__main__":
if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
print("Dataset not found.")
exit()
# Set context mode
if args.context_mode == "GRAPH":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
if args.parallel:
# Parallel mode
context.reset_auto_parallel_context()
init()
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL, gradients_mean=True)
args.rank_id = get_rank()
args.group_size = get_group_size()
else:
args.rank_id = 0
args.group_size = 1
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
# Process dataset
mpii = MPII()
train, valid = mpii.setup_val_split()
train_generator = DatasetGenerator(args.input_res, args.output_res, mpii, train)
train_size = len(train_generator)
train_sampler = ds.DistributedSampler(num_shards=args.group_size, shard_id=args.rank_id, shuffle=True)
train_data = ds.GeneratorDataset(train_generator, ["data", "label"], sampler=train_sampler)
train_data = train_data.batch(args.batch_size, True, args.group_size)
print("train data size:", train_size)
step_per_epoch = math.ceil(train_size / args.batch_size / args.group_size)
# Define loss function
loss_func = HeatmapLoss()
# Define optimizer
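# Exponential schedule: lr = initial_lr * decay_rate^(cur_epoch / decay_epoch)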
lr_decay = nn.exponential_decay_lr(
args.initial_lr, args.decay_rate, args.num_epoch * step_per_epoch, step_per_epoch, args.decay_epoch
)
optimizer = nn.Adam(net.trainable_params(), lr_decay)
# Define model
model = Model(net, loss_func, optimizer, amp_level=args.amp_level, keep_batchnorm_fp32=False)
# Define callback functions
callbacks = []
callbacks.append(LossMonitor(args.loss_log_interval))
callbacks.append(TimeMonitor(train_size))
# Save checkpoint file
if args.rank_id == 0:
config_ck = CheckpointConfig(
save_checkpoint_steps=args.save_checkpoint_epochs * step_per_epoch,
keep_checkpoint_max=args.keep_checkpoint_max,
)
ckpoint = ModelCheckpoint("ckpt", config=config_ck)
callbacks.append(ckpoint)
model.train(args.num_epoch, train_data, callbacks=callbacks, dataset_sink_mode=True)