forked from mindspore-Ecosystem/mindspore
commit
438cbefaf4
0
mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/operator_library/kernels/nnacl/fp32/softmax_fp32.c
Executable file → Normal file
0
mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/operator_library/kernels/nnacl/fp32/softmax_fp32.c
Executable file → Normal file
|
@ -0,0 +1,263 @@
|
||||||
|
# 目录
|
||||||
|
|
||||||
|
<!-- TOC -->
|
||||||
|
|
||||||
|
- [simple_baselines描述](#simple_baselines描述)
|
||||||
|
- [模型架构](#模型架构)
|
||||||
|
- [数据集](#数据集)
|
||||||
|
- [特性](#特性)
|
||||||
|
- [混合精度](#混合精度)
|
||||||
|
- [环境要求](#环境要求)
|
||||||
|
- [快速入门](#快速入门)
|
||||||
|
- [脚本说明](#脚本说明)
|
||||||
|
- [脚本及样例代码](#脚本及样例代码)
|
||||||
|
- [脚本参数](#脚本参数)
|
||||||
|
- [训练过程](#训练过程)
|
||||||
|
- [评估过程](#评估过程)
|
||||||
|
- [模型描述](#模型描述)
|
||||||
|
- [性能](#性能)
|
||||||
|
- [评估性能](#评估性能)
|
||||||
|
- [随机情况说明](#随机情况说明)
|
||||||
|
- [ModelZoo主页](#ModelZoo主页)
|
||||||
|
|
||||||
|
<!-- /TOC -->
|
||||||
|
|
||||||
|
# simple baselines描述
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
simple_baselines模型网络由微软亚洲研究院Bin Xiao等人提出,作者认为当前流行的人体姿态估计和追踪方法都过于复杂,已有的关于人体姿势估计和姿势追踪模型在结构上看似差异较大,但在性能方面确又接近。作者提出了一种简单有效的基线方法,通过在主干网络ResNet上添加反卷积层,这恰恰是从高和低分辨率特征图中估计热图的最简单方法,从而有助于激发和评估该领域的新想法。
|
||||||
|
|
||||||
|
simple_baselines模型网络具体细节可参考[论文1](https://arxiv.org/pdf/1804.06208.pdf),simple_baselines模型网络Mindspore实现基于原微软亚洲研究院发布的Pytorch版本实现,具体可参考(<https://github.com/microsoft/human-pose-estimation.pytorch>)。
|
||||||
|
|
||||||
|
## 论文
|
||||||
|
|
||||||
|
1. [论文](https://arxiv.org/pdf/1804.06208.pdf):Bin Xiao, Haiping Wu, Yichen Wei."Simple baselines for human pose estimation and tracking"
|
||||||
|
|
||||||
|
# 模型架构
|
||||||
|
|
||||||
|
simple_baselines的总体网络架构如下:
|
||||||
|
[链接](https://arxiv.org/pdf/1804.06208.pdf)
|
||||||
|
|
||||||
|
# 数据集
|
||||||
|
|
||||||
|
使用的数据集:[COCO2017]
|
||||||
|
|
||||||
|
- 数据集大小:
|
||||||
|
- 训练集:19.56G, 118,287个图像
|
||||||
|
- 测试集:825MB, 5,000个图像
|
||||||
|
- 数据格式:JPG文件
|
||||||
|
- 注:数据在src/dataset.py中处理
|
||||||
|
|
||||||
|
# 特性
|
||||||
|
|
||||||
|
## 混合精度
|
||||||
|
|
||||||
|
采用[混合精度](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||||
|
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||||
|
|
||||||
|
# 环境要求
|
||||||
|
|
||||||
|
- 硬件(Ascend)
|
||||||
|
- 准备Ascend处理器搭建硬件环境。
|
||||||
|
- 框架
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- 如需查看详情,请参见如下资源:
|
||||||
|
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||||
|
|
||||||
|
# 快速入门
|
||||||
|
|
||||||
|
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||||
|
|
||||||
|
- 预训练模型
|
||||||
|
|
||||||
|
当开始训练之前需要获取mindspore图像网络预训练模型,可通过在[official model zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet)中运行Resnet训练脚本来获取模型权重文件,预训练文件名称为resnet50.ckpt。
|
||||||
|
|
||||||
|
- 数据集准备
|
||||||
|
|
||||||
|
simple_baselines网络模型使用COCO2017数据集用于训练和推理,数据集可通过[official website](https://cocodataset.org/)官方网站下载使用。
|
||||||
|
|
||||||
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
|
```text
|
||||||
|
# 分布式训练
|
||||||
|
用法:sh run_distribute_train.sh --is_model_arts False --run_distribute True
|
||||||
|
|
||||||
|
# 单机训练
|
||||||
|
用法:sh run_standalone_train.sh --device_id 0 --is_model_arts False --run_distribute False
|
||||||
|
|
||||||
|
# 运行评估示例
|
||||||
|
用法:sh run_eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
# 脚本说明
|
||||||
|
|
||||||
|
## 脚本及样例代码
|
||||||
|
|
||||||
|
```shell
|
||||||
|
|
||||||
|
└──simple_baselines
|
||||||
|
├── README.md
|
||||||
|
├── scripts
|
||||||
|
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||||
|
├── run_eval.sh # 启动Ascend评估
|
||||||
|
├── run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||||
|
├── src
|
||||||
|
├── utils
|
||||||
|
├── coco.py # COCO数据集评估结果
|
||||||
|
├── inference.py # 热图关键点预测
|
||||||
|
├── nms.py # nms
|
||||||
|
├── transforms.py # 图像处理转换
|
||||||
|
├── config.py # 参数配置
|
||||||
|
├── dataset.py # 数据预处理
|
||||||
|
├── network_with_loss.py # 损失函数定义
|
||||||
|
└── pose_resnet.py # 主干网络定义
|
||||||
|
├── eval.py # 评估网络
|
||||||
|
└── train.py # 训练网络
|
||||||
|
```
|
||||||
|
|
||||||
|
## 脚本参数
|
||||||
|
|
||||||
|
在src/config.py中配置相关参数。
|
||||||
|
|
||||||
|
- 配置模型相关参数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.MODEL.INIT_WEIGHTS = True # 初始化模型权重
|
||||||
|
config.MODEL.PRETRAINED = 'resnet50.ckpt' # 预训练模型
|
||||||
|
config.MODEL.NUM_JOINTS = 17 # 关键点数量
|
||||||
|
config.MODEL.IMAGE_SIZE = [192, 256] # 图像大小
|
||||||
|
```
|
||||||
|
|
||||||
|
- 配置网络相关参数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.NETWORK.NUM_LAYERS = 50 # resnet主干网络层数
|
||||||
|
config.NETWORK.DECONV_WITH_BIAS = False # 网络反卷积偏差
|
||||||
|
config.NETWORK.NUM_DECONV_LAYERS = 3 # 网络反卷积层数
|
||||||
|
config.NETWORK.NUM_DECONV_FILTERS = [256, 256, 256] # 反卷积层过滤器尺寸
|
||||||
|
config.NETWORK.NUM_DECONV_KERNELS = [4, 4, 4] # 反卷积层内核大小
|
||||||
|
config.NETWORK.FINAL_CONV_KERNEL = 1 # 最终卷积层内核大小
|
||||||
|
config.NETWORK.HEATMAP_SIZE = [48, 64] # 热图尺寸
|
||||||
|
```
|
||||||
|
|
||||||
|
- 配置训练相关参数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.TRAIN.SHUFFLE = True # 训练数据随机排序
|
||||||
|
config.TRAIN.BATCH_SIZE = 64 # 训练批次大小
|
||||||
|
config.TRAIN.BEGIN_EPOCH = 0 # 测试数据集文件名
|
||||||
|
config.DATASET.FLIP = True # 数据集随机翻转
|
||||||
|
config.DATASET.SCALE_FACTOR = 0.3 # 数据集随机规模因数
|
||||||
|
config.DATASET.ROT_FACTOR = 40 # 数据集随机旋转因数
|
||||||
|
config.TRAIN.BEGIN_EPOCH = 0 # 初始周期数
|
||||||
|
config.TRAIN.END_EPOCH = 140 # 最终周期数
|
||||||
|
config.TRAIN.LR = 0.001 # 初始学习率
|
||||||
|
config.TRAIN.LR_FACTOR = 0.1 # 学习率降低因子
|
||||||
|
```
|
||||||
|
|
||||||
|
- 配置验证相关参数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.TEST.BATCH_SIZE = 32 # 验证批次大小
|
||||||
|
config.TEST.FLIP_TEST = True # 翻转验证
|
||||||
|
config.TEST.USE_GT_BBOX = False # 使用标注框
|
||||||
|
```
|
||||||
|
|
||||||
|
- 配置nms相关参数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config.TEST.OKS_THRE = 0.9 # OKS阈值
|
||||||
|
config.TEST.IN_VIS_THRE = 0.2 # 可视化阈值
|
||||||
|
config.TEST.BBOX_THRE = 1.0 # 候选框阈值
|
||||||
|
config.TEST.IMAGE_THRE = 0.0 # 图像阈值
|
||||||
|
config.TEST.NMS_THRE = 1.0 # nms阈值
|
||||||
|
```
|
||||||
|
|
||||||
|
## 训练过程
|
||||||
|
|
||||||
|
### 用法
|
||||||
|
|
||||||
|
#### Ascend处理器环境运行
|
||||||
|
|
||||||
|
```text
|
||||||
|
# 分布式训练
|
||||||
|
用法:sh run_distribute_train.sh --is_model_arts False --run_distribute True
|
||||||
|
|
||||||
|
# 单机训练
|
||||||
|
用法:sh run_standalone_train.sh --device_id 0 --is_model_arts False --run_distribute False
|
||||||
|
|
||||||
|
# 运行评估示例
|
||||||
|
用法:sh run_eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 结果
|
||||||
|
|
||||||
|
- 使用COCO2017数据集训练simple_baselines
|
||||||
|
|
||||||
|
```text
|
||||||
|
分布式训练结果(8P)
|
||||||
|
epoch:1 step:2340, loss is 0.0008106
|
||||||
|
epoch:2 step:2340, loss is 0.0006160
|
||||||
|
epoch:3 step:2340, loss is 0.0006480
|
||||||
|
epoch:4 step:2340, loss is 0.0005620
|
||||||
|
epoch:5 step:2340, loss is 0.0005207
|
||||||
|
...
|
||||||
|
epoch:138 step:2340, loss is 0.0003183
|
||||||
|
epoch:139 step:2340, loss is 0.0002866
|
||||||
|
epoch:140 step:2340, loss is 0.0003393
|
||||||
|
```
|
||||||
|
|
||||||
|
## 评估过程
|
||||||
|
|
||||||
|
### 用法
|
||||||
|
|
||||||
|
#### Ascend处理器环境运行
|
||||||
|
|
||||||
|
可通过改变config.py文件中的"config.TEST.MODEL_FILE"文件进行相应模型推理。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 评估
|
||||||
|
sh eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 结果
|
||||||
|
|
||||||
|
使用COCO2017数据集文件夹中val2017进行评估simple_baselines,如下所示:
|
||||||
|
|
||||||
|
```text
|
||||||
|
coco eval results saved to /cache/train_output/multi_train_poseresnet_v5_2-140_2340/keypoints_results.pkl
|
||||||
|
AP: 0.704
|
||||||
|
```
|
||||||
|
|
||||||
|
# 模型描述
|
||||||
|
|
||||||
|
## 性能
|
||||||
|
|
||||||
|
### 评估性能
|
||||||
|
|
||||||
|
#### COCO2017上性能参数
|
||||||
|
|
||||||
|
| Parameters | Ascend 910 |
|
||||||
|
| ------------------- | --------------------------- |
|
||||||
|
| 模型版本 | simple_baselines |
|
||||||
|
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||||
|
| 上传日期 | 2021-03-29 |
|
||||||
|
| MindSpore版本 | 1.1.0 |
|
||||||
|
| 数据集 | COCO2017 |
|
||||||
|
| 训练参数 | epoch=140, batch_size=64 |
|
||||||
|
| 优化器 | Adam |
|
||||||
|
| 损失函数 | Mean Squared Error |
|
||||||
|
| 输出 | heatmap |
|
||||||
|
| 输出 | heatmap |
|
||||||
|
| 速度 | 1pc: 251.4 ms/step |
|
||||||
|
| 训练性能 | AP: 0.704 |
|
||||||
|
|
||||||
|
# 随机情况说明
|
||||||
|
|
||||||
|
dataset.py中设置了“create_dataset”函数内的种子,同时在model.py中使用了初始化网络权重。
|
||||||
|
|
||||||
|
# ModelZoo主页
|
||||||
|
|
||||||
|
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
This file evaluates the model used.
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor, float32, context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.pose_resnet import GetPoseResNet
|
||||||
|
from src.dataset import flip_pairs
|
||||||
|
from src.dataset import CreateDatasetCoco
|
||||||
|
from src.utils.coco import evaluate
|
||||||
|
from src.utils.transforms import flip_back
|
||||||
|
from src.utils.inference import get_final_preds
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
import moxing as mox
|
||||||
|
|
||||||
|
set_seed(config.GENERAL.EVAL_SEED)
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='Evaluate')
|
||||||
|
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
|
||||||
|
parser.add_argument('--train_url', required=True, default=None, help='Location of evaluate outputs.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def validate(cfg, val_dataset, model, output_dir, ann_path):
|
||||||
|
'''
|
||||||
|
validate
|
||||||
|
'''
|
||||||
|
model.set_train(False)
|
||||||
|
num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
|
||||||
|
all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
|
||||||
|
dtype=np.float32)
|
||||||
|
all_boxes = np.zeros((num_samples, 2))
|
||||||
|
image_id = []
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for item in val_dataset.create_dict_iterator():
|
||||||
|
inputs = item['image'].asnumpy()
|
||||||
|
output = model(Tensor(inputs, float32)).asnumpy()
|
||||||
|
if cfg.TEST.FLIP_TEST:
|
||||||
|
inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
|
||||||
|
output_flipped = model(inputs_flipped)
|
||||||
|
output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)
|
||||||
|
|
||||||
|
if cfg.TEST.SHIFT_HEATMAP:
|
||||||
|
output_flipped[:, :, :, 1:] = \
|
||||||
|
output_flipped.copy()[:, :, :, 0:-1]
|
||||||
|
|
||||||
|
output = (output + output_flipped) * 0.5
|
||||||
|
|
||||||
|
c = item['center'].asnumpy()
|
||||||
|
s = item['scale'].asnumpy()
|
||||||
|
score = item['score'].asnumpy()
|
||||||
|
file_id = list(item['id'].asnumpy())
|
||||||
|
|
||||||
|
preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
|
||||||
|
num_images, _ = preds.shape[:2]
|
||||||
|
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
|
||||||
|
all_preds[idx:idx + num_images, :, 2:3] = maxvals
|
||||||
|
all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
|
||||||
|
all_boxes[idx:idx + num_images, 1] = score
|
||||||
|
image_id.extend(file_id)
|
||||||
|
idx += num_images
|
||||||
|
if idx % 1024 == 0:
|
||||||
|
print('{} samples validated in {} seconds'.format(idx, time.time() - start))
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
|
||||||
|
_, perf_indicator = evaluate(cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id, ann_path)
|
||||||
|
print("AP:", perf_indicator)
|
||||||
|
return perf_indicator
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)
|
||||||
|
|
||||||
|
model = GetPoseResNet(config)
|
||||||
|
|
||||||
|
ckpt_name = ''
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
ckpt_name = config.MODELARTS.CACHE_INPUT
|
||||||
|
else:
|
||||||
|
ckpt_name = config.DATASET.ROOT
|
||||||
|
ckpt_name = ckpt_name + config.TEST.MODEL_FILE
|
||||||
|
print('loading model ckpt from {}'.format(ckpt_name))
|
||||||
|
load_param_into_net(model, load_checkpoint(ckpt_name))
|
||||||
|
|
||||||
|
valid_dataset = CreateDatasetCoco(
|
||||||
|
train_mode=False,
|
||||||
|
num_parallel_workers=config.TEST.NUM_PARALLEL_WORKERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_name = ckpt_name.split('/')
|
||||||
|
ckpt_name = ckpt_name[len(ckpt_name) - 1]
|
||||||
|
ckpt_name = ckpt_name.split('.')[0]
|
||||||
|
output_dir = ''
|
||||||
|
ann_path = ''
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
output_dir = config.MODELARTS.CACHE_OUTPUT
|
||||||
|
ann_path = config.MODELARTS.CACHE_INPUT
|
||||||
|
else:
|
||||||
|
output_dir = config.TEST.OUTPUT_DIR
|
||||||
|
ann_path = config.DATASET.ROOT
|
||||||
|
output_dir = output_dir + ckpt_name
|
||||||
|
ann_path = ann_path + config.DATASET.TEST_JSON
|
||||||
|
validate(config, valid_dataset, model, output_dir, ann_path)
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
##############export checkpoint file into air, onnx, mindir models#################
|
||||||
|
python export.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.common.dtype as ms
|
||||||
|
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.pose_resnet import GetPoseResNet
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='simple_baselines')
|
||||||
|
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||||
|
help="device target")
|
||||||
|
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||||
|
parser.add_argument("--ckpt_url", type=str, required=True, help="Checkpoint file path.")
|
||||||
|
parser.add_argument("--file_name", type=str, default="simple_baselines", help="output file name.")
|
||||||
|
parser.add_argument('--file_format', type=str, choices=["MINDIR"], default='MINDIR', help='file format')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
cfg = config
|
||||||
|
|
||||||
|
net = GetPoseResNet(config)
|
||||||
|
|
||||||
|
assert cfg.checkpoint_dir is not None, "cfg.checkpoint_dir is None."
|
||||||
|
param_dict = load_checkpoint(args.ckpt_url)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
input_arr = Tensor(np.ones([64, 3, 192, 224]), ms.float32)
|
||||||
|
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,80 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
echo "========================================================================"
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run.sh RANK_SIZE"
|
||||||
|
echo "For example: bash run_distribute.sh 8"
|
||||||
|
echo "It is better to use the absolute path."
|
||||||
|
echo "========================================================================"
|
||||||
|
set -e
|
||||||
|
|
||||||
|
RANK_SIZE=$1
|
||||||
|
export RANK_SIZE
|
||||||
|
|
||||||
|
EXEC_PATH=$(pwd)
|
||||||
|
echo "$EXEC_PATH"
|
||||||
|
|
||||||
|
test_dist_8pcs()
|
||||||
|
{
|
||||||
|
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
||||||
|
export RANK_SIZE=8
|
||||||
|
}
|
||||||
|
|
||||||
|
test_dist_2pcs()
|
||||||
|
{
|
||||||
|
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
||||||
|
export RANK_SIZE=2
|
||||||
|
}
|
||||||
|
|
||||||
|
test_dist_${RANK_SIZE}pcs
|
||||||
|
|
||||||
|
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||||
|
|
||||||
|
cd ../
|
||||||
|
rm -rf distribute_train
|
||||||
|
mkdir distribute_train
|
||||||
|
cd distribute_train
|
||||||
|
for((i=0;i<${RANK_SIZE};i++))
|
||||||
|
do
|
||||||
|
rm -rf device$i
|
||||||
|
mkdir device$i
|
||||||
|
cd ./device$i
|
||||||
|
mkdir src
|
||||||
|
cd src
|
||||||
|
mkdir utils
|
||||||
|
cd ../../../
|
||||||
|
cp ./train.py ./distribute_train/device$i
|
||||||
|
cp ./src/*.py ./distribute_train/device$i/src
|
||||||
|
cp ./src/utils/*.py ./distribute_train/device$i/src/utils
|
||||||
|
cd ./distribute_train/device$i
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$i
|
||||||
|
echo "start training for device $i"
|
||||||
|
env > env$i.log
|
||||||
|
python train.py --is_model_arts False --run_distribute True > train$i.log 2>&1 &
|
||||||
|
echo "$i finish"
|
||||||
|
cd ../
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ $? -eq 0 ];then
|
||||||
|
echo "training success"
|
||||||
|
else
|
||||||
|
echo "training failed"
|
||||||
|
exit 2
|
||||||
|
fi
|
||||||
|
echo "finish"
|
||||||
|
cd ../
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
export DEVICE_ID=$1
|
||||||
|
|
||||||
|
python eval.py > eval_log$1.txt 2>&1 &
|
|
@ -0,0 +1,25 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "========================================================================"
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_standalone_train.sh"
|
||||||
|
echo "For example: bash run_standalone_train.sh"
|
||||||
|
echo "It is better to use the absolute path."
|
||||||
|
echo "========================================================================"
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
export DEVICE_ID=$1
|
||||||
|
python -u ../train.py --device_id ${DEVICE_ID} --is_model_arts False --run_distribute False > train${DEVICE_ID}.log 2>&1 &
|
||||||
|
echo "finish"
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
config
|
||||||
|
'''
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
config = edict()
|
||||||
|
|
||||||
|
#general
|
||||||
|
config.GENERAL = edict()
|
||||||
|
config.GENERAL.VERSION = 'commit'
|
||||||
|
config.GENERAL.TRAIN_SEED = 1
|
||||||
|
config.GENERAL.EVAL_SEED = 1
|
||||||
|
config.GENERAL.DATASET_SEED = 1
|
||||||
|
config.GENERAL.RUN_DISTRIBUTE = True
|
||||||
|
|
||||||
|
#model arts
|
||||||
|
config.MODELARTS = edict()
|
||||||
|
config.MODELARTS.IS_MODEL_ARTS = False
|
||||||
|
config.MODELARTS.CACHE_INPUT = '/cache/data_tzh/'
|
||||||
|
config.MODELARTS.CACHE_OUTPUT = '/cache/train_out/'
|
||||||
|
|
||||||
|
# model
|
||||||
|
config.MODEL = edict()
|
||||||
|
config.MODEL.IS_TRAINED = True
|
||||||
|
config.MODEL.INIT_WEIGHTS = True
|
||||||
|
config.MODEL.PRETRAINED = 'resnet50.ckpt'
|
||||||
|
config.MODEL.NUM_JOINTS = 17
|
||||||
|
config.MODEL.IMAGE_SIZE = [192, 256]
|
||||||
|
|
||||||
|
# network
|
||||||
|
config.NETWORK = edict()
|
||||||
|
config.NETWORK.NUM_LAYERS = 50
|
||||||
|
config.NETWORK.DECONV_WITH_BIAS = False
|
||||||
|
config.NETWORK.NUM_DECONV_LAYERS = 3
|
||||||
|
config.NETWORK.NUM_DECONV_FILTERS = [256, 256, 256]
|
||||||
|
config.NETWORK.NUM_DECONV_KERNELS = [4, 4, 4]
|
||||||
|
config.NETWORK.FINAL_CONV_KERNEL = 1
|
||||||
|
config.NETWORK.REVERSE = True
|
||||||
|
|
||||||
|
config.NETWORK.TARGET_TYPE = 'gaussian'
|
||||||
|
config.NETWORK.HEATMAP_SIZE = [48, 64]
|
||||||
|
config.NETWORK.SIGMA = 2
|
||||||
|
|
||||||
|
# loss
|
||||||
|
config.LOSS = edict()
|
||||||
|
config.LOSS.USE_TARGET_WEIGHT = True
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
config.DATASET = edict()
|
||||||
|
config.DATASET.TYPE = 'COCO'
|
||||||
|
config.DATASET.ROOT = '/opt_data/xidian_wks/zhao/simple_baselines/coco2017/'
|
||||||
|
config.DATASET.TRAIN_SET = 'train2017'
|
||||||
|
config.DATASET.TRAIN_JSON = 'annotations/person_keypoints_train2017.json'
|
||||||
|
config.DATASET.TEST_SET = 'val2017'
|
||||||
|
config.DATASET.TEST_JSON = 'annotations/person_keypoints_val2017.json'
|
||||||
|
|
||||||
|
# training data augmentation
|
||||||
|
config.DATASET.FLIP = True
|
||||||
|
config.DATASET.SCALE_FACTOR = 0.3
|
||||||
|
config.DATASET.ROT_FACTOR = 40
|
||||||
|
|
||||||
|
# train
|
||||||
|
config.TRAIN = edict()
|
||||||
|
config.TRAIN.SHUFFLE = True
|
||||||
|
config.TRAIN.BATCH_SIZE = 64
|
||||||
|
config.TRAIN.BEGIN_EPOCH = 0
|
||||||
|
config.TRAIN.END_EPOCH = 140
|
||||||
|
config.TRAIN.LR = 0.001
|
||||||
|
config.TRAIN.LR_FACTOR = 0.1
|
||||||
|
config.TRAIN.LR_STEP = [90, 120]
|
||||||
|
config.TRAIN.NUM_PARALLEL_WORKERS = 8
|
||||||
|
config.TRAIN.SAVE_CKPT = True
|
||||||
|
config.TRAIN.CKPT_PATH = "/opt_data/xidian_wks/zhao/simple_baselines/"
|
||||||
|
|
||||||
|
# valid
|
||||||
|
config.TEST = edict()
|
||||||
|
config.TEST.BATCH_SIZE = 32
|
||||||
|
config.TEST.FLIP_TEST = True
|
||||||
|
config.TEST.POST_PROCESS = True
|
||||||
|
config.TEST.SHIFT_HEATMAP = True
|
||||||
|
config.TEST.USE_GT_BBOX = False
|
||||||
|
config.TEST.NUM_PARALLEL_WORKERS = 2
|
||||||
|
config.TEST.MODEL_FILE = 'multi_train_poseresnet_commit_0-140_292.ckpt'
|
||||||
|
config.TEST.COCO_BBOX_FILE = 'annotations/COCO_val2017_detections_AP_H_56_person.json'
|
||||||
|
config.TEST.OUTPUT_DIR = 'results/'
|
||||||
|
|
||||||
|
# nms
|
||||||
|
config.TEST.OKS_THRE = 0.9
|
||||||
|
config.TEST.IN_VIS_THRE = 0.2
|
||||||
|
config.TEST.BBOX_THRE = 1.0
|
||||||
|
config.TEST.IMAGE_THRE = 0.0
|
||||||
|
config.TEST.NMS_THRE = 1.0
|
|
@ -0,0 +1,365 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
dataset processing
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
|
from src.utils.transforms import fliplr_joints, get_affine_transform, affine_transform
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
ds.config.set_seed(config.GENERAL.DATASET_SEED) # Set Random Seed
|
||||||
|
flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
|
||||||
|
[9, 10], [11, 12], [13, 14], [15, 16]]
|
||||||
|
|
||||||
|
class CocoDatasetGenerator:
|
||||||
|
'''
|
||||||
|
About the specific operations of coco2017 data set processing
|
||||||
|
'''
|
||||||
|
def __init__(self, cfg, is_train=False):
|
||||||
|
self.image_thre = cfg.TEST.IMAGE_THRE
|
||||||
|
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE, dtype=np.int32)
|
||||||
|
self.image_width = cfg.MODEL.IMAGE_SIZE[0]
|
||||||
|
self.image_height = cfg.MODEL.IMAGE_SIZE[1]
|
||||||
|
self.aspect_ratio = self.image_width * 1.0 / self.image_height
|
||||||
|
self.heatmap_size = np.array(cfg.NETWORK.HEATMAP_SIZE, dtype=np.int32)
|
||||||
|
self.sigma = cfg.NETWORK.SIGMA
|
||||||
|
self.target_type = cfg.NETWORK.TARGET_TYPE
|
||||||
|
self.scale_factor = cfg.DATASET.SCALE_FACTOR
|
||||||
|
self.rotation_factor = cfg.DATASET.ROT_FACTOR
|
||||||
|
self.flip = cfg.DATASET.FLIP
|
||||||
|
self.db = []
|
||||||
|
self.is_train = is_train
|
||||||
|
self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
|
||||||
|
[9, 10], [11, 12], [13, 14], [15, 16]]
|
||||||
|
self.num_joints = 17
|
||||||
|
|
||||||
|
def load_gt_dataset(self, image_path, ann_file):
|
||||||
|
'''
|
||||||
|
load_gt_dataset
|
||||||
|
'''
|
||||||
|
self.db = []
|
||||||
|
|
||||||
|
with open(ann_file, "rb") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
json_dict = json.loads(lines[0].decode("utf-8"))
|
||||||
|
|
||||||
|
objs = {}
|
||||||
|
cnt = 0
|
||||||
|
for item in json_dict['annotations']:
|
||||||
|
# exclude iscrowd and no-keypoint record
|
||||||
|
if item['iscrowd'] != 0 or item['num_keypoints'] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# assert the record is valid
|
||||||
|
assert item['iscrowd'] == 0, 'is crowd'
|
||||||
|
assert item['category_id'] == 1, 'is not people'
|
||||||
|
assert item['area'] > 0, 'area le 0'
|
||||||
|
assert item['num_keypoints'] > 0, 'has no keypoint'
|
||||||
|
assert max(item['keypoints']) > 0
|
||||||
|
|
||||||
|
image_id = item['image_id']
|
||||||
|
obj = [{'num_keypoints': item['num_keypoints'], 'keypoints': item['keypoints'], 'bbox': item['bbox']}]
|
||||||
|
objs[image_id] = obj if image_id not in objs else objs[image_id] + obj
|
||||||
|
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
print('loaded %d records from coco dataset.' % cnt)
|
||||||
|
|
||||||
|
for item in json_dict['images']:
|
||||||
|
image_id = item['id']
|
||||||
|
width = item['width']
|
||||||
|
height = item['height']
|
||||||
|
if image_id not in objs:
|
||||||
|
continue
|
||||||
|
valid_objs = []
|
||||||
|
for obj in objs[image_id]:
|
||||||
|
x, y, w, h = obj['bbox']
|
||||||
|
x1 = max(0, x)
|
||||||
|
y1 = max(0, y)
|
||||||
|
x2 = min(width - 1, x1 + max(0, w - 1))
|
||||||
|
y2 = min(height - 1, y1 + max(0, h - 1))
|
||||||
|
if x2 >= x1 and y2 >= y1:
|
||||||
|
tmp_obj = deepcopy(obj)
|
||||||
|
tmp_obj['bbox'] = np.array((x1, y1, x2, y2)) - np.array((0, 0, x1, y1))
|
||||||
|
valid_objs.append(tmp_obj)
|
||||||
|
else:
|
||||||
|
assert False, 'invalid bbox!'
|
||||||
|
objs[image_id] = valid_objs
|
||||||
|
|
||||||
|
for obj in objs[image_id]:
|
||||||
|
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
for ipt in range(self.num_joints):
|
||||||
|
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
|
||||||
|
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
|
||||||
|
joints_3d[ipt, 2] = 0
|
||||||
|
t_vis = obj['keypoints'][ipt * 3 + 2]
|
||||||
|
if t_vis > 1:
|
||||||
|
t_vis = 1
|
||||||
|
joints_3d_vis[ipt, 0] = t_vis
|
||||||
|
joints_3d_vis[ipt, 1] = t_vis
|
||||||
|
joints_3d_vis[ipt, 2] = 0
|
||||||
|
|
||||||
|
scale, center = self._bbox2sc(obj['bbox'])
|
||||||
|
|
||||||
|
self.db.append({
|
||||||
|
'id': int(item['id']),
|
||||||
|
'image': os.path.join(image_path, item['file_name']),
|
||||||
|
'center': center,
|
||||||
|
'scale': scale,
|
||||||
|
'joints_3d': joints_3d,
|
||||||
|
'joints_3d_vis': joints_3d_vis,
|
||||||
|
})
|
||||||
|
|
||||||
|
def load_detect_dataset(self, image_path, ann_file, bbox_file):
|
||||||
|
'''
|
||||||
|
load_detect_dataset
|
||||||
|
'''
|
||||||
|
self.db = []
|
||||||
|
all_boxes = None
|
||||||
|
with open(bbox_file, 'r') as f:
|
||||||
|
all_boxes = json.load(f)
|
||||||
|
|
||||||
|
assert all_boxes, 'Loading %s fail!' % bbox_file
|
||||||
|
print('Total boxes: {}'.format(len(all_boxes)))
|
||||||
|
|
||||||
|
with open(ann_file, "rb") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
json_dict = json.loads(lines[0].decode("utf-8"))
|
||||||
|
index_to_filename = {}
|
||||||
|
for item in json_dict['images']:
|
||||||
|
index_to_filename[item['id']] = item['file_name']
|
||||||
|
for det_res in all_boxes:
|
||||||
|
if det_res['category_id'] != 1:
|
||||||
|
continue
|
||||||
|
image = os.path.join(image_path,
|
||||||
|
index_to_filename[det_res['image_id']])
|
||||||
|
|
||||||
|
bbox = det_res['bbox']
|
||||||
|
score = det_res['score']
|
||||||
|
if score < self.image_thre:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scale, center = self._bbox2sc(bbox)
|
||||||
|
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
|
||||||
|
joints_3d_vis = np.ones((self.num_joints, 3), dtype=np.float)
|
||||||
|
|
||||||
|
self.db.append({
|
||||||
|
'id': int(det_res['image_id']),
|
||||||
|
'image': image,
|
||||||
|
'center': center,
|
||||||
|
'scale': scale,
|
||||||
|
'score': score,
|
||||||
|
'joints_3d': joints_3d,
|
||||||
|
'joints_3d_vis': joints_3d_vis,
|
||||||
|
})
|
||||||
|
|
||||||
|
def _bbox2sc(self, bbox):
|
||||||
|
"""
|
||||||
|
reform xywh to meet the need of aspect ratio
|
||||||
|
"""
|
||||||
|
x, y, w, h = bbox[:4]
|
||||||
|
center = np.zeros((2), dtype=np.float32)
|
||||||
|
center[0] = x + w * 0.5
|
||||||
|
center[1] = y + h * 0.5
|
||||||
|
|
||||||
|
if w > self.aspect_ratio * h:
|
||||||
|
h = w * 1.0 / self.aspect_ratio
|
||||||
|
elif w < self.aspect_ratio * h:
|
||||||
|
w = h * self.aspect_ratio
|
||||||
|
scale = np.array(
|
||||||
|
[w * 1.0 / 200, h * 1.0 / 200], dtype=np.float32)
|
||||||
|
if center[0] != -1:
|
||||||
|
scale = scale * 1.25
|
||||||
|
|
||||||
|
return scale, center
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
db_rec = deepcopy(self.db[idx])
|
||||||
|
|
||||||
|
image_file = db_rec['image']
|
||||||
|
|
||||||
|
data_numpy = cv2.imread(
|
||||||
|
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
|
||||||
|
|
||||||
|
if data_numpy is None:
|
||||||
|
print('[ERROR] fail to read {}'.format(image_file))
|
||||||
|
raise ValueError('Fail to read {}'.format(image_file))
|
||||||
|
|
||||||
|
joints = db_rec['joints_3d']
|
||||||
|
joints_vis = db_rec['joints_3d_vis']
|
||||||
|
|
||||||
|
c = db_rec['center']
|
||||||
|
s = db_rec['scale']
|
||||||
|
score = db_rec['score'] if 'score' in db_rec else 1
|
||||||
|
r = 0
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
sf = self.scale_factor
|
||||||
|
rf = self.rotation_factor
|
||||||
|
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
|
||||||
|
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
|
||||||
|
if random.random() <= 0.6 else 0
|
||||||
|
|
||||||
|
if self.flip and random.random() <= 0.5:
|
||||||
|
data_numpy = data_numpy[:, ::-1, :]
|
||||||
|
joints, joints_vis = fliplr_joints(
|
||||||
|
joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
|
||||||
|
c[0] = data_numpy.shape[1] - c[0] - 1
|
||||||
|
|
||||||
|
trans = get_affine_transform(c, s, r, self.image_size)
|
||||||
|
image = cv2.warpAffine(
|
||||||
|
data_numpy,
|
||||||
|
trans,
|
||||||
|
(int(self.image_size[0]), int(self.image_size[1])),
|
||||||
|
flags=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
for i in range(self.num_joints):
|
||||||
|
if joints_vis[i, 0] > 0.0:
|
||||||
|
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
|
||||||
|
|
||||||
|
target, target_weight = self.generate_heatmap(joints, joints_vis)
|
||||||
|
|
||||||
|
return image, target, target_weight, s, c, score, db_rec['id']
|
||||||
|
|
||||||
|
def generate_heatmap(self, joints, joints_vis):
|
||||||
|
'''
|
||||||
|
generate_heatmap
|
||||||
|
'''
|
||||||
|
target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
|
||||||
|
target_weight[:, 0] = joints_vis[:, 0]
|
||||||
|
|
||||||
|
assert self.target_type == 'gaussian', \
|
||||||
|
'Only support gaussian map now!'
|
||||||
|
|
||||||
|
if self.target_type == 'gaussian':
|
||||||
|
target = np.zeros((self.num_joints,
|
||||||
|
self.heatmap_size[1],
|
||||||
|
self.heatmap_size[0]),
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
|
tmp_size = self.sigma * 3
|
||||||
|
|
||||||
|
for joint_id in range(self.num_joints):
|
||||||
|
feat_stride = self.image_size / self.heatmap_size
|
||||||
|
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
|
||||||
|
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
|
||||||
|
# Check that any part of the gaussian is in-bounds
|
||||||
|
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
||||||
|
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
||||||
|
if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
|
||||||
|
or br[0] < 0 or br[1] < 0:
|
||||||
|
# If not, just return the image as is
|
||||||
|
target_weight[joint_id] = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# # Generate gaussian
|
||||||
|
size = 2 * tmp_size + 1
|
||||||
|
x = np.arange(0, size, 1, np.float32)
|
||||||
|
y = x[:, np.newaxis]
|
||||||
|
x0 = y0 = size // 2
|
||||||
|
# The gaussian is not normalized, we want the center value to equal 1
|
||||||
|
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
|
||||||
|
|
||||||
|
# Usable gaussian range
|
||||||
|
g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
|
||||||
|
g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
|
||||||
|
# Image range
|
||||||
|
img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
|
||||||
|
img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
|
||||||
|
|
||||||
|
v = target_weight[joint_id]
|
||||||
|
if v > 0.5:
|
||||||
|
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
||||||
|
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
||||||
|
|
||||||
|
return target, target_weight
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.db)
|
||||||
|
|
||||||
|
def CreateDatasetCoco(rank=0,
|
||||||
|
group_size=1,
|
||||||
|
train_mode=True,
|
||||||
|
num_parallel_workers=8,
|
||||||
|
transform=None,
|
||||||
|
shuffle=None):
|
||||||
|
'''
|
||||||
|
CreateDatasetCoco
|
||||||
|
'''
|
||||||
|
per_batch_size = config.TRAIN.BATCH_SIZE if train_mode else config.TEST.BATCH_SIZE
|
||||||
|
|
||||||
|
image_path = ''
|
||||||
|
ann_file = ''
|
||||||
|
bbox_file = ''
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
image_path = config.MODELARTS.CACHE_INPUT
|
||||||
|
ann_file = config.MODELARTS.CACHE_INPUT
|
||||||
|
bbox_file = config.MODELARTS.CACHE_INPUT
|
||||||
|
else:
|
||||||
|
image_path = config.DATASET.ROOT
|
||||||
|
ann_file = config.DATASET.ROOT
|
||||||
|
bbox_file = config.DATASET.ROOT
|
||||||
|
|
||||||
|
if train_mode:
|
||||||
|
image_path = image_path + config.DATASET.TRAIN_SET
|
||||||
|
ann_file = ann_file + config.DATASET.TRAIN_JSON
|
||||||
|
else:
|
||||||
|
image_path = image_path + config.DATASET.TEST_SET
|
||||||
|
ann_file = ann_file + config.DATASET.TEST_JSON
|
||||||
|
bbox_file = bbox_file + config.TEST.COCO_BBOX_FILE
|
||||||
|
|
||||||
|
print('loading dataset from {}'.format(image_path))
|
||||||
|
|
||||||
|
shuffle = shuffle if shuffle is not None else train_mode
|
||||||
|
dataset_generator = CocoDatasetGenerator(config, is_train=train_mode)
|
||||||
|
|
||||||
|
if not train_mode and config.TEST.USE_GT_BBOX:
|
||||||
|
print('loading bbox file from {}'.format(bbox_file))
|
||||||
|
dataset_generator.load_detect_dataset(image_path, ann_file, bbox_file)
|
||||||
|
else:
|
||||||
|
dataset_generator.load_gt_dataset(image_path, ann_file)
|
||||||
|
|
||||||
|
coco_dataset = ds.GeneratorDataset(dataset_generator,
|
||||||
|
column_names=["image", "target", "weight", "scale", "center", "score", "id"],
|
||||||
|
num_parallel_workers=num_parallel_workers,
|
||||||
|
num_shards=group_size,
|
||||||
|
shard_id=rank,
|
||||||
|
shuffle=shuffle)
|
||||||
|
if transform is None:
|
||||||
|
transform_img = [
|
||||||
|
C.Rescale(1.0 / 255.0, 0.0),
|
||||||
|
C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
transform_img = transform
|
||||||
|
coco_dataset = coco_dataset.map(input_columns="image",
|
||||||
|
num_parallel_workers=num_parallel_workers,
|
||||||
|
operations=transform_img)
|
||||||
|
coco_dataset = coco_dataset.batch(per_batch_size, drop_remainder=train_mode)
|
||||||
|
|
||||||
|
return coco_dataset
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
network_with_loss
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
class JointsMSELoss(_Loss):
|
||||||
|
'''
|
||||||
|
JointsMSELoss
|
||||||
|
'''
|
||||||
|
def __init__(self, use_target_weight):
|
||||||
|
super(JointsMSELoss, self).__init__()
|
||||||
|
self.criterion = nn.MSELoss(reduction='mean')
|
||||||
|
self.use_target_weight = use_target_weight
|
||||||
|
self.shape = P.Shape()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.squeeze = P.Squeeze(1)
|
||||||
|
self.mul = P.Mul()
|
||||||
|
|
||||||
|
def construct(self, output, target, target_weight):
|
||||||
|
'''
|
||||||
|
construct
|
||||||
|
'''
|
||||||
|
total_shape = self.shape(output)
|
||||||
|
batch_size = total_shape[0]
|
||||||
|
num_joints = total_shape[1]
|
||||||
|
remained_size = 1
|
||||||
|
for i in range(2, len(total_shape)):
|
||||||
|
remained_size *= total_shape[i]
|
||||||
|
|
||||||
|
split = P.Split(1, num_joints)
|
||||||
|
new_shape = (batch_size, num_joints, remained_size)
|
||||||
|
heatmaps_pred = split(self.reshape(output, new_shape))
|
||||||
|
heatmaps_gt = split(self.reshape(target, new_shape))
|
||||||
|
loss = 0
|
||||||
|
|
||||||
|
for idx in range(num_joints):
|
||||||
|
heatmap_pred_squeezed = self.squeeze(heatmaps_pred[idx])
|
||||||
|
heatmap_gt_squeezed = self.squeeze(heatmaps_gt[idx])
|
||||||
|
if self.use_target_weight:
|
||||||
|
loss += 0.5 * self.criterion(self.mul(heatmap_pred_squeezed, target_weight[:, idx]),
|
||||||
|
self.mul(heatmap_gt_squeezed, target_weight[:, idx]))
|
||||||
|
else:
|
||||||
|
loss += 0.5 * self.criterion(heatmap_pred_squeezed, heatmap_gt_squeezed)
|
||||||
|
|
||||||
|
return loss / num_joints
|
||||||
|
|
||||||
|
class PoseResNetWithLoss(nn.Cell):
|
||||||
|
"""
|
||||||
|
Pack the model network and loss function together to calculate the loss value.
|
||||||
|
"""
|
||||||
|
def __init__(self, network, loss):
|
||||||
|
super(PoseResNetWithLoss, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.loss = loss
|
||||||
|
|
||||||
|
def construct(self, image, target, weight, scale=None, center=None, score=None, idx=None):
|
||||||
|
output = self.network(image)
|
||||||
|
output = F.mixed_precision_cast(mstype.float32, output)
|
||||||
|
target = F.mixed_precision_cast(mstype.float32, target)
|
||||||
|
weight = F.mixed_precision_cast(mstype.float32, weight)
|
||||||
|
return self.loss(output, target, weight)
|
|
@ -0,0 +1,222 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
simple_baselines network
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
import os
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.initializer as init
|
||||||
|
import mindspore.ops.operations as F
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
BN_MOMENTUM = 0.1
|
||||||
|
|
||||||
|
class MPReverse(nn.Cell):
|
||||||
|
'''
|
||||||
|
MPReverse
|
||||||
|
'''
|
||||||
|
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
|
||||||
|
super(MPReverse, self).__init__()
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
|
||||||
|
self.reverse = F.ReverseV2(axis=[2, 3])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.reverse(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
x = self.reverse(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Bottleneck(nn.Cell):
|
||||||
|
'''
|
||||||
|
model part of network
|
||||||
|
'''
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
||||||
|
stride=stride, padding=1, has_bias=False, pad_mode='pad')
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, has_bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''
|
||||||
|
construct
|
||||||
|
'''
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PoseResNet(nn.Cell):
|
||||||
|
'''
|
||||||
|
PoseResNet
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, block, layers, cfg):
|
||||||
|
self.inplanes = 64
|
||||||
|
self.deconv_with_bias = cfg.NETWORK.DECONV_WITH_BIAS
|
||||||
|
|
||||||
|
super(PoseResNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, has_bias=False, pad_mode='pad')
|
||||||
|
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.maxpool = MPReverse(kernel_size=3, stride=2, pad_mode='same')
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||||
|
|
||||||
|
# used for deconv layers
|
||||||
|
self.deconv_layers = self._make_deconv_layer(
|
||||||
|
cfg.NETWORK.NUM_DECONV_LAYERS,
|
||||||
|
cfg.NETWORK.NUM_DECONV_FILTERS,
|
||||||
|
cfg.NETWORK.NUM_DECONV_KERNELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = nn.Conv2d(
|
||||||
|
in_channels=cfg.NETWORK.NUM_DECONV_FILTERS[-1],
|
||||||
|
out_channels=cfg.MODEL.NUM_JOINTS,
|
||||||
|
kernel_size=cfg.NETWORK.FINAL_CONV_KERNEL,
|
||||||
|
stride=1,
|
||||||
|
padding=1 if cfg.NETWORK.FINAL_CONV_KERNEL == 3 else 0,
|
||||||
|
pad_mode='pad',
|
||||||
|
has_bias=True,
|
||||||
|
weight_init=init.Normal(0.001)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
'''
|
||||||
|
_make_layer
|
||||||
|
'''
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.SequentialCell([nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||||
|
kernel_size=1, stride=stride, has_bias=False),
|
||||||
|
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM)])
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for i in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes))
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
|
||||||
|
'''
|
||||||
|
_make_deconv_layer
|
||||||
|
'''
|
||||||
|
assert num_layers == len(num_filters), \
|
||||||
|
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
|
||||||
|
assert num_layers == len(num_kernels), \
|
||||||
|
'ERROR: num_deconv_layers is different len(num_deconv_filters)'
|
||||||
|
layers = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
kernel = num_kernels[i]
|
||||||
|
padding = 1
|
||||||
|
planes = num_filters[i]
|
||||||
|
|
||||||
|
layers.append(nn.Conv2dTranspose(
|
||||||
|
in_channels=self.inplanes,
|
||||||
|
out_channels=planes,
|
||||||
|
kernel_size=kernel,
|
||||||
|
stride=2,
|
||||||
|
padding=padding,
|
||||||
|
has_bias=self.deconv_with_bias,
|
||||||
|
pad_mode='pad',
|
||||||
|
weight_init=init.Normal(0.001)
|
||||||
|
))
|
||||||
|
layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
self.inplanes = planes
|
||||||
|
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''
|
||||||
|
construct
|
||||||
|
'''
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.deconv_layers(x)
|
||||||
|
x = self.final_layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=''):
|
||||||
|
if os.path.isfile(pretrained):
|
||||||
|
# load params from pretrained
|
||||||
|
param_dict = load_checkpoint(pretrained)
|
||||||
|
load_param_into_net(self, param_dict)
|
||||||
|
print('=> loading pretrained model {}'.format(pretrained))
|
||||||
|
else:
|
||||||
|
print('=> imagenet pretrained model dose not exist')
|
||||||
|
raise ValueError('{} is not a file'.format(pretrained))
|
||||||
|
|
||||||
|
|
||||||
|
resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),
|
||||||
|
101: (Bottleneck, [3, 4, 23, 3]),
|
||||||
|
152: (Bottleneck, [3, 8, 36, 3])}
|
||||||
|
|
||||||
|
|
||||||
|
def GetPoseResNet(cfg):
|
||||||
|
'''
|
||||||
|
GetPoseResNet
|
||||||
|
'''
|
||||||
|
num_layers = cfg.NETWORK.NUM_LAYERS
|
||||||
|
block_class, layers = resnet_spec[num_layers]
|
||||||
|
network = PoseResNet(block_class, layers, cfg)
|
||||||
|
|
||||||
|
if cfg.MODEL.IS_TRAINED and cfg.MODEL.INIT_WEIGHTS:
|
||||||
|
pretrained = ''
|
||||||
|
if cfg.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
pretrained = cfg.MODELARTS.CACHE_INPUT + cfg.MODEL.PRETRAINED
|
||||||
|
else:
|
||||||
|
pretrained = cfg.TRAIN.CKPT_PATH + cfg.MODEL.PRETRAINED
|
||||||
|
network.init_weights(pretrained)
|
||||||
|
return network
|
|
@ -0,0 +1,137 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
coco
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict, OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
|
has_coco = True
|
||||||
|
except ImportError:
|
||||||
|
has_coco = False
|
||||||
|
|
||||||
|
from src.utils.nms import oks_nms
|
||||||
|
|
||||||
|
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
|
||||||
|
'''
|
||||||
|
_write_coco_keypoint_results
|
||||||
|
'''
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for img, items in img_kpts.items():
|
||||||
|
item_size = len(items)
|
||||||
|
if not items:
|
||||||
|
continue
|
||||||
|
kpts = np.array([items[k]['keypoints']
|
||||||
|
for k in range(item_size)])
|
||||||
|
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
|
||||||
|
keypoints[:, 0::3] = kpts[:, :, 0]
|
||||||
|
keypoints[:, 1::3] = kpts[:, :, 1]
|
||||||
|
keypoints[:, 2::3] = kpts[:, :, 2]
|
||||||
|
|
||||||
|
result = [{'image_id': int(img),
|
||||||
|
'keypoints': list(keypoints[k]),
|
||||||
|
'score': items[k]['score'],
|
||||||
|
'category_id': 1,
|
||||||
|
} for k in range(item_size)]
|
||||||
|
results.extend(result)
|
||||||
|
|
||||||
|
with open(res_file, 'w') as f:
|
||||||
|
json.dump(results, f, sort_keys=True, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
|
||||||
|
'''
|
||||||
|
_do_python_keypoint_eval
|
||||||
|
'''
|
||||||
|
coco = COCO(ann_path)
|
||||||
|
coco_dt = coco.loadRes(res_file)
|
||||||
|
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
|
||||||
|
coco_eval.params.useSegm = None
|
||||||
|
coco_eval.evaluate()
|
||||||
|
coco_eval.accumulate()
|
||||||
|
coco_eval.summarize()
|
||||||
|
stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
|
||||||
|
|
||||||
|
info_str = []
|
||||||
|
for ind, name in enumerate(stats_names):
|
||||||
|
info_str.append((name, coco_eval.stats[ind]))
|
||||||
|
|
||||||
|
eval_file = os.path.join(
|
||||||
|
res_folder, 'keypoints_results.pkl')
|
||||||
|
|
||||||
|
with open(eval_file, 'wb') as f:
|
||||||
|
pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
print('coco eval results saved to %s' % eval_file)
|
||||||
|
|
||||||
|
return info_str
|
||||||
|
|
||||||
|
def evaluate(cfg, preds, output_dir, all_boxes, img_id, ann_path):
|
||||||
|
'''
|
||||||
|
evaluate
|
||||||
|
'''
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
res_file = os.path.join(output_dir, 'keypoints_results.json')
|
||||||
|
img_kpts_dict = defaultdict(list)
|
||||||
|
for idx, file_id in enumerate(img_id):
|
||||||
|
img_kpts_dict[file_id].append({
|
||||||
|
'keypoints': preds[idx],
|
||||||
|
'area': all_boxes[idx][0],
|
||||||
|
'score': all_boxes[idx][1],
|
||||||
|
})
|
||||||
|
|
||||||
|
# rescoring and oks nms
|
||||||
|
num_joints = cfg.MODEL.NUM_JOINTS
|
||||||
|
in_vis_thre = cfg.TEST.IN_VIS_THRE
|
||||||
|
oks_thre = cfg.TEST.OKS_THRE
|
||||||
|
oks_nmsed_kpts = {}
|
||||||
|
for img, items in img_kpts_dict.items():
|
||||||
|
for item in items:
|
||||||
|
kpt_score = 0
|
||||||
|
valid_num = 0
|
||||||
|
for n_jt in range(num_joints):
|
||||||
|
max_jt = item['keypoints'][n_jt][2]
|
||||||
|
if max_jt > in_vis_thre:
|
||||||
|
kpt_score = kpt_score + max_jt
|
||||||
|
valid_num = valid_num + 1
|
||||||
|
if valid_num != 0:
|
||||||
|
kpt_score = kpt_score / valid_num
|
||||||
|
item['score'] = kpt_score * item['score']
|
||||||
|
keep = oks_nms(items, oks_thre)
|
||||||
|
if not keep:
|
||||||
|
oks_nmsed_kpts[img] = items
|
||||||
|
else:
|
||||||
|
oks_nmsed_kpts[img] = [items[kep] for kep in keep]
|
||||||
|
|
||||||
|
# evaluate and save
|
||||||
|
image_set = cfg.DATASET.TEST_SET
|
||||||
|
_write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
|
||||||
|
if 'test' not in image_set and has_coco:
|
||||||
|
ann_path = ann_path if ann_path else os.path.join(cfg.DATASET.ROOT, 'annotations',
|
||||||
|
'person_keypoints_' + image_set + '.json')
|
||||||
|
info_str = _do_python_keypoint_eval(res_file, output_dir, ann_path)
|
||||||
|
name_value = OrderedDict(info_str)
|
||||||
|
return name_value, name_value['AP']
|
||||||
|
return {'Null': 0}, 0
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
inference
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from src.utils.transforms import transform_preds
|
||||||
|
|
||||||
|
def get_max_preds(batch_heatmaps):
|
||||||
|
'''
|
||||||
|
get predictions from score maps
|
||||||
|
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
|
||||||
|
'''
|
||||||
|
assert isinstance(batch_heatmaps, np.ndarray), \
|
||||||
|
'batch_heatmaps should be numpy.ndarray'
|
||||||
|
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
|
||||||
|
|
||||||
|
batch_size = batch_heatmaps.shape[0]
|
||||||
|
num_joints = batch_heatmaps.shape[1]
|
||||||
|
width = batch_heatmaps.shape[3]
|
||||||
|
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
|
||||||
|
idx = np.argmax(heatmaps_reshaped, 2)
|
||||||
|
maxvals = np.amax(heatmaps_reshaped, 2)
|
||||||
|
|
||||||
|
maxvals = maxvals.reshape((batch_size, num_joints, 1))
|
||||||
|
idx = idx.reshape((batch_size, num_joints, 1))
|
||||||
|
|
||||||
|
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
|
||||||
|
|
||||||
|
preds[:, :, 0] = (preds[:, :, 0]) % width
|
||||||
|
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
|
||||||
|
|
||||||
|
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
|
||||||
|
pred_mask = pred_mask.astype(np.float32)
|
||||||
|
|
||||||
|
preds *= pred_mask
|
||||||
|
return preds, maxvals
|
||||||
|
|
||||||
|
|
||||||
|
def get_final_preds(config, batch_heatmaps, center, scale):
|
||||||
|
'''
|
||||||
|
get_final_preds
|
||||||
|
'''
|
||||||
|
coords, maxvals = get_max_preds(batch_heatmaps)
|
||||||
|
|
||||||
|
heatmap_height = batch_heatmaps.shape[2]
|
||||||
|
heatmap_width = batch_heatmaps.shape[3]
|
||||||
|
# post-processing
|
||||||
|
if config.TEST.POST_PROCESS:
|
||||||
|
for n in range(coords.shape[0]):
|
||||||
|
for p in range(coords.shape[1]):
|
||||||
|
hm = batch_heatmaps[n][p]
|
||||||
|
px = int(math.floor(coords[n][p][0] + 0.5))
|
||||||
|
py = int(math.floor(coords[n][p][1] + 0.5))
|
||||||
|
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
|
||||||
|
diff = np.array([hm[py][px+1] - hm[py][px-1],
|
||||||
|
hm[py+1][px]-hm[py-1][px]])
|
||||||
|
coords[n][p] += np.sign(diff) * .25
|
||||||
|
|
||||||
|
preds = coords.copy()
|
||||||
|
|
||||||
|
# Transform back
|
||||||
|
for i in range(coords.shape[0]):
|
||||||
|
preds[i] = transform_preds(coords[i], center[i], scale[i],
|
||||||
|
[heatmap_width, heatmap_height])
|
||||||
|
|
||||||
|
return preds, maxvals
|
|
@ -0,0 +1,74 @@
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
nms operation
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
|
||||||
|
'''
|
||||||
|
oks_iou
|
||||||
|
'''
|
||||||
|
if not isinstance(sigmas, np.ndarray):
|
||||||
|
sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
|
||||||
|
.62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
||||||
|
var = (sigmas * 2) ** 2
|
||||||
|
xg = g[0::3]
|
||||||
|
yg = g[1::3]
|
||||||
|
vg = g[2::3]
|
||||||
|
ious = np.zeros((d.shape[0]))
|
||||||
|
for n_d in range(0, d.shape[0]):
|
||||||
|
xd = d[n_d, 0::3]
|
||||||
|
yd = d[n_d, 1::3]
|
||||||
|
vd = d[n_d, 2::3]
|
||||||
|
dx = xd - xg
|
||||||
|
dy = yd - yg
|
||||||
|
e = (dx ** 2 + dy ** 2) / var / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
|
||||||
|
if in_vis_thre is not None:
|
||||||
|
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
|
||||||
|
e = e[ind]
|
||||||
|
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
|
||||||
|
return ious
|
||||||
|
|
||||||
|
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
|
||||||
|
"""
|
||||||
|
greedily select boxes with high confidence and overlap with current maximum <= thresh
|
||||||
|
rule out overlap >= thresh, overlap = oks
|
||||||
|
:param kpts_db
|
||||||
|
:param thresh: retain overlap < thresh
|
||||||
|
:return: indexes to keep
|
||||||
|
"""
|
||||||
|
kpts = len(kpts_db)
|
||||||
|
if kpts == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
|
||||||
|
kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
|
||||||
|
areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
|
||||||
|
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
|
||||||
|
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
|
||||||
|
|
||||||
|
inds = np.where(oks_ovr <= thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
return keep
|
|
@ -0,0 +1,137 @@
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
transforms
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def flip_back(output_flipped, matched_parts):
|
||||||
|
'''
|
||||||
|
ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width)
|
||||||
|
'''
|
||||||
|
assert output_flipped.ndim == 4,\
|
||||||
|
'output_flipped should be [batch_size, num_joints, height, width]'
|
||||||
|
|
||||||
|
output_flipped = output_flipped[:, :, :, ::-1]
|
||||||
|
|
||||||
|
for pair in matched_parts:
|
||||||
|
tmp = output_flipped[:, pair[0], :, :].copy()
|
||||||
|
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
|
||||||
|
output_flipped[:, pair[1], :, :] = tmp
|
||||||
|
|
||||||
|
return output_flipped
|
||||||
|
|
||||||
|
|
||||||
|
def fliplr_joints(joints, joints_vis, width, matched_parts):
|
||||||
|
"""
|
||||||
|
flip coords
|
||||||
|
"""
|
||||||
|
# Flip horizontal
|
||||||
|
joints[:, 0] = width - joints[:, 0] - 1
|
||||||
|
|
||||||
|
# Change left-right parts
|
||||||
|
for pair in matched_parts:
|
||||||
|
joints[pair[0], :], joints[pair[1], :] = \
|
||||||
|
joints[pair[1], :], joints[pair[0], :].copy()
|
||||||
|
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
|
||||||
|
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
|
||||||
|
|
||||||
|
return joints*joints_vis, joints_vis
|
||||||
|
|
||||||
|
|
||||||
|
def transform_preds(coords, center, scale, output_size):
|
||||||
|
'''
|
||||||
|
transform_preds
|
||||||
|
'''
|
||||||
|
target_coords = np.zeros(coords.shape)
|
||||||
|
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
|
||||||
|
for p in range(coords.shape[0]):
|
||||||
|
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
|
||||||
|
return target_coords
|
||||||
|
|
||||||
|
|
||||||
|
def get_affine_transform(center,
|
||||||
|
scale,
|
||||||
|
rot,
|
||||||
|
output_size,
|
||||||
|
shift=np.array([0, 0], dtype=np.float32),
|
||||||
|
inv=0):
|
||||||
|
'''
|
||||||
|
get_affine_transform
|
||||||
|
'''
|
||||||
|
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
|
||||||
|
print(scale)
|
||||||
|
scale = np.array([scale, scale])
|
||||||
|
|
||||||
|
scale_tmp = scale * 200.0
|
||||||
|
src_w = scale_tmp[0]
|
||||||
|
dst_w = output_size[0]
|
||||||
|
dst_h = output_size[1]
|
||||||
|
|
||||||
|
rot_rad = np.pi * rot / 180
|
||||||
|
src_dir = get_dir([0, src_w * -0.5], rot_rad)
|
||||||
|
dst_dir = np.array([0, dst_w * -0.5], np.float32)
|
||||||
|
|
||||||
|
src = np.zeros((3, 2), dtype=np.float32)
|
||||||
|
dst = np.zeros((3, 2), dtype=np.float32)
|
||||||
|
src[0, :] = center + scale_tmp * shift
|
||||||
|
src[1, :] = center + src_dir + scale_tmp * shift
|
||||||
|
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||||
|
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
||||||
|
|
||||||
|
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
|
||||||
|
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
|
||||||
|
|
||||||
|
if inv:
|
||||||
|
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||||
|
else:
|
||||||
|
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||||
|
|
||||||
|
return trans
|
||||||
|
|
||||||
|
|
||||||
|
def affine_transform(pt, t):
|
||||||
|
new_pt = np.array([pt[0], pt[1], 1.]).T
|
||||||
|
new_pt = np.dot(t, new_pt)
|
||||||
|
return new_pt[:2]
|
||||||
|
|
||||||
|
|
||||||
|
def get_3rd_point(a, b):
|
||||||
|
direct = a - b
|
||||||
|
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dir(src_point, rot_rad):
|
||||||
|
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||||
|
|
||||||
|
src_result = [0, 0]
|
||||||
|
src_result[0] = src_point[0] * cs - src_point[1] * sn
|
||||||
|
src_result[1] = src_point[0] * sn + src_point[1] * cs
|
||||||
|
|
||||||
|
return src_result
|
||||||
|
|
||||||
|
|
||||||
|
def crop(img, center, scale, output_size, rot=0):
|
||||||
|
trans = get_affine_transform(center, scale, rot, output_size)
|
||||||
|
|
||||||
|
dst_img = cv2.warpAffine(img,
|
||||||
|
trans,
|
||||||
|
(int(output_size[0]), int(output_size[1])),
|
||||||
|
flags=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
return dst_img
|
|
@ -0,0 +1,147 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
'''
|
||||||
|
train
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.communication.management import init
|
||||||
|
from mindspore.train import Model
|
||||||
|
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
|
||||||
|
from mindspore.nn.optim import Adam
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
from src.pose_resnet import GetPoseResNet
|
||||||
|
from src.network_with_loss import JointsMSELoss, PoseResNetWithLoss
|
||||||
|
from src.dataset import CreateDatasetCoco
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
import moxing as mox
|
||||||
|
|
||||||
|
set_seed(config.GENERAL.TRAIN_SEED)
|
||||||
|
def get_lr(begin_epoch,
|
||||||
|
total_epochs,
|
||||||
|
steps_per_epoch,
|
||||||
|
lr_init=0.1,
|
||||||
|
factor=0.1,
|
||||||
|
epoch_number_to_drop=(90, 120)
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
get_lr
|
||||||
|
'''
|
||||||
|
lr_each_step = []
|
||||||
|
total_steps = steps_per_epoch * total_epochs
|
||||||
|
step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
|
||||||
|
for i in range(int(total_steps)):
|
||||||
|
if i in step_number_to_drop:
|
||||||
|
lr_init = lr_init * factor
|
||||||
|
lr_each_step.append(lr_init)
|
||||||
|
current_step = steps_per_epoch * begin_epoch
|
||||||
|
lr_each_step = np.array(lr_each_step, dtype=np.float32)
|
||||||
|
learning_rate = lr_each_step[current_step:]
|
||||||
|
return learning_rate
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Simpleposenet training")
|
||||||
|
parser.add_argument('--data_url', required=False, default=None, help='Location of data.')
|
||||||
|
parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
|
||||||
|
parser.add_argument('--device_id', required=False, default=None, type=int, help='Location of training outputs.')
|
||||||
|
parser.add_argument('--run_distribute', required=False, default=False, help='Location of training outputs.')
|
||||||
|
parser.add_argument('--is_model_arts', required=False, default=False, help='Location of training outputs.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("loading parse...")
|
||||||
|
args = parse_args()
|
||||||
|
device_id = args.device_id
|
||||||
|
config.GENERAL.RUN_DISTRIBUTE = args.run_distribute
|
||||||
|
config.MODELARTS.IS_MODEL_ARTS = args.is_model_arts
|
||||||
|
if config.GENERAL.RUN_DISTRIBUTE or config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend",
|
||||||
|
save_graphs=False,
|
||||||
|
device_id=device_id)
|
||||||
|
|
||||||
|
if config.GENERAL.RUN_DISTRIBUTE:
|
||||||
|
init()
|
||||||
|
rank = int(os.getenv('DEVICE_ID'))
|
||||||
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
|
context.set_auto_parallel_context(device_num=device_num,
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
else:
|
||||||
|
rank = 0
|
||||||
|
device_num = 1
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)
|
||||||
|
|
||||||
|
dataset = CreateDatasetCoco(rank=rank,
|
||||||
|
group_size=device_num,
|
||||||
|
train_mode=True,
|
||||||
|
num_parallel_workers=config.TRAIN.NUM_PARALLEL_WORKERS,
|
||||||
|
)
|
||||||
|
net = GetPoseResNet(config)
|
||||||
|
loss = JointsMSELoss(config.LOSS.USE_TARGET_WEIGHT)
|
||||||
|
net_with_loss = PoseResNetWithLoss(net, loss)
|
||||||
|
dataset_size = dataset.get_dataset_size()
|
||||||
|
lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
|
||||||
|
config.TRAIN.END_EPOCH,
|
||||||
|
dataset_size,
|
||||||
|
lr_init=config.TRAIN.LR,
|
||||||
|
factor=config.TRAIN.LR_FACTOR,
|
||||||
|
epoch_number_to_drop=config.TRAIN.LR_STEP))
|
||||||
|
opt = Adam(net.trainable_params(), learning_rate=lr)
|
||||||
|
time_cb = TimeMonitor(data_size=dataset_size)
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
cb = [time_cb, loss_cb]
|
||||||
|
if config.TRAIN.SAVE_CKPT:
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
|
||||||
|
prefix = ''
|
||||||
|
if config.GENERAL.RUN_DISTRIBUTE:
|
||||||
|
prefix = 'multi_' + 'train_poseresnet_' + config.GENERAL.VERSION + '_' + os.getenv('DEVICE_ID')
|
||||||
|
else:
|
||||||
|
prefix = 'single_' + 'train_poseresnet_' + config.GENERAL.VERSION
|
||||||
|
|
||||||
|
directory = ''
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
directory = config.MODELARTS.CACHE_OUTPUT + 'device_'+ os.getenv('DEVICE_ID')
|
||||||
|
elif config.GENERAL.RUN_DISTRIBUTE:
|
||||||
|
directory = config.TRAIN.CKPT_PATH + 'device_'+ os.getenv('DEVICE_ID')
|
||||||
|
else:
|
||||||
|
directory = config.TRAIN.CKPT_PATH + 'device'
|
||||||
|
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=directory, config=config_ck)
|
||||||
|
cb.append(ckpoint_cb)
|
||||||
|
model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
|
||||||
|
epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
|
||||||
|
print("************ Start training now ************")
|
||||||
|
print('start training, epoch size = %d' % epoch_size)
|
||||||
|
model.train(epoch_size, dataset, callbacks=cb)
|
||||||
|
|
||||||
|
if config.MODELARTS.IS_MODEL_ARTS:
|
||||||
|
mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue