!16674 [crowdfunding] Add new model: posenet(GPU)

Merge pull request !16674 from dlliu123/posenet_gpu-master
i-robot 2021-07-01 03:29:27 +00:00 committed by Gitee
commit 94443f8e8c
14 changed files with 1450 additions and 0 deletions


@ -0,0 +1,366 @@
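# Contents
<!-- TOC -->
- [Contents](#contents)
- [PoseNet Description](#posenet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Standalone Training](#standalone-training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [PoseNet on KingsCollege](#posenet-on-kingscollege)
            - [PoseNet on StMarysChurch](#posenet-on-stmaryschurch)
        - [Inference Performance](#inference-performance)
            - [PoseNet on KingsCollege](#posenet-on-kingscollege-1)
            - [PoseNet on StMarysChurch](#posenet-on-stmaryschurch-1)
    - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->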
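# PoseNet Description
PoseNet is a robust, real-time monocular six-degree-of-freedom (6-DOF) relocalization system proposed by the University of Cambridge. It trains a convolutional neural network to regress the 6-DOF camera pose from a single RGB image end to end, with no additional processing. The algorithm estimates pose in real time, at 5 ms per frame, in both indoor and outdoor scenes. The network has 23 convolutional layers and can be applied to complex image-plane regression problems; this is made possible by transfer learning from large-scale classification data. PoseNet localizes from high-level features, and the authors show that it is robust to lighting changes, motion blur, and cases where conventional SIFT registration fails. They also show that the model generalizes to other scenes and can regress pose from a small number of training samples.
[Paper](https://arxiv.org/abs/1505.07427): Kendall A, Grimes M, Cipolla R. "PoseNet: A convolutional network for real-time 6-dof camera relocalization." *In IEEE International Conference on Computer Vision* (pp. 2938-2946), 2015.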
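# Model Architecture
The backbone is GoogLeNet, which has 22 convolutional layers and 3 classification branches; 2 of the branches are discarded at test time. PoseNet makes three small changes: the softmax layers are removed and a fully connected regression layer with 7 neurons is added to regress the pose; a feature-vector layer with 2048 neurons is inserted before the fully connected regression layer; and at test time the regressed quaternion must be normalized to unit length. All input images are resized to 224x224.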
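The test-time normalization mentioned above can be sketched in plain Python. This is only an illustration: it assumes the 7 regressed values are ordered as 3 position values followed by 4 quaternion values (the order `eval.py` uses when slicing the ground-truth pose), and `split_pose` is a hypothetical helper that does not exist in the repository.

```python
import math


def split_pose(output):
    """Split a 7-value pose into a position and a unit-length quaternion.

    Hypothetical helper: assumes the layout [x, y, z, q0, q1, q2, q3].
    """
    if len(output) != 7:
        raise ValueError("expected 7 values: 3 for position, 4 for the quaternion")
    position = list(output[:3])
    quaternion = list(output[3:])
    norm = math.sqrt(sum(c * c for c in quaternion))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero quaternion")
    # Dividing by the Euclidean norm yields a unit quaternion.
    unit_quaternion = [c / norm for c in quaternion]
    return position, unit_quaternion


position, unit_quaternion = split_pose([1.0, 2.0, 3.0, 2.0, 0.0, 0.0, 0.0])
print(position, unit_quaternion)
```

Normalizing only at test time keeps the training objective a plain regression, while still guaranteeing that the reported orientation is a valid rotation.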
# Dataset
[KingsCollege](<http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset>)
- Dataset size: 5.73G, including videos
    - Training set: 2.9G, 1220 images (seq1, seq4, seq5, seq6, seq8)
    - Test set: 852M, 342 images (seq2, seq3, seq7)
- Data format: txt files (image_url + label)
    - Note: the data is processed in src/dataset.py.
[StMarysChurch](<http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset>)
- Dataset size: 5.04G, including videos
    - Training set: 3.5G, 1487 images (seq1, seq2, seq4, seq5, seq7, seq8, seq9, seq10, seq11, seq12, seq14)
    - Test set: 1.34G, 530 images (seq3, seq5, seq13)
- Data format: txt files (image_url + label)
    - Note: the data is processed in src/dataset.py.
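The label files can be read roughly as shown below. This sketch mirrors what `src/dataset.py` does (skip three header lines, then read an image path followed by seven floats per line), but `parse_pose_lines` and the sample line are hypothetical and made up for illustration. Unlike the original loader, it skips lines that do not have exactly eight fields.

```python
def parse_pose_lines(lines, header_lines=3):
    """Parse label lines of the form 'image_path p0 p1 p2 p3 p4 p5 p6'.

    Returns a list of (image_path, position, quaternion) tuples, where
    position holds the first 3 floats and quaternion the remaining 4.
    """
    samples = []
    for line in lines[header_lines:]:
        fields = line.split()
        if len(fields) != 8:
            continue  # skip blank or malformed lines
        image_path = fields[0]
        pose = [float(value) for value in fields[1:]]
        samples.append((image_path, pose[:3], pose[3:]))
    return samples


# Made-up content; the real files are dataset_train.txt and dataset_test.txt.
example_lines = [
    "header line 1",
    "header line 2",
    "",
    "seq1/frame00001.png 1.0 2.0 3.0 0.5 0.5 0.5 0.5",
]
samples = parse_pose_lines(example_lines)
print(samples)
```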
# Environment Requirements
- Hardware (Ascend/GPU)
    - Prepare the hardware environment with Ascend or GPU processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
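# Quick Start
After installing MindSpore from the official website, you can train and evaluate the model as follows:
- Running on Ascend
```bash
# Run the standalone training example
bash run_standalone_train.sh [DATASET_NAME] [DEVICE_ID]
# Run the distributed training example
bash run_distribute_train.sh [DATASET_NAME] [RANK_SIZE]
# Run the evaluation example
bash run_eval.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
```
For distributed training, create an hccl configuration file in JSON format in advance.
Follow the instructions at the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
- Running on GPU
To run on GPU, change `device_target` in the configuration file src/config.py from `Ascend` to `GPU`.
```bash
# Run the standalone training example
bash run_standalone_train_gpu.sh [DATASET_NAME] [DEVICE_ID]
# Run the distributed training example (this script takes the device count first)
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_NAME]
# Run the evaluation example
bash run_eval_gpu.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
```
KingsCollege is the default dataset. You can also pass `$dataset_name` to the scripts to select another dataset. See the individual scripts for details.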
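# Script Description
## Script and Sample Code
```bash
├── model_zoo
    ├── README.md                          // descriptions of all models
    ├── posenet
        ├── README.md                      // description of posenet
        ├── scripts
        │   ├──run_standalone_train.sh     // shell script for standalone training on Ascend
        │   ├──run_distribute_train.sh     // shell script for distributed training on Ascend
        │   ├──run_eval.sh                 // shell script for evaluation on Ascend
        │   ├──run_standalone_train_gpu.sh // shell script for standalone training on GPU
        │   ├──run_distribute_train_gpu.sh // shell script for distributed training on GPU
        │   ├──run_eval_gpu.sh             // shell script for evaluation on GPU
        ├── src
        │   ├──dataset.py                  // converts the dataset to mindrecord format, creates the dataset, and preprocesses data
        │   ├──posenet.py                  // posenet architecture
        │   ├──loss.py                     // posenet loss function
        │   ├──config.py                   // parameter configuration
        ├── train.py                       // training script
        ├── eval.py                        // evaluation script
        ├── export.py                      // exports a checkpoint file to mindir
```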
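## Script Parameters
Both training and evaluation parameters can be configured in config.py.
- Configuration for PoseNet and the 2 datasets.
```python
# common_config
'device_target': 'GPU', # target device: Ascend/GPU
'device_id': 0, # ID of the device used for training or evaluation; can be ignored when running distributed training with run_distribute_train.sh
'pre_trained': True, # whether to train from a pre-trained model
'max_steps': 30000, # maximum number of iterations
'save_checkpoint': True, # whether to save checkpoint files
'pre_trained_file': '../pre_trained_googlenet_imagenet.ckpt', # path to the pre-trained checkpoint file
'checkpoint_dir': '../checkpoint', # directory where checkpoints are saved
'save_checkpoint_epochs': 5, # interval, in epochs, between saved checkpoints
'keep_checkpoint_max': 10 # maximum number of checkpoint files to keep
# dataset_config
'batch_size': 75, # batch size
'lr_init': 0.001, # initial learning rate
'weight_decay': 0.5, # weight decay
'name': 'KingsCollege', # dataset name
'dataset_path': '../KingsCollege/', # dataset path
'mindrecord_dir': '../MindrecordKingsCollege' # path of the dataset mindrecord files
```
On ModelArts, change the pre-trained checkpoint path 'pre_trained_file' to the corresponding absolute path,
for example "/home/work/user-job-dir/posenet/pre_trained_googlenet_imagenet.ckpt".
See `config.py` for more configuration details.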
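`eval.py` and `export.py` pick one of the two dataset dictionaries with an if/elif chain. As an alternative sketch, the same selection can be written as a dictionary lookup that fails loudly on an unknown name. The values below are copied from `src/config.py`; the `DATASET_CONFIGS` table and `get_dataset_config` helper are hypothetical and use plain dictionaries instead of `EasyDict` so that the example has no dependencies.

```python
# Values copied from src/config.py; the lookup table itself is hypothetical.
DATASET_CONFIGS = {
    "KingsCollege": {
        "batch_size": 75,
        "lr_init": 0.001,
        "weight_decay": 0.5,
        "name": "KingsCollege",
        "dataset_path": "../KingsCollege/",
        "mindrecord_dir": "../MindrecordKingsCollege",
    },
    "StMarysChurch": {
        "batch_size": 75,
        "lr_init": 0.001,
        "weight_decay": 0.5,
        "name": "StMarysChurch",
        "dataset_path": "../StMarysChurch/",
        "mindrecord_dir": "../MindrecordStMarysChurch",
    },
}


def get_dataset_config(name):
    """Return the configuration for a dataset, or raise ValueError if unknown."""
    try:
        return DATASET_CONFIGS[name]
    except KeyError:
        known = ", ".join(sorted(DATASET_CONFIGS))
        raise ValueError("unknown dataset {!r}; expected one of: {}".format(name, known)) from None


cfg = get_dataset_config("StMarysChurch")
print(cfg["dataset_path"], cfg["mindrecord_dir"])
```

A lookup table keeps the list of supported datasets in one place, so adding a third dataset would not require touching every script.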
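## Training Process
### Standalone Training
- Running on Ascend
```bash
bash run_standalone_train.sh [DATASET_NAME] [DEVICE_ID]
```
The command above runs in the background. You can view the results in the train_alone.log file in the train directory.
After training, you can find the checkpoint files under the default script folder. The loss values are reported as follows:
```bash
epoch:1 step:38, loss is 1722.1506
epoch:2 step:38, loss is 1671.5763
...
```
Model checkpoints are saved in the checkpoint folder.
- Running on GPU
```bash
bash run_standalone_train_gpu.sh [DATASET_NAME] [DEVICE_ID]
```
The command above runs in the background. You can view the results in the train_alone.log file in the train directory.
After training, you can find the checkpoint files under the default script folder. The loss values are reported as follows:
```bash
epoch:1 step:38, loss is 1722.1506
epoch:2 step:38, loss is 1671.5763
...
```
Model checkpoints are saved in the checkpoint folder.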
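To follow the loss over time, the log lines can be parsed with a small script. The sketch below assumes every loss line has the form `epoch:<n> step:<n>, loss is <value>`, as in the sample output above; `parse_losses` is a hypothetical helper and is not part of the repository.

```python
import re

# Matches lines such as "epoch:1 step:38, loss is 1722.1506".
LOSS_PATTERN = re.compile(r"epoch:(\d+) step:(\d+), loss is (\d+(?:\.\d+)?)")


def parse_losses(log_text):
    """Return a list of (epoch, step, loss) tuples found in the log text."""
    records = []
    for match in LOSS_PATTERN.finditer(log_text):
        epoch, step, loss = match.groups()
        records.append((int(epoch), int(step), float(loss)))
    return records


log_text = """epoch:1 step:38, loss is 1722.1506
epoch:2 step:38, loss is 1671.5763
"""
records = parse_losses(log_text)
print(records)
```

In practice `log_text` would be read from the training log file; lines that do not match the pattern are ignored.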
### Distributed Training
- Running on Ascend
```bash
bash run_distribute_train.sh [DATASET_NAME] [RANK_SIZE]
```
The shell script above runs distributed training in the background. You can view the results in the device[X]/log files. The loss values are reported as follows:
```bash
device0/log:epoch:1 step:38, loss is 1722.1506
device0/log:epoch:2 step:38, loss is 1671.5763
...
device1/log:epoch:1 step:38, loss is 1722.1506
device1/log:epoch:2 step:38, loss is 1671.5763
...
```
- Running on GPU
```bash
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_NAME]
```
The shell script above runs distributed training in the background. You can view the results in the device[X]/log files. The loss values are reported as follows:
```bash
device0/log:epoch:1 step:38, loss is 1722.1506
device0/log:epoch:2 step:38, loss is 1671.5763
...
device1/log:epoch:1 step:38, loss is 1722.1506
device1/log:epoch:2 step:38, loss is 1671.5763
...
```
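## Evaluation Process
### Evaluation
- Evaluating the KingsCollege dataset on Ascend
Before running the command below, check the path of the checkpoint used for evaluation.
Set the checkpoint path to a relative path, for example "../checkpoint/train_posenet_KingsCollege-790_38.ckpt".
```bash
bash run_eval.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
```
The command above runs in the background. You can view the results in the eval/eval.log file. The accuracy on the test dataset is as follows:
```bash
Median error 3.56644630432129 m and 3.07089155413442 degrees
```
- Evaluating the KingsCollege dataset on GPU
Before running the command below, check the path of the checkpoint used for evaluation.
Set the checkpoint path to a relative path, for example "../checkpoint/train_posenet_KingsCollege-1875_2.ckpt".
```bash
bash run_eval_gpu.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
```
The command above runs in the background. You can view the results in the eval/eval.log file. The accuracy on the test dataset is as follows:
```bash
Median error 3.56644630432129 m and 3.07089155413442 degrees
```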
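The two reported numbers are the median position error and the median rotation error over the test set. The sketch below reproduces the per-sample computation from `eval.py` in plain Python rather than NumPy: the position error is the Euclidean distance, and the rotation error is `2 * arccos(|q1 . q2|)` in degrees for unit quaternions. The `pose_error` helper and the sample values are hypothetical; clamping the dot product to 1 is an added safeguard against rounding.

```python
import math
from statistics import median


def pose_error(true_x, true_q, pred_x, pred_q):
    """Return (position error, rotation error in degrees) for one sample."""
    position_error = math.sqrt(sum((t - p) ** 2 for t, p in zip(true_x, pred_x)))
    norm_t = math.sqrt(sum(c * c for c in true_q))
    norm_p = math.sqrt(sum(c * c for c in pred_q))
    dot = sum((t / norm_t) * (p / norm_p) for t, p in zip(true_q, pred_q))
    # Clamp so that rounding slightly above 1 cannot break acos.
    d = min(1.0, abs(dot))
    angle_deg = 2.0 * math.acos(d) * 180.0 / math.pi
    return position_error, angle_deg


# Made-up samples: (true position, true quaternion, predicted position, predicted quaternion).
samples = [
    ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [1.0, 0.0, 0.0, 0.0]),
    ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]),
    ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 0.0]),
]
errors = [pose_error(*sample) for sample in samples]
median_position = median(e[0] for e in errors)
median_angle = median(e[1] for e in errors)
print("Median error", median_position, "m and", median_angle, "degrees")
```

The absolute value of the dot product accounts for the fact that `q` and `-q` describe the same rotation.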
# Model Description
## Performance
### Evaluation Performance
#### PoseNet on KingsCollege
| Parameters | Ascend | GPU |
| -------------------------- | ----------------------------------------------------------- | ---------------------- |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755G | NV SXM2 V100-32G |
| Upload date | 2021-03-26 | 2021-05-20 |
| MindSpore version | 1.1.1-alpha | 1.2.1-alpha |
| Dataset | KingsCollege | KingsCollege |
| Training parameters | max_steps=30000, batch_size=75, lr_init=0.001 | max_steps=30000, batch_size=75, lr_init=0.001 |
| Optimizer | Adagrad | Adagrad |
| Loss function | Custom loss function | Custom loss function |
| Output | Distance, angle | Distance, angle |
| Loss | 1110.86 | 1110.86 |
| Speed | 1 device: 750 ms/step; 8 devices: 856 ms/step | 8 devices: 675 ms/step (unstable) |
| Total time | 8 devices: 75 min | 8 devices: 60 min |
| Parameters (M) | 10.7 | 10.7 |
| Checkpoint for fine-tuning | 82.91M (.ckpt file) | 82.91M (.ckpt file) |
| Model for inference | 41.66M (.mindir file) | 41.66M (.mindir file) |
| Scripts | [posenet script](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/posenet) | [posenet script](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/master/cv/posenet) |
#### PoseNet on StMarysChurch
| Parameters | Ascend | GPU |
| -------------------------- | ----------------------------------------------------------- | ---------------------- |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755G | NV SXM2 V100-32G |
| Upload date | 2021-03-26 | 2021-05-20 |
| MindSpore version | 1.1.1-alpha | 1.2.1-alpha |
| Dataset | StMarysChurch | StMarysChurch |
| Training parameters | max_steps=30000, batch_size=75, lr_init=0.001 | max_steps=30000, batch_size=75, lr_init=0.001 |
| Optimizer | Adagrad | Adagrad |
| Loss function | Custom loss function | Custom loss function |
| Output | Distance, angle | Distance, angle |
| Loss | 1077.86 | 1023.67 |
| Speed | 1 device: 800 ms/step; 8 devices: 1122 ms/step | 8 devices: 850 ms/step (unstable) |
| Total time | 1 device: 6 h 40 min; 8 devices: 85 min | 8 devices: 80 min |
| Parameters (M) | 10.7 | 10.7 |
| Checkpoint for fine-tuning | 82.91M (.ckpt file) | 82.91M (.ckpt file) |
| Model for inference | 41.66M (.mindir file) | 41.66M (.mindir file) |
| Scripts | [posenet script](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/posenet) | [posenet script](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/master/cv/posenet) |
### Inference Performance
#### PoseNet on KingsCollege
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Resource | Ascend 910 | GPU |
| Upload date | 2021-03-26 | 2021-05-20 |
| MindSpore version | 1.1.1-alpha | 1.2.1-alpha |
| Dataset | KingsCollege | KingsCollege |
| batch_size | 1 | 1 |
| Output | Distance, angle | Distance, angle |
| Accuracy | 1 device: 1.928 m, 4.24 degrees; 8 devices: 1.89 m, 4.31 degrees | 8 devices: 1.80 m, 3.68 degrees |
| Model for inference | 41.66M (.mindir file) | 41.66M (.mindir file) |
#### PoseNet on StMarysChurch
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Resource | Ascend 910 | GPU |
| Upload date | 2021-03-26 | 2021-05-20 |
| MindSpore version | 1.1.1-alpha | 1.2.1-alpha |
| Dataset | StMarysChurch | StMarysChurch |
| batch_size | 1 | 1 |
| Output | Distance, angle | Distance, angle |
| Accuracy | 1 device: 1.884 m, 7.20 degrees; 8 devices: 1.90 m, 6.23 degrees | 8 devices: 1.89 m, 6.24 degrees |
| Model for inference | 41.66M (.mindir file) | 41.66M (.mindir file) |
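## Transfer Learning
GoogLeNet is pre-trained on the ImageNet dataset and then transferred to PoseNet.
# Description of Random Situation
A random seed is set in train.py.
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).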


@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test posenet"""
import ast
import os
import time
import argparse
import math
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.config import common_config, KingsCollege, StMarysChurch
from src.posenet import PoseNet
from src.dataset import data_to_mindrecord, create_posenet_dataset
set_seed(1)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='posenet eval')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    parser.add_argument('--dataset', type=str, default='KingsCollege',
                        choices=['KingsCollege', 'StMarysChurch'],
                        help='dataset name.')
    parser.add_argument('--ckpt_url', type=str, default=None, help='Checkpoint file path')
    parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='Train in Modelarts.')
    parser.add_argument('--data_url', default=None, help='Location of data.')
    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
    args_opt = parser.parse_args()

    cfg = common_config
    if args_opt.dataset == "KingsCollege":
        dataset_cfg = KingsCollege
    elif args_opt.dataset == "StMarysChurch":
        dataset_cfg = StMarysChurch

    device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    if args_opt.device_id is not None:
        context.set_context(device_id=args_opt.device_id)
    else:
        context.set_context(device_id=cfg.device_id)

    eval_dataset_path = dataset_cfg.dataset_path
    if args_opt.is_modelarts:
        import moxing as mox
        mox.file.copy_parallel(src_url=args_opt.data_url,
                               dst_url='/cache/dataset_eval/device_' + os.getenv('DEVICE_ID'))
        eval_dataset_path = '/cache/dataset_eval/device_' + os.getenv('DEVICE_ID') + '/'

    # It will generate eval mindrecord file in cfg.mindrecord_dir,
    # and the file name is "dataset_cfg.name + _posenet_eval.mindrecord".
    prefix = "_posenet_eval.mindrecord"
    mindrecord_dir = dataset_cfg.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, dataset_cfg.name + prefix)
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        print("Create mindrecord for eval.")
        data_to_mindrecord(eval_dataset_path, False, mindrecord_file)
        print("Create mindrecord done, at {}".format(mindrecord_dir))
    # Wait until the mindrecord index file exists before reading the dataset.
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)

    dataset = create_posenet_dataset(mindrecord_file, batch_size=1, device_num=1, is_training=False)
    data_num = dataset.get_dataset_size()

    net = PoseNet()
    param_dict = load_checkpoint(args_opt.ckpt_url)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    print("Processing, please wait a moment.")
    # One row per sample: [position error, rotation error in degrees].
    results = np.zeros((data_num, 2))
    for step, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        image = item['image']
        poses = item['image_pose']
        pose_x = np.squeeze(poses[:, 0:3])
        pose_q = np.squeeze(poses[:, 3:])
        # Only the outputs of the final branch (p3) are used at test time.
        p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = net(Tensor(image))
        predicted_x = p3_x.asnumpy()
        predicted_q = p3_q.asnumpy()
        # Normalize both quaternions, then compute the angle between them.
        q1 = pose_q / np.linalg.norm(pose_q)
        q2 = predicted_q / np.linalg.norm(predicted_q)
        # Clamp to 1 so that rounding error cannot make arccos return NaN.
        d = min(1.0, abs(np.sum(np.multiply(q1, q2))))
        theta = 2 * np.arccos(d) * 180 / math.pi
        error_x = np.linalg.norm(pose_x - predicted_x)
        results[step, :] = [error_x, theta]
        print('Iteration: ', step, ', Error XYZ (m): ', error_x, ', Error Q (degrees): ', theta)
    median_result = np.median(results, axis=0)
    print('Median error ', median_result[0], 'm and ', median_result[1], 'degrees.')


@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into a mindir model#################
python export.py --ckpt_url /path/to/checkpoint.ckpt
"""
import argparse
import numpy as np
import mindspore.common.dtype as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import common_config, KingsCollege, StMarysChurch
from src.posenet import PoseNet
parser = argparse.ArgumentParser(description='PoseNet')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument('--dataset', type=str, default='KingsCollege',
choices=['KingsCollege', 'StMarysChurch'],
help='Name of dataset.')
parser.add_argument("--ckpt_url", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="posenet", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["MINDIR"], default='MINDIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
    cfg = common_config
    if args.dataset == 'KingsCollege':
        dataset_cfg = KingsCollege
    elif args.dataset == 'StMarysChurch':
        dataset_cfg = StMarysChurch

    net = PoseNet()
    assert cfg.checkpoint_dir is not None, "cfg.checkpoint_dir is None."
    param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(net, param_dict)

    # Dummy input with the network's expected shape: (batch, channels, height, width).
    input_arr = Tensor(np.ones([dataset_cfg.batch_size, 3, 224, 224]), ms.float32)
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)


@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh DATASET_NAME RANK_SIZE"
echo "For example: bash run_distribute_train.sh KingsCollege 8"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATASET_NAME=$1
RANK_SIZE=$2
export DATASET_NAME
export RANK_SIZE
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}
test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}
# Only RANK_SIZE values of 2 and 8 have a matching function above.
test_dist_${RANK_SIZE}pcs
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
for((i=0;i<${RANK_SIZE};i++))
do
    # Give each device its own working directory with a copy of the code.
    rm -rf device$i
    mkdir device$i
    cd ./device$i
    mkdir src
    cd ../
    cp ./train.py ./device$i
    cp ./src/*.py ./device$i/src
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py --run_distribute True --dataset $1 --device_num $2 --is_modelarts False > train$i.log 2>&1 &
    echo "$i finish"
    cd ../
done
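# The training jobs run in the background, so $? only reflects the last command of the loop, not the training result.
if [ $? -eq 0 ];then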
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../


@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh RANK_SIZE DATASET_NAME"
echo "For example: bash run_distribute_train_gpu.sh 8 KingsCollege"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
export DEVICE_NUM=$1
export RANK_SIZE=$1
export DATASET_NAME=$2
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf ./train_parallel
mkdir ./train_parallel
cd ./train_parallel
mkdir src
cd ../
# The working directory here is the posenet root, so copy from ./ rather than ../
cp ./*.py ./train_parallel
cp ./src/*.py ./train_parallel/src
cd ./train_parallel
env > env.log
echo "start training"
mpirun -n $1 --allow-run-as-root \
python train.py --device_num $1 --dataset $2 --is_modelarts False --run_distribute True > train.log 2>&1 &


@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DEVICE_ID DATASET_NAME CKPT_PATH"
echo "For example: bash run_eval.sh 0 KingsCollege ../checkpoint/train_posenet_KingsCollege-790_38.ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DEVICE_ID=$1
DATASET_NAME=$2
CKPT_PATH=$3
export DEVICE_ID
export DATASET_NAME
export CKPT_PATH
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf eval/
mkdir eval
cd ./eval
mkdir src
cd ../
cp ./eval.py ./eval
cp ./src/*.py ./eval/src
cd ./eval
env > env0.log
echo "Eval begin."
python eval.py --device_id $1 --dataset $2 --ckpt_url $3 --is_modelarts False > ./eval.log 2>&1 &
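# eval.py runs in the background, so $? only tells whether the job was launched; check eval.log for the result.
if [ $? -eq 0 ];then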
echo "evaling success"
else
echo "evaling failed"
exit 2
fi
echo "finish"
cd ../


@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_gpu.sh DEVICE_ID DATASET_NAME CKPT_PATH"
echo "For example: bash run_eval_gpu.sh 0 KingsCollege ../checkpoint/train_posenet_KingsCollege-1875_2.ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DEVICE_ID=$1
DATASET_NAME=$2
CKPT_PATH=$3
export DEVICE_ID
export DATASET_NAME
export CKPT_PATH
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf eval/
mkdir eval
cd ./eval
mkdir src
cd ../
cp ./eval.py ./eval
cp ./src/*.py ./eval/src
cd ./eval
env > env0.log
echo "Eval begin."
python eval.py --device_id $1 --dataset $2 --ckpt_url $3 --is_modelarts False > ./eval.log 2>&1 &
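# eval.py runs in the background, so $? only tells whether the job was launched; check eval.log for the result.
if [ $? -eq 0 ];then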
echo "evaling success"
else
echo "evaling failed"
exit 2
fi
echo "finish"
cd ../


@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DATASET_NAME DEVICE_ID"
echo "For example: bash run_standalone_train.sh KingsCollege 0"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATASET_NAME=$1
DEVICE_ID=$2
export DATASET_NAME
export DEVICE_ID
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf train
mkdir train
cd ./train
mkdir src
cd ../
cp ./train.py ./train
cp ./src/*.py ./train/src
cd ./train
env > env0.log
echo "Standalone train begin."
python train.py --run_distribute False --device_id $2 --dataset $1 --device_num 1 --is_modelarts False > ./train_alone.log 2>&1 &
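# train.py runs in the background, so $? only tells whether the job was launched; check train_alone.log for the result.
if [ $? -eq 0 ];then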
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../


@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_gpu.sh DATASET_NAME DEVICE_ID"
echo "For example: bash run_standalone_train_gpu.sh KingsCollege 0"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATASET_NAME=$1
DEVICE_ID=$2
export DATASET_NAME
export DEVICE_ID
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf train
mkdir train
cd ./train
mkdir src
cd ../
cp ./train.py ./train
cp ./src/*.py ./train/src
cd ./train
env > env0.log
echo "Standalone train begin."
python train.py --run_distribute False --device_id $2 --dataset $1 --device_num 1 --is_modelarts False > ./train_alone.log 2>&1 &
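# train.py runs in the background, so $? only tells whether the job was launched; check train_alone.log for the result.
if [ $? -eq 0 ];then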
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../


@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict
common_config = edict({
'device_target': 'GPU',
'device_id': 0,
'pre_trained': True,
'max_steps': 30000,
'save_checkpoint': True,
# 'pre_trained_file': '/home/work/user-job-dir/posenet/pre_trained_googlenet_imagenet.ckpt',
'pre_trained_file': '../pre_trained_googlenet_imagenet.ckpt',
'checkpoint_dir': '../checkpoint',
'save_checkpoint_epochs': 5,
'keep_checkpoint_max': 10
})
KingsCollege = edict({
'batch_size': 75,
'lr_init': 0.001,
'weight_decay': 0.5,
'name': 'KingsCollege',
'dataset_path': '../KingsCollege/',
'mindrecord_dir': '../MindrecordKingsCollege'
})
StMarysChurch = edict({
'batch_size': 75,
'lr_init': 0.001,
'weight_decay': 0.5,
'name': 'StMarysChurch',
'dataset_path': '../StMarysChurch/',
'mindrecord_dir': '../MindrecordStMarysChurch'
})


@ -0,0 +1,107 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import os
import numpy as np
from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
class Dataset:
    """dataset read"""
    def __init__(self, root, is_training=True):
        self.root = os.path.expanduser(root)
        self.train = is_training
        self.image_poses = []
        self.image_paths = []
        # root is expected to end with a path separator (see dataset_path in config.py).
        if self.train:
            txt_file = self.root + 'dataset_train.txt'
        else:
            txt_file = self.root + 'dataset_test.txt'
        with open(txt_file, 'r') as f:
            # Skip the three header lines of the label file.
            next(f)
            next(f)
            next(f)
            for line in f:
                # Each line: image path, 3 position values, 4 quaternion values.
                fname, p0, p1, p2, p3, p4, p5, p6 = line.split()
                p0 = float(p0)
                p1 = float(p1)
                p2 = float(p2)
                p3 = float(p3)
                p4 = float(p4)
                p5 = float(p5)
                p6 = float(p6)
                self.image_poses.append((p0, p1, p2, p3, p4, p5, p6))
                self.image_paths.append(self.root + fname)

    def __getdata__(self):
        img_paths = self.image_paths
        img_poses = self.image_poses
        return img_paths, img_poses

    def __len__(self):
        return len(self.image_paths)


def data_to_mindrecord(data_path, is_training, mindrecord_file, file_num=1):
    """Create MindRecord file."""
    writer = FileWriter(mindrecord_file, file_num)
    data = Dataset(data_path, is_training)
    image_paths, image_poses = data.__getdata__()
    posenet_json = {
        "image": {"type": "bytes"},
        "image_pose": {"type": "float32", "shape": [-1]}
    }
    writer.add_schema(posenet_json, "posenet_json")
    image_files_num = len(image_paths)
    for ind, image_name in enumerate(image_paths):
        with open(image_name, 'rb') as f:
            image = f.read()
        image_pose = np.array(image_poses[ind])
        row = {"image": image, "image_pose": image_pose}
        if (ind + 1) % 10 == 0:
            print("writing {}/{} into mindrecord".format(ind + 1, image_files_num))
        writer.write_raw_data([row])
    writer.commit()


def create_posenet_dataset(mindrecord_file, batch_size=1, device_num=1, is_training=True, rank_id=0):
    """Create PoseNet dataset with MindDataset."""
    dataset = ds.MindDataset(mindrecord_file, columns_list=["image", "image_pose"],
                             num_shards=device_num, shard_id=rank_id,
                             num_parallel_workers=8, shuffle=True)
    decode = C.Decode()
    dataset = dataset.map(operations=decode, input_columns=["image"])
    transforms_list = []
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    if is_training:
        # Training: resize, then take a random 224x224 crop.
        resized_op = C.Resize([455, 256])
        random_crop_op = C.RandomCrop(224)
        normalize_op = C.Normalize(mean=mean, std=std)
        to_tensor_op = C.HWC2CHW()
        transforms_list = [resized_op, random_crop_op, normalize_op, to_tensor_op]
    else:
        # Evaluation: resize, then take the central 224x224 crop.
        resized_op = C.Resize([455, 224])
        center_crop_op = C.CenterCrop(224)
        normalize_op = C.Normalize(mean=mean, std=std)
        to_tensor_op = C.HWC2CHW()
        transforms_list = [resized_op, center_crop_op, normalize_op, to_tensor_op]
    dataset = dataset.map(operations=transforms_list, input_columns=['image'])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset


@ -0,0 +1,89 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define evaluation loss function for network."""
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import common_config as config
from src.posenet import PoseNet


class EuclideanDistance(nn.Cell):
    """calculate euclidean distance"""
    def __init__(self):
        super(EuclideanDistance, self).__init__()
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.reduce_sum = P.ReduceSum()
        self.sqrt = P.Sqrt()

    def construct(self, predicted, real):
        # For a (batch, dim) input the two passes below reduce first over the
        # batch axis and then over the coordinate axis, so the result is a
        # single scalar: the square root of the sum of all squared differences.
        res = self.sub(predicted, real)
        res = self.mul(res, res)
        res = self.reduce_sum(res, 0)
        res = self.sqrt(res)
        res = self.mul(res, res)
        res = self.reduce_sum(res, 0)
        res = self.sqrt(res)

        return res


class PoseLoss(_Loss):
    """define loss function"""
    def __init__(self, w1_x, w2_x, w3_x, w1_q, w2_q, w3_q):
        super(PoseLoss, self).__init__()
        self.w1_x = w1_x
        self.w2_x = w2_x
        self.w3_x = w3_x
        self.w1_q = w1_q
        self.w2_q = w2_q
        self.w3_q = w3_q
        self.ed = EuclideanDistance()

    def construct(self, p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poseGT):
        """construct"""
        pose_x = poseGT[:, 0:3]
        pose_q = poseGT[:, 3:]

        l1_x = self.ed(pose_x, p1_x) * self.w1_x
        l1_q = self.ed(pose_q, p1_q) * self.w1_q
        l2_x = self.ed(pose_x, p2_x) * self.w2_x
        l2_q = self.ed(pose_q, p2_q) * self.w2_q
        l3_x = self.ed(pose_x, p3_x) * self.w3_x
        l3_q = self.ed(pose_q, p3_q) * self.w3_q

        loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
        return loss


class PosenetWithLoss(nn.Cell):
    """PoseNet wrapped with its loss; optionally loads pre-trained weights."""
    def __init__(self, pre_trained=False):
        super(PosenetWithLoss, self).__init__()

        net = PoseNet()
        if pre_trained:
            param_dict = load_checkpoint(config.pre_trained_file)
            load_param_into_net(net, param_dict)

        self.network = net
        self.loss = PoseLoss(3.0, 3.0, 10.0, 150, 150, 500)
        self.cast = P.Cast()

    def construct(self, data, poseGT):
        p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = self.network(data)
        loss = self.loss(p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poseGT)
        return self.cast(loss, mstype.float32)
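
To make the arithmetic of `EuclideanDistance` and `PoseLoss` easier to follow, the sketch below replays the same steps in NumPy on random data. It is an illustration under assumptions, not the commit's code: `euclidean_distance` and `pose_loss` are hypothetical names, the inputs are assumed to be 2-D `(batch, dim)` arrays, and the weights are the ones passed to `PoseLoss` in `PosenetWithLoss`.

```python
import numpy as np


def euclidean_distance(predicted, real):
    """Mirror EuclideanDistance.construct step by step for a (batch, dim) input."""
    res = predicted - real
    res = res * res
    res = np.sqrt(res.sum(axis=0))   # per-coordinate norm over the batch axis
    res = res * res                  # undo the square root
    return np.sqrt(res.sum(axis=0))  # scalar: norm over every entry


def pose_loss(preds, pose_gt, weights):
    """preds: three (xyz, quaternion) pairs; weights: three (w_x, w_q) pairs."""
    pose_x, pose_q = pose_gt[:, 0:3], pose_gt[:, 3:]
    total = 0.0
    for (p_x, p_q), (w_x, w_q) in zip(preds, weights):
        total += euclidean_distance(pose_x, p_x) * w_x
        total += euclidean_distance(pose_q, p_q) * w_q
    return total


rng = np.random.default_rng(0)
pose_gt = rng.normal(size=(4, 7))  # 3 position values + 4 quaternion values
preds = [(rng.normal(size=(4, 3)), rng.normal(size=(4, 4))) for _ in range(3)]
weights = [(3.0, 150), (3.0, 150), (10.0, 500)]  # (w_x, w_q) per output head

loss = pose_loss(preds, pose_gt, weights)
# The two-pass reduction equals the Frobenius norm of the difference matrix.
frob = np.linalg.norm(pose_gt[:, 0:3] - preds[0][0])
print(loss, frob)
```

One consequence worth noting is that the distance is taken over the whole batch at once rather than averaged per sample, so the loss magnitude grows with the batch size.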

@ -0,0 +1,174 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PoseNet"""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P


def weight_variable():
    """Weight variable."""
    return TruncatedNormal(0.02)


class Conv2dBlock(nn.Cell):
    """
    Basic convolutional block
    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        kernel_size (int): Input kernel size. Default: 1
        stride (int): Stride size for the first convolutional layer. Default: 1.
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same".
    Returns:
        Tensor, output tensor.
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode="same"):
        super(Conv2dBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, pad_mode=pad_mode, weight_init=weight_variable())
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Inception(nn.Cell):
    """
    Inception Block
    """
    def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1)
        self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1),
                                     Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
        # kernel_size is 3 (not 5) in this branch, following the GoogLeNet
        # implementation this block depends on.
        self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
                                     Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same")
        self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        """construct"""
        branch1 = self.b1(x)
        branch2 = self.b2(x)
        branch3 = self.b3(x)
        cell = self.maxpool(x)
        branch4 = self.b4(cell)
        return self.concat((branch1, branch2, branch3, branch4))


class PoseNet(nn.Cell):
    """
    PoseNet architecture
    """
    def __init__(self):
        super(PoseNet, self).__init__()
        # Stem
        self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
        self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # Inception stages
        self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")

        self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Layers of the two auxiliary regression branches (cls1 and cls2)
        self.avgpool5x5 = nn.AvgPool2d(kernel_size=5, stride=3, pad_mode="valid")
        self.conv1x1_1 = Conv2dBlock(512, 128, kernel_size=1, stride=1)
        self.conv1x1_2 = Conv2dBlock(528, 128, kernel_size=1, stride=1)
        self.fc2048 = nn.Dense(2048, 1024)
        self.relu = nn.ReLU()
        self.dropout7 = nn.Dropout(0.7)
        self.cls_fc_pose_xyz_1024 = nn.Dense(1024, 512)
        self.cls_fc_pose_xyz_512 = nn.Dense(512, 3)
        self.cls_fc_pose_wpqr_1024 = nn.Dense(1024, 4)

        # Layers of the main regression branch (cls3)
        self.avgpool7x7 = nn.AvgPool2d(kernel_size=7, stride=1, pad_mode="valid")
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(1024, 2048)
        self.dropout5 = nn.Dropout(0.5)
        self.cls_fc_pose_xyz = nn.Dense(2048, 3)
        self.cls_fc_pose_wpqr = nn.Dense(2048, 4)
        self.print = P.Print()

    def construct(self, x):
        """construct"""
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        x = self.block3a(x)
        x = self.block3b(x)
        x = self.maxpool3(x)
        x = self.block4a(x)

        # Auxiliary branch 1, taken after block4a
        cls1 = self.avgpool5x5(x)
        cls1 = self.conv1x1_1(cls1)
        cls1 = self.relu(cls1)
        cls1 = self.flatten(cls1)
        cls1 = self.fc2048(cls1)
        cls1 = self.relu(cls1)
        cls1 = self.dropout7(cls1)
        cls1_fc_pose_xyz = self.cls_fc_pose_xyz_1024(cls1)
        cls1_fc_pose_xyz = self.cls_fc_pose_xyz_512(cls1_fc_pose_xyz)
        cls1_fc_pose_wpqr = self.cls_fc_pose_wpqr_1024(cls1)

        x = self.block4b(x)
        x = self.block4c(x)
        x = self.block4d(x)

        # Auxiliary branch 2, taken after block4d
        cls2 = self.avgpool5x5(x)
        cls2 = self.conv1x1_2(cls2)
        cls2 = self.relu(cls2)
        cls2 = self.flatten(cls2)
        cls2 = self.fc2048(cls2)
        cls2 = self.relu(cls2)
        cls2 = self.dropout7(cls2)
        cls2_fc_pose_xyz = self.cls_fc_pose_xyz_1024(cls2)
        cls2_fc_pose_xyz = self.cls_fc_pose_xyz_512(cls2_fc_pose_xyz)
        cls2_fc_pose_wpqr = self.cls_fc_pose_wpqr_1024(cls2)

        x = self.block4e(x)
        x = self.maxpool4(x)
        x = self.block5a(x)
        x = self.block5b(x)

        # Main branch, taken after block5b
        cls3 = self.dropout5(x)
        cls3 = self.avgpool7x7(cls3)
        cls3 = self.flatten(cls3)
        cls3 = self.fc(cls3)
        cls3 = self.relu(cls3)
        cls3 = self.dropout5(cls3)
        cls3_fc_pose_xyz = self.cls_fc_pose_xyz(cls3)
        cls3_fc_pose_wpqr = self.cls_fc_pose_wpqr(cls3)

        return cls1_fc_pose_xyz, cls1_fc_pose_wpqr, \
               cls2_fc_pose_xyz, cls2_fc_pose_wpqr, \
               cls3_fc_pose_xyz, cls3_fc_pose_wpqr
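
The input widths of `fc2048` (2048) and `fc` (1024) follow from the spatial sizes that reach the three heads. The short sketch below traces those sizes for a 224x224 input. It assumes the usual output-size rules for the two padding modes (`"same"` rounds `size / stride` up, `"valid"` uses `(size - kernel) // stride + 1`); `same_out` and `valid_out` are hypothetical helper names used only for this check.

```python
import math


def same_out(size, stride):
    """Spatial size after a layer with pad_mode="same"."""
    return math.ceil(size / stride)


def valid_out(size, kernel, stride):
    """Spatial size after a layer with pad_mode="valid"."""
    return (size - kernel) // stride + 1


size = 224
size = same_out(size, 2)  # conv1 (stride 2)
size = same_out(size, 2)  # maxpool1
size = same_out(size, 2)  # maxpool2
size = same_out(size, 2)  # maxpool3: resolution seen by block4a..block4e

# Auxiliary heads: avgpool5x5, then a 1x1 conv that emits 128 channels.
aux = valid_out(size, kernel=5, stride=3)
aux_features = 128 * aux * aux

# Main head: maxpool4, block5a/5b, then avgpool7x7.
size = same_out(size, 2)
main = valid_out(size, kernel=7, stride=1)
main_channels = 384 + 384 + 128 + 128  # concatenated branches of block5b
main_features = main_channels * main * main

print(aux_features, main_features)
```

Because the dense layers have fixed input widths, feeding a crop size other than 224 would change these products and would require resizing the `nn.Dense` layers accordingly.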

@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train posenet"""
import ast
import argparse
import os
import time
from mindspore import context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.nn import Adagrad
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.config import common_config, KingsCollege, StMarysChurch
from src.dataset import data_to_mindrecord, create_posenet_dataset
from src.loss import PosenetWithLoss
set_seed(1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Posenet train.')
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distributed training. (Default: False)")
    parser.add_argument('--device_id', type=int, default=None,
                        help='device id of GPU or Ascend. (Default: None)')
    parser.add_argument('--dataset', type=str, default='KingsCollege',
                        choices=['KingsCollege', 'StMarysChurch'],
                        help='Name of dataset.')
    parser.add_argument('--device_num', type=int, default=1, help='Number of devices.')
    parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='Train in Modelarts.')
    parser.add_argument('--data_url', default=None, help='Location of data.')
    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
    args_opt = parser.parse_args()

    cfg = common_config
    if args_opt.dataset == "KingsCollege":
        dataset_cfg = KingsCollege
    elif args_opt.dataset == "StMarysChurch":
        dataset_cfg = StMarysChurch

    device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    if args_opt.run_distribute:
        if device_target == "Ascend":
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              auto_parallel_search_mode="recursive_programming")
            init()
        elif device_target == "GPU":
            init()
            context.set_auto_parallel_context(device_num=args_opt.device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              auto_parallel_search_mode="recursive_programming")
    else:
        if args_opt.device_id is not None:
            context.set_context(device_id=args_opt.device_id)
        else:
            context.set_context(device_id=cfg.device_id)

    train_dataset_path = dataset_cfg.dataset_path
    if args_opt.is_modelarts:
        import moxing as mox
        mox.file.copy_parallel(src_url=args_opt.data_url,
                               dst_url='/cache/dataset_train/device_' + os.getenv('DEVICE_ID'))
        train_dataset_path = '/cache/dataset_train/device_' + os.getenv('DEVICE_ID') + '/'

    # The training mindrecord file is generated in dataset_cfg.mindrecord_dir,
    # and its file name is dataset_cfg.name + "_posenet_train.mindrecord".
    prefix = "_posenet_train.mindrecord"
    mindrecord_dir = dataset_cfg.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, dataset_cfg.name + prefix)
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        print("Create mindrecord for train.")
        data_to_mindrecord(train_dataset_path, True, mindrecord_file)
        print("Create mindrecord done, at {}".format(mindrecord_dir))
    # Wait until the index file exists before opening the dataset.
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)
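
    # Pass the process rank so that each device reads its own shard in
    # distributed runs; with the default rank_id=0 every device reads shard 0.
    rank_id = get_rank() if args_opt.run_distribute else 0
    dataset = create_posenet_dataset(mindrecord_file, batch_size=dataset_cfg.batch_size,
                                     device_num=args_opt.device_num, is_training=True,
                                     rank_id=rank_id)
    step_per_epoch = dataset.get_dataset_size()

    net_with_loss = PosenetWithLoss(cfg.pre_trained)
    opt = Adagrad(params=net_with_loss.trainable_params(),
                  learning_rate=dataset_cfg.lr_init,
                  weight_decay=dataset_cfg.weight_decay)
    model = Model(net_with_loss, optimizer=opt)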

    time_cb = TimeMonitor(data_size=step_per_epoch)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if cfg.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_epochs * step_per_epoch,
                                     keep_checkpoint_max=cfg.keep_checkpoint_max)
        if args_opt.is_modelarts:
            save_checkpoint_path = '/cache/train_output/checkpoint'
            if args_opt.device_num == 1:
                ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
                                          directory=save_checkpoint_path,
                                          config=config_ck)
                cb += [ckpt_cb]
            if args_opt.device_num > 1 and get_rank() % 8 == 0:
                ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
                                          directory=save_checkpoint_path,
                                          config=config_ck)
                cb += [ckpt_cb]
        else:
            save_checkpoint_path = cfg.checkpoint_dir
            if not os.path.isdir(save_checkpoint_path):
                os.makedirs(save_checkpoint_path)

            if args_opt.device_num == 1:
                ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
                                          directory=save_checkpoint_path,
                                          config=config_ck)
                cb += [ckpt_cb]
            if args_opt.device_num > 1 and get_rank() % 8 == 0:
                ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
                                          directory=save_checkpoint_path,
                                          config=config_ck)
                cb += [ckpt_cb]

    epoch_size = cfg.max_steps // args_opt.device_num // step_per_epoch
    model.train(epoch_size, dataset, callbacks=cb)
    print("Train success!")

    if args_opt.is_modelarts:
        mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)
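
Two pieces of plain arithmetic in the training script are easy to misread: how `epoch_size` is derived from a step budget, and which process registers the checkpoint callback. The sketch below restates both as small pure-Python functions. The function names are hypothetical, and the batch size (75) and step budget (30000) are made-up illustration values, not taken from `src/config.py`; only the 1220-image training set size comes from the README.

```python
def compute_epoch_size(max_steps, device_num, step_per_epoch):
    """Same integer arithmetic as the script: split the step budget across
    devices, then convert steps to whole epochs (remainders are dropped)."""
    return max_steps // device_num // step_per_epoch


def should_save_checkpoint(device_num, rank):
    """Mirror the callback gating: a single device always saves; in multi-device
    runs only ranks divisible by 8 save (rank 0 on an 8-device machine)."""
    if device_num == 1:
        return True
    return rank % 8 == 0


# Assumed illustration values (not from the commit's config).
num_images = 1220
batch_size = 75
max_steps = 30000

# drop_remainder=True in the dataset discards the last partial batch.
step_per_epoch = num_images // batch_size
epochs = compute_epoch_size(max_steps, device_num=1, step_per_epoch=step_per_epoch)
savers = [rank for rank in range(8) if should_save_checkpoint(8, rank)]
print(step_per_epoch, epochs, savers)
```

Since both divisions floor, the number of steps actually run can be slightly below `max_steps`.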