forked from mindspore-Ecosystem/mindspore
add SiamRPN
This commit is contained in:
parent
90f37a98c4
commit
31e6936bee
|
@ -0,0 +1,242 @@
|
|||
# 目录
|
||||
|
||||
- [目录](#目录)
|
||||
- [SiamRPN描述](#概述)
|
||||
- [模型架构](#s模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [特性](#特性)
|
||||
- [混合精度](#混合精度)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [训练性能](#训练性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# 概述
|
||||
|
||||
Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN网络组成,它抛弃了传统的多尺度测试和在线跟踪,从而使得跟踪速度非常快。在VOT2015, VOT2016和VOT2017上取得了领先的性能,并且速度能都达到160fps。
|
||||
|
||||
[论文](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf):Bo Li,Junjie Yan,Wei Wu,Zheng Zhu,Xiaolin Hu, High Performance Visual Tracking with Siamese Region Proposal Network[C]// 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2018.
|
||||
|
||||
# 模型架构
|
||||
|
||||
此网络由Siamese Network和Region Proposal Network两部分组成。前者用来提取特征,后者用来产生候选区域。其中,RPN子网络由两个分支组成,一个是用来区分目标和背景的分类分支,另外一个是微调候选区域的回归分支。整个网络实现了端到端的训练。
|
||||
|
||||
# 数据集
|
||||
|
||||
:[VID-youtube-bb](https://pan.baidu.com/s/1QnQEM_jtc3alX8RyZ3i4-g) [VOT2015](https://www.votchallenge.net/vot2015/dataset.html) [VOT2016](https://www.votchallenge.net/vot2016/dataset.html)
|
||||
|
||||
- 百度网盘密码:myq4
|
||||
- 数据集大小:143.8G,共600多万图像
|
||||
- 训练集:141G,共600多万图像
|
||||
- 测试集:2.8G,共60个视频
|
||||
- 数据格式:RGB
|
||||
|
||||
# 特性
|
||||
|
||||
## 混合精度
|
||||
|
||||
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend/GPU)
|
||||
- 使用Ascend/GPU处理器来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.2/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.2/index.html#operator_api)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```python
|
||||
# 运行训练示例
|
||||
bash scripts/run.sh 0 1
|
||||
|
||||
# 运行分布式训练示例
|
||||
bash scripts/run_distribute_train.sh /path/dataset /path/rank_table
|
||||
|
||||
# 运行评估示例
|
||||
bash scripts/run_eval.sh 0 /path/dataset /path/ckpt/siamRPN-50_1417.ckpt eval.json
|
||||
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── research
|
||||
├── cv
|
||||
├── siamRPN
|
||||
├── README_CN.md // googlenet相关说明
|
||||
├── ascend310_infer // 实现310推理源代码
|
||||
├── scripts
|
||||
│ ├──run.sh // 训练脚本
|
||||
├── src
|
||||
│ ├──data_loader.py // 数据集加载处理脚本
|
||||
│ ├──net.py // siamRPN架构
|
||||
│ ├──loss.py // 损失函数
|
||||
│ ├──util.py // 工具脚本
|
||||
│ ├──tracker.py
|
||||
│ ├──generate_anchors.py
|
||||
│ ├──tracker.py
|
||||
│ ├──evaluation.py
|
||||
│ ├──config.py // 参数配置
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
├── export_mindir.py // 将checkpoint文件导出到air/mindir
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 训练相关参数。
|
||||
|
||||
```python
|
||||
checkpoint_path = r'/home/data/ckpt' # 模型检查点保存目录
|
||||
pretrain_model = 'alexnet.ckpt' # 预训练模型名称
|
||||
train_path = r'/home/data/ytb_vid_filter' # 训练数据集路径
|
||||
cur_epoch = 0 #当前训练周期
|
||||
max_epoches = 50 #训练最大周期
|
||||
batch_size = 32 #每次训练样本数
|
||||
|
||||
start_lr = 3e-2 #初始训练学习率
|
||||
end_lr = 1e-7 #结束学习率
|
||||
momentum = 0.9 #动量
|
||||
weight_decay = 0.0005 # 权重衰减值
|
||||
check = True #是否加载模型
|
||||
```
|
||||
|
||||
更多配置细节请参考脚本`config.py`。
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
python train.py --device_id=0 > train.log 2>&1 &
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过train.log文件查看结果。
|
||||
|
||||
训练结束后,您可在默认脚本文件夹下找到检查点文件。采用以下方式达到损失值:
|
||||
|
||||
```bash
|
||||
# grep "loss is " train.log
|
||||
epoch:1 step:390, loss is 1.4842823
|
||||
epcoh:2 step:390, loss is 1.0897788
|
||||
...
|
||||
```
|
||||
|
||||
模型检查点保存在当前目录下。
|
||||
|
||||
### 分布式训练
|
||||
|
||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
||||
|
||||
请遵循以下链接中的说明:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||
|
||||
- 在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
|
||||
- 在 ModelArts 上使用8卡训练 VID-youtube-bb 数据集
|
||||
|
||||
```python
|
||||
# (1) 在网页上设置 "is_cloudtrain=True"
|
||||
# 在网页上设置 "is_parallel=True"
|
||||
# 在网页上设置 "unzip_mode=1"(原始的数据集设置为0,tar压缩文件设置为1)
|
||||
# 在网页上设置 "data_url=/cache/data/ytb_vid_filter/"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||
# (3) 在网页上设置你的代码路径为 "/path/siamRPN"
|
||||
# (4) 在网页上设置启动文件为 "train.py"
|
||||
# (5) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (6) 创建训练作业
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 评估过程如下,需要vot数据集对应video的图片放于对应文件夹的color文件夹下,标签groundtruth.txt放于该目录下。
|
||||
|
||||
```bash
|
||||
# 使用数据集
|
||||
python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-50_1417.ckpt --filename=eval.json &> evallog &
|
||||
```
|
||||
|
||||
- 上述python命令在后台运行,可通过`evallog`文件查看评估进程,结束后可通过`eval.json`文件查看评估结果。评估结果如下:
|
||||
|
||||
```bash
|
||||
{... "all_videos": {"accuracy": 0.5809545709441025, "robustness": 0.33422978326730364, "eao": 0.3102655908013835}}
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 训练性能
|
||||
|
||||
| 参数 | siamRPN(Ascend) |
|
||||
| -------------------------- | ---------------------------------------------- |
|
||||
| 模型版本 | siamRPN |
|
||||
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755 GB |
|
||||
| 上传日期 | 2021-07-22 |
|
||||
| MindSpore版本 | 1.2.0-alpha |
|
||||
| 数据集 |VID-youtube-bb |
|
||||
| 训练参数 |epoch=50, steps=1147, batch_size = 32 |
|
||||
| 优化器 | SGD |
|
||||
| 损失函数 | 自定义损失函数 |
|
||||
| 输出 | 目标框 |
|
||||
| 损失 |100~0.05 |
|
||||
| 速度 | 8卡:120毫秒/步 |
|
||||
| 总时长 | 8卡:12.3小时 |
|
||||
| 调优检查点 | 247.58MB(.ckpt 文件) |
|
||||
| 脚本 | [siamRPN脚本](https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/research/cv/siamRPN) |
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | siamRPN(Ascend) | siamRPN(Ascend) |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| 模型版本 | simaRPN | simaRPN |
|
||||
| 资源 | Ascend 910 | Ascend 910 |
|
||||
| 上传日期 | 2021-07-22 | 2021-07-22 |
|
||||
| MindSpore版本 | 1.2.0-alpha | 1.2.0-alpha |
|
||||
| 数据集 | vot2015,60个video | vot2016,60个video |
|
||||
| batch_size | 1 | 1 |
|
||||
| 输出 | 目标框 | 目标框 |
|
||||
| 准确率 | 单卡:accuracy:0.58,robustness:0.33,eao:0.31; | 单卡:accuracy:0.56,robustness:0.39,eao:0.28;|
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
在train.py中,我们设置了随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval vot"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src import evaluation as eval_
|
||||
from src.net import SiameseRPN
|
||||
from src.tracker import SiamRPNTracker
|
||||
|
||||
import cv2
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
|
||||
def get_axis_aligned_bbox(region):
|
||||
""" convert region to (cx, cy, w, h) that represent by axis aligned box
|
||||
"""
|
||||
nv = len(region)
|
||||
region = np.array(region)
|
||||
if nv == 8:
|
||||
x1 = min(region[0::2])
|
||||
x2 = max(region[0::2])
|
||||
y1 = min(region[1::2])
|
||||
y2 = max(region[1::2])
|
||||
A1 = np.linalg.norm(region[0:2] - region[2:4]) * \
|
||||
np.linalg.norm(region[2:4] - region[4:6])
|
||||
A2 = (x2 - x1) * (y2 - y1)
|
||||
s = np.sqrt(A1 / A2)
|
||||
w = s * (x2 - x1) + 1
|
||||
h = s * (y2 - y1) + 1
|
||||
x = x1
|
||||
y = y1
|
||||
else:
|
||||
x = region[0]
|
||||
y = region[1]
|
||||
w = region[2]
|
||||
h = region[3]
|
||||
|
||||
return x, y, w, h
|
||||
|
||||
|
||||
def test(model_path, data_path, save_name):
|
||||
""" using tracking """
|
||||
# ------------ prepare data -----------
|
||||
direct_file = os.path.join(data_path, 'list.txt')
|
||||
with open(direct_file, 'r') as f:
|
||||
direct_lines = f.readlines()
|
||||
video_names = np.sort([x.split('\n')[0] for x in direct_lines])
|
||||
video_paths = [os.path.join(data_path, x) for x in video_names]
|
||||
# ------------ prepare models -----------
|
||||
model = SiameseRPN()
|
||||
param_dict = load_checkpoint(model_path)
|
||||
param_not_load = load_param_into_net(model, param_dict)
|
||||
print(param_not_load)
|
||||
# ------------ starting validation -----------
|
||||
results = {}
|
||||
accuracy = 0
|
||||
all_overlaps = []
|
||||
all_failures = []
|
||||
gt_lenth = []
|
||||
for video_path in tqdm(video_paths, total=len(video_paths)):
|
||||
# ------------ prepare groundtruth -----------
|
||||
groundtruth_path = os.path.join(video_path, 'groundtruth.txt')
|
||||
with open(groundtruth_path, 'r') as f:
|
||||
boxes = f.readlines()
|
||||
if ',' in boxes[0]:
|
||||
boxes = [list(map(float, box.split(','))) for box in boxes]
|
||||
else:
|
||||
boxes = [list(map(int, box.split())) for box in boxes]
|
||||
gt = boxes.copy()
|
||||
gt[:][2] = gt[:][0] + gt[:][2]
|
||||
gt[:][3] = gt[:][1] + gt[:][3]
|
||||
frames = [os.path.join(video_path, 'color', x) for x in np.sort(os.listdir(os.path.join(video_path, '/color')))]
|
||||
frames = [x for x in frames if '.jpg' in x]
|
||||
tic = time.perf_counter()
|
||||
template_idx = 0
|
||||
tracker = SiamRPNTracker(model)
|
||||
res = []
|
||||
for idx, frame in tqdm(enumerate(frames), total=len(frames)):
|
||||
frame = cv2.imdecode(np.fromfile(frame, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
h, w = frame.shape[0], frame.shape[1]
|
||||
if idx == template_idx:
|
||||
box = get_axis_aligned_bbox(boxes[idx])
|
||||
tracker.init(frame, box)
|
||||
res.append([1])
|
||||
elif idx < template_idx:
|
||||
res.append([0])
|
||||
else:
|
||||
bbox, _ = tracker.update(frame)
|
||||
bbox = np.array(bbox)
|
||||
bbox = list((bbox[0] - bbox[2] / 2 + 1 / 2, bbox[1] - bbox[3] / 2 + 1 / 2, \
|
||||
bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2))
|
||||
if eval_.judge_failures(bbox, boxes[idx], 0):
|
||||
res.append([2])
|
||||
print('fail')
|
||||
template_idx = min(idx + 5, len(frames) - 1)
|
||||
else:
|
||||
res.append(bbox)
|
||||
duration = time.perf_counter() - tic
|
||||
acc, overlaps, failures, num_failures = eval_.calculate_accuracy_failures(res, gt, [w, h])
|
||||
accuracy += acc
|
||||
result1 = {}
|
||||
result1['acc'] = acc
|
||||
result1['num_failures'] = num_failures
|
||||
result1['fps'] = round(len(frames) / duration, 3)
|
||||
results[video_path.split('/')[-1]] = result1
|
||||
all_overlaps.append(overlaps)
|
||||
all_failures.append(failures)
|
||||
gt_lenth.append(len(frames))
|
||||
all_length = sum([len(x) for x in all_overlaps])
|
||||
robustness = sum([len(x) for x in all_failures]) / all_length * 100
|
||||
eao = eval_.calculate_eao("VOT2015", all_failures, all_overlaps, gt_lenth)
|
||||
result1 = {}
|
||||
result1['accuracy'] = accuracy / float(len(video_paths))
|
||||
result1['robustness'] = robustness
|
||||
result1['eao'] = eao
|
||||
results['all_videos'] = result1
|
||||
print('accuracy is ', accuracy / float(len(video_paths)))
|
||||
print('robustness is ', robustness)
|
||||
print('eao is ', eao)
|
||||
json.dump(results, open(save_name, 'w'))
|
||||
|
||||
def parse_args():
|
||||
'''parse_args'''
|
||||
parser = argparse.ArgumentParser(description='Mindspore SiameseRPN Infering')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend'), help='run platform')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='DEVICE_ID')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of siamRPN')
|
||||
parser.add_argument('--filename', type=str, default='', help='save result file')
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
if args.platform == 'Ascend':
|
||||
device_id = args.device_id
|
||||
context.set_context(device_id=device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||
model_file_path = args.checkpoint_path
|
||||
data_file_path = args.dataset_path
|
||||
save_file_name = args.filename
|
||||
test(model_path=model_file_path, data_path=data_file_path, save_name=save_file_name)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" export script """
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import context, Tensor, export
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from src.net import SiameseRPN
|
||||
|
||||
|
||||
def siamrpn_export():
|
||||
""" export function """
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=args.device_id)
|
||||
net = SiameseRPN(groups=1)
|
||||
load_checkpoint(args.ckpt_file, net=net)
|
||||
net.set_train(False)
|
||||
input_data1 = Tensor(np.zeros([1, 3, 127, 127]), mindspore.float32)
|
||||
input_data2 = Tensor(np.zeros([1, 3, 255, 255]), mindspore.float32)
|
||||
input_data = [input_data1, input_data2]
|
||||
export(net, *input_data, file_name='siamrpn3', file_format="MINDIR")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument('--ckpt_file', type=str, required=True, help='siamRPN ckpt file.')
|
||||
args = parser.parse_args()
|
||||
siamrpn_export()
|
|
@ -0,0 +1,5 @@
|
|||
lmdb
|
||||
fire
|
||||
opencv-python
|
||||
tqdm
|
||||
Shaply
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATA_PATH RANK_TABLE"
|
||||
echo "For example: bash run.sh 0"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
DEVICE_ID=$1
|
||||
|
||||
export DEVICE_ID=$DEVICE_ID
|
||||
python3 train.py --device_id=$DEVICE_ID > train.log 2>&1 &
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATA_PATH RANK_TABLE"
|
||||
echo "For example: bash run.sh /path/dataset /path/rank_table"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
RANK_SIZE=8
|
||||
RANK_TABLE=$(get_real_path $1)
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export RANK_TABLE_FILE=$RANK_TABLE
|
||||
|
||||
start_divice=0
|
||||
for((i=$start_divice;i<$[$start_divice+$RANK_SIZE];i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
mkdir device$i/src
|
||||
cp ./train.py ./device$i
|
||||
cp ./src/net.py ./src/loss.py ./src/config.py ./src/util.py ./src/data_loader.py ./src/generate_anchors.py ./device$i/src
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
python3 train.py --is_parallel=True &> log &
|
||||
cd ../
|
||||
done
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,21 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_ID=$1
|
||||
export DATA_NAME=$2
|
||||
export MODEL_PATH=$3
|
||||
export FILENAME=$4
|
||||
python eval.py --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> evallog &
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" unique configs """
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Config setup
|
||||
"""
|
||||
exemplar_size = 127 # exemplar size
|
||||
instance_size = 255 # instance size 271
|
||||
context_amount = 0.5 # context amount
|
||||
sample_type = 'uniform'
|
||||
exem_stretch = False
|
||||
scale_range = (0.001, 0.7)
|
||||
ratio_range = (0.1, 10)
|
||||
# pairs per video
|
||||
pairs_per_video_per_epoch = 2
|
||||
frame_range_vid = 100 # frame range of choosing the instance
|
||||
frame_range_ytb = 1
|
||||
|
||||
# training related
|
||||
checkpoint_path = r'./ckpt'
|
||||
pretrain_model = 'mindspore_alexnet.ckpt'
|
||||
train_path = r'./ytb_vid_filter'
|
||||
cur_epoch = 0
|
||||
max_epoches = 50
|
||||
batch_size = 32
|
||||
|
||||
start_lr = 3e-2
|
||||
end_lr = 1e-7
|
||||
momentum = 0.9
|
||||
weight_decay = 0.0005
|
||||
check = True
|
||||
|
||||
|
||||
max_translate = 12 # max translation of random shift
|
||||
max_stretch = 0.15 # scale step of instance image
|
||||
total_stride = 8 # total stride of backbone
|
||||
valid_scope = int((instance_size - exemplar_size) / total_stride / 2)
|
||||
anchor_scales = np.array([8,])
|
||||
anchor_ratios = np.array([0.33, 0.5, 1, 2, 3])
|
||||
anchor_num = len(anchor_scales) * len(anchor_ratios)
|
||||
anchor_base_size = 8
|
||||
pos_threshold = 0.6
|
||||
neg_threshold = 0.3
|
||||
pos_num = 16
|
||||
neg_num = 48
|
||||
|
||||
# tracking related
|
||||
gray_ratio = 0.25
|
||||
score_size = int((instance_size - exemplar_size) / 8 + 1)
|
||||
penalty_k = 0.22
|
||||
window_influence = 0.40
|
||||
lr_box = 0.30
|
||||
min_scale = 0.1
|
||||
max_scale = 10
|
||||
|
||||
#cloud train
|
||||
cloud_data_path = '/cache/data'
|
||||
|
||||
|
||||
config = Config()
|
|
@ -0,0 +1,208 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" data loader class """
|
||||
import pickle
|
||||
import glob
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from src.generate_anchors import generate_anchors
|
||||
from src.util import box_transform, compute_iou, crop_and_pad
|
||||
from src.config import config
|
||||
|
||||
class TrainDataLoader:
|
||||
""" dataloader """
|
||||
def __init__(self, data_dir):
|
||||
self.ret = {}
|
||||
self.data_dir = data_dir
|
||||
meta_data_path = os.path.join(data_dir, 'meta_data.pkl')
|
||||
config.pairs_per_video_per_epoch = 2
|
||||
self.meta_data = pickle.load(open(meta_data_path, 'rb'))
|
||||
self.video_names = [x[0] for x in self.meta_data]
|
||||
self.meta_data = {x[0]: x[1] for x in self.meta_data}
|
||||
self.training = True
|
||||
for key in self.meta_data.keys():
|
||||
trajs = self.meta_data[key]
|
||||
for trkid in list(trajs.keys()):
|
||||
if len(trajs[trkid]) < 2:
|
||||
del trajs[trkid]
|
||||
#dataset config
|
||||
self.num = len(self.video_names) if config.pairs_per_video_per_epoch is None or not self.training \
|
||||
else config.pairs_per_video_per_epoch * len(self.video_names)
|
||||
self.valid_scope = int((config.instance_size - config.exemplar_size) / 8 / 2)*2+1
|
||||
self.anchors = generate_anchors(total_stride=config.total_stride, base_size=config.anchor_base_size,
|
||||
scales=config.anchor_scales, \
|
||||
ratios=config.anchor_ratios, score_size=self.valid_scope)
|
||||
|
||||
def imread(self, image_name):
|
||||
img = cv2.imread(image_name)
|
||||
return img
|
||||
|
||||
def RandomStretch(self, sample, gt_w, gt_h):
|
||||
scale_h = 1.0 + np.random.uniform(-config.max_stretch, config.max_stretch)
|
||||
scale_w = 1.0 + np.random.uniform(-config.max_stretch, config.max_stretch)
|
||||
h, w = sample.shape[:2]
|
||||
shape = int(w * scale_w), int(h * scale_h)
|
||||
scale_w = int(w * scale_w) / w
|
||||
scale_h = int(h * scale_h) / h
|
||||
gt_w = gt_w * scale_w
|
||||
gt_h = gt_h * scale_h
|
||||
return cv2.resize(sample, shape, cv2.INTER_LINEAR), gt_w, gt_h
|
||||
|
||||
def compute_target(self, anchors, box, pos_threshold=0.6, neg_threshold=0.3, pos_num=16, num_neg=48):
|
||||
""" compute iou to label """
|
||||
total_num = pos_num + num_neg
|
||||
regression_target = box_transform(anchors, box)
|
||||
iou = compute_iou(anchors, box).flatten()
|
||||
pos_cand = np.where(iou > pos_threshold)[0]
|
||||
if len(pos_cand) > pos_num:
|
||||
pos_index = np.random.choice(pos_cand, pos_num, replace=False)
|
||||
|
||||
else:
|
||||
pos_index = pos_cand
|
||||
pos_num = len(pos_index)
|
||||
neg_cand = np.where(iou < neg_threshold)[0]
|
||||
neg_num = total_num - pos_num
|
||||
neg_index = np.random.choice(neg_cand, neg_num, replace=False)
|
||||
label = np.ones_like(iou) * -100
|
||||
label[pos_index] = 1
|
||||
label[neg_index] = 0
|
||||
pos_neg_diff = np.hstack((label.reshape(-1, 1), regression_target))
|
||||
return pos_neg_diff
|
||||
|
||||
def __getitem__(self, idx):
|
||||
all_idx = np.arange(self.num)
|
||||
np.random.shuffle(all_idx)
|
||||
all_idx = np.insert(all_idx, 0, idx, 0)
|
||||
for vedio_idx in all_idx:
|
||||
vedio_idx = vedio_idx % len(self.video_names)
|
||||
video = self.video_names[vedio_idx]
|
||||
trajs = self.meta_data[video]
|
||||
# sample one trajs
|
||||
if not trajs.keys():
|
||||
continue
|
||||
|
||||
trkid = np.random.choice(list(trajs.keys()))
|
||||
traj = trajs[trkid]
|
||||
assert len(traj) > 1, "video_name: {}".format(video)
|
||||
# sample exemplar
|
||||
exemplar_idx = np.random.choice(list(range(len(traj))))
|
||||
if 'ILSVRC2015' in video:
|
||||
exemplar_name = \
|
||||
glob.glob(os.path.join(self.data_dir, video, traj[exemplar_idx] + ".{:02d}.x*.jpg".format(trkid)))[
|
||||
0]
|
||||
else:
|
||||
exemplar_name = \
|
||||
glob.glob(os.path.join(self.data_dir, video, traj[exemplar_idx] + ".{}.x*.jpg".format(trkid)))[0]
|
||||
exemplar_gt_w, exemplar_gt_h, exemplar_w_image, exemplar_h_image = \
|
||||
float(exemplar_name.split('_')[-4]), float(exemplar_name.split('_')[-3]), \
|
||||
float(exemplar_name.split('_')[-2]), float(exemplar_name.split('_')[-1][:-4])
|
||||
exemplar_ratio = min(exemplar_gt_w / exemplar_gt_h, exemplar_gt_h / exemplar_gt_w)
|
||||
exemplar_scale = exemplar_gt_w * exemplar_gt_h / (exemplar_w_image * exemplar_h_image)
|
||||
if not config.scale_range[0] <= exemplar_scale < config.scale_range[1]:
|
||||
continue
|
||||
if not config.ratio_range[0] <= exemplar_ratio < config.ratio_range[1]:
|
||||
continue
|
||||
|
||||
exemplar_img = self.imread(exemplar_name)
|
||||
# sample instance
|
||||
if 'ILSVRC2015' in exemplar_name:
|
||||
frame_range = config.frame_range_vid
|
||||
else:
|
||||
frame_range = config.frame_range_ytb
|
||||
low_idx = max(0, exemplar_idx - frame_range)
|
||||
up_idx = min(len(traj), exemplar_idx + frame_range + 1)
|
||||
weights = self._sample_weights(exemplar_idx, low_idx, up_idx, config.sample_type)
|
||||
instance = np.random.choice(traj[low_idx:exemplar_idx] + traj[exemplar_idx + 1:up_idx], p=weights)
|
||||
|
||||
if 'ILSVRC2015' in video:
|
||||
instance_name = \
|
||||
glob.glob(os.path.join(self.data_dir, video, instance + ".{:02d}.x*.jpg".format(trkid)))[0]
|
||||
else:
|
||||
instance_name = glob.glob(os.path.join(self.data_dir, video, instance + ".{}.x*.jpg".format(trkid)))[0]
|
||||
|
||||
instance_gt_w, instance_gt_h, instance_w_image, instance_h_image = \
|
||||
float(instance_name.split('_')[-4]), float(instance_name.split('_')[-3]), \
|
||||
float(instance_name.split('_')[-2]), float(instance_name.split('_')[-1][:-4])
|
||||
instance_ratio = min(instance_gt_w / instance_gt_h, instance_gt_h / instance_gt_w)
|
||||
instance_scale = instance_gt_w * instance_gt_h / (instance_w_image * instance_h_image)
|
||||
if not config.scale_range[0] <= instance_scale < config.scale_range[1]:
|
||||
continue
|
||||
if not config.ratio_range[0] <= instance_ratio < config.ratio_range[1]:
|
||||
continue
|
||||
|
||||
instance_img = self.imread(instance_name)
|
||||
|
||||
if np.random.rand(1) < config.gray_ratio:
|
||||
exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_RGB2GRAY)
|
||||
exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_GRAY2RGB)
|
||||
instance_img = cv2.cvtColor(instance_img, cv2.COLOR_RGB2GRAY)
|
||||
instance_img = cv2.cvtColor(instance_img, cv2.COLOR_GRAY2RGB)
|
||||
if config.exem_stretch:
|
||||
exemplar_img, exemplar_gt_w, exemplar_gt_h = self.RandomStretch(exemplar_img, exemplar_gt_w,
|
||||
exemplar_gt_h)
|
||||
exemplar_img, _ = crop_and_pad(exemplar_img, (exemplar_img.shape[1] - 1) / 2,
|
||||
(exemplar_img.shape[0] - 1) / 2, config.exemplar_size,
|
||||
config.exemplar_size)
|
||||
|
||||
instance_img, gt_w, gt_h = self.RandomStretch(instance_img, instance_gt_w, instance_gt_h)
|
||||
im_h, im_w, _ = instance_img.shape
|
||||
cy_o = (im_h - 1) / 2
|
||||
cx_o = (im_w - 1) / 2
|
||||
cy = cy_o + np.random.randint(- config.max_translate, config.max_translate + 1)
|
||||
cx = cx_o + np.random.randint(- config.max_translate, config.max_translate + 1)
|
||||
gt_cx = cx_o - cx
|
||||
gt_cy = cy_o - cy
|
||||
|
||||
instance_img_1, _ = crop_and_pad(instance_img, cx, cy, config.instance_size, config.instance_size)
|
||||
|
||||
|
||||
pos_neg_diff = self.compute_target(self.anchors, np.array(list(map(round, [gt_cx, gt_cy, gt_w, gt_h]))),
|
||||
pos_threshold=config.pos_threshold, neg_threshold=config.neg_threshold,
|
||||
pos_num=config.pos_num, num_neg=config.neg_num)
|
||||
self.ret['template_cropped_resized'] = exemplar_img
|
||||
self.ret['detection_cropped_resized'] = instance_img_1
|
||||
self.ret['pos_neg_diff'] = pos_neg_diff
|
||||
self._tranform()
|
||||
return (self.ret['template_tensor'], self.ret['detection_tensor'], self.ret['pos_neg_diff_tensor'])
|
||||
|
||||
def _tranform(self):
|
||||
"""PIL to Tensor"""
|
||||
template_pil = self.ret['template_cropped_resized'].copy()
|
||||
detection_pil = self.ret['detection_cropped_resized'].copy()
|
||||
pos_neg_diff = self.ret['pos_neg_diff'].copy()
|
||||
|
||||
template_tensor = (np.transpose(np.array(template_pil), (2, 0, 1))).astype(np.float32)
|
||||
detection_tensor = (np.transpose(np.array(detection_pil), (2, 0, 1))).astype(np.float32)
|
||||
self.ret['template_tensor'] = template_tensor
|
||||
self.ret['detection_tensor'] = detection_tensor
|
||||
|
||||
self.ret['pos_neg_diff_tensor'] = pos_neg_diff
|
||||
|
||||
def _sample_weights(self, center, low_idx, high_idx, s_type='uniform'):
|
||||
""" sample weights"""
|
||||
weights = list(range(low_idx, high_idx))
|
||||
weights.remove(center)
|
||||
weights = np.array(weights)
|
||||
if s_type == 'linear':
|
||||
weights = abs(weights - center)
|
||||
elif s_type == 'sqrt':
|
||||
weights = np.sqrt(abs(weights - center))
|
||||
elif s_type == 'uniform':
|
||||
weights = np.ones_like(weights)
|
||||
return weights / sum(weights)
|
||||
|
||||
def __len__(self):
|
||||
return self.num
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" evaluation """
|
||||
|
||||
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
|
||||
|
||||
|
||||
def calculate_eao(dataset_name, all_failures, all_overlaps, gt_traj_length, skipping=5):
|
||||
'''
|
||||
input:dataset name
|
||||
all_failures: type is list , index of failure
|
||||
all_overlaps: type is list , length of list is the length of all_failures
|
||||
gt_traj_length: type is list , length of list is the length of all_failures
|
||||
skipping:number of skipping per failing
|
||||
'''
|
||||
if dataset_name == "VOT2016":
|
||||
|
||||
low = 108
|
||||
high = 371
|
||||
|
||||
elif dataset_name == "VOT2015":
|
||||
low = 108
|
||||
high = 371
|
||||
|
||||
fragment_num = sum([len(x)+1 for x in all_failures])
|
||||
max_len = max([len(x) for x in all_overlaps])
|
||||
tags = [1] * max_len
|
||||
seq_weight = 1 / (1 + 1e-10) # division by zero
|
||||
|
||||
eao = {}
|
||||
|
||||
# prepare segments
|
||||
fweights = np.ones(fragment_num, dtype=np.float32) * np.nan
|
||||
fragments = np.ones((fragment_num, max_len), dtype=np.float32) * np.nan
|
||||
seg_counter = 0
|
||||
for traj_len, failures, overlaps in zip(gt_traj_length, all_failures, all_overlaps):
|
||||
if failures:
|
||||
points = [x+skipping for x in failures if
|
||||
x+skipping <= len(overlaps)]
|
||||
points.insert(0, 0)
|
||||
for i, _ in enumerate(points):
|
||||
if i != len(points) - 1:
|
||||
fragment = np.array(overlaps[points[i]:points[i+1]+1], dtype=np.float32)
|
||||
fragments[seg_counter, :] = 0
|
||||
else:
|
||||
fragment = np.array(overlaps[points[i]:], dtype=np.float32)
|
||||
fragment[np.isnan(fragment)] = 0
|
||||
fragments[seg_counter, :len(fragment)] = fragment
|
||||
if i != len(points) - 1:
|
||||
tag_value = tags[points[i]:points[i+1]+1]
|
||||
w = sum(tag_value) / (points[i+1] - points[i]+1)
|
||||
fweights[seg_counter] = seq_weight * w
|
||||
else:
|
||||
tag_value = tags[points[i]:len(overlaps)]
|
||||
w = sum(tag_value) / (traj_len - points[i]+1e-16)
|
||||
fweights[seg_counter] = seq_weight * w
|
||||
seg_counter += 1
|
||||
else:
|
||||
# no failure
|
||||
max_idx = min(len(overlaps), max_len)
|
||||
fragments[seg_counter, :max_idx] = overlaps[:max_idx]
|
||||
tag_value = tags[0: max_idx]
|
||||
w = sum(tag_value) / max_idx
|
||||
fweights[seg_counter] = seq_weight * w
|
||||
seg_counter += 1
|
||||
|
||||
expected_overlaps = calculate_expected_overlap(fragments, fweights)
|
||||
print(len(expected_overlaps))
|
||||
# calculate eao
|
||||
weight = np.zeros((len(expected_overlaps)))
|
||||
weight[low-1:high-1+1] = 1
|
||||
expected_overlaps = np.array(expected_overlaps, dtype=np.float32)
|
||||
is_valid = np.logical_not(np.isnan(expected_overlaps))
|
||||
eao_ = np.sum(expected_overlaps[is_valid] * weight[is_valid]) / np.sum(weight[is_valid])
|
||||
eao = eao_
|
||||
return eao
|
||||
|
||||
def calculate_expected_overlap(fragments, fweights):
|
||||
""" compute expected iou """
|
||||
max_len = fragments.shape[1]
|
||||
expected_overlaps = np.zeros((max_len), np.float32)
|
||||
expected_overlaps[0] = 1
|
||||
|
||||
# TODO Speed Up
|
||||
for i in range(1, max_len):
|
||||
mask = np.logical_not(np.isnan(fragments[:, i]))
|
||||
if np.any(mask):
|
||||
fragment = fragments[mask, 1:i+1]
|
||||
seq_mean = np.sum(fragment, 1) / fragment.shape[1]
|
||||
expected_overlaps[i] = np.sum(seq_mean *
|
||||
fweights[mask]) / np.sum(fweights[mask])
|
||||
return expected_overlaps
|
||||
|
||||
def calculate_accuracy_failures(pred_trajectory, gt_trajectory, \
|
||||
bound=None):
|
||||
'''
|
||||
args:
|
||||
pred_trajectory:list of bbox
|
||||
gt_trajectory: list of bbox ,shape == pred_trajectory
|
||||
bound :w and h of img
|
||||
return :
|
||||
overlaps:list ,iou value in pred_trajectory
|
||||
acc : mean iou value
|
||||
failures: failures point in pred_trajectory
|
||||
num_failures: number of failres
|
||||
'''
|
||||
|
||||
|
||||
overlaps = []
|
||||
failures = []
|
||||
|
||||
for i, pred_traj in enumerate(pred_trajectory):
|
||||
if len(pred_traj) == 1:
|
||||
|
||||
if pred_trajectory[i][0] == 2:
|
||||
failures.append(i)
|
||||
overlaps.append(float("nan"))
|
||||
|
||||
else:
|
||||
if bound is not None:
|
||||
poly_img = Polygon(np.array([[0, 0],\
|
||||
[0, bound[1]],\
|
||||
[bound[0], bound[1]],\
|
||||
[bound[0], 0]])).convex_hull
|
||||
|
||||
|
||||
if len(gt_trajectory[i]) == 8:
|
||||
|
||||
poly_pred = Polygon(np.array([[pred_trajectory[i][0], pred_trajectory[i][1]], \
|
||||
[pred_trajectory[i][2], pred_trajectory[i][1]], \
|
||||
[pred_trajectory[i][2], pred_trajectory[i][3]], \
|
||||
[pred_trajectory[i][0], pred_trajectory[i][3]] \
|
||||
])).convex_hull
|
||||
poly_gt = Polygon(np.array(gt_trajectory[i]).reshape(4, 2)).convex_hull
|
||||
if bound is not None:
|
||||
gt_inter_img = poly_gt.intersection(poly_img)
|
||||
pred_inter_img = poly_pred.intersection(poly_img)
|
||||
inter_area = gt_inter_img.intersection(pred_inter_img).area
|
||||
overlap = inter_area /(gt_inter_img.area + pred_inter_img.area - inter_area)
|
||||
else:
|
||||
inter_area = poly_gt.intersection(poly_pred).area
|
||||
overlap = inter_area / (poly_gt.area + poly_pred.area - inter_area)
|
||||
elif len(gt_trajectory[i]) == 4:
|
||||
|
||||
overlap = iou(np.array(pred_trajectory[i]).reshape(-1, 4), np.array(gt_trajectory[i]).reshape(-1, 4))
|
||||
overlaps.append(overlap)
|
||||
acc = 0
|
||||
num_failures = len(failures)
|
||||
if overlaps:
|
||||
acc = np.nanmean(overlaps)
|
||||
return acc, overlaps, failures, num_failures
|
||||
|
||||
def judge_failures(pred_bbox, gt_bbox, threshold=0):
|
||||
"""" judge whether to fail or not """
|
||||
if len(gt_bbox) == 4:
|
||||
if iou(np.array(pred_bbox).reshape(-1, 4), np.array(gt_bbox).reshape(-1, 4)) > threshold:
|
||||
return False
|
||||
else:
|
||||
poly_pred = Polygon(np.array([[pred_bbox[0], pred_bbox[1]], \
|
||||
[pred_bbox[2], pred_bbox[1]], \
|
||||
[pred_bbox[2], pred_bbox[3]], \
|
||||
[pred_bbox[0], pred_bbox[3]] \
|
||||
])).convex_hull
|
||||
poly_gt = Polygon(np.array(gt_bbox).reshape(4, 2)).convex_hull
|
||||
inter_area = poly_gt.intersection(poly_pred).area
|
||||
overlap = inter_area / (poly_gt.area + poly_pred.area - inter_area)
|
||||
if overlap > threshold:
|
||||
return False
|
||||
return True
|
||||
|
||||
def iou(box1, box2):
|
||||
""" compute iou """
|
||||
box1, box2 = box1.copy(), box2.copy()
|
||||
N = box1.shape[0]
|
||||
K = box2.shape[0]
|
||||
box1 = np.array(box1.reshape((N, 1, 4)))+np.zeros((1, K, 4))#box1=[N,K,4]
|
||||
box2 = np.array(box2.reshape((1, K, 4)))+np.zeros((N, 1, 4))#box1=[N,K,4]
|
||||
x_max = np.max(np.stack((box1[:, :, 0], box2[:, :, 0]), axis=-1), axis=2)
|
||||
x_min = np.min(np.stack((box1[:, :, 2], box2[:, :, 2]), axis=-1), axis=2)
|
||||
y_max = np.max(np.stack((box1[:, :, 1], box2[:, :, 1]), axis=-1), axis=2)
|
||||
y_min = np.min(np.stack((box1[:, :, 3], box2[:, :, 3]), axis=-1), axis=2)
|
||||
tb = x_min-x_max
|
||||
lr = y_min-y_max
|
||||
tb[np.where(tb < 0)] = 0
|
||||
lr[np.where(lr < 0)] = 0
|
||||
over_square = tb*lr
|
||||
all_square = (box1[:, :, 2] - box1[:, :, 0]) * (box1[:, :, 3] - box1[:, :, 1]) + (box2[:, :, 2] - \
|
||||
box2[:, :, 0]) * (box2[:, :, 3] - box2[:, :, 1]) - over_square
|
||||
return over_square / all_square
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" anchor generator"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_anchors(total_stride, base_size, scales, ratios, score_size):
|
||||
""" anchor generator class"""
|
||||
anchor_num = len(ratios) * len(scales)
|
||||
anchor = np.zeros((anchor_num, 4), dtype=np.float32)
|
||||
size = base_size * base_size
|
||||
count = 0
|
||||
for ratio in ratios:
|
||||
ws = int(np.sqrt(size / ratio))
|
||||
hs = int(ws * ratio)
|
||||
for scale in scales:
|
||||
wws = ws * scale
|
||||
hhs = hs * scale
|
||||
anchor[count, 0] = 0
|
||||
anchor[count, 1] = 0
|
||||
anchor[count, 2] = wws
|
||||
anchor[count, 3] = hhs
|
||||
count += 1
|
||||
|
||||
anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
|
||||
ori = - (score_size // 2) * total_stride
|
||||
# the left displacement
|
||||
xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
|
||||
[ori + total_stride * dy for dy in range(score_size)])
|
||||
xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
|
||||
np.tile(yy.flatten(), (anchor_num, 1)).flatten()
|
||||
anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
|
||||
return anchor
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" define loss function"""
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MultiBoxLoss(nn.Cell):
|
||||
""" loss class """
|
||||
def __init__(self, batch_size=1):
|
||||
super(MultiBoxLoss, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.cast = ops.Cast()
|
||||
self.realsum_false = ops.ReduceSum(keep_dims=False)
|
||||
self.realdiv = ops.RealDiv()
|
||||
self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
self.onehot = ops.OneHot()
|
||||
self.SmooL1loss = nn.SmoothL1Loss()
|
||||
self.realsum_true = ops.ReduceSum(keep_dims=True)
|
||||
self.div = ops.Div()
|
||||
|
||||
self.equal = ops.Equal()
|
||||
self.select = ops.Select()
|
||||
self.tile = ops.Tile()
|
||||
self.reshape = ops.Reshape()
|
||||
self.transpose = ops.Transpose()
|
||||
self.split = ops.Split(axis=0, output_num=self.batch_size)
|
||||
|
||||
self.ones = Tensor(np.ones((1, 1445)), mindspore.float32)
|
||||
self.zeros_class_pred = Tensor(np.zeros((1, 1445, 2)), mindspore.float32)
|
||||
self.nage_class_pred = Tensor(np.ones((1, 1445, 2)), mindspore.float32)
|
||||
self.zeros_class_target = Tensor(np.zeros((1, 1445)), mindspore.float32)
|
||||
|
||||
self.depth, self.on_value, self.off_value = 2, Tensor(1.0, mindspore.float32), Tensor(0.0, mindspore.float32)
|
||||
|
||||
self.c_all_loss = Tensor(0, mindspore.float32)
|
||||
self.r_all_loss = Tensor(0, mindspore.float32)
|
||||
|
||||
def construct(self, predictions1, predictions2, targets):
|
||||
""" class """
|
||||
cout = self.transpose(self.reshape(predictions1, (-1, 2, 5 * 17 * 17)), (0, 2, 1))
|
||||
rout = self.transpose(self.reshape(predictions2, (-1, 4, 5 * 17 * 17)), (0, 2, 1))
|
||||
cout = self.split(cout)
|
||||
rout = self.split(rout)
|
||||
targets = self.cast(targets, mindspore.float32)
|
||||
ctargets = targets[:, :, 0:1]
|
||||
rtargets = targets[:, :, 1:]
|
||||
ctargets = self.split(ctargets)
|
||||
rtargets = self.split(rtargets)
|
||||
c_all_loss = self.c_all_loss
|
||||
r_all_loss = self.r_all_loss
|
||||
for batch in range(self.batch_size):
|
||||
class_pred, class_target = cout[batch], ctargets[batch]
|
||||
class_pred = self.reshape(class_pred, (1, -1, 2))
|
||||
class_target = self.reshape(class_target, (1, -1))
|
||||
class_target = self.cast(class_target, mindspore.float32)
|
||||
|
||||
pos_mask = self.equal(class_target, self.ones)
|
||||
neg_mask = self.equal(class_target, self.zeros_class_target)
|
||||
class_target_pos = self.select(pos_mask, class_target, self.zeros_class_target - self.ones)
|
||||
class_target_pos = self.cast(class_target_pos, mindspore.int32)
|
||||
class_target_pos = self.reshape(class_target_pos, (-1,))
|
||||
class_target_neg = self.select(neg_mask, class_target, self.zeros_class_target - self.ones)
|
||||
class_target_neg = self.cast(class_target_neg, mindspore.int32)
|
||||
class_target_neg = self.reshape(class_target_neg, (-1,))
|
||||
pos_mask1 = self.cast(pos_mask, mindspore.int32)#
|
||||
pos_mask2 = self.cast(pos_mask, mindspore.float32)
|
||||
pos_num = self.realsum_false(pos_mask2)
|
||||
pos_mask1 = self.reshape(pos_mask1, (1, -1, 1))
|
||||
pos_mask1 = self.tile(pos_mask1, (1, 1, 2))
|
||||
neg_mask1 = self.cast(neg_mask, mindspore.int32)
|
||||
neg_mask2 = self.cast(neg_mask, mindspore.float32)
|
||||
neg_num = self.realsum_false(neg_mask2)
|
||||
neg_mask1 = self.reshape(neg_mask1, (1, -1, 1))
|
||||
neg_mask1 = self.tile(neg_mask1, (1, 1, 2))
|
||||
pos_mask1 = self.cast(pos_mask1, mindspore.bool_)
|
||||
neg_mask1 = self.cast(neg_mask1, mindspore.bool_)
|
||||
class_pred_pos = self.select(pos_mask1, class_pred, self.zeros_class_pred)
|
||||
class_pred_neg = self.select(neg_mask1, class_pred, self.zeros_class_pred)
|
||||
class_pos = self.reshape(class_pred_pos, (-1, 2))
|
||||
class_neg = self.reshape(class_pred_neg, (-1, 2))
|
||||
closs_pos = self.cross_entropy(class_pos, class_target_pos)
|
||||
closs_neg = self.cross_entropy(class_neg, class_target_neg)
|
||||
c_all_loss += (self.realdiv(self.realsum_false(closs_pos), (pos_num + 1e-6)) + \
|
||||
self.realdiv(self.realsum_false(closs_neg), (neg_num + 1e-6)))/2
|
||||
reg_pred = rout[batch]
|
||||
reg_pred = self.reshape(reg_pred, (-1, 4))
|
||||
reg_target = rtargets[batch]
|
||||
reg_target = self.reshape(reg_target, (-1, 4))
|
||||
rloss = self.SmooL1loss(reg_pred, reg_target) # 1445, 4
|
||||
rloss = self.realsum_false(rloss, -1)
|
||||
rloss = rloss / 4
|
||||
rloss = self.reshape(rloss, (1, -1))
|
||||
rloss = self.select(pos_mask, rloss, self.zeros_class_target)
|
||||
rloss = self.realsum_false(rloss) / (pos_num + 1e-6)
|
||||
r_all_loss += rloss
|
||||
loss = (c_all_loss + 5 * r_all_loss) / self.batch_size
|
||||
return loss
|
|
@ -0,0 +1,230 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""net structure"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
class SiameseRPN(nn.Cell):
|
||||
"""
|
||||
SiameseRPN Network.
|
||||
|
||||
Args:
|
||||
groups (int): Size of one batch.
|
||||
k (int): Numbers of one point‘s anchors.
|
||||
s (int): Numbers of one anchor‘s parameters.
|
||||
|
||||
Returns:
|
||||
coutputs tensor, routputs tensor.
|
||||
"""
|
||||
def __init__(self, groups=1, k=5, s=4, is_train=False, is_trackinit=False, is_track=False):
|
||||
super(SiameseRPN, self).__init__()
|
||||
self.groups = groups
|
||||
self.k = k
|
||||
self.s = s
|
||||
self.is_train = is_train
|
||||
self.is_trackinit = is_trackinit
|
||||
self.is_track = is_track
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
self.featureExtract = nn.SequentialCell(
|
||||
[nn.Conv2d(3, 96, kernel_size=11, stride=2, pad_mode='valid', has_bias=True),
|
||||
nn.BatchNorm2d(96, use_batch_statistics=False),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid'),
|
||||
nn.Conv2d(96, 256, kernel_size=5, pad_mode='valid', has_bias=True),
|
||||
nn.BatchNorm2d(256, use_batch_statistics=False),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid'),
|
||||
nn.Conv2d(256, 384, kernel_size=3, pad_mode='valid', has_bias=True),
|
||||
nn.BatchNorm2d(384, use_batch_statistics=False),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(384, 384, kernel_size=3, pad_mode='valid', has_bias=True),
|
||||
nn.BatchNorm2d(384),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(384, 256, kernel_size=3, pad_mode='valid', has_bias=True),
|
||||
nn.BatchNorm2d(256)])
|
||||
self.conv1 = nn.Conv2d(256, 2 * self.k * 256, kernel_size=3, pad_mode='valid', has_bias=True)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(256, 4 * self.k * 256, kernel_size=3, pad_mode='valid', has_bias=True)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.conv3 = nn.Conv2d(256, 256, kernel_size=3, pad_mode='valid', has_bias=True)
|
||||
self.relu3 = nn.ReLU()
|
||||
self.conv4 = nn.Conv2d(256, 256, kernel_size=3, pad_mode='valid', has_bias=True)
|
||||
self.relu4 = nn.ReLU()
|
||||
|
||||
self.op_split_input = ops.Split(axis=1, output_num=self.groups)
|
||||
self.op_split_krenal = ops.Split(axis=0, output_num=self.groups)
|
||||
self.op_concat = ops.Concat(axis=1)
|
||||
self.conv2d_cout = ops.Conv2D(out_channel=10, kernel_size=4)
|
||||
self.conv2d_rout = ops.Conv2D(out_channel=20, kernel_size=4)
|
||||
self.regress_adjust = nn.Conv2d(4 * self.k, 4 * self.k, 1, pad_mode='valid', has_bias=True)
|
||||
self.reshape = ops.Reshape()
|
||||
self.transpose = ops.Transpose()
|
||||
self.softmax = ops.Softmax(axis=2)
|
||||
self.print = ops.Print()
|
||||
|
||||
def construct(self, template=None, detection=None, ckernal=None, rkernal=None):
|
||||
""" forward function """
|
||||
if self.is_train is True and template is not None and detection is not None:
|
||||
template_feature = self.featureExtract(template)
|
||||
detection_feature = self.featureExtract(detection)
|
||||
|
||||
ckernal = self.conv1(template_feature)
|
||||
ckernal = self.reshape(ckernal.view(self.groups, 2 * self.k, 256, 4, 4), (-1, 256, 4, 4))
|
||||
cinput = self.reshape(self.conv3(detection_feature), (1, -1, 20, 20))
|
||||
|
||||
rkernal = self.conv2(template_feature)
|
||||
rkernal = self.reshape(rkernal.view(self.groups, 4 * self.k, 256, 4, 4), (-1, 256, 4, 4))
|
||||
rinput = self.reshape(self.conv4(detection_feature), (1, -1, 20, 20))
|
||||
c_features = self.op_split_input(cinput)
|
||||
c_weights = self.op_split_krenal(ckernal)
|
||||
r_features = self.op_split_input(rinput)
|
||||
r_weights = self.op_split_krenal(rkernal)
|
||||
coutputs = ()
|
||||
routputs = ()
|
||||
for i in range(self.groups):
|
||||
coutputs = coutputs + (self.conv2d_cout(c_features[i], c_weights[i]),)
|
||||
routputs = routputs + (self.conv2d_rout(r_features[i], r_weights[i]),)
|
||||
coutputs = self.op_concat(coutputs)
|
||||
routputs = self.op_concat(routputs)
|
||||
coutputs = self.reshape(coutputs, (self.groups, 10, 17, 17))
|
||||
routputs = self.reshape(routputs, (self.groups, 20, 17, 17))
|
||||
routputs = self.regress_adjust(routputs)
|
||||
out1, out2 = coutputs, routputs
|
||||
|
||||
elif self.is_trackinit is True and template is not None:
|
||||
|
||||
template = self.transpose(template, (2, 0, 1))
|
||||
template = self.expand_dims(template, 0)
|
||||
template_feature = self.featureExtract(template)
|
||||
|
||||
ckernal = self.conv1(template_feature)
|
||||
ckernal = self.reshape(ckernal.view(self.groups, 2 * self.k, 256, 4, 4), (-1, 256, 4, 4))
|
||||
|
||||
rkernal = self.conv2(template_feature)
|
||||
rkernal = self.reshape(rkernal.view(self.groups, 4 * self.k, 256, 4, 4), (-1, 256, 4, 4))
|
||||
out1, out2 = ckernal, rkernal
|
||||
elif self.is_track is True and detection is not None:
|
||||
detection = self.transpose(detection, (2, 0, 1))
|
||||
detection = self.expand_dims(detection, 0)
|
||||
detection_feature = self.featureExtract(detection)
|
||||
cinput = self.reshape(self.conv3(detection_feature), (1, -1, 20, 20))
|
||||
rinput = self.reshape(self.conv4(detection_feature), (1, -1, 20, 20))
|
||||
|
||||
c_features = self.op_split_input(cinput)
|
||||
c_weights = self.op_split_krenal(ckernal)
|
||||
r_features = self.op_split_input(rinput)
|
||||
r_weights = self.op_split_krenal(rkernal)
|
||||
coutputs = ()
|
||||
routputs = ()
|
||||
for i in range(self.groups):
|
||||
coutputs = coutputs + (self.conv2d_cout(c_features[i], c_weights[i]),)
|
||||
routputs = routputs + (self.conv2d_rout(r_features[i], r_weights[i]),)
|
||||
coutputs = self.op_concat(coutputs)
|
||||
routputs = self.op_concat(routputs)
|
||||
coutputs = self.reshape(coutputs, (self.groups, 10, 17, 17))
|
||||
routputs = self.reshape(routputs, (self.groups, 20, 17, 17))
|
||||
routputs = self.regress_adjust(routputs)
|
||||
pred_score = self.transpose(
|
||||
self.reshape(coutputs, (-1, 2, 1445)), (0, 2, 1))
|
||||
pred_regression = self.transpose(
|
||||
self.reshape(routputs, (-1, 4, 1445)), (0, 2, 1))
|
||||
pred_score = self.softmax(pred_score)[0, :, 1]
|
||||
out1, out2 = pred_score, pred_regression
|
||||
else:
|
||||
out1, out2 = template, detection
|
||||
return out1, out2
|
||||
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 5.0
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
class MyTrainOneStepCell(nn.Cell):
|
||||
"""MyTrainOneStepCell"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
|
||||
self.network.set_grad()
|
||||
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.cast = ops.Cast()
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, template, detection, target):
|
||||
weights = self.weights
|
||||
|
||||
loss = self.network(template, detection, target)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(template, detection, target, sens)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
class BuildTrainNet(nn.Cell):
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNet, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
def construct(self, template, detection, target):
|
||||
cout, rout = self.network(template=template, detection=detection, ckernal=detection, rkernal=detection)
|
||||
total_loss = self.criterion(cout, rout, target)
|
||||
return total_loss
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" eval tracker"""
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from src.config import config
|
||||
from src.util import get_exemplar_image, get_instance_image, box_transform_inv
|
||||
from src.generate_anchors import generate_anchors
|
||||
|
||||
|
||||
class SiamRPNTracker:
|
||||
""" Tracker for SiamRPN"""
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
valid_scope = 2 * config.valid_scope + 1
|
||||
self.anchors = generate_anchors(config.total_stride, config.anchor_base_size, config.anchor_scales,
|
||||
config.anchor_ratios,
|
||||
valid_scope)
|
||||
self.window = np.tile(np.outer(np.hanning(config.score_size), np.hanning(config.score_size))[None, :],
|
||||
[config.anchor_num, 1, 1]).flatten()
|
||||
|
||||
def _cosine_window(self, size):
|
||||
"""
|
||||
get the cosine window
|
||||
"""
|
||||
cos_window = np.hanning(int(size[0]))[:, np.newaxis].dot(np.hanning(int(size[1]))[np.newaxis, :])
|
||||
cos_window = cos_window.astype(np.float32)
|
||||
cos_window /= np.sum(cos_window)
|
||||
return cos_window
|
||||
|
||||
def init(self, frame, bbox):
|
||||
""" initialize siamfc tracker
|
||||
Args:
|
||||
frame: an RGB image
|
||||
bbox: one-based bounding box [x, y, width, height]
|
||||
"""
|
||||
self.pos = np.array(
|
||||
[bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2]) # center x, center y, zero based
|
||||
# same to original code
|
||||
self.target_sz = np.array([bbox[2], bbox[3]]) # width, height
|
||||
self.bbox = np.array([bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2, bbox[2], bbox[3]])
|
||||
# same to original code
|
||||
self.origin_target_sz = np.array([bbox[2], bbox[3]])
|
||||
# get exemplar img
|
||||
self.img_mean = np.mean(frame, axis=(0, 1))
|
||||
|
||||
exemplar_img, _, _ = get_exemplar_image(frame, self.bbox,
|
||||
config.exemplar_size, config.context_amount, self.img_mean)
|
||||
exemplar_img = Tensor(exemplar_img, ms.float32)
|
||||
self.model.is_train = False
|
||||
self.model.is_trackinit = True
|
||||
self.model.is_track = False
|
||||
self.ckernal, self.rkernal = self.model(template=exemplar_img, detection=exemplar_img,
|
||||
ckernal=exemplar_img, rkernal=exemplar_img)
|
||||
|
||||
|
||||
|
||||
|
||||
def update(self, frame):
|
||||
"""track object based on the previous frame
|
||||
Args:
|
||||
frame: an RGB image
|
||||
|
||||
Returns:
|
||||
bbox: tuple of 1-based bounding box(xmin, ymin, xmax, ymax)
|
||||
"""
|
||||
instance_img_np, _, _, scale_x = get_instance_image(frame, self.bbox, config.exemplar_size,
|
||||
config.instance_size,
|
||||
config.context_amount, self.img_mean)
|
||||
|
||||
self.model.is_train = False
|
||||
self.model.is_trackinit = False
|
||||
self.model.is_track = True
|
||||
instance_img_np = Tensor(instance_img_np, ms.float32)
|
||||
pred_score, pred_regression = self.model(template=instance_img_np, detection=instance_img_np,
|
||||
ckernal=self.ckernal, rkernal=self.rkernal)
|
||||
delta = pred_regression[0].asnumpy()
|
||||
box_pred = box_transform_inv(self.anchors, delta)
|
||||
pred_score = pred_score.asnumpy()
|
||||
|
||||
def change(r):
|
||||
return np.maximum(r, 1. / r)
|
||||
|
||||
def sz(w, h):
|
||||
pad = (w + h) * 0.5
|
||||
sz2 = (w + pad) * (h + pad)
|
||||
return np.sqrt(sz2)
|
||||
|
||||
def sz_wh(wh):
|
||||
pad = (wh[0] + wh[1]) * 0.5
|
||||
sz2 = (wh[0] + pad) * (wh[1] + pad)
|
||||
return np.sqrt(sz2)
|
||||
|
||||
s_c = change(sz(box_pred[:, 2], box_pred[:, 3]) / (sz_wh(self.target_sz * scale_x))) # scale penalty
|
||||
r_c = change((self.target_sz[0] / self.target_sz[1]) / (box_pred[:, 2] / box_pred[:, 3])) # ratio penalty
|
||||
penalty = np.exp(-(r_c * s_c - 1.) * config.penalty_k)
|
||||
pscore = penalty * pred_score
|
||||
pscore = pscore * (1 - config.window_influence) + self.window * config.window_influence
|
||||
best_pscore_id = np.argmax(pscore)
|
||||
|
||||
target = box_pred[best_pscore_id, :] / scale_x
|
||||
|
||||
lr = penalty[best_pscore_id] * pred_score[best_pscore_id] * config.lr_box
|
||||
|
||||
res_x = np.clip(target[0] + self.pos[0], 0, frame.shape[1])
|
||||
res_y = np.clip(target[1] + self.pos[1], 0, frame.shape[0])
|
||||
|
||||
res_w = np.clip(self.target_sz[0] * (1 - lr) + target[2] * lr, config.min_scale * self.origin_target_sz[0],
|
||||
config.max_scale * self.origin_target_sz[0])
|
||||
res_h = np.clip(self.target_sz[1] * (1 - lr) + target[3] * lr, config.min_scale * self.origin_target_sz[1],
|
||||
config.max_scale * self.origin_target_sz[1])
|
||||
|
||||
self.pos = np.array([res_x, res_y])
|
||||
self.target_sz = np.array([res_w, res_h])
|
||||
bbox = np.array([res_x, res_y, res_w, res_h])
|
||||
self.bbox = (
|
||||
np.clip(bbox[0], 0, frame.shape[1]).astype(np.float64),
|
||||
np.clip(bbox[1], 0, frame.shape[0]).astype(np.float64),
|
||||
np.clip(bbox[2], 10, frame.shape[1]).astype(np.float64),
|
||||
np.clip(bbox[3], 10, frame.shape[0]).astype(np.float64))
|
||||
return self.bbox, pred_score[best_pscore_id]
|
|
@ -0,0 +1,200 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" function for data preprocessing"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
def round_up(value):
|
||||
"""Replace the built-in round function to achieve accurate rounding with 2 decimal places
|
||||
|
||||
:param value:object
|
||||
:return:object
|
||||
"""
|
||||
|
||||
return round(value + 1e-6 + 1000) - 1000
|
||||
|
||||
|
||||
def crop_and_pad(img, cx, cy, model_sz, original_sz, img_mean=None):
|
||||
"""change img size
|
||||
|
||||
:param img:rgb
|
||||
:param cx: center x
|
||||
:param cy: center y
|
||||
:param model_sz: changed size
|
||||
:param original_sz: origin size
|
||||
:param img_mean: mean of img
|
||||
:return: changed img ,scale for origin to changed
|
||||
"""
|
||||
im_h, im_w, _ = img.shape
|
||||
|
||||
xmin = cx - (original_sz - 1) / 2
|
||||
xmax = xmin + original_sz - 1
|
||||
ymin = cy - (original_sz - 1) / 2
|
||||
ymax = ymin + original_sz - 1
|
||||
|
||||
left = int(round_up(max(0., -xmin)))
|
||||
top = int(round_up(max(0., -ymin)))
|
||||
right = int(round_up(max(0., xmax - im_w + 1)))
|
||||
bottom = int(round_up(max(0., ymax - im_h + 1)))
|
||||
|
||||
xmin = int(round_up(xmin + left))
|
||||
xmax = int(round_up(xmax + left))
|
||||
ymin = int(round_up(ymin + top))
|
||||
ymax = int(round_up(ymax + top))
|
||||
r, c, k = img.shape
|
||||
if any([top, bottom, left, right]):
|
||||
te_im = np.zeros((r + top + bottom, c + left + right, k), np.uint8) # 0 is better than 1 initialization
|
||||
te_im[top:top + r, left:left + c, :] = img
|
||||
if top:
|
||||
te_im[0:top, left:left + c, :] = img_mean
|
||||
if bottom:
|
||||
te_im[r + top:, left:left + c, :] = img_mean
|
||||
if left:
|
||||
te_im[:, 0:left, :] = img_mean
|
||||
if right:
|
||||
te_im[:, c + left:, :] = img_mean
|
||||
im_patch_original = te_im[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
|
||||
else:
|
||||
im_patch_original = img[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
|
||||
if not np.array_equal(model_sz, original_sz):
|
||||
im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed
|
||||
else:
|
||||
im_patch = im_patch_original
|
||||
scale = model_sz / im_patch_original.shape[0]
|
||||
return im_patch, scale
|
||||
|
||||
|
||||
def get_exemplar_image(img, bbox, size_z, context_amount, img_mean=None):
|
||||
""" preprocessing exemplar
|
||||
|
||||
:param img: exemplar img
|
||||
:param bbox: init bbox
|
||||
:param size_z: changed size
|
||||
:param context_amount: context amount
|
||||
:param img_mean: mean of img
|
||||
:return: img ,scale
|
||||
"""
|
||||
cx, cy, w, h = bbox
|
||||
wc_z = w + context_amount * (w + h)
|
||||
hc_z = h + context_amount * (w + h)
|
||||
s_z = np.sqrt(wc_z * hc_z)
|
||||
scale_z = size_z / s_z
|
||||
exemplar_img, _ = crop_and_pad(img, cx, cy, size_z, s_z, img_mean)
|
||||
return exemplar_img, scale_z, s_z
|
||||
|
||||
|
||||
def get_instance_image(img, bbox, size_z, size_x, context_amount, img_mean=None):
|
||||
""" preprocessing instance
|
||||
|
||||
:param img: instance img
|
||||
:param bbox: init bbox
|
||||
:param size_z: changed size
|
||||
:param context_amount: context amount
|
||||
:param img_mean: mean of img
|
||||
:return: img ,scale
|
||||
"""
|
||||
cx, cy, w, h = bbox # float type
|
||||
wc_z = w + context_amount * (w + h)
|
||||
hc_z = h + context_amount * (w + h)
|
||||
s_z = np.sqrt(wc_z * hc_z) # the width of the crop box
|
||||
|
||||
s_x = s_z * size_x / size_z
|
||||
instance_img, scale_x = crop_and_pad(img, cx, cy, size_x, s_x, img_mean)
|
||||
w_x = w * scale_x
|
||||
h_x = h * scale_x
|
||||
return instance_img, w_x, h_x, scale_x
|
||||
|
||||
|
||||
def box_transform(anchors, gt_box):
|
||||
"""transform box
|
||||
|
||||
:param anchors: object
|
||||
:param gt_box: object
|
||||
:return: object
|
||||
"""
|
||||
anchor_xctr = anchors[:, :1]
|
||||
anchor_yctr = anchors[:, 1:2]
|
||||
anchor_w = anchors[:, 2:3]
|
||||
anchor_h = anchors[:, 3:]
|
||||
gt_cx, gt_cy, gt_w, gt_h = gt_box
|
||||
|
||||
target_x = (gt_cx - anchor_xctr) / anchor_w
|
||||
target_y = (gt_cy - anchor_yctr) / anchor_h
|
||||
target_w = np.log(gt_w / anchor_w)
|
||||
target_h = np.log(gt_h / anchor_h)
|
||||
regression_target = np.hstack((target_x, target_y, target_w, target_h))
|
||||
return regression_target
|
||||
|
||||
|
||||
def box_transform_inv(anchors, offset):
|
||||
"""invert transform box
|
||||
|
||||
:param anchors: object
|
||||
:param offset: object
|
||||
:return: object
|
||||
"""
|
||||
anchor_xctr = anchors[:, :1]
|
||||
anchor_yctr = anchors[:, 1:2]
|
||||
anchor_w = anchors[:, 2:3]
|
||||
anchor_h = anchors[:, 3:]
|
||||
offset_x, offset_y, offset_w, offset_h = offset[:, :1], offset[:, 1:2], offset[:, 2:3], offset[:, 3:]
|
||||
|
||||
box_cx = anchor_w * offset_x + anchor_xctr
|
||||
box_cy = anchor_h * offset_y + anchor_yctr
|
||||
box_w = anchor_w * np.exp(offset_w)
|
||||
box_h = anchor_h * np.exp(offset_h)
|
||||
box = np.hstack([box_cx, box_cy, box_w, box_h])
|
||||
return box
|
||||
|
||||
|
||||
def compute_iou(anchors, box):
|
||||
"""compute iou
|
||||
|
||||
:param anchors: object
|
||||
:param box:object
|
||||
:return: iou value
|
||||
"""
|
||||
if np.array(anchors).ndim == 1:
|
||||
anchors = np.array(anchors)[None, :]
|
||||
else:
|
||||
anchors = np.array(anchors)
|
||||
if np.array(box).ndim == 1:
|
||||
box = np.array(box)[None, :]
|
||||
else:
|
||||
box = np.array(box)
|
||||
gt_box = np.tile(box.reshape(1, -1), (anchors.shape[0], 1))
|
||||
|
||||
anchor_x1 = anchors[:, :1] - anchors[:, 2:3] / 2 + 0.5
|
||||
anchor_x2 = anchors[:, :1] + anchors[:, 2:3] / 2 - 0.5
|
||||
anchor_y1 = anchors[:, 1:2] - anchors[:, 3:] / 2 + 0.5
|
||||
anchor_y2 = anchors[:, 1:2] + anchors[:, 3:] / 2 - 0.5
|
||||
|
||||
gt_x1 = gt_box[:, :1] - gt_box[:, 2:3] / 2 + 0.5
|
||||
gt_x2 = gt_box[:, :1] + gt_box[:, 2:3] / 2 - 0.5
|
||||
gt_y1 = gt_box[:, 1:2] - gt_box[:, 3:] / 2 + 0.5
|
||||
gt_y2 = gt_box[:, 1:2] + gt_box[:, 3:] / 2 - 0.5
|
||||
|
||||
xx1 = np.max([anchor_x1, gt_x1], axis=0)
|
||||
xx2 = np.min([anchor_x2, gt_x2], axis=0)
|
||||
yy1 = np.max([anchor_y1, gt_y1], axis=0)
|
||||
yy2 = np.min([anchor_y2, gt_y2], axis=0)
|
||||
|
||||
inter_area = np.max([xx2 - xx1, np.zeros(xx1.shape)], axis=0) * np.max([yy2 - yy1, np.zeros(xx1.shape)],
|
||||
axis=0)
|
||||
area_anchor = (anchor_x2 - anchor_x1) * (anchor_y2 - anchor_y1)
|
||||
area_gt = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
|
||||
iou = inter_area / (area_anchor + area_gt - inter_area + 1e-6)
|
||||
return iou
|
|
@ -0,0 +1,215 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" train """
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import sys
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import numpy as np
|
||||
from src.data_loader import TrainDataLoader
|
||||
from src.net import SiameseRPN, BuildTrainNet, MyTrainOneStepCell
|
||||
from src.config import config
|
||||
from src.loss import MultiBoxLoss
|
||||
|
||||
sys.path.append('../')
|
||||
|
||||
parser = argparse.ArgumentParser(description='Mindspore SiameseRPN Training')
|
||||
|
||||
parser.add_argument('--is_parallel', default=False, type=bool, help='whether parallel or not parallel')
|
||||
|
||||
parser.add_argument('--is_cloudtrain', default=False, type=bool, help='whether cloud or not')
|
||||
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
|
||||
parser.add_argument('--unzip_mode', default=0, type=int, metavar='N', help='unzip mode:0:no unzip,1:tar,2:unzip')
|
||||
|
||||
parser.add_argument('--device_id', default=2, type=int, metavar='N', help='number of total epochs to run')
|
||||
|
||||
|
||||
#add random seed
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
ds.config.set_seed(1)
|
||||
|
||||
|
||||
def main(args):
|
||||
""" Model"""
|
||||
net = SiameseRPN(groups=config.batch_size, is_train=True)
|
||||
criterion = MultiBoxLoss(config.batch_size)
|
||||
|
||||
if config.check:
|
||||
checkpoint_path = os.path.join(config.checkpoint_path, config.pretrain_model)
|
||||
print("Load checkpoint Done ")
|
||||
print(config.checkpoint_path)
|
||||
if not checkpoint_path is None:
|
||||
param_dict = load_checkpoint(checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
cur_epoch = config.cur_epoch
|
||||
total_epoch = config.max_epoches - cur_epoch
|
||||
# dataloader
|
||||
data_loader = TrainDataLoader(config.train_path)
|
||||
if args.is_parallel:
|
||||
# get rank_id and rank_size
|
||||
rank_id = get_rank()
|
||||
rank_size = int(os.getenv('RANK_SIZE'))
|
||||
# create dataset
|
||||
dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True,
|
||||
num_parallel_workers=rank_size, num_shards=rank_size, shard_id=rank_id)
|
||||
else:
|
||||
dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True)
|
||||
dataset = dataset.batch(config.batch_size, drop_remainder=True)
|
||||
|
||||
# set training
|
||||
net.set_train()
|
||||
|
||||
|
||||
conv_params = list(filter(lambda layer: 'featureExtract.0.' not in layer.name and 'featureExtract.1.'
|
||||
not in layer.name and 'featureExtract.4.' not in layer.name and 'featureExtract.5.'
|
||||
not in layer.name and 'featureExtract.8.' not in layer.name and 'featureExtract.9.'
|
||||
not in layer.name, net.trainable_params()))
|
||||
|
||||
lr = adjust_learning_rate(config.start_lr, config.end_lr, config.max_epoches, dataset.get_dataset_size())
|
||||
#start fixed epoch,fix layers,select optimizer
|
||||
del lr[0:dataset.get_dataset_size() * cur_epoch]
|
||||
|
||||
optimizer = nn.optim.SGD(learning_rate=lr, params=conv_params, momentum=config.momentum,
|
||||
weight_decay=config.weight_decay)
|
||||
|
||||
train_net = BuildTrainNet(net, criterion)
|
||||
train_network = MyTrainOneStepCell(train_net, optimizer)
|
||||
|
||||
#define Model
|
||||
model = Model(train_network)
|
||||
loss_cb = LossMonitor()
|
||||
|
||||
|
||||
class Print_info(Callback):
|
||||
""" print callback function """
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
self.tlosses = AverageMeter()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
epoch_seconds = (time.time() - self.epoch_time) * 1000
|
||||
print("epoch time: %s, per step time: %s"%(epoch_seconds, epoch_seconds/cb_params.batch_num))
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
self.tlosses.update(loss)
|
||||
print("epoch: %s step: %s, loss is %s, "
|
||||
"avg_loss is %s, step time is %s" % (cb_params.cur_epoch_num, cb_params.cur_step_num, loss,
|
||||
self.tlosses.avg, step_mseconds), flush=True)
|
||||
print_cb = Print_info()
|
||||
#save checkpoint
|
||||
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(), keep_checkpoint_max=51)
|
||||
if args.is_cloudtrain:
|
||||
ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory=config.train_path+'/ckpt', config=ckpt_cfg)
|
||||
else:
|
||||
ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory='./ckpt', config=ckpt_cfg)
|
||||
|
||||
if config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
|
||||
model.train(total_epoch, dataset, callbacks=[loss_cb, ckpt_cb, print_cb], dataset_sink_mode=False)
|
||||
else:
|
||||
model.train(epoch=total_epoch, train_dataset=dataset, callbacks=[loss_cb, ckpt_cb, print_cb],
|
||||
dataset_sink_mode=False)
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
|
||||
def adjust_learning_rate(start_lr, end_lr, total_epochs, steps_pre_epoch):
|
||||
""" adjust lr """
|
||||
lr = np.logspace(np.log10(start_lr), np.log10(end_lr), num=total_epochs)[0]
|
||||
gamma = np.logspace(np.log10(start_lr), np.log10(end_lr), num=total_epochs)[1] / \
|
||||
np.logspace(np.log10(start_lr), np.log10(end_lr), num=total_epochs)[0]
|
||||
lr_each_step = []
|
||||
for _ in range(steps_pre_epoch):
|
||||
lr_each_step.append(lr)
|
||||
for _ in range(2, total_epochs + 1):
|
||||
lr = lr * gamma
|
||||
for _ in range(steps_pre_epoch):
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
if __name__ == '__main__':
|
||||
Args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if not Args.is_parallel:
|
||||
device_id = Args.device_id
|
||||
context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if Args.is_cloudtrain:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID') if os.getenv('DEVICE_ID') is not None else 0)
|
||||
local_data_path = config.cloud_data_path
|
||||
# adapt to cloud: define distributed local data path
|
||||
local_data_path = os.path.join(local_data_path, str(device_id))
|
||||
# adapt to cloud: download data from obs to local location
|
||||
mox.file.copy_parallel(src_url=Args.data_url, dst_url=local_data_path)
|
||||
tar_command1 = "tar -zxf " + local_data_path + "/ytb_vid_filter.tar.gz -C " + local_data_path + '/train/'
|
||||
zip_command1 = "unzip -o -q " + local_data_path + "/train.zip -d " + local_data_path + '/train/'
|
||||
config.checkpoint_path = local_data_path
|
||||
if Args.unzip_mode == 2:
|
||||
os.system(zip_command1)
|
||||
local_data_path = local_data_path + '/train'
|
||||
elif Args.unzip_mode == 1:
|
||||
os.system("mkdir " + local_data_path + '/train')
|
||||
os.system(tar_command1)
|
||||
local_data_path = local_data_path + '/train/ytb_vid_filter'
|
||||
config.train_path = local_data_path
|
||||
elif Args.is_parallel:
|
||||
config.train_path = os.getenv('DATA_PATH')
|
||||
if Args.is_parallel:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
if device_num > 1:
|
||||
context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
parameter_broadcast=True, gradients_mean=True)
|
||||
init()
|
||||
main(Args)
|
||||
if Args.is_cloudtrain:
|
||||
mox.file.copy_parallel(src_url=local_data_path + '/ckpt', dst_url=Args.train_url + '/ckpt')
|
Loading…
Reference in New Issue