forked from mindspore-Ecosystem/mindspore
!16674 [crowdfunding] Add new model: posenet(GPU)
Merge pull request !16674 from dlliu123/posenet_gpu-master
This commit is contained in:
commit
94443f8e8c
|
@ -0,0 +1,366 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [PoseNet描述](#posenet描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [单机训练](#单机训练)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [KingsCollege上的PoseNet](#KingsCollege上的PoseNet)
|
||||
- [StMarysChurch上的PoseNet](#StMarysChurch上的PoseNet)
|
||||
- [推理性能](#推理性能)
|
||||
- [KingsCollege上的PoseNet](#KingsCollege上的PoseNet)
|
||||
- [StMarysChurch上的PoseNet](#StMarysChurch上的PoseNet)
|
||||
- [使用流程](#使用流程)
|
||||
- [推理](#推理)
|
||||
- [继续训练预训练模型](#继续训练预训练模型)
|
||||
- [迁移学习](#迁移学习)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# PoseNet描述
|
||||
|
||||
PoseNet是剑桥大学提出的一种鲁棒、实时的6DOF(单目六自由度)重定位系统。该系统训练一个卷积神经网络,可以端到端的从RGB图像中回归出6DOF姿态,而不需要其它额外的处理。该算法以每帧5ms的处理速度实时进行室内外场景的位姿估计。该网络模型包含23个卷积层,可以用来解决复杂的图像平面回归问题。这可以通过利用从大规模分类数据中进行的迁移学习来实现。PoseNet利用高层特征进行图像定位,作者证明对于光照变化、运动模糊以及传统SIFT注册失败的案例具有较好的鲁棒性。此外,作者展示了模型推广到其他场景的扩展性以及小样本上进行姿态回归的能力。
|
||||
|
||||
[论文](https://arxiv.org/abs/1505.07427):Kendall A, Grimes M, Cipolla R. "PoseNet: A convolutional network for real-time 6-dof camera relocalization."*In IEEE International Conference on Computer Vision (pp. 2938–2946), 2015.
|
||||
|
||||
# 模型架构
|
||||
|
||||
基本骨架模型采用GoogLeNet,该模型包括22个卷积层和3个分类分支(其中2个分类分支在测试时将进行丢弃)。改进包括3个小点:移除softmax层并新增具有7个神经元的全连接回归层(用于回归位姿);在全连接回归层前插入神经元数为2048的特征向量层;测试时,回归出的四元数需进行单位化。输入数据均resize到224x224。
|
||||
|
||||
# 数据集
|
||||
|
||||
[KingsCollege](<http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset>)
|
||||
|
||||
- 数据集大小:5.73G,含视频videos
|
||||
- 训练集:2.9G,共1220张图像(seq1, seq4, seq5, seq6, seq8)
|
||||
- 测试集:852M,共342张图像(seq2, seq3, seq7)
|
||||
- 数据格式:txt文件(image_url + label)
|
||||
- 注:数据将在src/dataset.py中处理。
|
||||
|
||||
[StMarysChurch](<http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset>)
|
||||
|
||||
- 数据集大小:5.04G,含视频videos
|
||||
- 训练集:3.5G,共1487张图像(seq1, seq2, seq4, seq5, seq7, seq8, seq9, seq10, seq11, seq12, seq14)
|
||||
- 测试集:1.34G,共530张图像(seq3, seq5, seq13)
|
||||
- 数据格式:txt文件(image_url + label)
|
||||
- 注:数据将在src/dataset.py中处理。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend/GPU)
|
||||
- 使用Ascend来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```python
|
||||
# 运行单机训练示例
|
||||
sh run_standalone_train.sh [DATASET_NAME] [DEVICE_ID]
|
||||
|
||||
# 运行分布式训练示例
|
||||
sh run_distribute_train.sh [DATASET_NAME] [RANK_SIZE]
|
||||
|
||||
# 运行评估示例
|
||||
sh run_eval.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
|
||||
```
|
||||
|
||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
||||
|
||||
请遵循以下链接中的说明:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||
|
||||
- GPU处理器环境运行
|
||||
|
||||
为了在GPU处理器环境运行,请将配置文件src/config.py中的`device_target`从`Ascend`改为`GPU`
|
||||
|
||||
```python
|
||||
# 运行单机训练示例
|
||||
sh run_standalone_train_gpu.sh [DATASET_NAME] [DEVICE_ID]
|
||||
|
||||
# 运行分布式训练示例
|
||||
sh run_distribute_train_gpu.sh [DATASET_NAME] [RANK_SIZE]
|
||||
|
||||
# 运行评估示例
|
||||
sh run_eval_gpu.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
|
||||
```
|
||||
|
||||
默认使用KingsCollege数据集。您也可以将`$dataset_name`传入脚本,以便选择其他数据集。如需查看更多详情,请参考指定脚本。
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── posenet
|
||||
├── README.md // posenet相关说明
|
||||
├── scripts
|
||||
│ ├──run_standalone_train.sh // 单机到Ascend处理器的shell脚本
|
||||
│ ├──run_distribute_train.sh // 分布式到Ascend处理器的shell脚本
|
||||
│ ├──run_eval.sh // Ascend评估的shell脚本
|
||||
│ ├──run_standalone_train_gpu.sh // 单机到GPU的shell脚本
|
||||
│ ├──run_distribute_train_gpu.sh // 分布式到GPU的shell脚本
|
||||
│ ├──run_eval_gpu.sh // GPU评估的shell脚本
|
||||
├── src
|
||||
│ ├──dataset.py // 数据集转换成mindrecord格式,创建数据集及数据预处理
|
||||
│ ├──posenet.py // posenet架构
|
||||
│ ├──loss.py // posenet的损失函数定义
|
||||
│ ├──config.py // 参数配置
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
├── export.py // 将checkpoint文件导出到mindir下
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 配置PoseNet和2种数据集。
|
||||
|
||||
```python
|
||||
# common_config
|
||||
'device_target': 'GPU', # 运行设备Ascend/GPU
|
||||
'device_id': 0, # 用于训练或评估数据集的设备ID使用run_distribute_train.sh进行分布式训练时可以忽略
|
||||
'pre_trained': True, # 是否基于预训练模型训练
|
||||
'max_steps': 30000, # 最大迭代次数
|
||||
'save_checkpoint': True, # 是否保存检查点文件
|
||||
'pre_trained_file': '../pre_trained_googlenet_imagenet.ckpt', # checkpoint文件保存的路径
|
||||
'checkpoint_dir': '../checkpoint', # checkpoint文件夹路径
|
||||
'save_checkpoint_epochs': 5, # 保存检查点间隔epoch数
|
||||
'keep_checkpoint_max': 10 # 保存的最大checkpoint文件数
|
||||
|
||||
# dataset_config
|
||||
'batch_size': 75, # 批处理大小
|
||||
'lr_init': 0.001, # 初始学习率
|
||||
'weight_decay': 0.5, # 权重衰减率
|
||||
'name': 'KingsCollege', # 数据集名字
|
||||
'dataset_path': '../KingsCollege/', # 数据集路径
|
||||
'mindrecord_dir': '../MindrecordKingsCollege' # 数据集mindrecord文件路径
|
||||
```
|
||||
|
||||
注:预训练checkpoint文件'pre_trained_file'在ModelArts环境下需调整为对应的绝对路径
|
||||
比如"/home/work/user-job-dir/posenet/pre_trained_googlenet_imagenet.ckpt"
|
||||
|
||||
更多配置细节请参考脚本`config.py`。
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 单机训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_standalone_train.sh [DATASET_NAME] [DEVICE_ID]
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过train.log文件查看结果。
|
||||
|
||||
训练结束后,您可在默认脚本文件夹下找到检查点文件。采用以下方式得到损失值:
|
||||
|
||||
```bash
|
||||
epoch:1 step:38, loss is 1722.1506
|
||||
epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
```
|
||||
|
||||
模型检查点保存在checkpoint文件夹下。
|
||||
|
||||
- GPU处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_standalone_train_gpu.sh [DATASET_NAME] [DEVICE_ID]
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过train.log文件查看结果。
|
||||
|
||||
训练结束后,您可在默认脚本文件夹下找到检查点文件。采用以下方式得到损失值:
|
||||
|
||||
```bash
|
||||
epoch:1 step:38, loss is 1722.1506
|
||||
epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
```
|
||||
|
||||
模型检查点保存在checkpoint文件夹下。
|
||||
|
||||
### 分布式训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_distribute_train.sh [DATASET_NAME] [RANK_SIZE]
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布训练。您可以通过device[X]/log文件查看结果。采用以下方式达到损失值:
|
||||
|
||||
```bash
|
||||
device0/log:epoch:1 step:38, loss is 1722.1506
|
||||
device0/log:epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
device1/log:epoch:1 step:38, loss is 1722.1506
|
||||
device1/log:epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
```
|
||||
|
||||
- GPU处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_distribute_train_gpu.sh [DATASET_NAME] [RANK_SIZE]
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布训练。您可以通过device[X]/log文件查看结果。采用以下方式达到损失值:
|
||||
|
||||
```bash
|
||||
device0/log:epoch:1 step:38, loss is 1722.1506
|
||||
device0/log:epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
device1/log:epoch:1 step:38, loss is 1722.1506
|
||||
device1/log:epcoh:2 step:38, loss is 1671.5763
|
||||
...
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend环境运行时评估KingsCollege数据集
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。
|
||||
请将检查点路径设置为相对路径,例如“../checkpoint/train_posenet_KingsCollege-790_38.ckpt”。
|
||||
|
||||
```bash
|
||||
sh run_eval.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval/eval.log文件查看结果。测试数据集的准确性如下:
|
||||
|
||||
```bash
|
||||
Median error 3.56644630432129 m and 3.07089155413442 degrees
|
||||
```
|
||||
|
||||
- 在GPU环境运行时评估KingsCollege数据集
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。
|
||||
请将检查点路径设置为相对路径,例如“../checkpoint/train_posenet_KingsCollege-1875_2.ckpt”。
|
||||
|
||||
```bash
|
||||
sh run_eval_gpu.sh [DEVICE_ID] [DATASET_NAME] [CKPT_PATH]
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval/eval.log文件查看结果。测试数据集的准确性如下:
|
||||
|
||||
```bash
|
||||
Median error 3.56644630432129 m and 3.07089155413442 degrees
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### KingsCollege上的PoseNet
|
||||
|
||||
| 参数 | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | ---------------------- |
|
||||
| 资源 | Ascend 910 ;CPU 2.60GHz,192核;内存:755G | NV SMX2 V100-32G |
|
||||
| 上传日期 | 2021-03-26 | 2021-05-20 |
|
||||
| MindSpore版本 | 1.1.1-alpha | 1.2.1-alpha |
|
||||
| 数据集 | KingsCollege | KingsCollege |
|
||||
| 训练参数 | max_steps=30000, batch_size=75, lr_init=0.001 | max_steps=30000, batch_size=75, lr_init=0.001 |
|
||||
| 优化器 | Adagrad | Adagrad |
|
||||
| 损失函数 | 自定义损失函数 | 自定义损失函数 |
|
||||
| 输出 | 距离、角度 | 距离、角度 |
|
||||
| 损失 | 1110.86 | 1110.86 |
|
||||
| 速度 | 单卡:750毫秒/步; 8卡:856毫秒/步 | 8卡:675毫秒/步(不稳定) |
|
||||
| 总时长 | 8卡:75分钟 | 8卡:60分钟 |
|
||||
| 参数(M) | 10.7 | 10.7 |
|
||||
| 微调检查点 | 82.91M (.ckpt文件) | 82.91M (.ckpt文件) |
|
||||
| 推理模型 | 41.66M (.mindir文件) | 41.66M (.mindir文件) |
|
||||
| 脚本 | [posenet脚本](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/posenet) | [posenet脚本](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/master/cv/posenet) |
|
||||
|
||||
#### StMarysChurch上的PoseNet
|
||||
|
||||
| 参数 | Ascend | GPU |
|
||||
| -------------------------- | ----------------------------------------------------------- | ---------------------- |
|
||||
| 资源 | Ascend 910 ;CPU 2.60GHz,192核;内存:755G | NV SMX2 V100-32G |
|
||||
| 上传日期 | 2021-03-26 | 2021-05-20 |
|
||||
| MindSpore版本 | 1.1.1-alpha | 1.2.1-alpha |
|
||||
| 数据集 | StMarysChurch | StMarysChurch |
|
||||
| 训练参数 | max_steps=30000, batch_size=75, lr_init=0.001 | max_steps=30000, batch_size=75, lr_init=0.001 |
|
||||
| 优化器 | Adagrad | Adagrad |
|
||||
| 损失函数 | 自定义损失函数 | 自定义损失函数 |
|
||||
| 输出 | 距离、角度 | 距离、角度 |
|
||||
| 损失 | 1077.86 | 1023.67 |
|
||||
| 速度 | 单卡:800毫秒/步; 8卡:1122毫秒/步 | 8卡:850毫秒/步(不稳定) |
|
||||
| 总时长 | 单卡:6小时40分钟; 8卡:85分钟 | 8卡:80分钟 |
|
||||
| 参数(M) | 10.7 | 10.7 |
|
||||
| 微调检查点 | 82.91M (.ckpt文件) | 82.91M (.ckpt文件) |
|
||||
| 推理模型 | 41.66M (.mindir文件) | 41.66M (.mindir文件) |
|
||||
| 脚本 | [posenet脚本](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/research/cv/posenet) | [posenet脚本](https://gitee.com/mindspore/mindspore/tree/r1.1/model_zoo/master/cv/posenet) |
|
||||
|
||||
### 推理性能
|
||||
|
||||
#### KingsCollege上的PoseNet
|
||||
|
||||
| 参数 | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| 资源 | Ascend 910 | GPU |
|
||||
| 上传日期 | 2021-03-26 | 2021-05-20 |
|
||||
| MindSpore 版本 | 1.1.1-alpha | 1.2.1-alpha |
|
||||
| 数据集 | KingsCollege | KingsCollege |
|
||||
| batch_size | 1 | 1 |
|
||||
| 输出 | 距离、角度 |距离、角度 |
|
||||
| 准确性 | 单卡: 1.928米 4.24度; 8卡:1.89米 4.31度 | 8卡:1.80米 3.68度 |
|
||||
| 推理模型 | 41.66M (.mindir文件) | 41.66M (.mindir文件) |
|
||||
|
||||
#### StMarysChurch上的PoseNet
|
||||
|
||||
| 参数 | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| 资源 | Ascend 910 | GPU |
|
||||
| 上传日期 | 2021-03-26 | 2021-05-20 |
|
||||
| MindSpore 版本 | 1.1.1-alpha | 1.2.1-alpha |
|
||||
| 数据集 | StMarysChurch | StMarysChurch |
|
||||
| batch_size | 1 | 1 |
|
||||
| 输出 | 距离、角度 | 距离、角度 |
|
||||
| 准确性 | 单卡: 1.884米 7.20度; 8卡:1.90米 6.23度 | 8卡:1.89米 6.24度 |
|
||||
| 推理模型 | 41.66M (.mindir文件) | 41.66M (.mindir文件) |
|
||||
|
||||
## 迁移学习
|
||||
|
||||
在Imagenet数据集上预训练GoogLeNet,迁移至PoseNet。
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
在train.py中,我们设置了随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test posenet"""
|
||||
import ast
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import common_config, KingsCollege, StMarysChurch
|
||||
from src.posenet import PoseNet
|
||||
from src.dataset import data_to_mindrecord, create_posenet_dataset
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='posenet eval')
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
parser.add_argument('--dataset', type=str, default='KingsCollege',
|
||||
choices=['KingsCollege', 'StMarysChurch'],
|
||||
help='dataset name.')
|
||||
parser.add_argument('--ckpt_url', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='Train in Modelarts.')
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = common_config
|
||||
if args_opt.dataset == "KingsCollege":
|
||||
dataset_cfg = KingsCollege
|
||||
elif args_opt.dataset == "StMarysChurch":
|
||||
dataset_cfg = StMarysChurch
|
||||
|
||||
device_target = cfg.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
|
||||
if args_opt.device_id is not None:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
else:
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
eval_dataset_path = dataset_cfg.dataset_path
|
||||
if args_opt.is_modelarts:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=args_opt.data_url,
|
||||
dst_url='/cache/dataset_eval/device_' + os.getenv('DEVICE_ID'))
|
||||
eval_dataset_path = '/cache/dataset_eval/device_' + os.getenv('DEVICE_ID') + '/'
|
||||
|
||||
# It will generate eval mindrecord file in cfg.mindrecord_dir,
|
||||
# and the file name is "dataset_cfg.name + _posenet_eval.mindrecord".
|
||||
prefix = "_posenet_eval.mindrecord"
|
||||
mindrecord_dir = dataset_cfg.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, dataset_cfg.name + prefix)
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
print("Create mindrecord for eval.")
|
||||
data_to_mindrecord(eval_dataset_path, False, mindrecord_file)
|
||||
print("Create mindrecord done, at {}".format(mindrecord_dir))
|
||||
while not os.path.exists(mindrecord_file + ".db"):
|
||||
time.sleep(5)
|
||||
|
||||
dataset = create_posenet_dataset(mindrecord_file, batch_size=1, device_num=1, is_training=False)
|
||||
data_num = dataset.get_dataset_size()
|
||||
|
||||
net = PoseNet()
|
||||
param_dict = load_checkpoint(args_opt.ckpt_url)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
print("Processing, please wait a moment.")
|
||||
results = np.zeros((data_num, 2))
|
||||
for step, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
image = item['image']
|
||||
poses = item['image_pose']
|
||||
|
||||
pose_x = np.squeeze(poses[:, 0:3])
|
||||
pose_q = np.squeeze(poses[:, 3:])
|
||||
p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = net(Tensor(image))
|
||||
predicted_x = p3_x.asnumpy()
|
||||
predicted_q = p3_q.asnumpy()
|
||||
|
||||
q1 = pose_q / np.linalg.norm(pose_q)
|
||||
q2 = predicted_q / np.linalg.norm(predicted_q)
|
||||
d = abs(np.sum(np.multiply(q1, q2)))
|
||||
theta = 2 * np.arccos(d) * 180 / math.pi
|
||||
error_x = np.linalg.norm(pose_x - predicted_x)
|
||||
|
||||
results[step, :] = [error_x, theta]
|
||||
print('Iteration: ', step, ', Error XYZ (m): ', error_x, ', Error Q (degrees): ', theta)
|
||||
|
||||
median_result = np.median(results, axis=0)
|
||||
print('Median error ', median_result[0], 'm and ', median_result[1], 'degrees.')
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.config import common_config, KingsCollege, StMarysChurch
|
||||
from src.posenet import PoseNet
|
||||
|
||||
parser = argparse.ArgumentParser(description='PoseNet')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument('--dataset', type=str, default='KingsCollege',
|
||||
choices=['KingsCollege', 'StMarysChurch'],
|
||||
help='Name of dataset.')
|
||||
parser.add_argument("--ckpt_url", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="posenet", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["MINDIR"], default='MINDIR', help='file format')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
cfg = common_config
|
||||
if args.dataset == 'KingsCollege':
|
||||
dataset_cfg = KingsCollege
|
||||
elif args.dataset == 'StMarysChurch':
|
||||
dataset_cfg = StMarysChurch
|
||||
|
||||
net = PoseNet()
|
||||
|
||||
assert cfg.checkpoint_dir is not None, "cfg.checkpoint_dir is None."
|
||||
param_dict = load_checkpoint(args.ckpt_url)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.ones([dataset_cfg.batch_size, 3, 224, 224]), ms.float32)
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,76 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATASET_NAME RANK_SIZE"
|
||||
echo "For example: bash run_distribute.sh dataset_name 8"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
|
||||
DATASET_NAME=$1
|
||||
RANK_SIZE=$2
|
||||
export DATASET_NAME
|
||||
export RANK_SIZE
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
||||
export RANK_SIZE=8
|
||||
}
|
||||
|
||||
test_dist_2pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
||||
export RANK_SIZE=2
|
||||
}
|
||||
|
||||
test_dist_${RANK_SIZE}pcs
|
||||
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
for((i=0;i<${RANK_SIZE};i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cd ./device$i
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./train.py ./device$i
|
||||
cp ./src/*.py ./device$i/src
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
python train.py --run_distribute True --dataset $1 --device_num $2 --is_modelarts False > train$i.log 2>&1 &
|
||||
echo "$i finish"
|
||||
cd ../
|
||||
done
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "training success"
|
||||
else
|
||||
echo "training failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distribute_train.sh DATA_PATH RANK_SIZE"
|
||||
echo "For example: bash run_distribute_train.sh /path/dataset 8"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
export DATASET_NAME=$2
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf ./train_parallel
|
||||
mkdir ./train_parallel
|
||||
cd ./train_parallel
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ../*.py ./train_parallel
|
||||
cp ../src/*.py ./train_parallel/src
|
||||
cd ./train_parallel
|
||||
env > env.log
|
||||
echo "start training"
|
||||
mpirun -n $1 --allow-run-as-root \
|
||||
python train.py --device_num $1 --dataset $2 --is_modelarts False --run_distribute True > train.log 2>&1 &
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DEVICE_ID DATASET_NAME CKPT_PATH"
|
||||
echo "For example: bash run_eval.sh device_id dataset ckpt_url"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
|
||||
DEVICE_ID=$1
|
||||
DATASET_NAME=$2
|
||||
CKPT_PATH=$3
|
||||
export DEVICE_ID
|
||||
export DATASET_NAME
|
||||
export CKPT_PATH
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf eval/
|
||||
mkdir eval
|
||||
cd ./eval
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./eval.py ./eval
|
||||
cp ./src/*.py ./eval/src
|
||||
cd ./eval
|
||||
|
||||
env > env0.log
|
||||
echo "Eval begin."
|
||||
python eval.py --device_id $1 --dataset $2 --ckpt_url $3 --is_modelarts False > ./eval.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "evaling success"
|
||||
else
|
||||
echo "evaling failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DEVICE_ID DATASET_NAME CKPT_PATH"
|
||||
echo "For example: bash run_eval.sh device_id dataset ckpt_url"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
|
||||
DEVICE_ID=$1
|
||||
DATASET_NAME=$2
|
||||
CKPT_PATH=$3
|
||||
export DEVICE_ID
|
||||
export DATASET_NAME
|
||||
export CKPT_PATH
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf eval/
|
||||
mkdir eval
|
||||
cd ./eval
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./eval.py ./eval
|
||||
cp ./src/*.py ./eval/src
|
||||
cd ./eval
|
||||
|
||||
env > env0.log
|
||||
echo "Eval begin."
|
||||
python eval.py --device_id $1 --dataset $2 --ckpt_url $3 --is_modelarts False > ./eval.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "evaling success"
|
||||
else
|
||||
echo "evaling failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,57 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATASET_NAME DEVICE_ID"
|
||||
echo "For example: bash run_standalone_train.sh dataset_name device_id"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
set -e
|
||||
|
||||
DATASET_NAME=$1
|
||||
DEVICE_ID=$2
|
||||
export DATASET_NAME
|
||||
export DEVICE_ID
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf train
|
||||
mkdir train
|
||||
cd ./train
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./train.py ./train
|
||||
cp ./src/*.py ./train/src
|
||||
cd ./train
|
||||
|
||||
env > env0.log
|
||||
|
||||
echo "Standalone train begin."
|
||||
python train.py --run_distribute False --device_id $2 --dataset $1 --device_num 1 --is_modelarts False > ./train_alone.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "training success"
|
||||
else
|
||||
echo "training failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,57 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATASET_NAME DEVICE_ID"
|
||||
echo "For example: bash run_standalone_train.sh dataset_name device_id"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
set -e
|
||||
|
||||
DATASET_NAME=$1
|
||||
DEVICE_ID=$2
|
||||
export DATASET_NAME
|
||||
export DEVICE_ID
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf train
|
||||
mkdir train
|
||||
cd ./train
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./train.py ./train
|
||||
cp ./src/*.py ./train/src
|
||||
cd ./train
|
||||
|
||||
env > env0.log
|
||||
|
||||
echo "Standalone train begin."
|
||||
python train.py --run_distribute False --device_id $2 --dataset $1 --device_num 1 --is_modelarts False > ./train_alone.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "training success"
|
||||
else
|
||||
echo "training failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
common_config = edict({
|
||||
'device_target': 'GPU',
|
||||
'device_id': 0,
|
||||
'pre_trained': True,
|
||||
'max_steps': 30000,
|
||||
'save_checkpoint': True,
|
||||
# 'pre_trained_file': '/home/work/user-job-dir/posenet/pre_trained_googlenet_imagenet.ckpt',
|
||||
'pre_trained_file': '../pre_trained_googlenet_imagenet.ckpt',
|
||||
'checkpoint_dir': '../checkpoint',
|
||||
'save_checkpoint_epochs': 5,
|
||||
'keep_checkpoint_max': 10
|
||||
})
|
||||
|
||||
KingsCollege = edict({
|
||||
'batch_size': 75,
|
||||
'lr_init': 0.001,
|
||||
'weight_decay': 0.5,
|
||||
'name': 'KingsCollege',
|
||||
'dataset_path': '../KingsCollege/',
|
||||
'mindrecord_dir': '../MindrecordKingsCollege'
|
||||
})
|
||||
|
||||
StMarysChurch = edict({
|
||||
'batch_size': 75,
|
||||
'lr_init': 0.001,
|
||||
'weight_decay': 0.5,
|
||||
'name': 'StMarysChurch',
|
||||
'dataset_path': '../StMarysChurch/',
|
||||
'mindrecord_dir': '../MindrecordStMarysChurch'
|
||||
})
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data operations, will be used in train.py and eval.py"""
|
||||
import os
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
|
||||
class Dataset:
|
||||
"""dataset read"""
|
||||
def __init__(self, root, is_training=True):
|
||||
self.root = os.path.expanduser(root)
|
||||
self.train = is_training
|
||||
self.image_poses = []
|
||||
self.image_paths = []
|
||||
if self.train:
|
||||
txt_file = self.root + 'dataset_train.txt'
|
||||
else:
|
||||
txt_file = self.root + 'dataset_test.txt'
|
||||
|
||||
with open(txt_file, 'r') as f:
|
||||
next(f)
|
||||
next(f)
|
||||
next(f)
|
||||
for line in f:
|
||||
fname, p0, p1, p2, p3, p4, p5, p6 = line.split()
|
||||
p0 = float(p0)
|
||||
p1 = float(p1)
|
||||
p2 = float(p2)
|
||||
p3 = float(p3)
|
||||
p4 = float(p4)
|
||||
p5 = float(p5)
|
||||
p6 = float(p6)
|
||||
self.image_poses.append((p0, p1, p2, p3, p4, p5, p6))
|
||||
self.image_paths.append(self.root + fname)
|
||||
|
||||
def __getdata__(self):
|
||||
img_paths = self.image_paths
|
||||
img_poses = self.image_poses
|
||||
return img_paths, img_poses
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def data_to_mindrecord(data_path, is_training, mindrecord_file, file_num=1):
|
||||
"""Create MindRecord file."""
|
||||
writer = FileWriter(mindrecord_file, file_num)
|
||||
data = Dataset(data_path, is_training)
|
||||
image_paths, image_poses = data.__getdata__()
|
||||
posenet_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"image_pose": {"type": "float32", "shape": [-1]}
|
||||
}
|
||||
writer.add_schema(posenet_json, "posenet_json")
|
||||
|
||||
image_files_num = len(image_paths)
|
||||
for ind, image_name in enumerate(image_paths):
|
||||
with open(image_name, 'rb') as f:
|
||||
image = f.read()
|
||||
image_pose = np.array(image_poses[ind])
|
||||
row = {"image": image, "image_pose": image_pose}
|
||||
if (ind + 1) % 10 == 0:
|
||||
print("writing {}/{} into mindrecord".format(ind + 1, image_files_num))
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
def create_posenet_dataset(mindrecord_file, batch_size=1, device_num=1, is_training=True, rank_id=0):
|
||||
"""Create PoseNet dataset with MindDataset."""
|
||||
dataset = ds.MindDataset(mindrecord_file, columns_list=["image", "image_pose"],
|
||||
num_shards=device_num, shard_id=rank_id,
|
||||
num_parallel_workers=8, shuffle=True)
|
||||
|
||||
decode = C.Decode()
|
||||
dataset = dataset.map(operations=decode, input_columns=["image"])
|
||||
transforms_list = []
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
if is_training:
|
||||
resized_op = C.Resize([455, 256])
|
||||
random_crop_op = C.RandomCrop(224)
|
||||
normlize_op = C.Normalize(mean=mean, std=std)
|
||||
to_tensor_op = C.HWC2CHW()
|
||||
transforms_list = [resized_op, random_crop_op, normlize_op, to_tensor_op]
|
||||
else:
|
||||
resized_op = C.Resize([455, 224])
|
||||
center_crop_op = C.CenterCrop(224)
|
||||
normlize_op = C.Normalize(mean=mean, std=std)
|
||||
to_tensor_op = C.HWC2CHW()
|
||||
transforms_list = [resized_op, center_crop_op, normlize_op, to_tensor_op]
|
||||
|
||||
dataset = dataset.map(operations=transforms_list, input_columns=['image'])
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
|
||||
return dataset
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define evaluation loss function for network."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import common_config as config
|
||||
from src.posenet import PoseNet
|
||||
|
||||
class EuclideanDistance(nn.Cell):
|
||||
"""calculate euclidean distance"""
|
||||
def __init__(self):
|
||||
super(EuclideanDistance, self).__init__()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.sqrt = P.Sqrt()
|
||||
|
||||
def construct(self, predicted, real):
|
||||
res = self.sub(predicted, real)
|
||||
res = self.mul(res, res)
|
||||
res = self.reduce_sum(res, 0)
|
||||
res = self.sqrt(res)
|
||||
res = self.mul(res, res)
|
||||
res = self.reduce_sum(res, 0)
|
||||
res = self.sqrt(res)
|
||||
|
||||
return res
|
||||
|
||||
class PoseLoss(_Loss):
|
||||
"""define loss function"""
|
||||
def __init__(self, w1_x, w2_x, w3_x, w1_q, w2_q, w3_q):
|
||||
super(PoseLoss, self).__init__()
|
||||
self.w1_x = w1_x
|
||||
self.w2_x = w2_x
|
||||
self.w3_x = w3_x
|
||||
self.w1_q = w1_q
|
||||
self.w2_q = w2_q
|
||||
self.w3_q = w3_q
|
||||
self.ed = EuclideanDistance()
|
||||
|
||||
def construct(self, p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poseGT):
|
||||
"""construct"""
|
||||
pose_x = poseGT[:, 0:3]
|
||||
pose_q = poseGT[:, 3:]
|
||||
|
||||
l1_x = self.ed(pose_x, p1_x) * self.w1_x
|
||||
l1_q = self.ed(pose_q, p1_q) * self.w1_q
|
||||
l2_x = self.ed(pose_x, p2_x) * self.w2_x
|
||||
l2_q = self.ed(pose_q, p2_q) * self.w2_q
|
||||
l3_x = self.ed(pose_x, p3_x) * self.w3_x
|
||||
l3_q = self.ed(pose_q, p3_q) * self.w3_q
|
||||
loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
|
||||
|
||||
return loss
|
||||
|
||||
class PosenetWithLoss(nn.Cell):
|
||||
"""net with loss, and do pre_trained"""
|
||||
def __init__(self, pre_trained=False):
|
||||
super(PosenetWithLoss, self).__init__()
|
||||
|
||||
net = PoseNet()
|
||||
if pre_trained:
|
||||
param_dict = load_checkpoint(config.pre_trained_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
self.network = net
|
||||
self.loss = PoseLoss(3.0, 3.0, 10.0, 150, 150, 500)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, data, poseGT):
|
||||
p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = self.network(data)
|
||||
loss = self.loss(p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poseGT)
|
||||
return self.cast(loss, mstype.float32)
|
|
@ -0,0 +1,174 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""PoseNet"""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
def weight_variable():
|
||||
"""Weight variable."""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
class Conv2dBlock(nn.Cell):
|
||||
"""
|
||||
Basic convolutional block
|
||||
Args:
|
||||
in_channles (int): Input channel.
|
||||
out_channels (int): Output channel.
|
||||
kernel_size (int): Input kernel size. Default: 1
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same".
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode="same"):
|
||||
super(Conv2dBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, pad_mode=pad_mode, weight_init=weight_variable())
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class Inception(nn.Cell):
|
||||
"""
|
||||
Inception Block
|
||||
"""
|
||||
def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
|
||||
super(Inception, self).__init__()
|
||||
self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1)
|
||||
self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1),
|
||||
Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
|
||||
# kernel_size = 3: depend on googlenet
|
||||
self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
|
||||
Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same")
|
||||
self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
branch1 = self.b1(x)
|
||||
branch2 = self.b2(x)
|
||||
branch3 = self.b3(x)
|
||||
cell = self.maxpool(x)
|
||||
branch4 = self.b4(cell)
|
||||
return self.concat((branch1, branch2, branch3, branch4))
|
||||
|
||||
class PoseNet(nn.Cell):
|
||||
"""
|
||||
PoseNet architecture
|
||||
"""
|
||||
def __init__(self):
|
||||
super(PoseNet, self).__init__()
|
||||
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
|
||||
self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
|
||||
self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
|
||||
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
|
||||
self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
|
||||
self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
|
||||
self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
|
||||
self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
|
||||
self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
|
||||
|
||||
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
|
||||
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
|
||||
|
||||
self.avgpool5x5 = nn.AvgPool2d(kernel_size=5, stride=3, pad_mode="valid")
|
||||
self.conv1x1_1 = Conv2dBlock(512, 128, kernel_size=1, stride=1)
|
||||
self.conv1x1_2 = Conv2dBlock(528, 128, kernel_size=1, stride=1)
|
||||
self.fc2048 = nn.Dense(2048, 1024)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout7 = nn.Dropout(0.7)
|
||||
self.cls_fc_pose_xyz_1024 = nn.Dense(1024, 512)
|
||||
self.cls_fc_pose_xyz_512 = nn.Dense(512, 3)
|
||||
self.cls_fc_pose_wpqr_1024 = nn.Dense(1024, 4)
|
||||
|
||||
self.avgpool7x7 = nn.AvgPool2d(kernel_size=7, stride=1, pad_mode="valid")
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc = nn.Dense(1024, 2048)
|
||||
self.dropout5 = nn.Dropout(0.5)
|
||||
self.cls_fc_pose_xyz = nn.Dense(2048, 3)
|
||||
self.cls_fc_pose_wpqr = nn.Dense(2048, 4)
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.conv1(x)
|
||||
x = self.maxpool1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.maxpool2(x)
|
||||
|
||||
x = self.block3a(x)
|
||||
x = self.block3b(x)
|
||||
x = self.maxpool3(x)
|
||||
x = self.block4a(x)
|
||||
|
||||
cls1 = self.avgpool5x5(x)
|
||||
cls1 = self.conv1x1_1(cls1)
|
||||
cls1 = self.relu(cls1)
|
||||
cls1 = self.flatten(cls1)
|
||||
cls1 = self.fc2048(cls1)
|
||||
cls1 = self.relu(cls1)
|
||||
cls1 = self.dropout7(cls1)
|
||||
cls1_fc_pose_xyz = self.cls_fc_pose_xyz_1024(cls1)
|
||||
cls1_fc_pose_xyz = self.cls_fc_pose_xyz_512(cls1_fc_pose_xyz)
|
||||
cls1_fc_pose_wpqr = self.cls_fc_pose_wpqr_1024(cls1)
|
||||
|
||||
x = self.block4b(x)
|
||||
x = self.block4c(x)
|
||||
x = self.block4d(x)
|
||||
|
||||
cls2 = self.avgpool5x5(x)
|
||||
cls2 = self.conv1x1_2(cls2)
|
||||
cls2 = self.relu(cls2)
|
||||
cls2 = self.flatten(cls2)
|
||||
cls2 = self.fc2048(cls2)
|
||||
cls2 = self.relu(cls2)
|
||||
cls2 = self.dropout7(cls2)
|
||||
cls2_fc_pose_xyz = self.cls_fc_pose_xyz_1024(cls2)
|
||||
cls2_fc_pose_xyz = self.cls_fc_pose_xyz_512(cls2_fc_pose_xyz)
|
||||
cls2_fc_pose_wpqr = self.cls_fc_pose_wpqr_1024(cls2)
|
||||
|
||||
x = self.block4e(x)
|
||||
x = self.maxpool4(x)
|
||||
x = self.block5a(x)
|
||||
x = self.block5b(x)
|
||||
|
||||
cls3 = self.dropout5(x)
|
||||
cls3 = self.avgpool7x7(cls3)
|
||||
cls3 = self.flatten(cls3)
|
||||
cls3 = self.fc(cls3)
|
||||
cls3 = self.relu(cls3)
|
||||
cls3 = self.dropout5(cls3)
|
||||
cls3_fc_pose_xyz = self.cls_fc_pose_xyz(cls3)
|
||||
cls3_fc_pose_wpqr = self.cls_fc_pose_wpqr(cls3)
|
||||
|
||||
return cls1_fc_pose_xyz, cls1_fc_pose_wpqr, \
|
||||
cls2_fc_pose_xyz, cls2_fc_pose_wpqr, \
|
||||
cls3_fc_pose_xyz, cls3_fc_pose_wpqr
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train posenet"""
|
||||
import ast
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.nn import Adagrad
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
|
||||
from src.config import common_config, KingsCollege, StMarysChurch
|
||||
from src.dataset import data_to_mindrecord, create_posenet_dataset
|
||||
from src.loss import PosenetWithLoss
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Posenet train.')
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument('--device_id', type=int, default=None,
|
||||
help='device id of GPU or Ascend. (Default: None)')
|
||||
parser.add_argument('--dataset', type=str, default='KingsCollege',
|
||||
choices=['KingsCollege', 'StMarysChurch'],
|
||||
help='Name of dataset.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Number of device.')
|
||||
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='Train in Modelarts.')
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = common_config
|
||||
if args_opt.dataset == "KingsCollege":
|
||||
dataset_cfg = KingsCollege
|
||||
elif args_opt.dataset == "StMarysChurch":
|
||||
dataset_cfg = StMarysChurch
|
||||
|
||||
device_target = cfg.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
if args_opt.run_distribute:
|
||||
if device_target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
auto_parallel_search_mode="recursive_programming")
|
||||
init()
|
||||
elif device_target == "GPU":
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
auto_parallel_search_mode="recursive_programming")
|
||||
else:
|
||||
if args_opt.device_id is not None:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
else:
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
train_dataset_path = dataset_cfg.dataset_path
|
||||
if args_opt.is_modelarts:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=args_opt.data_url,
|
||||
dst_url='/cache/dataset_train/device_' + os.getenv('DEVICE_ID'))
|
||||
train_dataset_path = '/cache/dataset_train/device_' + os.getenv('DEVICE_ID') + '/'
|
||||
|
||||
# It will generate train mindrecord file in cfg.mindrecord_dir,
|
||||
# and the file name is "dataset_cfg.name + _posenet_train.mindrecord".
|
||||
prefix = "_posenet_train.mindrecord"
|
||||
mindrecord_dir = dataset_cfg.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, dataset_cfg.name + prefix)
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
print("Create mindrecord for train.")
|
||||
data_to_mindrecord(train_dataset_path, True, mindrecord_file)
|
||||
print("Create mindrecord done, at {}".format(mindrecord_dir))
|
||||
while not os.path.exists(mindrecord_file + ".db"):
|
||||
time.sleep(5)
|
||||
|
||||
dataset = create_posenet_dataset(mindrecord_file, batch_size=dataset_cfg.batch_size,
|
||||
device_num=args_opt.device_num, is_training=True)
|
||||
step_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
net_with_loss = PosenetWithLoss(cfg.pre_trained)
|
||||
opt = Adagrad(params=net_with_loss.trainable_params(),
|
||||
learning_rate=dataset_cfg.lr_init,
|
||||
weight_decay=dataset_cfg.weight_decay)
|
||||
model = Model(net_with_loss, optimizer=opt)
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if cfg.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_epochs * step_per_epoch,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
if args_opt.is_modelarts:
|
||||
save_checkpoint_path = '/cache/train_output/checkpoint'
|
||||
if args_opt.device_num == 1:
|
||||
ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
|
||||
directory=save_checkpoint_path,
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
if args_opt.device_num > 1 and get_rank() % 8 == 0:
|
||||
ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
|
||||
directory=save_checkpoint_path,
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
else:
|
||||
save_checkpoint_path = cfg.checkpoint_dir
|
||||
if not os.path.isdir(save_checkpoint_path):
|
||||
os.makedirs(save_checkpoint_path)
|
||||
|
||||
if args_opt.device_num == 1:
|
||||
ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
|
||||
directory=save_checkpoint_path,
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
if args_opt.device_num > 1 and get_rank() % 8 == 0:
|
||||
ckpt_cb = ModelCheckpoint(prefix='train_posenet_' + args_opt.dataset,
|
||||
directory=save_checkpoint_path,
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
|
||||
epoch_size = cfg.max_steps // args_opt.device_num // step_per_epoch
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
print("Train success!")
|
||||
|
||||
if args_opt.is_modelarts:
|
||||
mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)
|
Loading…
Reference in New Issue