# Contents

<!-- TOC -->

- [Contents](#contents)
- [DeepSort Description](#deepsort-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Exporting a MindIR Model](#exporting-a-mindir-model)
- [Inference Process](#inference-process)
    - [Usage](#usage)
    - [Result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

## DeepSort Description

DeepSort is a multi-object tracking method proposed in 2017. The network placed first on the MOT16 challenge, improving tracking accuracy while running roughly 20x faster than earlier approaches.

[Paper](https://arxiv.org/abs/1703.07402): Nicolai Wojke, Alex Bewley, Dietrich Paulus. "Simple Online and Realtime Tracking with a Deep Association Metric". *Presented at ICIP 2017*.

## Model Architecture

DeepSort consists of a feature extractor, a Kalman filter, and the Hungarian algorithm. The feature extractor extracts appearance features of the person inside each detection box, the Kalman filter predicts each person's position in the current frame from the previous frame, and the Hungarian algorithm matches the predicted track positions against the detected person positions.
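
The repository's actual matching logic lives in src/sort/linear_assignment.py, src/sort/kalman_filter.py, and src/sort/nn_matching.py. As a rough sketch of the association idea only (not the project's code), the following matches predicted boxes to detections with the Hungarian algorithm over a 1 - IoU cost matrix; the `iou`/`associate` helpers and the 0.7 cost cap are illustrative assumptions:

```python
# Minimal sketch of the track-detection association step; the real
# implementation in src/sort/ also gates matches by appearance features.
import numpy as np
from scipy.optimize import linear_sum_assignment


def iou(box, boxes):
    """IoU between one box and an array of boxes, all in (x, y, w, h)."""
    tl = np.maximum(box[:2], boxes[:, :2])                           # top-left corners
    br = np.minimum(box[:2] + box[2:], boxes[:, :2] + boxes[:, 2:])  # bottom-right corners
    wh = np.maximum(0., br - tl)
    inter = wh.prod(axis=1)
    return inter / (box[2:].prod() + boxes[:, 2:].prod(axis=1) - inter)


def associate(predicted_boxes, detected_boxes, max_cost=0.7):
    """Hungarian matching on a 1 - IoU cost matrix.

    Returns a list of (track_idx, detection_idx) matches.
    """
    cost = np.array([1. - iou(t, detected_boxes) for t in predicted_boxes])
    rows, cols = linear_sum_assignment(cost)  # optimal assignment
    return [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= max_cost]


# Tiny usage example with made-up boxes:
tracks = np.array([[10., 20., 50., 100.]])
dets = np.array([[12., 22., 50., 100.], [300., 40., 45., 90.]])
print(associate(tracks, dets))  # -> [(0, 0)]
```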

## Dataset

Datasets used: [MOT16](<https://motchallenge.net/data/MOT16.zip>), [Market-1501](<https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view>)

MOT16:

- Dataset size: 1.9 GB, 14 video frame sequences in total
    - test: 7 video sequences
    - train: 7 video sequences
- Data format (one train sequence):
    - det: person coordinates, confidence scores, and related information for the sequence
    - gt: tracking ground-truth labels
    - img1: all frames of the sequence
- Note: the detection coordinates and confidence scores provided by the authors differ from those shipped with the dataset, so tracking uses the authors' detections, provided as [npy](https://drive.google.com/drive/folders/18fKzfqnqhqW3s9zwsCbnVJ5XF2JFeqMp) files.
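
For orientation, each row of these npy detection files follows the ten-column MOTChallenge layout plus an appearance feature vector, which is exactly how `create_detections` in deep_sort_app.py unpacks them. A minimal sketch of inspecting one frame (the file path reuses the default from deep_sort_app.py and may differ in your setup):

```python
# Column layout per row: frame, id, x, y, w, h, confidence,
# three unused MOTChallenge fields, then the appearance feature vector.
import numpy as np

detections = np.load("./detections/MOT16_POI_train/MOT16-02.npy")  # placeholder path
frame_indices = detections[:, 0].astype(int)

rows = detections[frame_indices == 1]          # all detections in frame 1
for row in rows:
    bbox, confidence, feature = row[2:6], row[6], row[10:]
    print(bbox, confidence, feature.shape)
```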

Market-1501:

- Usage:
    - Purpose: training the DeepSort feature extractor
    - Method: first preprocess the data with prepare.py

## Environment Requirements

- Hardware (Ascend/ModelArts)
    - Prepare an Ascend or ModelArts environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

## Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation:

```python
# Enter the script directory and extract the det information (using the authors' detection boxes); set the data path inside the script
python process-npy.py
# Enter the script directory and preprocess the dataset (Market-1501); set the dataset path inside the script
python prepare.py
# Enter the script directory and train the DeepSort feature extractor
python src/deep/train.py --run_modelarts=False --run_distribute=True --data_url="" --train_url=""
# Enter the script directory and extract the detections information
python generater_detection.py --run_modelarts=False --run_distribute=True --data_url="" --train_url="" --det_url="" --ckpt_url="" --model_name=""
# Enter the script directory and generate the tracking information
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""

# Ascend distributed (8p) training
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
```

Ascend training: generate the [RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)

## Script Description

### Script and Sample Code

```bash
├── DeepSort
    ├── scripts
    │   ├──run_distribute_train.sh     // multi-card training on Ascend
    ├── src                            // source code
    │   ├──application_util
    │   │   ├──image_viewer.py
    │   │   ├──preprocessing.py
    │   │   ├──visualization.py
    │   ├──deep
    │   │   ├──feature_extractor.py    // extracts appearance features of the person in each detection box
    │   │   ├──original_model.py       // feature extractor model
    │   │   ├──train.py                // trains the network model
    │   ├──sort
    │   │   ├──detection.py
    │   │   ├──iou_matching.py         // IoU matching between predicted tracks and detections
    │   │   ├──kalman_filter.py        // Kalman filter; predicts track box states
    │   │   ├──linear_assignment.py
    │   │   ├──nn_matching.py          // box matching
    │   │   ├──track.py                // single track
    │   │   ├──tracker.py              // tracker
    ├── deep_sort_app.py               // object tracking
    ├── evaluate_motchallenge.py       // generates tracking result files
    ├── generate_videos.py             // renders tracking videos from results
    ├── generater-detection.py         // generates detection information
    ├── postprocess.py                 // processes Ascend310 inference results and computes accuracy
    ├── preprocess.py                  // generates Ascend310 inference data
    ├── prepare.py                     // preprocesses the training dataset
    ├── process-npy.py                 // extracts per-frame person coordinates and confidences
    ├── show_results.py                // visualizes tracking results
    ├── README.md                      // DeepSort documentation
```

### Script Parameters

```python
Major parameters in train.py, generater_detection.py, and evaluate_motchallenge.py:

--data_url: Absolute full path to the dataset used for training and information extraction
--train_url: Output file path
--epoch: Total number of training epochs
--batch_size: Training batch size
--device_target: Device the code runs on; the value is 'Ascend'
--ckpt_url: Absolute full path to the checkpoint file saved after training
--model_name: Model file name
--det_url: Path to the per-frame person information files
--detection_url: Path to the files with person coordinates, confidences, and features
--run_distribute: Run distributed (multi-card) training
--run_modelarts: Run on ModelArts
```

### Training Process

#### Training

- Running in an Ascend processor environment

```bash
python src/deep/train.py --run_modelarts=False --run_distribute=False --data_url="" --train_url=""
# or enter the script directory and run the script
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
```

After training, the loss values are as follows:

```bash
# grep "loss is " log
epoch: 1 step: 3984, loss is 6.4320717
epoch: 1 step: 3984, loss is 6.414733
epoch: 1 step: 3984, loss is 6.4306755
epoch: 1 step: 3984, loss is 6.4387856
epoch: 1 step: 3984, loss is 6.463995
...
epoch: 2 step: 3984, loss is 6.436552
epoch: 2 step: 3984, loss is 6.408932
epoch: 2 step: 3984, loss is 6.4517527
epoch: 2 step: 3984, loss is 6.448922
epoch: 2 step: 3984, loss is 6.4611588
...
```

The model checkpoints are saved in the current directory.

### Evaluation Process

#### Evaluation

Before running the commands below, check the checkpoint path used for evaluation.

- Running in an Ascend processor environment

```bash
# Enter the script directory and extract the det information (using the authors' detection boxes)
python process-npy.py
# Enter the script directory and extract the detections information
python generater_detection.py --run_modelarts False --run_distribute True --data_url "" --train_url "" --det_url "" --ckpt_url "" --model_name ""
# Enter the script directory and generate the tracking information
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""
# Generate the tracking results
python eval_motchallenge.py --run_modelarts=False --data_url="" --train_url="" --result_url=""
```

- [Evaluation tool](https://github.com/cheind/py-motmetrics)

Note: the import paths referenced in the tool's scripts may have some issues; simply adjust the import paths yourself.

```bash
# measure accuracy
python motmetrics/apps/eval_motchallenge.py --groundtruths="" --tests=""
```
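
Alternatively, the same metrics can be computed programmatically through the py-motmetrics API. The sketch below is illustrative only: the tiny hand-made `gt_frames`/`hyp_frames` dicts stand in for data you would parse from gt/gt.txt and from the tracker output produced by evaluate_motchallenge.py:

```python
import motmetrics as mm

# Hypothetical per-frame data: frame index -> (object ids, boxes as (x, y, w, h)).
gt_frames = {1: ([1, 2], [[10, 20, 50, 100], [200, 50, 40, 90]])}
hyp_frames = {1: ([7, 8], [[12, 22, 50, 100], [205, 55, 40, 90]])}

acc = mm.MOTAccumulator(auto_id=True)
for frame_idx in sorted(gt_frames):
    gt_ids, gt_boxes = gt_frames[frame_idx]
    hyp_ids, hyp_boxes = hyp_frames.get(frame_idx, ([], []))
    # IoU-based distances; pairs with IoU below 0.5 count as no match
    dists = mm.distances.iou_matrix(gt_boxes, hyp_boxes, max_iou=0.5)
    acc.update(gt_ids, hyp_ids, dists)

mh = mm.metrics.create()
print(mh.compute(acc, metrics=['mota', 'motp', 'num_switches'], name='MOT16-02'))
```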

The accuracy on the test dataset is as follows:

| Sequence | MOTA  | MOTP  | MT  | ML | IDs | FM   | FP    | FN    |
| -------- | ----- | ----- | --- | -- | --- | ---- | ----- | ----- |
| MOT16-02 | 29.0% | 0.207 | 11  | 11 | 159 | 226  | 4151  | 8346  |
| MOT16-04 | 58.6% | 0.167 | 42  | 14 | 62  | 242  | 6269  | 13374 |
| MOT16-05 | 51.7% | 0.213 | 31  | 27 | 68  | 109  | 630   | 2595  |
| MOT16-09 | 64.3% | 0.162 | 12  | 1  | 39  | 58   | 309   | 1537  |
| MOT16-10 | 49.2% | 0.228 | 25  | 1  | 201 | 307  | 3089  | 2915  |
| MOT16-11 | 65.9% | 0.152 | 29  | 9  | 54  | 99   | 907   | 2162  |
| MOT16-13 | 45.0% | 0.237 | 61  | 7  | 269 | 335  | 3709  | 2251  |
| overall  | 51.9% | 0.189 | 211 | 70 | 852 | 1376 | 19094 | 33190 |

## Exporting a MindIR Model

```shell
python export.py --device_id [DEVICE_ID] --ckpt_file [CKPT_PATH]
```

By default this produces `deepsort.mindir` (the default `--file_name` and `--file_format` in export.py).

## Inference Process

### Usage

Before running inference, the MindIR file must be exported with export.py. Input files must be in BIN format.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

### Result

The inference result files are saved in the current path. Feed them into eval_motchallenge.py to produce the result files, then feed those into the evaluation tool to obtain the accuracy metrics.

## Model Description

### Performance

#### Evaluation Performance

| Parameters            | ModelArts                                          |
| --------------------- | -------------------------------------------------- |
| Resource              | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Upload date           | 2021-08-12                                         |
| MindSpore version     | 1.2.0                                              |
| Dataset               | MOT16, Market-1501                                 |
| Training parameters   | epoch=100, step=191, batch_size=8, lr=0.1          |
| Optimizer             | SGD                                                |
| Loss function         | SoftmaxCrossEntropyWithLogits                      |
| Loss                  | 0.03                                               |
| Speed                 | 9.804 ms/step                                      |
| Total time            | 10 min                                             |
| Fine-tuned checkpoint | ~40 MB (.ckpt file)                                |
| Script                | [DeepSort script]                                  |

## Description of Random Situation

A random seed is set in train.py.

## ModelZoo Homepage

Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

---

CMakeLists.txt:

```cmake
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
```

---

Build script for the Ascend 310 inference executable:

```bash
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
```

---

inc/utils.h:

```cpp
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"

std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif  // MINDSPORE_INFERENCE_UTILS_H_
```

---

src/main.cc (the inference entry point):

```cpp
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include <map>  // std::map for costTime_map

#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"

using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;

DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_mindir_path).empty()) {
    std::cout << "Invalid mindir" << std::endl;
    return 1;
  }

  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph graph;
  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);

  Model model;
  Status ret = model.Build(GraphCell(graph), context);
  if (ret != kSuccess) {
    std::cout << "ERROR: Build failed." << std::endl;
    return 1;
  }

  std::vector<MSTensor> model_inputs = model.GetInputs();
  if (model_inputs.empty()) {
    std::cout << "Invalid model, inputs is empty." << std::endl;
    return 1;
  }

  auto input0_files = GetAllFiles(FLAGS_input0_path);

  if (input0_files.empty()) {
    std::cout << "ERROR: input data empty." << std::endl;
    return 1;
  }

  std::map<double, double> costTime_map;
  size_t size = input0_files.size();

  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    double startTimeMs;
    double endTimeMs;
    std::vector<MSTensor> inputs;
    std::vector<MSTensor> outputs;
    std::cout << "Start predict input files:" << input0_files[i] << std::endl;

    auto input0 = ReadFileToTensor(input0_files[i]);

    inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
                        input0.Data().get(), input0.DataSize());

    for (auto shape : model_inputs[0].Shape()) {
      std::cout << "model input shape" << shape << std::endl;
    }
    gettimeofday(&start, nullptr);
    ret = model.Predict(inputs, &outputs);
    gettimeofday(&end, nullptr);
    if (ret != kSuccess) {
      std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
      return 1;
    }
    startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
    endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
    costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    WriteResult(input0_files[i], outputs);
  }
  double average = 0.0;
  int inferCount = 0;

  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  average = average / inferCount;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << "ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}
```

---

src/utils.cc:

```cpp
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>  // PATH_MAX
#include <cstdlib>  // realpath
#include "inc/utils.h"

using mindspore::MSTensor;
using mindspore::DataType;

std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}

int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    std::cout << "output size:" << outputSize << std::endl;
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE *outputFile = fopen(outFileName.c_str(), "wb");
    fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}

mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "Pointer file is nullptr" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}

DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  lstat(realPath.c_str(), &s);
  if (!S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}

std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  realPathRet = realpath(path.data(), realPathMem);

  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist.";
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}
```

---

deep_sort_app.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from __future__ import division, print_function, absolute_import

import argparse
import os
import cv2
import numpy as np

from src.application_util import preprocessing
from src.application_util import visualization
from src.sort import nn_matching
from src.sort.detection import Detection
from src.sort.tracker import Tracker


def gather_sequence_info(sequence_dir, detection_file):
    """Gather sequence information, such as image filenames, detections,
    groundtruth (if available).

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    detection_file : str
        Path to the detection file.

    Returns
    -------
    Dict
        A dictionary of the following sequence information:

        * sequence_name: Name of the sequence.
        * image_filenames: A dictionary that maps frame indices to image
          filenames.
        * detections: A numpy array of detections in MOTChallenge format.
        * groundtruth: A numpy array of ground truth in MOTChallenge format.
        * image_size: Image size (height, width).
        * min_frame_idx: Index of the first frame.
        * max_frame_idx: Index of the last frame.

    """
    image_dir = os.path.join(sequence_dir, "img1")
    image_filenames = {
        int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
        for f in os.listdir(image_dir)}
    groundtruth_file = os.path.join(sequence_dir, "gt/gt.txt")

    detections = None
    if detection_file is not None:
        detections = np.load(detection_file)
    groundtruth = None
    if os.path.exists(groundtruth_file):
        groundtruth = np.loadtxt(groundtruth_file, delimiter=',')

    if image_filenames:
        image = cv2.imread(next(iter(image_filenames.values())),
                           cv2.IMREAD_GRAYSCALE)
        image_size = image.shape
    else:
        image_size = None

    if image_filenames:
        min_frame_idx = min(image_filenames.keys())
        max_frame_idx = max(image_filenames.keys())
    else:
        min_frame_idx = int(detections[:, 0].min())
        max_frame_idx = int(detections[:, 0].max())

    info_filename = os.path.join(sequence_dir, "seqinfo.ini")
    if os.path.exists(info_filename):
        with open(info_filename, "r") as f:
            line_splits = [l.split('=') for l in f.read().splitlines()[1:]]
            info_dict = dict(
                s for s in line_splits if isinstance(s, list) and len(s) == 2)

        update_ms = 1000 / int(info_dict["frameRate"])
    else:
        update_ms = None

    feature_dim = detections.shape[1] - 10 if detections is not None else 0
    seq_info = {
        "sequence_name": os.path.basename(sequence_dir),
        "image_filenames": image_filenames,
        "detections": detections,
        "groundtruth": groundtruth,
        "image_size": image_size,
        "min_frame_idx": min_frame_idx,
        "max_frame_idx": max_frame_idx,
        "feature_dim": feature_dim,
        "update_ms": update_ms
    }
    return seq_info


def create_detections(detection_mat, frame_idx, min_height=0):
    """Create detections for given frame index from the raw detection matrix.

    Parameters
    ----------
    detection_mat : ndarray
        Matrix of detections. The first 10 columns of the detection matrix are
        in the standard MOTChallenge detection format. The remaining columns
        store the feature vector associated with each detection.
    frame_idx : int
        The frame index.
    min_height : Optional[int]
        A minimum detection bounding box height. Detections that are smaller
        than this value are disregarded.

    Returns
    -------
    List[tracker.Detection]
        Returns detection responses at given frame index.

    """
    frame_indices = detection_mat[:, 0].astype(int)
    mask = frame_indices == frame_idx

    detection_list = []
    for row in detection_mat[mask]:
        bbox, confidence, feature = row[2:6], row[6], row[10:]

        if bbox[3] < min_height:
            continue
        detection_list.append(Detection(bbox, confidence, feature))
    return detection_list


def run(sequence_dir, detection_file, output_file, min_confidence,
        nms_max_overlap, min_detection_height, max_cosine_distance,
        nn_budget, display):
    """Run multi-target tracker on a particular sequence.

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    detection_file : str
        Path to the detections file.
    output_file : str
        Path to the tracking output file. This file will contain the tracking
        results on completion.
    min_confidence : float
        Detection confidence threshold. Disregard all detections that have
        a confidence lower than this value.
    nms_max_overlap: float
        Maximum detection overlap (non-maxima suppression threshold).
    min_detection_height : int
        Detection height threshold. Disregard all detections that have
        a height lower than this value.
    max_cosine_distance : float
        Gating threshold for cosine distance metric (object appearance).
    nn_budget : Optional[int]
        Maximum size of the appearance descriptor gallery. If None, no budget
        is enforced.
    display : bool
        If True, show visualization of intermediate tracking results.

    """
    seq_info = gather_sequence_info(sequence_dir, detection_file)
    metric = nn_matching.NearestNeighborDistanceMetric(
        "cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)
    results = []

    def frame_callback(vis, frame_idx):
        print("Processing frame %05d" % frame_idx)

        # Load image and generate detections.
        detections = create_detections(
            seq_info["detections"], frame_idx, min_detection_height)
        detections = [d for d in detections if d.confidence >= min_confidence]

        # Run non-maxima suppression.
        boxes = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Update tracker.
        tracker.predict()
        tracker.update(detections)

        # Update visualization.
        if display:
            image = cv2.imread(
                seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)
            vis.set_image(image.copy())
            vis.draw_detections(detections)
            vis.draw_trackers(tracker.tracks)

        # Store results.
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlwh()
            results.append([
                frame_idx, track.track_id, bbox[0], bbox[1], bbox[2], bbox[3]])

    # Run tracker.
    if display:
        visualizer = visualization.Visualization(seq_info, update_ms=5)
    else:
        visualizer = visualization.NoVisualization(seq_info)
    visualizer.run(frame_callback)

    # Store results.
    with open(output_file, 'w') as f:
        for row in results:
            print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                row[0], row[1], row[2], row[3], row[4], row[5]), file=f)


def bool_string(input_string):
    if input_string not in {"True", "False"}:
        raise ValueError("Please enter a valid True/False choice")
    return input_string == "True"


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Deep SORT")
    parser.add_argument(
        "--sequence_dir", help="Path to MOTChallenge sequence directory",
        default="../MOT16/train/MOT16-02")
    parser.add_argument(
        "--detection_file", help="Path to custom detections.", default="./detections/MOT16_POI_train/MOT16-02.npy")
    parser.add_argument(
        "--output_file", help="Path to the tracking output file. This file will"
        " contain the tracking results on completion.",
        default="./tmp/hypotheses-det.txt")
    parser.add_argument(
        "--min_confidence", help="Detection confidence threshold. Disregard "
        "all detections that have a confidence lower than this value.",
        default=0.8, type=float)
    parser.add_argument(
        "--min_detection_height", help="Threshold on the detection bounding "
        "box height. Detections with height smaller than this value are "
        "disregarded", default=0, type=int)
    parser.add_argument(
        "--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
        "detection overlap.", default=1.0, type=float)
    parser.add_argument(
        "--max_cosine_distance", help="Gating threshold for cosine distance "
        "metric (object appearance).", type=float, default=0.2)
    parser.add_argument(
        "--nn_budget", help="Maximum size of the appearance descriptors "
        "gallery. If None, no budget is enforced.", type=int, default=100)
    parser.add_argument(
        "--display", help="Show intermediate tracking results",
        default=False, type=bool_string)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    run(
        args.sequence_dir, args.detection_file, args.output_file,
        args.min_confidence, args.nms_max_overlap, args.min_detection_height,
        args.max_cosine_distance, args.nn_budget, args.display)
```

---

evaluate_motchallenge.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import os
import deep_sort_app


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="MOTChallenge evaluation")
    parser.add_argument(
        "--detection_url", type=str, help="Path to detection files.")
    parser.add_argument(
        "--data_url", type=str, help="Path to image data.")
    parser.add_argument(
        "--train_url", type=str, help="Path to save result.")
    parser.add_argument(
        "--min_confidence", help="Detection confidence threshold. Disregard "
        "all detections that have a confidence lower than this value.",
        default=0.0, type=float)
    parser.add_argument(
        "--min_detection_height", help="Threshold on the detection bounding "
        "box height. Detections with height smaller than this value are "
        "disregarded", default=0, type=int)
    parser.add_argument(
        "--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
        "detection overlap.", default=1.0, type=float)
    parser.add_argument(
        "--max_cosine_distance", help="Gating threshold for cosine distance "
        "metric (object appearance).", type=float, default=0.2)
    parser.add_argument(
        "--nn_budget", help="Maximum size of the appearance descriptors "
        "gallery. If None, no budget is enforced.", type=int, default=100)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    detection_dir = args.detection_url
    DATA_DIR = args.data_url + '/'
    local_result_url = args.train_url

    if not os.path.exists(local_result_url):
        os.makedirs(local_result_url)
    sequences = os.listdir(DATA_DIR)
    for sequence in sequences:
        print("Running sequence %s" % sequence)
        sequence_dir = os.path.join(DATA_DIR, sequence)
        detection_file = os.path.join(detection_dir, "%s.npy" % sequence)
        output_file = os.path.join(local_result_url, "%s.txt" % sequence)
        deep_sort_app.run(
            sequence_dir, detection_file, output_file, args.min_confidence,
            args.nms_max_overlap, args.min_detection_height,
            args.max_cosine_distance, args.nn_budget, display=False)
```

---

export.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export

from src.deep.original_model import Net

parser = argparse.ArgumentParser(description='Tracking')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--image_height", type=int, default=128, help="Image height.")
parser.add_argument("--image_width", type=int, default=64, help="Image width.")
parser.add_argument("--file_name", type=str, default="deepsort", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)

if __name__ == '__main__':
    net = Net(reid=True, ascend=True)

    param_dict = load_checkpoint(args_opt.ckpt_file)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.zeros([args_opt.batch_size, 3, args_opt.image_height, args_opt.image_width]), ms.float32)
    export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
```

---

generate_videos.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
import show_results


def convert(filename_input, filename_output, ffmpeg_executable="ffmpeg"):
    import subprocess
    command = [ffmpeg_executable, "-i", filename_input, "-c:v", "libx264",
               "-preset", "slow", "-crf", "21", filename_output]
    subprocess.call(command)


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Siamese Tracking")
    parser.add_argument('--data_url', type=str, default='', help='Det directory.')
    parser.add_argument('--train_url', type=str, help='Folder to store the videos in')
    parser.add_argument(
        "--result_dir", help="Path to the folder with tracking output.", default="")
    parser.add_argument(
        "--convert_h264", help="If true, convert videos to libx264 (requires "
        "FFMPEG)", default=False)
    parser.add_argument(
        "--update_ms", help="Time between consecutive frames in milliseconds. "
        "Defaults to the frame_rate specified in seqinfo.ini, if available.",
        default=None)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    data_dir = args.data_url
    local_train_url = args.train_url
    result_dir = args.result_dir

    os.makedirs(local_train_url, exist_ok=True)
    for sequence_txt in os.listdir(result_dir):
        sequence = os.path.splitext(sequence_txt)[0]
        sequence_dir = os.path.join(data_dir, sequence)
        if not os.path.exists(sequence_dir):
            continue
        result_file = os.path.join(result_dir, sequence_txt)
        update_ms = args.update_ms
        video_filename = os.path.join(local_train_url, "%s.avi" % sequence)

        print("Saving %s to %s." % (sequence_txt, video_filename))
        show_results.run(
            sequence_dir, result_file, False, None, update_ms, video_filename)

    if not args.convert_h264:
        import sys
        sys.exit()
    for sequence_txt in os.listdir(result_dir):
        sequence = os.path.splitext(sequence_txt)[0]
        sequence_dir = os.path.join(data_dir, sequence)
        if not os.path.exists(sequence_dir):
            continue
        filename_in = os.path.join(local_train_url, "%s.avi" % sequence)
        filename_out = os.path.join(local_train_url, "%s.mp4" % sequence)
        convert(filename_in, filename_out)
```

---

generater-detection.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import errno
import argparse
import ast
import matplotlib

import numpy as np
import cv2

from mindspore.train.model import ParallelMode
from mindspore.communication.management import init
from mindspore import context
from src.deep.feature_extractor import Extractor

matplotlib.use("Agg")
ASCEND_SLOG_PRINT_TO_STDOUT = 1


def extract_image_patch(image, bbox, patch_shape=None):
    """Extract image patch from bounding box.

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the image
        boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct aspect ratio to patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to top left, bottom right
    bbox[2:] += bbox[:2]
    bbox = bbox.astype(int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox
    image = image[sy:ey, sx:ex]
    return image


class ImageEncoder:

    def __init__(self, model_path, batch_size=32):
        self.extractor = Extractor(model_path, batch_size)

    def _get_features(self, bbox_xywh, ori_img):
        im_crops = []
        self.height, self.width = ori_img.shape[:2]
        for box in bbox_xywh:
            im = extract_image_patch(ori_img, box)
            if im is None:
                print("WARNING: Failed to extract image patch: %s." % str(box))
                im = np.random.uniform(
                    0., 255., ori_img.shape).astype(np.uint8)
            im_crops.append(im)
        if im_crops:
            features = self.extractor(im_crops)
        else:
            features = np.array([])
        return features

    def __call__(self, image, boxes, batch_size=32):
        features = self._get_features(boxes, image)
        return features


def create_box_encoder(model_filename, batch_size=32):
    image_encoder = ImageEncoder(model_filename, batch_size)

    def encoder_box(image, boxes):
        return image_encoder(image, boxes)

    return encoder_box


def generate_detections(encoder_boxes, mot_dir, output_dir, det_path=None, detection_dir=None):
    """Generate detections with features.

    Parameters
    ----------
    encoder_boxes : Callable[image, ndarray] -> ndarray
        The encoder function takes as input a BGR color image and a matrix of
        bounding boxes in format `(x, y, w, h)` and returns a matrix of
        corresponding feature vectors.
    mot_dir : str
        Path to the MOTChallenge directory (can be either train or test).
    output_dir
        Path to the output directory. Will be created if it does not exist.
    detection_dir
        Path to custom detections. The directory structure should be the default
        MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
        standard MOTChallenge detections.

    """
    if detection_dir is None:
        detection_dir = mot_dir
    try:
        os.makedirs(output_dir)
    except OSError as exception:
        if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
            pass
        else:
            raise ValueError(
                "Failed to create output directory '%s'" % output_dir)
    for sequence in os.listdir(mot_dir):
        print("Processing %s" % sequence)
        sequence_dir = os.path.join(mot_dir, sequence)

        image_dir = os.path.join(sequence_dir, "img1")
        image_filenames = {
            int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
            for f in os.listdir(image_dir)}
        if det_path:
            detection_dir = os.path.join(det_path, sequence)
        else:
            detection_dir = os.path.join(sequence_dir, sequence)
        detection_file = os.path.join(detection_dir, "det/det.txt")

        detections_in = np.loadtxt(detection_file, delimiter=',')
        detections_out = []

        frame_indices = detections_in[:, 0].astype(int)
        min_frame_idx = frame_indices.min()
        max_frame_idx = frame_indices.max()
        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            if frame_idx not in image_filenames:
                print("WARNING could not find image for frame %d" % frame_idx)
                continue
            bgr_image = cv2.imread(
                image_filenames[frame_idx], cv2.IMREAD_COLOR)
            features = encoder_boxes(bgr_image, rows[:, 2:6].copy())
            detections_out += [np.r_[(row, feature)]
                               for row, feature in zip(rows, features)]

        output_filename = os.path.join(output_dir, "%s.npy" % sequence)
        print(output_filename)
        np.save(
            output_filename, np.asarray(detections_out), allow_pickle=False)


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Re-ID feature extractor")
    parser.add_argument('--run_distribute', type=ast.literal_eval,
                        default=False, help='Run distribute')
    parser.add_argument('--run_modelarts', type=ast.literal_eval,
                        default=False, help='Run on ModelArts')
    parser.add_argument("--device_id", type=int, default=4,
                        help="Use which device.")
    parser.add_argument('--data_url', type=str,
                        default='', help='Det directory.')
    parser.add_argument('--train_url', type=str, default='',
                        help='Train output directory.')
    parser.add_argument('--det_url', type=str, default='',
                        help='Det directory.')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='Batch size.')
    parser.add_argument("--ckpt_url", type=str, default='',
                        help="Path to checkpoint.")
    parser.add_argument("--model_name", type=str,
                        default="deepsort-30000_24.ckpt", help="Name of checkpoint.")
    parser.add_argument(
        "--detection_dir", help="Path to custom detections.", default=None)
    return parser.parse_args()


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target="Ascend", save_graphs=False)
    args = parse_args()
    if args.run_modelarts:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(device_id=device_id)
        local_data_url = '/cache/data'
        local_ckpt_url = '/cache/ckpt'
        local_train_url = '/cache/train'
        local_det_url = '/cache/det'
        mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
        mox.file.copy_parallel(args.data_url, local_data_url)
        mox.file.copy_parallel(args.det_url, local_det_url)
        if device_num > 1:
            init()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
        DATA_DIR = local_data_url + '/'
        ckpt_dir = local_ckpt_url + '/'
        det_dir = local_det_url + '/'
    else:
        if args.run_distribute:
            device_id = int(os.getenv('DEVICE_ID'))
            device_num = int(os.getenv('RANK_SIZE'))
            context.set_context(device_id=device_id)
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
        else:
            context.set_context(device_id=args.device_id)
            device_num = 1
            device_id = args.device_id
        DATA_DIR = args.data_url
        local_train_url = args.train_url
        ckpt_dir = args.ckpt_url
        det_dir = args.det_url

    encoder = create_box_encoder(
        ckpt_dir + args.model_name, batch_size=args.batch_size)
    generate_detections(encoder, DATA_DIR, local_train_url, det_path=det_dir)
    if args.run_modelarts:
        mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
```

---

postprocess.py:

```python
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np

parser = argparse.ArgumentParser('mindspore deepsort infer')

# Path for data
parser.add_argument('--det_dir', type=str, default='', help='det directory.')
parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
parser.add_argument('--output_dir', type=str, default="./", help='output dir.')

args, _ = parser.parse_known_args()

if __name__ == "__main__":
    rst_path = args.result_dir
    start = end = 0

    for sequence in os.listdir(args.det_dir):
        start = end
        detection_dir = os.path.join(args.det_dir, sequence)
        detection_file = os.path.join(detection_dir, "det/det.txt")

        detections_in = np.loadtxt(detection_file, delimiter=',')
        detections_out = []
        raws = []
        features = []

        frame_indices = detections_in[:, 0].astype(int)
        min_frame_idx = frame_indices.min()
        max_frame_idx = frame_indices.max()

        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            for box in rows:
                raws.append(box)
                end += 1

        raws = np.array(raws)
        for i in range(start, end):
            file_name = os.path.join(rst_path, "DeepSort_data_bs" + str(1) + '_' + str(i) + '_0.bin')
            output = np.fromfile(file_name, np.float32)
            features.append(output)
        features = np.array(features)
        detections_out += [np.r_[(row, feature)] for row, feature in zip(raws, features)]

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        output_filename = os.path.join(args.output_dir, "%s.npy" % sequence)
        print(output_filename)
        np.save(output_filename, np.asarray(detections_out), allow_pickle=False)
```
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from shutil import copyfile

# You only need to change this line to your dataset download path
download_path = '../data/Market-1501'

if not os.path.isdir(download_path):
    print('please change the download_path')

save_path = download_path + '/pytorch'
if not os.path.isdir(save_path):
    os.mkdir(save_path)

# -----------------------------------------
# query
query_path = download_path + '/query'
query_save_path = download_path + '/pytorch/query'
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# -----------------------------------------
# multi-query
query_path = download_path + '/gt_bbox'
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):
    query_save_path = download_path + '/pytorch/multi-query'
    if not os.path.isdir(query_save_path):
        os.mkdir(query_save_path)

    for root, dirs, files in os.walk(query_path, topdown=True):
        for name in files:
            if not name[-3:] == 'jpg':
                continue
            ID = name.split('_')
            src_path = query_path + '/' + name
            dst_path = query_save_path + '/' + ID[0]
            if not os.path.isdir(dst_path):
                os.mkdir(dst_path)
            copyfile(src_path, dst_path + '/' + name)

# -----------------------------------------
# gallery
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = download_path + '/pytorch/gallery'
if not os.path.isdir(gallery_save_path):
    os.mkdir(gallery_save_path)

for root, dirs, files in os.walk(gallery_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = gallery_path + '/' + name
        dst_path = gallery_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ---------------------------------------
# train_all
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train_all'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ---------------------------------------
# train_val
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train'
val_save_path = download_path + '/pytorch/val'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)
    os.mkdir(val_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
            dst_path = val_save_path + '/' + ID[0]  # first image is used as val image
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)
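# After running, the layout is expected to look like (with the default
# download_path):
#   ../data/Market-1501/pytorch/{query, multi-query, gallery, train_all, train, val}
# with one sub-folder per person ID inside each split; the first image of each
# ID in bounding_box_train lands in val, the rest in train.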
@@ -0,0 +1,192 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import matplotlib
import numpy as np
import cv2

matplotlib.use("Agg")
ASCEND_SLOG_PRINT_TO_STDOUT = 1

def extract_image_patch(image, bbox, patch_shape=None):
    """Extract image patch from bounding box.

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the image
        boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct aspect ratio to patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to top left, bottom right
    bbox[2:] += bbox[:2]
    bbox = bbox.astype(np.int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox

    image = image[sy:ey, sx:ex]
    return image

def statistic_normalize_img(img, statistic_norm=True):
    """Statistic normalize images."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    if statistic_norm:
        img = (img - mean) / std
    img = img.astype(np.float32)
    return img

def preprocess(im_crops):
    """
    Preprocessing steps:
    1. convert to float and scale from 0 to 1
    2. resize to (64, 128) as the Market-1501 dataset did
    3. normalize with the ImageNet statistics
    4. stack the crops into a single numpy array (NCHW)
    """
    def _resize(im, size):
        return cv2.resize(im.astype(np.float32)/255., size)
    im_batch = []
    size = (64, 128)
    for im in im_crops:
        im = _resize(im, size)
        im = statistic_normalize_img(im)
        im = im.transpose(2, 0, 1).copy()
        im = np.expand_dims(im, 0)
        im_batch.append(im)

    im_batch = np.array(im_batch)
    return im_batch


def get_features(bbox_xywh, ori_img):
    # Despite the name, this returns the preprocessed image crops; the
    # appearance features themselves are computed later by the Ascend 310 model.
    im_crops = []
    for box in bbox_xywh:
        im = extract_image_patch(ori_img, box)
        if im is None:
            print("WARNING: Failed to extract image patch: %s." % str(box))
            im = np.random.uniform(
                0., 255., ori_img.shape).astype(np.uint8)
        im_crops.append(im)
    if im_crops:
        features = preprocess(im_crops)
    else:
        features = np.array([])
    return features

def generate_detections(mot_dir, img_path, det_path=None):
    """Generate preprocessed detection crops as model inputs.

    Parameters
    ----------
    mot_dir : str
        Path to the MOTChallenge directory (can be either train or test).
    img_path : str
        Path to the output directory for the preprocessed .bin crops. Will be
        created by the caller if it does not exist.
    det_path : Optional[str]
        Path to custom detections. The directory structure should be the
        default MOTChallenge structure: `[sequence]/det/det.txt`. If None,
        uses the standard MOTChallenge detections.

    """
    count = 0
    for sequence in os.listdir(mot_dir):
        print("Processing %s" % sequence)
        sequence_dir = os.path.join(mot_dir, sequence)

        image_dir = os.path.join(sequence_dir, "img1")
        #image_dir = os.path.join(mot_dir, "img1")
        image_filenames = {
            int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
            for f in os.listdir(image_dir)}

        if det_path is not None:
            detection_dir = os.path.join(det_path, sequence)
        else:
            # default MOTChallenge layout: [sequence]/det/det.txt
            detection_dir = sequence_dir
        detection_file = os.path.join(detection_dir, "det/det.txt")
        detections_in = np.loadtxt(detection_file, delimiter=',')

        frame_indices = detections_in[:, 0].astype(np.int)
        min_frame_idx = frame_indices.astype(np.int).min()
        max_frame_idx = frame_indices.astype(np.int).max()
        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            if frame_idx not in image_filenames:
                print("WARNING could not find image for frame %d" % frame_idx)
                continue
            bgr_image = cv2.imread(image_filenames[frame_idx], cv2.IMREAD_COLOR)
            features = get_features(rows[:, 2:6].copy(), bgr_image)

            for data in features:
                file_name = "DeepSort_data_bs" + str(1) + "_" + str(count) + ".bin"
                file_path = img_path + "/" + file_name
                data.tofile(file_path)
                count += 1

def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Ascend 310 feature extractor")
    parser.add_argument('--data_path', type=str, default='', help='MOT directory.')
    parser.add_argument('--det_path', type=str, default='', help='Det directory.')
    parser.add_argument('--result_path', type=str, default='', help='Inference output directory.')

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    image_path = os.path.join(args.result_path, "00_data")
    os.mkdir(image_path)
    generate_detections(args.data_path, image_path, args.det_path)
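# Typical invocation (mirrors scripts/run_infer_310.sh; paths are placeholders):
#   python preprocess.py --data_path=../MOT16/train --det_path=../det \
#       --result_path=./preprocess_Result/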
@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np

npy_dir = ""  # the .npy files provided by the authors; the directory name is MOT16_POI_train.

for dirpath, dirnames, filenames in os.walk(npy_dir):
    for filename in filenames:
        load_dir = os.path.join(dirpath, filename)
        loadData = np.load(load_dir)
        dirname = "./det/" + filename[:8] + "/" + "det/"
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        with open(dirname + "det.txt", 'a') as f:
            for info in loadData:
                s = ""
                for i, num in enumerate(info):
                    if i in (0, 1, 7, 8, 9):
                        s += str(int(num))
                        if i != 9:
                            s += ','
                    elif i < 10:
                        s += str(num)
                        s += ','
                    else:
                        break
                f.write(s)
                f.write('\n')
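# Each det.txt row is assumed to follow the MOTChallenge detection format:
#   frame, -1, bb_left, bb_top, bb_width, bb_height, confidence, -1, -1, -1
# with integer columns 0, 1, 7, 8, 9 and float columns 2-6, as written above.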
@@ -0,0 +1,5 @@
# Note: OpenCV's cv2 module is installed via the opencv-python package;
# shutil is part of the Python standard library and needs no entry here.
opencv-python
mindspore
numpy
matplotlib
@@ -0,0 +1,86 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: bash run_distribute_train.sh [train_code_path] [RANK_TABLE_FILE] [DATA_PATH]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

train_code_path=$(get_real_path $1)
echo $train_code_path

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
    exit 1
fi

RANK_TABLE_FILE=$(get_real_path $2)
echo $RANK_TABLE_FILE

if [ ! -f $RANK_TABLE_FILE ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file."
    exit 1
fi

DATA_PATH=$(get_real_path $3)
echo $DATA_PATH

if [ ! -d $DATA_PATH ]
then
    echo "error: DATA_PATH=$DATA_PATH is not a directory."
    exit 1
fi

ulimit -c unlimited
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$RANK_TABLE_FILE
export RANK_SIZE=8
export RANK_START_ID=0

for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
        rm -rf ${train_code_path}/device${DEVICE_ID}
    fi
    mkdir ${train_code_path}/device${DEVICE_ID}
    cd ${train_code_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/deep_sort/deep/train.py --data_url=${DATA_PATH} \
                                                      --train_url=./checkpoint \
                                                      --run_distribute=True \
                                                      --run_modelarts=False > out.log 2>&1 &
done
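# Example (hypothetical paths):
#   bash scripts/run_distribute_train.sh /home/work/DeepSort \
#       /home/work/hccl_8p.json /home/work/dataset
# Each device writes its training log to <train_code_path>/device<N>/out.log.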
@@ -0,0 +1,126 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [[ $# -lt 4 || $# -gt 5 ]]; then
    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
    NEED_PREPROCESS means whether preprocessing is needed; its value is 'y' or 'n'.
    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
model=$(get_real_path $1)
dataset_path=$(get_real_path $2)
det_path=$(get_real_path $3)

if [ "$4" == "y" ] || [ "$4" == "n" ];then
    need_preprocess=$4
else
    echo "NEED_PREPROCESS must be 'y' or 'n'"
    exit 1
fi

device_id=0
if [ $# == 5 ]; then
    device_id=$5
fi

echo "mindir name: "$model
echo "dataset path: "$dataset_path
echo "det path: "$det_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id

export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function preprocess_data()
{
    if [ -d preprocess_Result ]; then
        rm -rf ./preprocess_Result
    fi
    mkdir preprocess_Result
    python3.7 ../preprocess.py --data_path=$dataset_path --det_path=$det_path --result_path=./preprocess_Result/ &>preprocess.log
}

function compile_app()
{
    cd ../ascend310_infer || exit
    bash build.sh &> build.log
}

function infer()
{
    cd - || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result

    ../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
}

function generater_detection()
{
    python3.7 ../postprocess.py --det_dir=$det_path --result_dir=./result_Files --output_dir=../detections/ &> detection.log
}

if [ $need_preprocess == "y" ]; then
    preprocess_data
    if [ $? -ne 0 ]; then
        echo "preprocess dataset failed"
        exit 1
    fi
fi
compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi
infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi
generater_detection
if [ $? -ne 0 ]; then
    echo "generate detections failed"
    exit 1
fi
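# Example (hypothetical paths):
#   bash run_infer_310.sh ../deepsort.mindir ../MOT16/train ../det y 0
# Inference inputs come from preprocess_Result/00_data, raw outputs land in
# result_Files, and the packed per-sequence .npy files in ../detections/.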
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import cv2
import numpy as np
import deep_sort_app

from src.sort.iou_matching import iou
from src.application_util import visualization


DEFAULT_UPDATE_MS = 20


def run(sequence_dir, result_file, show_false_alarms=False, detection_file=None,
        update_ms=None, video_filename=None):
    """Run tracking result visualization.

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    result_file : str
        Path to the tracking output file in MOTChallenge ground truth format.
    show_false_alarms : Optional[bool]
        If True, false alarms are highlighted as red boxes.
    detection_file : Optional[str]
        Path to the detection file.
    update_ms : Optional[int]
        Number of milliseconds between consecutive frames. Defaults to the
        frame rate specified in the seqinfo.ini file, or DEFAULT_UPDATE_MS if
        seqinfo.ini is not available.
    video_filename : Optional[str]
        If not None, a video of the tracking results is written to this file.

    """
    seq_info = deep_sort_app.gather_sequence_info(sequence_dir, detection_file)
    results = np.loadtxt(result_file, delimiter=',')

    if show_false_alarms and seq_info["groundtruth"] is None:
        raise ValueError("No groundtruth available. Cannot show false alarms.")

    def frame_callback(vis, frame_idx):
        print("Frame idx", frame_idx)
        image = cv2.imread(
            seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)

        vis.set_image(image.copy())

        if seq_info["detections"] is not None:
            detections = deep_sort_app.create_detections(
                seq_info["detections"], frame_idx)
            vis.draw_detections(detections)

        mask = results[:, 0].astype(np.int) == frame_idx
        track_ids = results[mask, 1].astype(np.int)
        boxes = results[mask, 2:6]
        vis.draw_groundtruth(track_ids, boxes)

        if show_false_alarms:
            groundtruth = seq_info["groundtruth"]
            mask = groundtruth[:, 0].astype(np.int) == frame_idx
            gt_boxes = groundtruth[mask, 2:6]
            for box in boxes:
                # NOTE(nwojke): This is not strictly correct, because we don't
                # solve the assignment problem here.
                min_iou_overlap = 0.5
                if iou(box, gt_boxes).max() < min_iou_overlap:
                    vis.viewer.color = 0, 0, 255
                    vis.viewer.thickness = 4
                    vis.viewer.rectangle(*box.astype(np.int))

    if update_ms is None:
        update_ms = seq_info["update_ms"]
    if update_ms is None:
        update_ms = DEFAULT_UPDATE_MS
    visualizer = visualization.Visualization(seq_info, update_ms)
    if video_filename is not None:
        visualizer.viewer.enable_videowriter(video_filename)
    visualizer.run(frame_callback)


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Siamese Tracking")
    parser.add_argument(
        "--sequence_dir", help="Path to the MOTChallenge sequence directory.",
        default="../MOT16/train")
    parser.add_argument(
        "--result_file", help="Tracking output in MOTChallenge file format.",
        default="./results/MOT16-01.txt")
    parser.add_argument(
        "--detection_file", help="Path to custom detections (optional).",
        default="../resources/detections/MOT16_POI_test/MOT16-01.npy")
    parser.add_argument(
        "--update_ms", help="Time between consecutive frames in milliseconds. "
        "Defaults to the frame_rate specified in seqinfo.ini, if available.",
        default=None)
    parser.add_argument(
        "--output_file", help="Filename of the (optional) output video.",
        default=None)
    parser.add_argument(
        "--show_false_alarms", help="Show false alarms as red bounding boxes.",
        type=bool, default=False)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    run(
        args.sequence_dir, args.result_file, args.show_false_alarms,
        args.detection_file, args.update_ms, args.output_file)
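# Typical invocation (paths are placeholders):
#   python show_results.py --sequence_dir=../MOT16/train/MOT16-02 \
#       --result_file=./results/MOT16-02.txt --output_file=MOT16-02.avi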
@@ -0,0 +1,356 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This module contains an image viewer and drawing routines based on OpenCV.
"""
import time
import cv2
import numpy as np

def is_in_bounds(mat, roi):
    """Check if ROI is fully contained in the image.

    Parameters
    ----------
    mat : ndarray
        An ndarray of ndim>=2.
    roi : (int, int, int, int)
        Region of interest (x, y, width, height) where (x, y) is the top-left
        corner.

    Returns
    -------
    bool
        Returns true if the ROI is contained in mat.

    """
    if roi[0] < 0 or roi[0] + roi[2] >= mat.shape[1]:
        return False
    if roi[1] < 0 or roi[1] + roi[3] >= mat.shape[0]:
        return False
    return True


def view_roi(mat, roi):
    """Get sub-array.

    The ROI must be valid, i.e., fully contained in the image.

    Parameters
    ----------
    mat : ndarray
        An ndarray of ndim=2 or ndim=3.
    roi : (int, int, int, int)
        Region of interest (x, y, width, height) where (x, y) is the top-left
        corner.

    Returns
    -------
    ndarray
        A view of the roi.

    """
    sx, ex = roi[0], roi[0] + roi[2]
    sy, ey = roi[1], roi[1] + roi[3]
    if mat.ndim == 2:
        return mat[sy:ey, sx:ex]
    return mat[sy:ey, sx:ex, :]


class ImageViewer:
    """An image viewer with drawing routines and video capture capabilities.

    Key Bindings:

    * 'SPACE' : pause
    * 'ESC' : quit
    * 's' : step a single frame while paused

    Parameters
    ----------
    update_ms : int
        Number of milliseconds between frames (1000 / frames per second).
    window_shape : (int, int)
        Shape of the window (width, height).
    caption : Optional[str]
        Title of the window.

    Attributes
    ----------
    image : ndarray
        Color image of shape (height, width, 3). You may directly manipulate
        this image to change the view. Otherwise, you may call any of the
        drawing routines of this class. Internally, the image is treated as
        being in BGR color space.

        Note that the image is resized to the image viewer's window_shape
        just prior to visualization. Therefore, you may pass differently sized
        images and call drawing routines with the appropriate, original point
        coordinates.
    color : (int, int, int)
        Current BGR color code that applies to all drawing routines.
        Values are in range [0-255].
    text_color : (int, int, int)
        Current BGR text color code that applies to all text rendering
        routines. Values are in range [0-255].
    thickness : int
        Stroke width in pixels that applies to all drawing routines.

    """

    def __init__(self, update_ms, window_shape=(640, 480), caption="Figure 1"):
        self._window_shape = window_shape
        self._caption = caption
        self._update_ms = update_ms
        self._video_writer = None
        self._user_fun = lambda: None
        self._terminate = False

        self.image = np.zeros(self._window_shape + (3,), dtype=np.uint8)
        self._color = (0, 0, 0)
        self.text_color = (255, 255, 255)
        self.thickness = 1

    @property
    def color(self):
        return self._color

    @color.setter
    def color(self, value):
        if len(value) != 3:
            raise ValueError("color must be tuple of 3")
        self._color = tuple(int(c) for c in value)

    def rectangle(self, x, y, w, h, label=None):
        """Draw a rectangle.

        Parameters
        ----------
        x : float | int
            Top left corner of the rectangle (x-axis).
        y : float | int
            Top left corner of the rectangle (y-axis).
        w : float | int
            Width of the rectangle.
        h : float | int
            Height of the rectangle.
        label : Optional[str]
            A text label that is placed at the top left corner of the
            rectangle.

        """
        pt1 = int(x), int(y)
        pt2 = int(x + w), int(y + h)
        cv2.rectangle(self.image, pt1, pt2, self._color, self.thickness)
        if label is not None:
            text_size = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_PLAIN, 1, self.thickness)

            center = pt1[0] + 5, pt1[1] + 5 + text_size[0][1]
            pt2 = pt1[0] + 10 + text_size[0][0], pt1[1] + 10 + \
                text_size[0][1]
            cv2.rectangle(self.image, pt1, pt2, self._color, -1)
            cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                        1, (255, 255, 255), self.thickness)

    def circle(self, x, y, radius, label=None):
        """Draw a circle.

        Parameters
        ----------
        x : float | int
            Center of the circle (x-axis).
        y : float | int
            Center of the circle (y-axis).
        radius : float | int
            Radius of the circle in pixels.
        label : Optional[str]
            A text label that is placed at the center of the circle.

        """
        image_size = int(radius + self.thickness + 1.5)  # actually half size
        roi = int(x - image_size), int(y - image_size), \
            int(2 * image_size), int(2 * image_size)
        if not is_in_bounds(self.image, roi):
            return

        image = view_roi(self.image, roi)
        center = image.shape[1] // 2, image.shape[0] // 2
        cv2.circle(
            image, center, int(radius + .5), self._color, self.thickness)
        if label is not None:
            cv2.putText(
                self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                2, self.text_color, 2)

    def gaussian(self, mean, covariance, label=None):
        """Draw 95% confidence ellipse of a 2-D Gaussian distribution.

        Parameters
        ----------
        mean : array_like
            The mean vector of the Gaussian distribution (ndim=1).
        covariance : array_like
            The 2x2 covariance matrix of the Gaussian distribution.
        label : Optional[str]
            A text label that is placed at the center of the ellipse.

        """
        # chi2inv(0.95, 2) = 5.9915
        vals, vecs = np.linalg.eigh(5.9915 * covariance)
        indices = vals.argsort()[::-1]
        vals, vecs = np.sqrt(vals[indices]), vecs[:, indices]

        center = int(mean[0] + .5), int(mean[1] + .5)
        axes = int(vals[0] + .5), int(vals[1] + .5)
        angle = int(180. * np.arctan2(vecs[1, 0], vecs[0, 0]) / np.pi)
        cv2.ellipse(
            self.image, center, axes, angle, 0, 360, self._color, 2)
        if label is not None:
            cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                        2, self.text_color, 2)

    def annotate(self, x, y, text):
        """Draws a text string at a given location.

        Parameters
        ----------
        x : int | float
            Bottom-left corner of the text in the image (x-axis).
        y : int | float
            Bottom-left corner of the text in the image (y-axis).
        text : str
            The text to be drawn.

        """
        cv2.putText(self.image, text, (int(x), int(y)), cv2.FONT_HERSHEY_PLAIN,
                    2, self.text_color, 2)

    def colored_points(self, points, colors=None, skip_index_check=False):
        """Draw a collection of points.

        The point size is fixed to 1.

        Parameters
        ----------
        points : ndarray
            The Nx2 array of image locations, where the first dimension is
            the x-coordinate and the second dimension is the y-coordinate.
        colors : Optional[ndarray]
            The Nx3 array of colors (dtype=np.uint8). If None, the current
            color attribute is used.
        skip_index_check : Optional[bool]
            If True, index range checks are skipped. This is faster, but
            requires all points to lie within the image dimensions.

        """
        if not skip_index_check:
            cond1, cond2 = points[:, 0] >= 0, points[:, 0] < 480
            cond3, cond4 = points[:, 1] >= 0, points[:, 1] < 640
            indices = np.logical_and.reduce((cond1, cond2, cond3, cond4))
            points = points[indices, :]
        if colors is None:
            colors = np.repeat(
                self._color, len(points)).reshape(3, len(points)).T
        indices = (points + .5).astype(np.int)
        self.image[indices[:, 1], indices[:, 0], :] = colors

    def enable_videowriter(self, output_filename, fourcc_string="MJPG",
                           fps=None):
        """ Write images to video file.

        Parameters
        ----------
        output_filename : str
            Output filename.
        fourcc_string : str
            The OpenCV FOURCC code that defines the video codec (check OpenCV
            documentation for more information).
        fps : Optional[float]
            Frames per second. If None, configured according to current
            parameters.

        """
        fourcc = cv2.VideoWriter_fourcc(*fourcc_string)
        if fps is None:
            fps = int(1000. / self._update_ms)
        self._video_writer = cv2.VideoWriter(
            output_filename, fourcc, fps, self._window_shape)

    def disable_videowriter(self):
        """ Disable writing videos.
        """
        self._video_writer = None

    def run(self, update_fun=None):
        """Start the image viewer.

        This method blocks until the user requests to close the window.

        Parameters
        ----------
        update_fun : Optional[Callable[] -> None]
            An optional callable that is invoked at each frame. May be used
            to play an animation/a video sequence.

        """
        if update_fun is not None:
            self._user_fun = update_fun

        self._terminate, is_paused = False, False
        # print("ImageViewer is paused, press space to start.")
        while not self._terminate:
            t0 = time.time()
            if not is_paused:
                self._terminate = not self._user_fun()
                if self._video_writer is not None:
                    self._video_writer.write(
                        cv2.resize(self.image, self._window_shape))
            t1 = time.time()
            remaining_time = max(1, int(self._update_ms - 1e3*(t1-t0)))
            cv2.imshow(
                self._caption, cv2.resize(self.image, self._window_shape[:2]))
            key = cv2.waitKey(remaining_time)
            if key & 255 == 27:  # ESC
                print("terminating")
                self._terminate = True
            elif key & 255 == 32:  # ' '
                print("toggling pause: " + str(not is_paused))
                is_paused = not is_paused
            elif key & 255 == 115:  # 's'
                print("stepping")
                self._terminate = not self._user_fun()
                is_paused = True

        # Due to a bug in OpenCV we must call imshow after destroying the
        # window. This will make the window appear again as soon as waitKey
        # is called.
        #
        # see https://github.com/Itseez/opencv/issues/4535
        self.image[:] = 0
        cv2.destroyWindow(self._caption)
        cv2.waitKey(1)
        cv2.imshow(self._caption, self.image)

    def stop(self):
        """Stop the control loop.

        After calling this method, the viewer will stop execution before the
        next frame and hand over control flow to the user.

        """
        self._terminate = True
@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

def non_max_suppression(boxes, max_bbox_overlap, scores=None):
    """Suppress overlapping detections.

    Original code from [1]_ has been adapted to include confidence score.

    .. [1] http://www.pyimagesearch.com/2015/02/16/
           faster-non-maximum-suppression-python/

    Examples
    --------

        >>> boxes = [d.roi for d in detections]
        >>> scores = [d.confidence for d in detections]
        >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
        >>> detections = [detections[i] for i in indices]

    Parameters
    ----------
    boxes : ndarray
        Array of ROIs (x, y, width, height).
    max_bbox_overlap : float
        ROIs that overlap more than this value are suppressed.
    scores : Optional[array_like]
        Detector confidence score.

    Returns
    -------
    List[int]
        Returns indices of detections that have survived non-maxima suppression.

    """
    if np.size(boxes) == 0:
        return []

    boxes = boxes.astype(np.float)
    pick = []

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2] + boxes[:, 0]
    y2 = boxes[:, 3] + boxes[:, 1]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    if scores is not None:
        idxs = np.argsort(scores)
    else:
        idxs = np.argsort(y2)

    while np.size(idxs) > 0:
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])

        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        overlap = (w * h) / area[idxs[:last]]

        idxs = np.delete(
            idxs, np.concatenate(
                ([last], np.where(overlap > max_bbox_overlap)[0])))

    return pick
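# Worked example: with boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10],
# [20, 20, 5, 5]]), scores = [0.9, 0.8, 0.7] and max_bbox_overlap = 0.5, the
# second box overlaps the highest-scoring first box by roughly 0.83 of its
# area and is suppressed, so non_max_suppression returns [0, 2].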
@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import colorsys
import numpy as np
from .image_viewer import ImageViewer


def create_unique_color_float(tag, hue_step=0.41):
    """Create a unique RGB color code for a given track id (tag).

    The color code is generated in HSV color space by moving along the
    hue angle and gradually changing the saturation.

    Parameters
    ----------
    tag : int
        The unique target identifying tag.
    hue_step : float
        Difference between two neighboring color codes in HSV space (more
        specifically, the distance in hue channel).

    Returns
    -------
    (float, float, float)
        RGB color code in range [0, 1]

    """
    h, v = (tag * hue_step) % 1, 1. - (int(tag * hue_step) % 4) / 5.
    r, g, b = colorsys.hsv_to_rgb(h, 1., v)
    return r, g, b


def create_unique_color_uchar(tag, hue_step=0.41):
    """Create a unique RGB color code for a given track id (tag).

    The color code is generated in HSV color space by moving along the
    hue angle and gradually changing the saturation.

    Parameters
    ----------
    tag : int
        The unique target identifying tag.
    hue_step : float
        Difference between two neighboring color codes in HSV space (more
        specifically, the distance in hue channel).

    Returns
    -------
    (int, int, int)
        RGB color code in range [0, 255]

    """
    r, g, b = create_unique_color_float(tag, hue_step)
    return int(255*r), int(255*g), int(255*b)


class NoVisualization:
    """
    A dummy visualization object that loops through all frames in a given
    sequence to update the tracker without performing any visualization.
    """

    def __init__(self, seq_info):
        self.frame_idx = seq_info["min_frame_idx"]
        self.last_idx = seq_info["max_frame_idx"]

    def set_image(self, image):
        pass

    def draw_groundtruth(self, track_ids, boxes):
        pass

    def draw_detections(self, detections):
        pass

    def draw_trackers(self, trackers):
        pass

    def run(self, frame_callback):
        while self.frame_idx <= self.last_idx:
            frame_callback(self, self.frame_idx)
            self.frame_idx += 1


class Visualization:
    """
    This class shows tracking output in an OpenCV image viewer.
    """

    def __init__(self, seq_info, update_ms):
        image_shape = seq_info["image_size"][::-1]
        aspect_ratio = float(image_shape[1]) / image_shape[0]
        image_shape = 1024, int(aspect_ratio * 1024)
        self.viewer = ImageViewer(
            update_ms, image_shape, "Figure %s" % seq_info["sequence_name"])
        self.viewer.thickness = 2
        self.frame_idx = seq_info["min_frame_idx"]
        self.last_idx = seq_info["max_frame_idx"]

    def run(self, frame_callback):
        self.viewer.run(lambda: self._update_fun(frame_callback))

    def _update_fun(self, frame_callback):
        if self.frame_idx > self.last_idx:
            return False  # Terminate
        frame_callback(self, self.frame_idx)
        self.frame_idx += 1
        return True

    def set_image(self, image):
        self.viewer.image = image

    def draw_groundtruth(self, track_ids, boxes):
        self.viewer.thickness = 2
        for track_id, box in zip(track_ids, boxes):
            self.viewer.color = create_unique_color_uchar(track_id)
            self.viewer.rectangle(*box.astype(np.int), label=str(track_id))

    def draw_detections(self, detections):
        self.viewer.thickness = 2
        self.viewer.color = 0, 0, 255
        for detection in detections:
            self.viewer.rectangle(*detection.tlwh)

    def draw_trackers(self, tracks):
        self.viewer.thickness = 2
        for track in tracks:
            if not track.is_confirmed() or track.time_since_update > 0:
                continue
            self.viewer.color = create_unique_color_uchar(track.track_id)
            self.viewer.rectangle(
                *track.to_tlwh().astype(np.int), label=str(track.track_id))
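# Minimal usage sketch (assuming a seq_info dict as built by deep_sort_app):
#   vis = Visualization(seq_info, update_ms=40)
#   vis.run(frame_callback)  # frame_callback(vis, frame_idx) draws each frame
# Use NoVisualization(seq_info) for headless runs with the same callback.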
@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import cv2
import mindspore

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .original_model import Net

class Extractor:
    def __init__(self, model_path, batch_size=32):
        self.net = Net(reid=True)
        self.batch_size = batch_size
        param_dict = load_checkpoint(model_path)
        load_param_into_net(self.net, param_dict)
        self.size = (64, 128)

    def statistic_normalize_img(self, img, statistic_norm=True):
        """Statistic normalize images."""
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        if statistic_norm:
            img = (img - mean) / std
        img = img.astype(np.float32)
        return img

    def _preprocess(self, im_crops):
        """
        Preprocessing steps:
        1. convert to float and scale from 0 to 1
        2. resize to (64, 128) as the Market-1501 dataset did
        3. normalize with the ImageNet statistics
        4. convert each crop to a MindSpore Tensor and stack into one batch
        """
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32)/255., size)
        im_batch = []
        for im in im_crops:
            im = _resize(im, self.size)
            im = self.statistic_normalize_img(im)
            im = mindspore.Tensor.from_numpy(im.transpose(2, 0, 1).copy())
            im = mindspore.ops.ExpandDims()(im, 0)
            im_batch.append(im)

        im_batch = mindspore.ops.Concat(axis=0)(tuple(im_batch))
        return im_batch

    def __call__(self, im_crops):
        out = np.zeros((len(im_crops), 128), np.float32)
        num_batches = len(im_crops) // self.batch_size
        s, e = 0, 0
        for i in range(num_batches):
            s, e = i * self.batch_size, (i + 1) * self.batch_size
            im_batch = self._preprocess(im_crops[s:e])
            feature = self.net(im_batch)
            out[s:e] = feature.asnumpy()
        if e < len(out):
            # handle the remainder that does not fill a whole batch
            im_batch = self._preprocess(im_crops[e:])
            feature = self.net(im_batch)
            out[e:] = feature.asnumpy()
        return out
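# Usage sketch (assumes a trained checkpoint; the path is a placeholder):
#   extractor = Extractor("./ckpt/deepsort.ckpt")
#   feats = extractor(crops)  # crops: list of HxWx3 BGR patches
#   feats.shape               # -> (len(crops), 128) float32 embeddings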
@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
import mindspore.nn as nn
import mindspore.ops as F

class BasicBlock(nn.Cell):
    def __init__(self, c_in, c_out, is_downsample=False):
        super(BasicBlock, self).__init__()
        self.add = mindspore.ops.Add()
        self.ReLU = F.ReLU()
        self.is_downsample = is_downsample
        if is_downsample:
            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, pad_mode='pad', padding=1,
                                   has_bias=False, weight_init='HeUniform')
        else:
            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, pad_mode='same',
                                   has_bias=False, weight_init='HeUniform')
        self.bn1 = nn.BatchNorm2d(c_out, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, pad_mode='pad', padding=1,
                               has_bias=False, weight_init='HeUniform')
        self.bn2 = nn.BatchNorm2d(c_out, momentum=0.9)
        if is_downsample:
            self.downsample = nn.SequentialCell(
                [nn.Conv2d(c_in, c_out, 1, stride=2, pad_mode='same', has_bias=False, weight_init='HeUniform'),
                 nn.BatchNorm2d(c_out, momentum=0.9)]
            )
        elif c_in != c_out:
            self.downsample = nn.SequentialCell(
                [nn.Conv2d(c_in, c_out, 1, stride=1, pad_mode='pad', has_bias=False, weight_init='HeUniform'),
                 nn.BatchNorm2d(c_out, momentum=0.9)]
            )
            self.is_downsample = True

    def construct(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.is_downsample:
            x = self.downsample(x)
        y = self.add(x, y)
        y = self.ReLU(y)
        return y


def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    blocks = []
    for i in range(repeat_times):
        if i == 0:
            blocks.append(BasicBlock(c_in, c_out, is_downsample=is_downsample))
        else:
            blocks.append(BasicBlock(c_out, c_out))
    return nn.SequentialCell(blocks)


class Net(nn.Cell):
    def __init__(self, num_classes=751, reid=False, ascend=False):
        super(Net, self).__init__()
        # input: 3 x 128 x 64
        self.conv = nn.SequentialCell(
            [nn.Conv2d(3, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
             nn.BatchNorm2d(32, momentum=0.9),
             nn.ELU(),
             nn.Conv2d(32, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
             nn.BatchNorm2d(32, momentum=0.9),
             nn.ELU(),
             nn.MaxPool2d(3, 2, pad_mode='same')]
        )
        # 32 x 64 x 32
        self.layer1 = make_layers(32, 32, 2, False)
        # 32 x 64 x 32
        self.layer2 = make_layers(32, 64, 2, True)
        # 64 x 32 x 16
        self.layer3 = make_layers(64, 128, 2, True)
        # 128 x 16 x 8
        self.dp = nn.Dropout(keep_prob=0.6)
        self.dense = nn.Dense(128*16*8, 128)
        self.bn1 = nn.BatchNorm1d(128, momentum=0.9)
        self.elu = nn.ELU()
        self.reid = reid
        self.ascend = ascend
        self.div = F.Div()
        self.batch_norm = nn.BatchNorm1d(128, momentum=0.9)
        self.classifier = nn.Dense(128, num_classes)
        self.Norm = nn.Norm(axis=0, keep_dims=True)

    def construct(self, x):
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view((x.shape[0], -1))
        if self.reid:
            # re-identification mode: return the 128-D appearance embedding
            x = self.dp(x)
            x = self.dense(x)
            if self.ascend:
                x = self.bn1(x)
            else:
                f = self.Norm(x)
                x = self.div(x, f)
            return x
        # classification mode: return logits over the training identities
        x = self.dp(x)
        x = self.dense(x)
        x = self.bn1(x)
        x = self.elu(x)
        x = self.classifier(x)
        return x
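# Shape sketch (input size matches the Market-1501 crops):
#   net = Net(num_classes=751, reid=True)
#   x = mindspore.Tensor(np.zeros((1, 3, 128, 64), np.float32))
#   net(x).shape  # -> (1, 128) normalized appearance embedding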
@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import ast
import numpy as np

import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset as ds
import mindspore.nn as nn

from mindspore import Tensor, context
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.common import set_seed
from original_model import Net

set_seed(1234)

def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Train on market1501")
    parser.add_argument('--train_url', type=str, default=None, help='Train output path')
    parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
    parser.add_argument("--epoch", help="Number of training epochs.", type=int, default=100)
    parser.add_argument("--batch_size", help="Batch size for training.", type=int, default=8)
    parser.add_argument("--num_parallel_workers", help="The number of parallel workers.", type=int, default=16)
    parser.add_argument("--pre_train", help='The ckpt file of the model.', type=str, default=None)
    parser.add_argument("--save_check_point", help="Whether to save the training result.", type=bool, default=True)

    # learning rate
    parser.add_argument("--learning_rate", help="Learning rate.", type=float, default=0.1)
    parser.add_argument("--decay_epoch", help="Decay epochs.", type=int, default=20)
    parser.add_argument('--gamma', type=float, default=0.10, help='Learning rate decay factor.')
    parser.add_argument("--momentum", help="Momentum for SGD.", type=float, default=0.9)

    # device settings
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
    parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distributed training')
    parser.add_argument('--run_modelarts', type=ast.literal_eval, default=True, help='Run on ModelArts')

    return parser.parse_args()

def get_lr(base_lr, total_epochs, steps_per_epoch, step_size, gamma):
    """Build a step-decay learning rate schedule, one value per training step."""
    lr_each_step = []
    for i in range(1, total_epochs+1):
        if i % step_size == 0:
            base_lr *= gamma
        for _ in range(steps_per_epoch):
            lr_each_step.append(base_lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
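# With the defaults (base_lr=0.1, decay_epoch=20, gamma=0.1) the schedule is
# piecewise constant: lr = 0.1 for epochs 1-19, 0.01 for epochs 20-39, 0.001
# for epochs 40-59, and so on, each value repeated steps_per_epoch times.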
|
||||
|
||||
args = parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
if args.run_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data'
|
||||
local_train_url = '/cache/train'
|
||||
mox.file.copy_parallel(args.data_url, local_data_url)
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num,\
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
DATA_DIR = local_data_url + '/'
|
||||
else:
|
||||
if args.run_distribute:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,\
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args.device_id)
|
||||
device_num = 1
|
||||
args.batch_size = args.batch_size*int(8/device_num)
|
||||
device_id = args.device_id
|
||||
DATA_DIR = args.data_url + '/'
|
||||
|
||||
data = ds.ImageFolderDataset(DATA_DIR, decode=True, shuffle=True,\
|
||||
num_parallel_workers=args.num_parallel_workers, num_shards=device_num, shard_id=device_id)
|
||||
|
||||
transform_img = [
|
||||
C.RandomCrop((128, 64), padding=4),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize([0.485*255, 0.456*255, 0.406*255], [0.229*255, 0.224*255, 0.225*255]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
num_classes = max(data.num_classes(), 0)
|
||||
|
||||
data = data.map(input_columns="image", operations=transform_img, num_parallel_workers=args.num_parallel_workers)
|
||||
data = data.batch(batch_size=args.batch_size)
|
||||
|
||||
data_size = data.get_dataset_size()
|
||||
|
||||
loss_cb = LossMonitor(data_size)
|
||||
time_cb = TimeMonitor(data_size=data_size)
|
||||
callbacks = [time_cb, loss_cb]
|
||||
|
||||
#save training results
|
||||
if args.save_check_point and (device_num == 1 or device_id == 0):
|
||||
|
||||
model_save_path = './ckpt_' + str(6) + '/'
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=data_size*args.epoch, keep_checkpoint_max=args.epoch)
|
||||
|
||||
if args.run_modelarts:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=local_train_url, config=config_ck)
|
||||
else:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=model_save_path, config=config_ck)
|
||||
callbacks += [ckpoint_cb]
|
||||
|
||||
#design learning rate
|
||||
lr = Tensor(get_lr(args.learning_rate, args.epoch, data_size, args.decay_epoch, args.gamma))
|
||||
# net definition
|
||||
net = Net(num_classes=num_classes)
|
||||
|
||||
# loss and optimizer
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum)
|
||||
#optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum, weight_decay=5e-4)
|
||||
#optimizer = mindspore.nn.Momentum(params = net.trainable_params(), learning_rate=lr, momentum=args.momentum)
|
||||
|
||||
#train
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer)
|
||||
|
||||
model.train(args.epoch, data, callbacks=callbacks, dataset_sink_mode=True)
|
||||
if args.run_modelarts:
|
||||
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
|
|
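# Illustration (added sketch, not part of the original script): with hypothetical
# values base_lr=0.1, total_epochs=3, steps_per_epoch=2, step_size=2, gamma=0.1,
# `get_lr` emits one value per training step and decays at every epoch index that
# is a multiple of `step_size`:
#
#     get_lr(0.1, 3, 2, 2, 0.1)
#     # -> [0.1, 0.1, 0.01, 0.01, 0.01, 0.01]  (float32, up to rounding)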
@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


class Detection:
    """
    This class represents a bounding box detection in a single image.

    Parameters
    ----------
    tlwh : array_like
        Bounding box in format `(x, y, w, h)`.
    confidence : float
        Detector confidence score.
    feature : array_like
        A feature vector that describes the object contained in this image.

    Attributes
    ----------
    tlwh : ndarray
        Bounding box in format `(top left x, top left y, width, height)`.
    confidence : float
        Detector confidence score.
    feature : ndarray | NoneType
        A feature vector that describes the object contained in this image.

    """

    def __init__(self, tlwh, confidence, feature):
        self.tlwh = np.asarray(tlwh, dtype=np.float64)
        self.confidence = float(confidence)
        self.feature = np.asarray(feature, dtype=np.float32)

    def to_tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def to_xyah(self):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret
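# Usage sketch (added illustration, not in the original file), round-tripping the
# box formats with made-up numbers:
if __name__ == "__main__":
    det = Detection([10., 20., 50., 100.], 0.9, np.zeros(128, dtype=np.float32))
    print(det.to_tlbr())  # [ 10.  20.  60. 120.] -> (min x, min y, max x, max y)
    print(det.to_xyah())  # [ 35.  70.   0.5 100.] -> (center x, center y, a, h)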
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
import numpy as np
from . import linear_assignment


def iou(bbox, candidates):
    """Compute intersection over union.

    Parameters
    ----------
    bbox : ndarray
        A bounding box in format `(top left x, top left y, width, height)`.
    candidates : ndarray
        A matrix of candidate bounding boxes (one per row) in the same format
        as `bbox`.

    Returns
    -------
    ndarray
        The intersection over union in [0, 1] between the `bbox` and each
        candidate. A higher score means a larger fraction of the `bbox` is
        occluded by the candidate.

    """
    bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
    candidates_tl = candidates[:, :2]
    candidates_br = candidates[:, :2] + candidates[:, 2:]

    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
    wh = np.maximum(0., br - tl)

    area_intersection = wh.prod(axis=1)
    area_bbox = bbox[2:].prod()
    area_candidates = candidates[:, 2:].prod(axis=1)
    return area_intersection / (area_bbox + area_candidates - area_intersection)


def iou_cost(tracks, detections, track_indices=None,
             detection_indices=None):
    """An intersection over union distance metric.

    Parameters
    ----------
    tracks : List[deep_sort.track.Track]
        A list of tracks.
    detections : List[deep_sort.detection.Detection]
        A list of detections.
    track_indices : Optional[List[int]]
        A list of indices to tracks that should be matched. Defaults to
        all `tracks`.
    detection_indices : Optional[List[int]]
        A list of indices to detections that should be matched. Defaults
        to all `detections`.

    Returns
    -------
    ndarray
        Returns a cost matrix of shape
        len(track_indices), len(detection_indices) where entry (i, j) is
        `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.

    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
    for row, track_idx in enumerate(track_indices):
        if tracks[track_idx].time_since_update > 1:
            cost_matrix[row, :] = linear_assignment.INFTY_COST
            continue

        bbox = tracks[track_idx].to_tlwh()
        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
        cost_matrix[row, :] = 1. - iou(bbox, candidates)
    return cost_matrix
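# Worked example (added; not in the original file). Two unit squares offset by
# half a side overlap in a 0.5 x 1 strip, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.
# Run this as a module of the package because of the relative import above.
if __name__ == "__main__":
    print(iou(np.array([0., 0., 1., 1.]),
              np.array([[0.5, 0., 1., 1.]])))  # ~[0.3333]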
@@ -0,0 +1,237 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import scipy.linalg


chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919}


class KalmanFilter:
    """
    A simple Kalman filter for tracking bounding boxes in image space.

    The 8-dimensional state space

        x, y, a, h, vx, vy, va, vh

    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).

    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.

        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.

        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        cholesky_factor = np.linalg.cholesky(covariance)
        d = measurements - mean
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False,
            overwrite_b=True)
        squared_maha = np.sum(z * z, axis=0)
        return squared_maha
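# Smoke test (added illustration, not part of the original file): initiate a
# track from one (x, y, a, h) measurement, predict one step, then correct with
# a nearby measurement; the corrected state is pulled toward the measurement.
if __name__ == "__main__":
    kf = KalmanFilter()
    mean, cov = kf.initiate(np.array([320., 240., 0.5, 100.]))
    mean, cov = kf.predict(mean, cov)
    mean, cov = kf.update(mean, cov, np.array([322., 241., 0.5, 101.]))
    print(mean[:4])  # corrected (x, y, a, h)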
@@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
import numpy as np
# from sklearn.utils.linear_assignment_ import linear_assignment
from scipy.optimize import linear_sum_assignment as linear_assignment
from . import kalman_filter


INFTY_COST = 1e+5


def min_cost_matching(
        distance_metric, max_distance, tracks, detections, track_indices=None,
        detection_indices=None):
    """Solve linear assignment problem.

    Parameters
    ----------
    distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
        The distance metric is given a list of tracks and detections as well as
        a list of N track indices and M detection indices. The metric should
        return the NxM dimensional cost matrix, where element (i, j) is the
        association cost between the i-th track in the given track indices and
        the j-th detection in the given detection_indices.
    max_distance : float
        Gating threshold. Associations with cost larger than this value are
        disregarded.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : List[int]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above).
    detection_indices : List[int]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above).

    Returns
    -------
    (List[(int, int)], List[int], List[int])
        Returns a tuple with the following three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.

    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    if len(detection_indices) == 0 or len(track_indices) == 0:
        return [], track_indices, detection_indices  # Nothing to match.

    cost_matrix = distance_metric(
        tracks, detections, track_indices, detection_indices)
    cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5

    row_indices, col_indices = linear_assignment(cost_matrix)

    matches, unmatched_tracks, unmatched_detections = [], [], []
    for col, detection_idx in enumerate(detection_indices):
        if col not in col_indices:
            unmatched_detections.append(detection_idx)
    for row, track_idx in enumerate(track_indices):
        if row not in row_indices:
            unmatched_tracks.append(track_idx)
    for row, col in zip(row_indices, col_indices):
        track_idx = track_indices[row]
        detection_idx = detection_indices[col]
        if cost_matrix[row, col] > max_distance:
            unmatched_tracks.append(track_idx)
            unmatched_detections.append(detection_idx)
        else:
            matches.append((track_idx, detection_idx))
    return matches, unmatched_tracks, unmatched_detections


def matching_cascade(
        distance_metric, max_distance, cascade_depth, tracks, detections,
        track_indices=None, detection_indices=None):
    """Run matching cascade.

    Parameters
    ----------
    distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
        The distance metric is given a list of tracks and detections as well as
        a list of N track indices and M detection indices. The metric should
        return the NxM dimensional cost matrix, where element (i, j) is the
        association cost between the i-th track in the given track indices and
        the j-th detection in the given detection indices.
    max_distance : float
        Gating threshold. Associations with cost larger than this value are
        disregarded.
    cascade_depth: int
        The cascade depth, should be set to the maximum track age.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : Optional[List[int]]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above). Defaults to all tracks.
    detection_indices : Optional[List[int]]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above). Defaults to all
        detections.

    Returns
    -------
    (List[(int, int)], List[int], List[int])
        Returns a tuple with the following three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.

    """
    if track_indices is None:
        track_indices = list(range(len(tracks)))
    if detection_indices is None:
        detection_indices = list(range(len(detections)))

    unmatched_detections = detection_indices
    matches = []
    for level in range(cascade_depth):
        if not unmatched_detections:  # No detections left
            break

        track_indices_l = [
            k for k in track_indices
            if tracks[k].time_since_update == 1 + level
        ]
        if not track_indices_l:  # Nothing to match at this level
            continue

        matches_l, _, unmatched_detections = \
            min_cost_matching(
                distance_metric, max_distance, tracks, detections,
                track_indices_l, unmatched_detections)
        matches += matches_l
    unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
    return matches, unmatched_tracks, unmatched_detections


def gate_cost_matrix(
        kf, cost_matrix, tracks, detections, track_indices, detection_indices,
        gated_cost=INFTY_COST, only_position=False):
    """Invalidate infeasible entries in cost matrix based on the state
    distributions obtained by Kalman filtering.

    Parameters
    ----------
    kf : The Kalman filter.
    cost_matrix : ndarray
        The NxM dimensional cost matrix, where N is the number of track indices
        and M is the number of detection indices, such that entry (i, j) is the
        association cost between `tracks[track_indices[i]]` and
        `detections[detection_indices[j]]`.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : List[int]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above).
    detection_indices : List[int]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above).
    gated_cost : Optional[float]
        Entries in the cost matrix corresponding to infeasible associations are
        set to this value. Defaults to a very large value.
    only_position : Optional[bool]
        If True, only the x, y position of the state distribution is considered
        during gating. Defaults to False.

    Returns
    -------
    ndarray
        Returns the modified cost matrix.

    """
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray(
        [detections[i].to_xyah() for i in detection_indices])
    for row, track_idx in enumerate(track_indices):
        track = tracks[track_idx]
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = gated_cost
    return cost_matrix
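# Toy example (added; not in the original file): a hypothetical metric returns a
# fixed cost matrix; costs above `max_distance` are clamped and later rejected,
# so the solver keeps only the cheap diagonal pairing. The tracks/detections
# lists are stand-ins since min_cost_matching only forwards them to the metric.
# Run as a module of the package because of the relative import above.
if __name__ == "__main__":
    costs = np.array([[0.1, 0.9],
                      [0.8, 0.2]])

    def fixed_metric(tracks, detections, track_indices, detection_indices):
        return costs[np.ix_(track_indices, detection_indices)]

    print(min_cost_matching(fixed_metric, 0.5, [None, None], [None, None],
                            [0, 1], [0, 1]))
    # -> ([(0, 0), (1, 1)], [], [])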
@@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


def _pdist(a, b):
    """Compute pair-wise squared distance between points in `a` and `b`.

    Parameters
    ----------
    a : array_like
        An NxM matrix of N samples of dimensionality M.
    b : array_like
        An LxM matrix of L samples of dimensionality M.

    Returns
    -------
    ndarray
        Returns a matrix of size len(a), len(b) such that element (i, j)
        contains the squared distance between `a[i]` and `b[j]`.

    """
    a, b = np.asarray(a), np.asarray(b)
    if np.size(a) == 0 or np.size(b) == 0:
        return np.zeros((len(a), len(b)))
    a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
    r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
    r2 = np.clip(r2, 0., float(np.inf))
    return r2


def _cosine_distance(a, b, data_is_normalized=False):
    """Compute pair-wise cosine distance between points in `a` and `b`.

    Parameters
    ----------
    a : array_like
        An NxM matrix of N samples of dimensionality M.
    b : array_like
        An LxM matrix of L samples of dimensionality M.
    data_is_normalized : Optional[bool]
        If True, assumes rows in a and b are unit length vectors.
        Otherwise, a and b are explicitly normalized to length 1.

    Returns
    -------
    ndarray
        Returns a matrix of size len(a), len(b) such that element (i, j)
        contains the cosine distance between `a[i]` and `b[j]`.

    """
    if not data_is_normalized:
        a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
        b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
    return 1. - np.dot(a, b.T)


def _nn_euclidean_distance(x, y):
    """ Helper function for nearest neighbor distance metric (Euclidean).

    Parameters
    ----------
    x : ndarray
        A matrix of N row-vectors (sample points).
    y : ndarray
        A matrix of M row-vectors (query points).

    Returns
    -------
    ndarray
        A vector of length M that contains for each entry in `y` the
        smallest Euclidean distance to a sample in `x`.

    """
    distances = _pdist(x, y)
    return np.maximum(0.0, distances.min(axis=0))


def _nn_cosine_distance(x, y):
    """ Helper function for nearest neighbor distance metric (cosine).

    Parameters
    ----------
    x : ndarray
        A matrix of N row-vectors (sample points).
    y : ndarray
        A matrix of M row-vectors (query points).

    Returns
    -------
    ndarray
        A vector of length M that contains for each entry in `y` the
        smallest cosine distance to a sample in `x`.

    """
    distances = _cosine_distance(x, y)
    return distances.min(axis=0)


class NearestNeighborDistanceMetric:
    """
    A nearest neighbor distance metric that, for each target, returns
    the closest distance to any sample that has been observed so far.

    Parameters
    ----------
    metric : str
        Either "euclidean" or "cosine".
    matching_threshold: float
        The matching threshold. Samples with larger distance are considered an
        invalid match.
    budget : Optional[int]
        If not None, fix samples per class to at most this number. Removes
        the oldest samples when the budget is reached.

    Attributes
    ----------
    samples : Dict[int -> List[ndarray]]
        A dictionary that maps from target identities to the list of samples
        that have been observed so far.

    """

    def __init__(self, metric, matching_threshold, budget=None):
        if metric == "euclidean":
            self._metric = _nn_euclidean_distance
        elif metric == "cosine":
            self._metric = _nn_cosine_distance
        else:
            raise ValueError(
                "Invalid metric; must be either 'euclidean' or 'cosine'")
        self.matching_threshold = matching_threshold
        self.budget = budget
        self.samples = {}

    def partial_fit(self, features, targets, active_targets):
        """Update the distance metric with new data.

        Parameters
        ----------
        features : ndarray
            An NxM matrix of N features of dimensionality M.
        targets : ndarray
            An integer array of associated target identities.
        active_targets : List[int]
            A list of targets that are currently present in the scene.

        """
        for feature, target in zip(features, targets):
            self.samples.setdefault(target, []).append(feature)
            if self.budget is not None:
                self.samples[target] = self.samples[target][-int(self.budget):]
        self.samples = {k: self.samples[k] for k in active_targets}

    def distance(self, features, targets):
        """Compute distance between features and targets.

        Parameters
        ----------
        features : ndarray
            An NxM matrix of N features of dimensionality M.
        targets : List[int]
            A list of targets to match the given `features` against.

        Returns
        -------
        ndarray
            Returns a cost matrix of shape len(targets), len(features), where
            element (i, j) contains the closest squared distance between
            `targets[i]` and `features[j]`.

        """
        cost_matrix = np.zeros((len(targets), len(features)))
        for i, target in enumerate(targets):
            cost_matrix[i, :] = self._metric(self.samples[target], features)
        return cost_matrix
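# Usage sketch (added; not in the original file): fit per-target feature
# galleries, then query cosine distances for new detection features.
if __name__ == "__main__":
    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)
    feats = np.eye(3, dtype=np.float32)  # three 3-D unit features
    metric.partial_fit(feats, np.array([1, 1, 2]), active_targets=[1, 2])
    queries = np.array([[1., 0., 0.], [0., 0., 1.]], dtype=np.float32)
    print(metric.distance(queries, [1, 2]))
    # row i: closest distance from target i's samples to each query
    # -> [[0. 1.]
    #     [1. 0.]]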
@@ -0,0 +1,178 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

class TrackState:
    """
    Enumeration type for the single target track state. Newly created tracks are
    classified as `tentative` until enough evidence has been collected. Then,
    the track state is changed to `confirmed`. Tracks that are no longer alive
    are classified as `deleted` to mark them for removal from the set of active
    tracks.

    """

    Tentative = 1
    Confirmed = 2
    Deleted = 3


class Track:
    """
    A single target track with state space `(x, y, a, h)` and associated
    velocities, where `(x, y)` is the center of the bounding box, `a` is the
    aspect ratio and `h` is the height.

    Parameters
    ----------
    mean : ndarray
        Mean vector of the initial state distribution.
    covariance : ndarray
        Covariance matrix of the initial state distribution.
    track_id : int
        A unique track identifier.
    n_init : int
        Number of consecutive detections before the track is confirmed. The
        track state is set to `Deleted` if a miss occurs within the first
        `n_init` frames.
    max_age : int
        The maximum number of consecutive misses before the track state is
        set to `Deleted`.
    feature : Optional[ndarray]
        Feature vector of the detection this track originates from. If not None,
        this feature is added to the `features` cache.

    Attributes
    ----------
    mean : ndarray
        Mean vector of the initial state distribution.
    covariance : ndarray
        Covariance matrix of the initial state distribution.
    track_id : int
        A unique track identifier.
    hits : int
        Total number of measurement updates.
    age : int
        Total number of frames since first occurrence.
    time_since_update : int
        Total number of frames since last measurement update.
    state : TrackState
        The current track state.
    features : List[ndarray]
        A cache of features. On each measurement update, the associated feature
        vector is added to this list.

    """

    def __init__(self, mean, covariance, track_id, n_init, max_age,
                 feature=None):
        self.mean = mean
        self.covariance = covariance
        self.track_id = track_id
        self.hits = 1
        self.age = 1
        self.time_since_update = 0

        self.state = TrackState.Tentative
        self.features = []
        if feature is not None:
            self.features.append(feature)

        self._n_init = n_init
        self._max_age = max_age

    def to_tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        Returns
        -------
        ndarray
            The bounding box.

        """
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    def to_tlbr(self):
        """Get current position in bounding box format `(min x, min y, max x,
        max y)`.

        Returns
        -------
        ndarray
            The bounding box.

        """
        ret = self.to_tlwh()
        ret[2:] = ret[:2] + ret[2:]
        return ret

    def predict(self, kf):
        """Propagate the state distribution to the current time step using a
        Kalman filter prediction step.

        Parameters
        ----------
        kf : kalman_filter.KalmanFilter
            The Kalman filter.

        """
        self.mean, self.covariance = kf.predict(self.mean, self.covariance)
        self.age += 1
        self.time_since_update += 1

    def update(self, kf, detection):
        """Perform Kalman filter measurement update step and update the feature
        cache.

        Parameters
        ----------
        kf : kalman_filter.KalmanFilter
            The Kalman filter.
        detection : Detection
            The associated detection.

        """
        self.mean, self.covariance = kf.update(
            self.mean, self.covariance, detection.to_xyah())
        self.features.append(detection.feature)

        self.hits += 1
        self.time_since_update = 0
        if self.state == TrackState.Tentative and self.hits >= self._n_init:
            self.state = TrackState.Confirmed

    def mark_missed(self):
        """Mark this track as missed (no association at the current time step).
        """
        if self.state == TrackState.Tentative:
            self.state = TrackState.Deleted
        elif self.time_since_update > self._max_age:
            self.state = TrackState.Deleted

    def is_tentative(self):
        """Returns True if this track is tentative (unconfirmed).
        """
        return self.state == TrackState.Tentative

    def is_confirmed(self):
        """Returns True if this track is confirmed."""
        return self.state == TrackState.Confirmed

    def is_deleted(self):
        """Returns True if this track is dead and should be deleted."""
        return self.state == TrackState.Deleted
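# Lifecycle sketch (added; not in the original file): a track starts Tentative,
# is Confirmed after `n_init` consecutive hits, and is Deleted after a miss
# while still Tentative, or once `time_since_update` exceeds `max_age`. The
# mean/covariance stand-ins below are fine because state transitions never
# touch them.
if __name__ == "__main__":
    t = Track(mean=None, covariance=None, track_id=1, n_init=3, max_age=30)
    print(t.is_tentative())  # True
    t.mark_missed()          # a miss while still tentative deletes the track
    print(t.is_deleted())    # True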
@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from __future__ import absolute_import
import numpy as np
from . import kalman_filter
from . import linear_assignment
from . import iou_matching
from .track import Track


class Tracker:
    """
    This is the multi-target tracker.

    Parameters
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        A distance metric for measurement-to-track association.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of consecutive detections before the track is confirmed. The
        track state is set to `Deleted` if a miss occurs within the first
        `n_init` frames.

    Attributes
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        The distance metric used for measurement to track association.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of frames that a track remains in initialization phase.
    kf : kalman_filter.KalmanFilter
        A Kalman filter to filter target trajectories in image space.
    tracks : List[Track]
        The list of active tracks at the current time step.

    """

    def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
        self.metric = metric
        self.max_iou_distance = max_iou_distance
        self.max_age = max_age
        self.n_init = n_init

        self.kf = kalman_filter.KalmanFilter()
        self.tracks = []
        self._next_id = 1

    def predict(self):
        """Propagate track state distributions one time step forward.

        This function should be called once every time step, before `update`.
        """
        for track in self.tracks:
            track.predict(self.kf)

    def update(self, detections):
        """Perform measurement update and track management.

        Parameters
        ----------
        detections : List[deep_sort.detection.Detection]
            A list of detections at the current time step.

        """
        # Run matching cascade.
        matches, unmatched_tracks, unmatched_detections = \
            self._match(detections)

        # Update track set.
        for track_idx, detection_idx in matches:
            self.tracks[track_idx].update(
                self.kf, detections[detection_idx])
        for track_idx in unmatched_tracks:
            self.tracks[track_idx].mark_missed()
        for detection_idx in unmatched_detections:
            self._initiate_track(detections[detection_idx])
        self.tracks = [t for t in self.tracks if not t.is_deleted()]

        # Update distance metric.
        active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
        features, targets = [], []
        for track in self.tracks:
            if not track.is_confirmed():
                continue
            features += track.features
            targets += [track.track_id for _ in track.features]
            track.features = []
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets)

    def _match(self, detections):

        def gated_metric(tracks, dets, track_indices, detection_indices):
            features = np.array([dets[i].feature for i in detection_indices])
            targets = np.array([tracks[i].track_id for i in track_indices])
            cost_matrix = self.metric.distance(features, targets)
            cost_matrix = linear_assignment.gate_cost_matrix(
                self.kf, cost_matrix, tracks, dets, track_indices,
                detection_indices)

            return cost_matrix

        # Split track set into confirmed and unconfirmed tracks.
        confirmed_tracks = [
            i for i, t in enumerate(self.tracks) if t.is_confirmed()]
        unconfirmed_tracks = [
            i for i, t in enumerate(self.tracks) if not t.is_confirmed()]

        # Associate confirmed tracks using appearance features.
        matches_a, unmatched_tracks_a, unmatched_detections = \
            linear_assignment.matching_cascade(
                gated_metric, self.metric.matching_threshold, self.max_age,
                self.tracks, detections, confirmed_tracks)

        # Associate remaining tracks together with unconfirmed tracks using IOU.
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update == 1]
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update != 1]
        matches_b, unmatched_tracks_b, unmatched_detections = \
            linear_assignment.min_cost_matching(
                iou_matching.iou_cost, self.max_iou_distance, self.tracks,
                detections, iou_track_candidates, unmatched_detections)

        matches = matches_a + matches_b
        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
        return matches, unmatched_tracks, unmatched_detections

    def _initiate_track(self, detection):
        mean, covariance = self.kf.initiate(detection.to_xyah())
        self.tracks.append(Track(
            mean, covariance, self._next_id, self.n_init, self.max_age,
            detection.feature))
        self._next_id += 1
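# Per-frame usage sketch (added; not part of the original file). The detections
# would come from a detector plus the appearance network; `sequence` is a
# hypothetical iterable of List[Detection] per frame:
#
#     metric = nn_matching.NearestNeighborDistanceMetric("cosine", 0.2, budget=100)
#     tracker = Tracker(metric)
#     for frame_detections in sequence:
#         tracker.predict()                  # propagate all tracks one step
#         tracker.update(frame_detections)   # associate, update, spawn, prune
#         for track in tracker.tracks:
#             if track.is_confirmed() and track.time_since_update == 0:
#                 print(track.track_id, track.to_tlwh())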