commit
394dfc07bc

@@ -0,0 +1,270 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [DeepSort Description](#deepsort-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Exporting a MindIR Model](#exporting-a-mindir-model)
- [Inference Process](#inference-process)
- [Usage](#usage)
- [Result](#result)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

## DeepSort Description

DeepSort is a multi-object tracking algorithm proposed in 2017. The network took first place on the MOT16 benchmark, improving accuracy while running about 20 times faster than its predecessor.

[Paper](https://arxiv.org/abs/1703.07402): Nicolai Wojke, Alex Bewley, Dietrich Paulus. "Simple Online and Realtime Tracking with a Deep Association Metric". *Presented at ICIP 2017*.

## Model Architecture

DeepSort consists of a feature extractor, a Kalman filter, and the Hungarian algorithm. The feature extractor extracts appearance features of the person inside each bounding box, the Kalman filter predicts each person's position in the current frame from the previous frame's state, and the Hungarian algorithm matches those predictions against the detected positions.
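
The association step can be pictured with a short sketch. This is a minimal illustration under assumptions, not the repository's implementation (the real matching code lives under `src/sort/`); it assumes `scipy` is available and that appearance features are compared by cosine distance:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment  # Hungarian algorithm


def associate(track_features, detection_features, max_cosine_distance=0.2):
    """Match tracks to detections by appearance; returns (track, detection) index pairs."""
    # Normalize rows so that 1 - dot product equals the cosine distance.
    t = track_features / np.linalg.norm(track_features, axis=1, keepdims=True)
    d = detection_features / np.linalg.norm(detection_features, axis=1, keepdims=True)
    cost = 1.0 - t @ d.T                      # pairwise cosine distance matrix
    rows, cols = linear_sum_assignment(cost)  # optimal one-to-one assignment
    # Gate out pairs whose appearance distance exceeds the threshold.
    return [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= max_cosine_distance]
```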

## Dataset

Datasets used: [MOT16](<https://motchallenge.net/data/MOT16.zip>), [Market-1501](<https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view>)

MOT16:

- Dataset size: 1.9 GB, 14 video frame sequences in total
    - test: 7 video frame sequences
    - train: 7 video frame sequences
- Data format (one training sequence):
    - det: person coordinates, confidences, and related detection information for the sequence
    - gt: ground-truth tracking annotations
    - img1: all frames of the sequence
- Note: the coordinates and confidences detected on the provided frame sequences differ from those used by the paper's authors, so tracking uses the authors' detections, provided as [npy](https://drive.google.com/drive/folders/18fKzfqnqhqW3s9zwsCbnVJ5XF2JFeqMp) files (see the reading sketch after this list).

Market-1501:

- Usage:
    - Purpose: training the DeepSort feature extractor
    - How: first preprocess the data with prepare.py
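
For reference, each row of the detection `.npy` files follows the layout consumed by `deep_sort_app.py`: the first 10 columns are the standard MOTChallenge detection fields and the remaining columns hold the appearance feature vector. A minimal reading sketch (the file name below is a placeholder):

```python
import numpy as np

detections = np.load("MOT16-02.npy")          # shape: (num_detections, 10 + feature_dim)
frame_indices = detections[:, 0].astype(int)  # column 0: frame index
boxes = detections[:, 2:6]                    # columns 2-5: bbox (x, y, width, height)
scores = detections[:, 6]                     # column 6: detection confidence
features = detections[:, 10:]                 # remaining columns: appearance features
```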

## Environment Requirements

- Hardware (Ascend/ModelArts)
    - Set up the hardware environment with Ascend processors or ModelArts.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

## Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation:

```bash
# Enter the script directory and extract the det information (using the detections
# provided by the authors); set the data path inside the script
python process-npy.py
# Enter the script directory and preprocess the dataset (Market-1501); set the
# dataset path inside the script
python prepare.py
# Enter the script directory and train the DeepSort feature extractor
python src/deep/train.py --run_modelarts=False --run_distribute=True --data_url="" --train_url=""
# Enter the script directory and extract the detection features
python generater_detection.py --run_modelarts=False --run_distribute=True --data_url="" --train_url="" --det_url="" --ckpt_url="" --model_name=""
# Enter the script directory and generate the tracking results
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""

# Multi-device training on Ascend
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
```

For Ascend training, generate the [RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

## Script Description

### Script and Sample Code

```bash
├── DeepSort
    ├── scripts
    │   ├── run_distribute_train.sh      // multi-device training on Ascend
    ├── src                              // source code
    │   ├── application_util
    │   │   ├── image_viewer.py
    │   │   ├── preprocessing.py
    │   │   ├── visualization.py
    │   ├── deep
    │   │   ├── feature_extractor.py     // extracts appearance features of the person in each box
    │   │   ├── original_model.py        // feature extractor model
    │   │   ├── train.py                 // trains the network model
    │   ├── sort
    │   │   ├── detection.py
    │   │   ├── iou_matching.py          // matches predictions against detected boxes
    │   │   ├── kalman_filter.py         // Kalman filter, predicts track states
    │   │   ├── linear_assignment.py
    │   │   ├── nn_matching.py           // appearance-based box matching
    │   │   ├── track.py                 // single track state
    │   │   ├── tracker.py               // multi-target tracker
    ├── deep_sort_app.py                 // object tracking
    ├── evaluate_motchallenge.py         // generates tracking results
    ├── generate_videos.py               // renders videos from the tracking results
    ├── generater-detection.py           // generates detection features
    ├── postprocess.py                   // processes Ascend 310 inference results and computes accuracy
    ├── preprocess.py                    // generates Ascend 310 inference input data
    ├── prepare.py                       // preprocesses the training dataset
    ├── process-npy.py                   // extracts per-frame person coordinates and confidences
    ├── show_results.py                  // displays tracking results
    ├── README.md                        // DeepSort documentation
```

### Script Parameters

```text
Major parameters of train.py, generater_detection.py, and evaluate_motchallenge.py:

--data_url: absolute full path to the dataset used for training and feature extraction
--train_url: output file path
--epoch: total number of training epochs
--batch_size: training batch size
--device_target: device the code runs on; the value is 'Ascend'
--ckpt_url: absolute full path to the checkpoint saved after training
--model_name: model file name
--det_url: path to the per-frame person detection files
--detection_url: path to the files with person coordinates, confidences, and features
--run_distribute: whether to run on multiple devices
--run_modelarts: whether to run on ModelArts
```

### Training Process

#### Training

- Running on Ascend

```bash
python src/deep/train.py --run_modelarts=False --run_distribute=False --data_url="" --train_url=""
# or enter the script directory and run the script
bash scripts/run_distribute_train.sh train_code_path RANK_TABLE_FILE DATA_PATH
```

After training, the loss values are as follows:

```bash
# grep "loss is " log
epoch: 1 step: 3984, loss is 6.4320717
epoch: 1 step: 3984, loss is 6.414733
epoch: 1 step: 3984, loss is 6.4306755
epoch: 1 step: 3984, loss is 6.4387856
epoch: 1 step: 3984, loss is 6.463995
...
epoch: 2 step: 3984, loss is 6.436552
epoch: 2 step: 3984, loss is 6.408932
epoch: 2 step: 3984, loss is 6.4517527
epoch: 2 step: 3984, loss is 6.448922
epoch: 2 step: 3984, loss is 6.4611588
...
```

The model checkpoint is saved in the current directory.

### Evaluation Process

#### Evaluation

Before running the commands below, check the checkpoint path used for evaluation.

- Running on Ascend

```bash
# Enter the script directory and extract the det information (using the detections provided by the authors)
python process-npy.py
# Enter the script directory and extract the detection features
python generater_detection.py --run_modelarts False --run_distribute True --data_url "" --train_url "" --det_url "" --ckpt_url "" --model_name ""
# Enter the script directory and generate the tracking results
python evaluate_motchallenge.py --data_url="" --train_url="" --detection_url=""
# Generate the tracking result files
python eval_motchallenge.py --run_modelarts=False --data_url="" --train_url="" --result_url=""
```

- [Evaluation tool](https://github.com/cheind/py-motmetrics)

Note: the imports in the tool's scripts may not resolve as-is; adjust the import paths as needed.

```bash
# Measure accuracy
python motmetrics/apps/eval_motchallenge.py --groundtruths="" --tests=""
```
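
As a rough, self-contained sketch of what the tool computes (the `motmetrics` API below follows that project's documentation; the frame data is toy input), ground truth and hypotheses are accumulated per frame and MOTA/MOTP are derived from the accumulated events:

```python
import motmetrics as mm

acc = mm.MOTAccumulator(auto_id=True)

# One toy frame: two ground-truth objects and two hypotheses, boxes as (x, y, w, h).
gt_ids, gt_boxes = [1, 2], [[10, 20, 30, 60], [50, 20, 30, 60]]
hyp_ids, hyp_boxes = [1, 2], [[12, 22, 30, 60], [52, 22, 30, 60]]

dists = mm.distances.iou_matrix(gt_boxes, hyp_boxes, max_iou=0.5)  # pairwise gt/hyp distances
acc.update(gt_ids, hyp_ids, dists)                                 # register one frame of events

mh = mm.metrics.create()
summary = mh.compute(acc, metrics=["mota", "motp", "num_switches"], name="toy")
print(mm.io.render_summary(summary))
```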

The accuracy on the test dataset is as follows:

| Sequence | MOTA | MOTP | MT | ML | IDs | FM | FP | FN |
| -------- | ---- | ---- | -- | -- | --- | -- | -- | -- |
| MOT16-02 | 29.0% | 0.207 | 11 | 11 | 159 | 226 | 4151 | 8346 |
| MOT16-04 | 58.6% | 0.167 | 42 | 14 | 62 | 242 | 6269 | 13374 |
| MOT16-05 | 51.7% | 0.213 | 31 | 27 | 68 | 109 | 630 | 2595 |
| MOT16-09 | 64.3% | 0.162 | 12 | 1 | 39 | 58 | 309 | 1537 |
| MOT16-10 | 49.2% | 0.228 | 25 | 1 | 201 | 307 | 3089 | 2915 |
| MOT16-11 | 65.9% | 0.152 | 29 | 9 | 54 | 99 | 907 | 2162 |
| MOT16-13 | 45.0% | 0.237 | 61 | 7 | 269 | 335 | 3709 | 2251 |
| overall | 51.9% | 0.189 | 211 | 70 | 852 | 1376 | 19094 | 33190 |

## [Exporting a MindIR Model](#contents)

```shell
python export.py --device_id [DEVICE_ID] --ckpt_file [CKPT_PATH]
```

export.py also accepts --file_name and --file_format (AIR, ONNX, or MINDIR; the default is MINDIR).

## [Inference Process](#contents)

### Usage

Before running inference, the MindIR file must be exported by export.py. Input files must be in BIN format.

```shell
# Ascend 310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

### Result

The inference result files are saved in the current path. Feed them to eval_motchallenge.py to produce the result files, then feed those to the evaluation tool to obtain the accuracy.

## Model Description

### Performance

#### Evaluation Performance

| Parameters | ModelArts |
| ---------- | --------- |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded | 2021-08-12 |
| MindSpore version | 1.2.0 |
| Dataset | MOT16, Market-1501 |
| Training parameters | epoch=100, step=191, batch_size=8, lr=0.1 |
| Optimizer | SGD |
| Loss function | SoftmaxCrossEntropyWithLogits |
| Loss | 0.03 |
| Speed | 9.804 ms/step |
| Total time | 10 minutes |
| Checkpoint for fine-tuning | about 40 MB (.ckpt file) |
| Script | [DeepSort script] |

## Description of Random Situation

A random seed is set in train.py.

## ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

@@ -0,0 +1,32 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"

std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

@@ -0,0 +1,134 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include <map>

#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"

using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;

DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_mindir_path).empty()) {
    std::cout << "Invalid mindir" << std::endl;
    return 1;
  }

  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph graph;
  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);

  Model model;
  Status ret = model.Build(GraphCell(graph), context);
  if (ret != kSuccess) {
    std::cout << "ERROR: Build failed." << std::endl;
    return 1;
  }

  std::vector<MSTensor> model_inputs = model.GetInputs();
  if (model_inputs.empty()) {
    std::cout << "Invalid model, inputs is empty." << std::endl;
    return 1;
  }

  auto input0_files = GetAllFiles(FLAGS_input0_path);

  if (input0_files.empty()) {
    std::cout << "ERROR: input data empty." << std::endl;
    return 1;
  }

  std::map<double, double> costTime_map;
  size_t size = input0_files.size();

  for (size_t i = 0; i < size; ++i) {
    struct timeval start = {0};
    struct timeval end = {0};
    double startTimeMs;
    double endTimeMs;
    std::vector<MSTensor> inputs;
    std::vector<MSTensor> outputs;
    std::cout << "Start predict input files:" << input0_files[i] << std::endl;

    auto input0 = ReadFileToTensor(input0_files[i]);

    inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
                        input0.Data().get(), input0.DataSize());

    for (auto shape : model_inputs[0].Shape()) {
      std::cout << "model input shape" << shape << std::endl;
    }
    gettimeofday(&start, nullptr);
    ret = model.Predict(inputs, &outputs);
    gettimeofday(&end, nullptr);
    if (ret != kSuccess) {
      std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
      return 1;
    }
    startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
    endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
    costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
    WriteResult(input0_files[i], outputs);
  }
  double average = 0.0;
  int inferCount = 0;

  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  average = average / inferCount;
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}

@@ -0,0 +1,130 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>
#include "inc/utils.h"

using mindspore::MSTensor;
using mindspore::DataType;

std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}

int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    std::cout << "output size:" << outputSize << std::endl;
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE * outputFile = fopen(outFileName.c_str(), "wb");
    fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}

mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "Pointer file is nullptr" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}

DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << " dirName is null ! " << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  lstat(realPath.c_str(), &s);
  if (!S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory !" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}

std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  realPathRet = realpath(path.data(), realPathMem);

  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}

@@ -0,0 +1,275 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from __future__ import division, print_function, absolute_import

import argparse
import os
import cv2
import numpy as np

from src.application_util import preprocessing
from src.application_util import visualization
from src.sort import nn_matching
from src.sort.detection import Detection
from src.sort.tracker import Tracker


def gather_sequence_info(sequence_dir, detection_file):
    """Gather sequence information, such as image filenames, detections,
    groundtruth (if available).

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    detection_file : str
        Path to the detection file.

    Returns
    -------
    Dict
        A dictionary of the following sequence information:

        * sequence_name: Name of the sequence
        * image_filenames: A dictionary that maps frame indices to image
          filenames.
        * detections: A numpy array of detections in MOTChallenge format.
        * groundtruth: A numpy array of ground truth in MOTChallenge format.
        * image_size: Image size (height, width).
        * min_frame_idx: Index of the first frame.
        * max_frame_idx: Index of the last frame.

    """
    image_dir = os.path.join(sequence_dir, "img1")
    image_filenames = {
        int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
        for f in os.listdir(image_dir)}
    groundtruth_file = os.path.join(sequence_dir, "gt/gt.txt")

    detections = None
    if detection_file is not None:
        detections = np.load(detection_file)
    groundtruth = None
    if os.path.exists(groundtruth_file):
        groundtruth = np.loadtxt(groundtruth_file, delimiter=',')

    if image_filenames:
        image = cv2.imread(next(iter(image_filenames.values())),
                           cv2.IMREAD_GRAYSCALE)
        image_size = image.shape
    else:
        image_size = None

    if image_filenames:
        min_frame_idx = min(image_filenames.keys())
        max_frame_idx = max(image_filenames.keys())
    else:
        min_frame_idx = int(detections[:, 0].min())
        max_frame_idx = int(detections[:, 0].max())

    info_filename = os.path.join(sequence_dir, "seqinfo.ini")
    if os.path.exists(info_filename):
        with open(info_filename, "r") as f:
            line_splits = [l.split('=') for l in f.read().splitlines()[1:]]
            info_dict = dict(
                s for s in line_splits if isinstance(s, list) and len(s) == 2)

        update_ms = 1000 / int(info_dict["frameRate"])
    else:
        update_ms = None

    feature_dim = detections.shape[1] - 10 if detections is not None else 0
    seq_info = {
        "sequence_name": os.path.basename(sequence_dir),
        "image_filenames": image_filenames,
        "detections": detections,
        "groundtruth": groundtruth,
        "image_size": image_size,
        "min_frame_idx": min_frame_idx,
        "max_frame_idx": max_frame_idx,
        "feature_dim": feature_dim,
        "update_ms": update_ms
    }
    return seq_info


def create_detections(detection_mat, frame_idx, min_height=0):
    """Create detections for given frame index from the raw detection matrix.

    Parameters
    ----------
    detection_mat : ndarray
        Matrix of detections. The first 10 columns of the detection matrix are
        in the standard MOTChallenge detection format. The remaining columns
        store the feature vector associated with each detection.
    frame_idx : int
        The frame index.
    min_height : Optional[int]
        A minimum detection bounding box height. Detections that are smaller
        than this value are disregarded.

    Returns
    -------
    List[tracker.Detection]
        Returns detection responses at given frame index.

    """
    frame_indices = detection_mat[:, 0].astype(int)
    mask = frame_indices == frame_idx

    detection_list = []
    for row in detection_mat[mask]:
        bbox, confidence, feature = row[2:6], row[6], row[10:]

        if bbox[3] < min_height:
            continue
        detection_list.append(Detection(bbox, confidence, feature))
    return detection_list


def run(sequence_dir, detection_file, output_file, min_confidence,
        nms_max_overlap, min_detection_height, max_cosine_distance,
        nn_budget, display):
    """Run multi-target tracker on a particular sequence.

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    detection_file : str
        Path to the detections file.
    output_file : str
        Path to the tracking output file. This file will contain the tracking
        results on completion.
    min_confidence : float
        Detection confidence threshold. Disregard all detections that have
        a confidence lower than this value.
    nms_max_overlap: float
        Maximum detection overlap (non-maxima suppression threshold).
    min_detection_height : int
        Detection height threshold. Disregard all detections that have
        a height lower than this value.
    max_cosine_distance : float
        Gating threshold for cosine distance metric (object appearance).
    nn_budget : Optional[int]
        Maximum size of the appearance descriptor gallery. If None, no budget
        is enforced.
    display : bool
        If True, show visualization of intermediate tracking results.

    """
    seq_info = gather_sequence_info(sequence_dir, detection_file)
    metric = nn_matching.NearestNeighborDistanceMetric(
        "cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)
    results = []

    def frame_callback(vis, frame_idx):
        print("Processing frame %05d" % frame_idx)

        # Load image and generate detections.
        detections = create_detections(
            seq_info["detections"], frame_idx, min_detection_height)
        detections = [d for d in detections if d.confidence >= min_confidence]

        # Run non-maxima suppression.
        boxes = np.array([d.tlwh for d in detections])

        scores = np.array([d.confidence for d in detections])
        indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Update tracker.
        tracker.predict()
        tracker.update(detections)

        # Update visualization.
        if display:
            image = cv2.imread(
                seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)
            vis.set_image(image.copy())
            vis.draw_detections(detections)
            vis.draw_trackers(tracker.tracks)

        # Store results.
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlwh()
            results.append([
                frame_idx, track.track_id, bbox[0], bbox[1], bbox[2], bbox[3]])

    # Run tracker.
    if display:
        visualizer = visualization.Visualization(seq_info, update_ms=5)
    else:
        visualizer = visualization.NoVisualization(seq_info)
    visualizer.run(frame_callback)

    # Store results.
    with open(output_file, 'w') as f:
        for row in results:
            print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                row[0], row[1], row[2], row[3], row[4], row[5]), file=f)


def bool_string(input_string):
    if input_string not in {"True", "False"}:
        raise ValueError("Please enter a valid True/False choice")
    return input_string == "True"


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Deep SORT")
    parser.add_argument(
        "--sequence_dir", help="Path to MOTChallenge sequence directory",
        default="../MOT16/train/MOT16-02")
    parser.add_argument(
        "--detection_file", help="Path to custom detections.", default="./detections/MOT16_POI_train/MOT16-02.npy")
    parser.add_argument(
        "--output_file", help="Path to the tracking output file. This file will"
        " contain the tracking results on completion.",
        default="./tmp/hypotheses-det.txt")
    parser.add_argument(
        "--min_confidence", help="Detection confidence threshold. Disregard "
        "all detections that have a confidence lower than this value.",
        default=0.8, type=float)
    parser.add_argument(
        "--min_detection_height", help="Threshold on the detection bounding "
        "box height. Detections with height smaller than this value are "
        "disregarded", default=0, type=int)
    parser.add_argument(
        "--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
        "detection overlap.", default=1.0, type=float)
    parser.add_argument(
        "--max_cosine_distance", help="Gating threshold for cosine distance "
        "metric (object appearance).", type=float, default=0.2)
    parser.add_argument(
        "--nn_budget", help="Maximum size of the appearance descriptors "
        "gallery. If None, no budget is enforced.", type=int, default=100)
    parser.add_argument(
        "--display", help="Show intermediate tracking results",
        default=False, type=bool_string)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    run(
        args.sequence_dir, args.detection_file, args.output_file,
        args.min_confidence, args.nms_max_overlap, args.min_detection_height,
        args.max_cosine_distance, args.nn_budget, args.display)

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import os
import deep_sort_app


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="MOTChallenge evaluation")
    parser.add_argument(
        "--detection_url", type=str, help="Path to detection files.")
    parser.add_argument(
        "--data_url", type=str, help="Path to image data.")
    parser.add_argument(
        "--train_url", type=str, help="Path to save result.")
    parser.add_argument(
        "--min_confidence", help="Detection confidence threshold. Disregard "
        "all detections that have a confidence lower than this value.",
        default=0.0, type=float)
    parser.add_argument(
        "--min_detection_height", help="Threshold on the detection bounding "
        "box height. Detections with height smaller than this value are "
        "disregarded", default=0, type=int)
    parser.add_argument(
        "--nms_max_overlap", help="Non-maxima suppression threshold: Maximum "
        "detection overlap.", default=1.0, type=float)
    parser.add_argument(
        "--max_cosine_distance", help="Gating threshold for cosine distance "
        "metric (object appearance).", type=float, default=0.2)
    parser.add_argument(
        "--nn_budget", help="Maximum size of the appearance descriptors "
        "gallery. If None, no budget is enforced.", type=int, default=100)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    detection_dir = args.detection_url
    DATA_DIR = args.data_url + '/'
    local_result_url = args.train_url

    if not os.path.exists(local_result_url):
        os.makedirs(local_result_url)
    sequences = os.listdir(DATA_DIR)
    for sequence in sequences:
        print("Running sequence %s" % sequence)
        sequence_dir = os.path.join(DATA_DIR, sequence)
        detection_file = os.path.join(detection_dir, "%s.npy" % sequence)
        output_file = os.path.join(local_result_url, "%s.txt" % sequence)
        deep_sort_app.run(
            sequence_dir, detection_file, output_file, args.min_confidence,
            args.nms_max_overlap, args.min_detection_height,
            args.max_cosine_distance, args.nn_budget, display=False)

@@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export

from src.deep.original_model import Net

parser = argparse.ArgumentParser(description='Tracking')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--image_height", type=int, default=128, help="Image height.")
parser.add_argument("--image_width", type=int, default=64, help="Image width.")
parser.add_argument("--file_name", type=str, default="deepsort", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)

if __name__ == '__main__':

    net = Net(reid=True, ascend=True)

    param_dict = load_checkpoint(args_opt.ckpt_file)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.zeros([args_opt.batch_size, 3, args_opt.image_height, args_opt.image_width]), ms.float32)
    export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
import show_results


def convert(filename_input, filename_output, ffmpeg_executable="ffmpeg"):
    import subprocess
    command = [ffmpeg_executable, "-i", filename_input, "-c:v", "libx264",
               "-preset", "slow", "-crf", "21", filename_output]
    subprocess.call(command)


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Siamese Tracking")
    parser.add_argument('--data_url', type=str, default='', help='Det directory.')
    parser.add_argument('--train_url', type=str, help='Folder to store the videos in')
    parser.add_argument(
        "--result_dir", help="Path to the folder with tracking output.", default="")
    parser.add_argument(
        "--convert_h264", help="If true, convert videos to libx264 (requires "
        "FFMPEG)", default=False)
    parser.add_argument(
        "--update_ms", help="Time between consecutive frames in milliseconds. "
        "Defaults to the frame_rate specified in seqinfo.ini, if available.",
        default=None)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    data_dir = args.data_url
    local_train_url = args.train_url
    result_dir = args.result_dir

    os.makedirs(local_train_url, exist_ok=True)
    for sequence_txt in os.listdir(result_dir):
        sequence = os.path.splitext(sequence_txt)[0]
        sequence_dir = os.path.join(data_dir, sequence)
        if not os.path.exists(sequence_dir):
            continue
        result_file = os.path.join(result_dir, sequence_txt)
        update_ms = args.update_ms
        video_filename = os.path.join(local_train_url, "%s.avi" % sequence)

        print("Saving %s to %s." % (sequence_txt, video_filename))
        show_results.run(
            sequence_dir, result_file, False, None, update_ms, video_filename)

    if not args.convert_h264:
        import sys
        sys.exit()
    for sequence_txt in os.listdir(result_dir):
        sequence = os.path.splitext(sequence_txt)[0]
        sequence_dir = os.path.join(data_dir, sequence)
        if not os.path.exists(sequence_dir):
            continue
        filename_in = os.path.join(local_train_url, "%s.avi" % sequence)
        filename_out = os.path.join(local_train_url, "%s.mp4" % sequence)
        convert(filename_in, filename_out)

@@ -0,0 +1,259 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import errno
import argparse
import ast
import matplotlib

import numpy as np
import cv2

from mindspore.train.model import ParallelMode
from mindspore.communication.management import init
from mindspore import context
from src.deep.feature_extractor import Extractor

matplotlib.use("Agg")
ASCEND_SLOG_PRINT_TO_STDOUT = 1


def extract_image_patch(image, bbox, patch_shape=None):
    """Extract image patch from bounding box.

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the image
        boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct aspect ratio to patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to top left, bottom right
    bbox[2:] += bbox[:2]
    bbox = bbox.astype(int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox
    image = image[sy:ey, sx:ex]
    return image


class ImageEncoder:

    def __init__(self, model_path, batch_size=32):
        self.extractor = Extractor(model_path, batch_size)

    def _get_features(self, bbox_xywh, ori_img):
        im_crops = []
        self.height, self.width = ori_img.shape[:2]
        for box in bbox_xywh:
            im = extract_image_patch(ori_img, box)
            if im is None:
                print("WARNING: Failed to extract image patch: %s." % str(box))
                im = np.random.uniform(
                    0., 255., ori_img.shape).astype(np.uint8)
            im_crops.append(im)
        if im_crops:
            features = self.extractor(im_crops)
        else:
            features = np.array([])
        return features

    def __call__(self, image, boxes, batch_size=32):
        features = self._get_features(boxes, image)
        return features


def create_box_encoder(model_filename, batch_size=32):
    image_encoder = ImageEncoder(model_filename, batch_size)

    def encoder_box(image, boxes):
        return image_encoder(image, boxes)

    return encoder_box


def generate_detections(encoder_boxes, mot_dir, output_dir, det_path=None, detection_dir=None):
    """Generate detections with features.

    Parameters
    ----------
    encoder_boxes : Callable[image, ndarray] -> ndarray
        The encoder function takes as input a BGR color image and a matrix of
        bounding boxes in format `(x, y, w, h)` and returns a matrix of
        corresponding feature vectors.
    mot_dir : str
        Path to the MOTChallenge directory (can be either train or test).
    output_dir
        Path to the output directory. Will be created if it does not exist.
    detection_dir
        Path to custom detections. The directory structure should be the default
        MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
        standard MOTChallenge detections.

    """
    if detection_dir is None:
        detection_dir = mot_dir
    try:
        os.makedirs(output_dir)
    except OSError as exception:
        if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
            pass
        else:
            raise ValueError(
                "Failed to create output directory '%s'" % output_dir)
    for sequence in os.listdir(mot_dir):
        print("Processing %s" % sequence)
        sequence_dir = os.path.join(mot_dir, sequence)

        image_dir = os.path.join(sequence_dir, "img1")
        # image_dir = os.path.join(mot_dir, "img1")
        image_filenames = {
            int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
            for f in os.listdir(image_dir)}
        if det_path:
            detection_dir = os.path.join(det_path, sequence)
        else:
            detection_dir = os.path.join(sequence_dir, sequence)
        detection_file = os.path.join(detection_dir, "det/det.txt")

        detections_in = np.loadtxt(detection_file, delimiter=',')
        detections_out = []

        frame_indices = detections_in[:, 0].astype(int)
        min_frame_idx = frame_indices.min()
        max_frame_idx = frame_indices.max()
        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            if frame_idx not in image_filenames:
                print("WARNING could not find image for frame %d" % frame_idx)
                continue
            bgr_image = cv2.imread(
                image_filenames[frame_idx], cv2.IMREAD_COLOR)
            features = encoder_boxes(bgr_image, rows[:, 2:6].copy())
            detections_out += [np.r_[(row, feature)]
                               for row, feature in zip(rows, features)]

        output_filename = os.path.join(output_dir, "%s.npy" % sequence)
        print(output_filename)
        np.save(
            output_filename, np.asarray(detections_out), allow_pickle=False)


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Re-ID feature extractor")
    parser.add_argument('--run_distribute', type=ast.literal_eval,
                        default=False, help='Run distribute')
    parser.add_argument('--run_modelarts', type=ast.literal_eval,
                        default=False, help='Run on ModelArts')
    parser.add_argument("--device_id", type=int, default=4,
                        help="Use which device.")
    parser.add_argument('--data_url', type=str,
                        default='', help='Det directory.')
    parser.add_argument('--train_url', type=str, default='',
                        help='Train output directory.')
    parser.add_argument('--det_url', type=str, default='',
                        help='Det directory.')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='Batch size.')
    parser.add_argument("--ckpt_url", type=str, default='',
                        help="Path to checkpoint.")
    parser.add_argument("--model_name", type=str,
                        default="deepsort-30000_24.ckpt", help="Name of checkpoint.")
    parser.add_argument(
        "--detection_dir", help="Path to custom detections. Defaults to None.", default=None)
    return parser.parse_args()


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target="Ascend", save_graphs=False)
    args = parse_args()
    if args.run_modelarts:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(device_id=device_id)
        local_data_url = '/cache/data'
        local_ckpt_url = '/cache/ckpt'
        local_train_url = '/cache/train'
        local_det_url = '/cache/det'
        mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
        mox.file.copy_parallel(args.data_url, local_data_url)
        mox.file.copy_parallel(args.det_url, local_det_url)
        if device_num > 1:
            init()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
        DATA_DIR = local_data_url + '/'
        ckpt_dir = local_ckpt_url + '/'
        det_dir = local_det_url + '/'
    else:
        if args.run_distribute:
            device_id = int(os.getenv('DEVICE_ID'))
            device_num = int(os.getenv('RANK_SIZE'))
            context.set_context(device_id=device_id)
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
        else:
            context.set_context(device_id=args.device_id)
            device_num = 1
            device_id = args.device_id
        DATA_DIR = args.data_url
        local_train_url = args.train_url
        ckpt_dir = args.ckpt_url
        det_dir = args.det_url

    encoder = create_box_encoder(
        ckpt_dir + args.model_name, batch_size=args.batch_size)
    generate_detections(encoder, DATA_DIR, local_train_url, det_path=det_dir)
    if args.run_modelarts:
        mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np

parser = argparse.ArgumentParser('mindspore deepsort infer')

# Path for data
parser.add_argument('--det_dir', type=str, default='', help='det directory.')
parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
parser.add_argument('--output_dir', type=str, default="./", help='output dir.')

args, _ = parser.parse_known_args()

if __name__ == "__main__":
    rst_path = args.result_dir
    # [start, end) indexes this sequence's detections across all inference bin files.
    start = end = 0

    for sequence in os.listdir(args.det_dir):
        start = end
        detection_dir = os.path.join(args.det_dir, sequence)
        detection_file = os.path.join(detection_dir, "det/det.txt")

        detections_in = np.loadtxt(detection_file, delimiter=',')
        detections_out = []
        raws = []
        features = []

        frame_indices = detections_in[:, 0].astype(int)
        min_frame_idx = frame_indices.min()
        max_frame_idx = frame_indices.max()

        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            # Each detection row corresponds to one inference input/output bin file.
            for box in rows:
                raws.append(box)
                end += 1

        raws = np.array(raws)
        for i in range(start, end):
            file_name = os.path.join(rst_path, "DeepSort_data_bs" + str(1) + '_' + str(i) + '_0.bin')
            output = np.fromfile(file_name, np.float32)
            features.append(output)
        features = np.array(features)
        detections_out += [np.r_[(row, feature)] for row, feature in zip(raws, features)]

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        output_filename = os.path.join(args.output_dir, "%s.npy" % sequence)
        print(output_filename)
        np.save(output_filename, np.asarray(detections_out), allow_pickle=False)

@@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from shutil import copyfile

# You only need to change this line to your dataset download path
download_path = '../data/Market-1501'

if not os.path.isdir(download_path):
    print('please change the download_path')

save_path = download_path + '/pytorch'
if not os.path.isdir(save_path):
    os.mkdir(save_path)
# -----------------------------------------
# query
query_path = download_path + '/query'
query_save_path = download_path + '/pytorch/query'
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# -----------------------------------------
# multi-query
query_path = download_path + '/gt_bbox'
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):
    query_save_path = download_path + '/pytorch/multi-query'
    if not os.path.isdir(query_save_path):
        os.mkdir(query_save_path)

    for root, dirs, files in os.walk(query_path, topdown=True):
        for name in files:
            if not name[-3:] == 'jpg':
                continue
            ID = name.split('_')
            src_path = query_path + '/' + name
            dst_path = query_save_path + '/' + ID[0]
            if not os.path.isdir(dst_path):
                os.mkdir(dst_path)
            copyfile(src_path, dst_path + '/' + name)

# -----------------------------------------
# gallery
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = download_path + '/pytorch/gallery'
if not os.path.isdir(gallery_save_path):
    os.mkdir(gallery_save_path)

for root, dirs, files in os.walk(gallery_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = gallery_path + '/' + name
        dst_path = gallery_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ---------------------------------------
# train_all
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train_all'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)


# ---------------------------------------
# train_val
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train'
val_save_path = download_path + '/pytorch/val'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)
    os.mkdir(val_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
            dst_path = val_save_path + '/' + ID[0]  # first image is used as val image
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)
|
|
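# Expected layout after running (a sketch assuming the standard Market-1501
# archive, not part of the original file):
#   Market-1501/pytorch/{query, multi-query, gallery, train_all, train, val},
# each containing one sub-directory per person ID with that ID's crops; the
# first crop of every training ID goes to val, the rest to train.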
@ -0,0 +1,192 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import matplotlib
import numpy as np
import cv2

matplotlib.use("Agg")
ASCEND_SLOG_PRINT_TO_STDOUT = 1

def extract_image_patch(image, bbox, patch_shape=None):
    """Extract image patch from bounding box.

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the image
        boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct aspect ratio to patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to top left, bottom right
    bbox[2:] += bbox[:2]
    bbox = bbox.astype(np.int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox

    image = image[sy:ey, sx:ex]
    return image


def statistic_normalize_img(img, statistic_norm=True):
    """Statistic normalize images."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    if statistic_norm:
        img = (img - mean) / std
    img = img.astype(np.float32)
    return img


def preprocess(im_crops):
    """
    1. to float with scale from 0 to 1
    2. resize to (64, 128) as the Market-1501 dataset did
    3. normalize channel-wise
    4. concatenate to a numpy array
    """
    def _resize(im, size):
        return cv2.resize(im.astype(np.float32)/255., size)
    im_batch = []
    size = (64, 128)
    for im in im_crops:
        im = _resize(im, size)
        im = statistic_normalize_img(im)
        im = im.transpose(2, 0, 1).copy()
        im = np.expand_dims(im, 0)
        im_batch.append(im)

    im_batch = np.array(im_batch)
    return im_batch


def get_features(bbox_xywh, ori_img):
    im_crops = []
    for box in bbox_xywh:
        im = extract_image_patch(ori_img, box)
        if im is None:
            print("WARNING: Failed to extract image patch: %s." % str(box))
            im = np.random.uniform(
                0., 255., ori_img.shape).astype(np.uint8)
        im_crops.append(im)
    if im_crops:
        features = preprocess(im_crops)
    else:
        features = np.array([])
    return features


def generate_detections(mot_dir, img_path, det_path=None):
    """Generate detections with features.

    Parameters
    ----------
    mot_dir : str
        Path to the MOTChallenge directory (can be either train or test).
    img_path : str
        Output directory for the preprocessed image patches (.bin files).
    det_path : Optional[str]
        Path to custom detections. The directory structure should be the default
        MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
        standard MOTChallenge detections.

    """

    count = 0
    for sequence in os.listdir(mot_dir):
        print("Processing %s" % sequence)
        sequence_dir = os.path.join(mot_dir, sequence)

        image_dir = os.path.join(sequence_dir, "img1")
        #image_dir = os.path.join(mot_dir, "img1")
        image_filenames = {
            int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
            for f in os.listdir(image_dir)}

        if det_path is not None:
            detection_dir = os.path.join(det_path, sequence)
        else:
            detection_dir = os.path.join(sequence_dir, sequence)
        detection_file = os.path.join(detection_dir, "det/det.txt")
        detections_in = np.loadtxt(detection_file, delimiter=',')

        frame_indices = detections_in[:, 0].astype(np.int)
        min_frame_idx = frame_indices.astype(np.int).min()
        max_frame_idx = frame_indices.astype(np.int).max()
        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
            print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
            mask = frame_indices == frame_idx
            rows = detections_in[mask]

            if frame_idx not in image_filenames:
                print("WARNING could not find image for frame %d" % frame_idx)
                continue
            bgr_image = cv2.imread(image_filenames[frame_idx], cv2.IMREAD_COLOR)
            features = get_features(rows[:, 2:6].copy(), bgr_image)

            for data in features:
                file_name = "DeepSort_data_bs" + str(1) + "_" + str(count) + ".bin"
                file_path = img_path + "/" + file_name
                data.tofile(file_path)
                count += 1


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Ascend 310 feature extractor")
    parser.add_argument('--data_path', type=str, default='', help='MOT directory.')
    parser.add_argument('--det_path', type=str, default='', help='Det directory.')
    parser.add_argument('--result_path', type=str, default='', help='Inference output directory.')

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    image_path = os.path.join(args.result_path, "00_data")
    os.mkdir(image_path)
    generate_detections(args.data_path, image_path, args.det_path)
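# Worked example of the preprocessing above (ours, not from the original file):
# after _resize a pixel value of 128 becomes 128 / 255 ≈ 0.502, and channel-wise
# normalization maps its R channel to (0.502 - 0.485) / 0.229 ≈ 0.073. Since
# cv2.resize takes (width, height), each crop becomes H=128, W=64, so a batch of
# N crops has shape (N, 1, 3, 128, 64) after np.array over the expanded dims.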
@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np

npy_dir = ""  # path to the npy files provided by the author; the directory name is MOT16_POI_train.

for dirpath, dirnames, filenames in os.walk(npy_dir):
    for filename in filenames:
        load_dir = os.path.join(dirpath, filename)
        loadData = np.load(load_dir)
        dirname = "./det/" + filename[:8] + "/" + "det/"
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        f = open(dirname + "det.txt", 'a')
        for info in loadData:
            s = ""
            for i, num in enumerate(info):
                if i in (0, 1, 7, 8, 9):
                    s += str(int(num))
                    if i != 9:
                        s += ','
                elif i < 10:
                    s += str(num)
                    s += ','
                else:
                    break
            #print(s)
            f.write(s)
            f.write('\n')
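# Output format sketch (our reading of the loop above, assuming the npy rows
# follow the MOT det.txt column order): each written line is
# frame,id,x,y,w,h,conf,-1,-1,-1 with columns 0, 1, 7, 8 and 9 cast to int and
# any columns beyond the tenth (the appearance features) dropped.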
@ -0,0 +1,5 @@
opencv-python
mindspore
numpy
matplotlib
@ -0,0 +1,86 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: bash run_distribute_train.sh [train_code_path] [RANK_TABLE_FILE] [DATA_PATH]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

train_code_path=$(get_real_path $1)
echo $train_code_path

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
    exit 1
fi

RANK_TABLE_FILE=$(get_real_path $2)
echo $RANK_TABLE_FILE

if [ ! -f $RANK_TABLE_FILE ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file."
    exit 1
fi

DATA_PATH=$(get_real_path $3)
echo $DATA_PATH

if [ ! -d $DATA_PATH ]
then
    echo "error: DATA_PATH=$DATA_PATH is not a directory."
    exit 1
fi

ulimit -c unlimited
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$RANK_TABLE_FILE
export RANK_SIZE=8
export RANK_START_ID=0


for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
        rm -rf ${train_code_path}/device${DEVICE_ID}
    fi
    mkdir ${train_code_path}/device${DEVICE_ID}
    cd ${train_code_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/deep_sort/deep/train.py --data_url=${DATA_PATH} \
                                                      --train_url=./checkpoint \
                                                      --run_distribute=True \
                                                      --run_modelarts=False > out.log 2>&1 &
done
@ -0,0 +1,126 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [[ $# -lt 4 || $# -gt 5 ]]; then
    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
    NEED_PREPROCESS means whether preprocessing is needed; its value is 'y' or 'n'.
    DEVICE_ID is optional; it can be set by the environment variable device_id, otherwise the value is zero"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
model=$(get_real_path $1)
dataset_path=$(get_real_path $2)
det_path=$(get_real_path $3)

if [ "$4" == "y" ] || [ "$4" == "n" ];then
    need_preprocess=$4
else
    echo "NEED_PREPROCESS must be 'y' or 'n'"
    exit 1
fi

device_id=0
if [ $# == 5 ]; then
    device_id=$5
fi

echo "mindir name: "$model
echo "dataset path: "$dataset_path
echo "det path: "$det_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id

export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function preprocess_data()
{
    if [ -d preprocess_Result ]; then
        rm -rf ./preprocess_Result
    fi
    mkdir preprocess_Result
    python3.7 ../preprocess.py --data_path=$dataset_path --det_path=$det_path --result_path=./preprocess_Result/ &>preprocess.log
}

function compile_app()
{
    cd ../ascend310_infer || exit
    bash build.sh &> build.log
}

function infer()
{
    cd - || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result

    ../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
}

function generater_detection()
{
    python3.7 ../postprocess.py --det_dir=$det_path --result_dir=./result_Files --output_dir=../detections/ &> detection.log
}

if [ $need_preprocess == "y" ]; then
    preprocess_data
    if [ $? -ne 0 ]; then
        echo "preprocess dataset failed"
        exit 1
    fi
fi
compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi
infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi
generater_detection
if [ $? -ne 0 ]; then
    echo "generate detections failed"
    exit 1
fi
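# Example invocation (ours; all paths are placeholders, not from the repo):
#   bash run_infer_310.sh ./deepsort.mindir ../MOT16/train ./det y 0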
@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import cv2
import numpy as np
import deep_sort_app

from src.sort.iou_matching import iou
from src.application_util import visualization


DEFAULT_UPDATE_MS = 20


def run(sequence_dir, result_file, show_false_alarms=False, detection_file=None,
        update_ms=None, video_filename=None):
    """Run tracking result visualization.

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    result_file : str
        Path to the tracking output file in MOTChallenge ground truth format.
    show_false_alarms : Optional[bool]
        If True, false alarms are highlighted as red boxes.
    detection_file : Optional[str]
        Path to the detection file.
    update_ms : Optional[int]
        Number of milliseconds between consecutive frames. Defaults to (a) the
        frame rate specified in the seqinfo.ini file or DEFAULT_UPDATE_MS ms if
        seqinfo.ini is not available.
    video_filename : Optional[str]
        If not None, a video of the tracking results is written to this file.

    """
    seq_info = deep_sort_app.gather_sequence_info(sequence_dir, detection_file)
    results = np.loadtxt(result_file, delimiter=',')

    if show_false_alarms and seq_info["groundtruth"] is None:
        raise ValueError("No groundtruth available. Cannot show false alarms.")

    def frame_callback(vis, frame_idx):
        print("Frame idx", frame_idx)
        image = cv2.imread(
            seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)

        vis.set_image(image.copy())

        if seq_info["detections"] is not None:
            detections = deep_sort_app.create_detections(
                seq_info["detections"], frame_idx)
            vis.draw_detections(detections)

        mask = results[:, 0].astype(np.int) == frame_idx
        track_ids = results[mask, 1].astype(np.int)
        boxes = results[mask, 2:6]
        vis.draw_groundtruth(track_ids, boxes)

        if show_false_alarms:
            groundtruth = seq_info["groundtruth"]
            mask = groundtruth[:, 0].astype(np.int) == frame_idx
            gt_boxes = groundtruth[mask, 2:6]
            for box in boxes:
                # NOTE(nwojke): This is not strictly correct, because we don't
                # solve the assignment problem here.
                min_iou_overlap = 0.5
                if iou(box, gt_boxes).max() < min_iou_overlap:
                    vis.viewer.color = 0, 0, 255
                    vis.viewer.thickness = 4
                    vis.viewer.rectangle(*box.astype(np.int))

    if update_ms is None:
        update_ms = seq_info["update_ms"]
    if update_ms is None:
        update_ms = DEFAULT_UPDATE_MS
    visualizer = visualization.Visualization(seq_info, update_ms)
    if video_filename is not None:
        visualizer.viewer.enable_videowriter(video_filename)
    visualizer.run(frame_callback)


def parse_args():
    """ Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description="Siamese Tracking")
    parser.add_argument(
        "--sequence_dir", help="Path to the MOTChallenge sequence directory.",
        default="../MOT16/train")
    parser.add_argument(
        "--result_file", help="Tracking output in MOTChallenge file format.",
        default="./results/MOT16-01.txt")
    parser.add_argument(
        "--detection_file", help="Path to custom detections (optional).",
        default="../resources/detections/MOT16_POI_test/MOT16-01.npy")
    parser.add_argument(
        "--update_ms", help="Time between consecutive frames in milliseconds. "
        "Defaults to the frame_rate specified in seqinfo.ini, if available.",
        default=None)
    parser.add_argument(
        "--output_file", help="Filename of the (optional) output video.",
        default=None)
    parser.add_argument(
        "--show_false_alarms", help="Show false alarms as red bounding boxes.",
        type=bool, default=False)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    run(
        args.sequence_dir, args.result_file, args.show_false_alarms,
        args.detection_file, args.update_ms, args.output_file)
@ -0,0 +1,356 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This module contains an image viewer and drawing routines based on OpenCV.
"""
import time
import cv2
import numpy as np

def is_in_bounds(mat, roi):
    """Check if ROI is fully contained in the image.

    Parameters
    ----------
    mat : ndarray
        An ndarray of ndim>=2.
    roi : (int, int, int, int)
        Region of interest (x, y, width, height) where (x, y) is the top-left
        corner.

    Returns
    -------
    bool
        Returns true if the ROI is contained in mat.

    """
    if roi[0] < 0 or roi[0] + roi[2] >= mat.shape[1]:
        return False
    if roi[1] < 0 or roi[1] + roi[3] >= mat.shape[0]:
        return False
    return True


def view_roi(mat, roi):
    """Get sub-array.

    The ROI must be valid, i.e., fully contained in the image.

    Parameters
    ----------
    mat : ndarray
        An ndarray of ndim=2 or ndim=3.
    roi : (int, int, int, int)
        Region of interest (x, y, width, height) where (x, y) is the top-left
        corner.

    Returns
    -------
    ndarray
        A view of the roi.

    """
    sx, ex = roi[0], roi[0] + roi[2]
    sy, ey = roi[1], roi[1] + roi[3]
    if mat.ndim == 2:
        return mat[sy:ey, sx:ex]
    return mat[sy:ey, sx:ex, :]


class ImageViewer:
    """An image viewer with drawing routines and video capture capabilities.

    Key Bindings:

    * 'SPACE' : pause
    * 'ESC' : quit

    Parameters
    ----------
    update_ms : int
        Number of milliseconds between frames (1000 / frames per second).
    window_shape : (int, int)
        Shape of the window (width, height).
    caption : Optional[str]
        Title of the window.

    Attributes
    ----------
    image : ndarray
        Color image of shape (height, width, 3). You may directly manipulate
        this image to change the view. Otherwise, you may call any of the
        drawing routines of this class. Internally, the image is treated as
        being in BGR color space.

        Note that the image is resized to the image viewer's window_shape
        just prior to visualization. Therefore, you may pass differently sized
        images and call drawing routines with the appropriate, original point
        coordinates.
    color : (int, int, int)
        Current BGR color code that applies to all drawing routines.
        Values are in range [0-255].
    text_color : (int, int, int)
        Current BGR text color code that applies to all text rendering
        routines. Values are in range [0-255].
    thickness : int
        Stroke width in pixels that applies to all drawing routines.

    """

    def __init__(self, update_ms, window_shape=(640, 480), caption="Figure 1"):
        self._window_shape = window_shape
        self._caption = caption
        self._update_ms = update_ms
        self._video_writer = None
        self._user_fun = lambda: None
        self._terminate = False

        self.image = np.zeros(self._window_shape + (3,), dtype=np.uint8)
        self._color = (0, 0, 0)
        self.text_color = (255, 255, 255)
        self.thickness = 1

    @property
    def color(self):
        return self._color

    @color.setter
    def color(self, value):
        if len(value) != 3:
            raise ValueError("color must be tuple of 3")
        self._color = tuple(int(c) for c in value)

    def rectangle(self, x, y, w, h, label=None):
        """Draw a rectangle.

        Parameters
        ----------
        x : float | int
            Top left corner of the rectangle (x-axis).
        y : float | int
            Top left corner of the rectangle (y-axis).
        w : float | int
            Width of the rectangle.
        h : float | int
            Height of the rectangle.
        label : Optional[str]
            A text label that is placed at the top left corner of the
            rectangle.

        """
        pt1 = int(x), int(y)
        pt2 = int(x + w), int(y + h)
        cv2.rectangle(self.image, pt1, pt2, self._color, self.thickness)
        if label is not None:
            text_size = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_PLAIN, 1, self.thickness)

            center = pt1[0] + 5, pt1[1] + 5 + text_size[0][1]
            pt2 = pt1[0] + 10 + text_size[0][0], pt1[1] + 10 + \
                text_size[0][1]
            cv2.rectangle(self.image, pt1, pt2, self._color, -1)
            cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                        1, (255, 255, 255), self.thickness)

    def circle(self, x, y, radius, label=None):
        """Draw a circle.

        Parameters
        ----------
        x : float | int
            Center of the circle (x-axis).
        y : float | int
            Center of the circle (y-axis).
        radius : float | int
            Radius of the circle in pixels.
        label : Optional[str]
            A text label that is placed at the center of the circle.

        """
        image_size = int(radius + self.thickness + 1.5)  # actually half size
        roi = int(x - image_size), int(y - image_size), \
            int(2 * image_size), int(2 * image_size)
        if not is_in_bounds(self.image, roi):
            return

        image = view_roi(self.image, roi)
        center = image.shape[1] // 2, image.shape[0] // 2
        cv2.circle(
            image, center, int(radius + .5), self._color, self.thickness)
        if label is not None:
            cv2.putText(
                self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                2, self.text_color, 2)

    def gaussian(self, mean, covariance, label=None):
        """Draw 95% confidence ellipse of a 2-D Gaussian distribution.

        Parameters
        ----------
        mean : array_like
            The mean vector of the Gaussian distribution (ndim=1).
        covariance : array_like
            The 2x2 covariance matrix of the Gaussian distribution.
        label : Optional[str]
            A text label that is placed at the center of the ellipse.

        """
        # chi2inv(0.95, 2) = 5.9915
        vals, vecs = np.linalg.eigh(5.9915 * covariance)
        indices = vals.argsort()[::-1]
        vals, vecs = np.sqrt(vals[indices]), vecs[:, indices]

        center = int(mean[0] + .5), int(mean[1] + .5)
        axes = int(vals[0] + .5), int(vals[1] + .5)
        angle = int(180. * np.arctan2(vecs[1, 0], vecs[0, 0]) / np.pi)
        cv2.ellipse(
            self.image, center, axes, angle, 0, 360, self._color, 2)
        if label is not None:
            cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN,
                        2, self.text_color, 2)

    def annotate(self, x, y, text):
        """Draws a text string at a given location.

        Parameters
        ----------
        x : int | float
            Bottom-left corner of the text in the image (x-axis).
        y : int | float
            Bottom-left corner of the text in the image (y-axis).
        text : str
            The text to be drawn.

        """
        cv2.putText(self.image, text, (int(x), int(y)), cv2.FONT_HERSHEY_PLAIN,
                    2, self.text_color, 2)

    def colored_points(self, points, colors=None, skip_index_check=False):
        """Draw a collection of points.

        The point size is fixed to 1.

        Parameters
        ----------
        points : ndarray
            The Nx2 array of image locations, where the first dimension is
            the x-coordinate and the second dimension is the y-coordinate.
        colors : Optional[ndarray]
            The Nx3 array of colors (dtype=np.uint8). If None, the current
            color attribute is used.
        skip_index_check : Optional[bool]
            If True, index range checks are skipped. This is faster, but
            requires all points to lie within the image dimensions.

        """
        if not skip_index_check:
            cond1, cond2 = points[:, 0] >= 0, points[:, 0] < 480
            cond3, cond4 = points[:, 1] >= 0, points[:, 1] < 640
            indices = np.logical_and.reduce((cond1, cond2, cond3, cond4))
            points = points[indices, :]
        if colors is None:
            colors = np.repeat(
                self._color, len(points)).reshape(3, len(points)).T
        indices = (points + .5).astype(np.int)
        self.image[indices[:, 1], indices[:, 0], :] = colors

    def enable_videowriter(self, output_filename, fourcc_string="MJPG",
                           fps=None):
        """ Write images to video file.

        Parameters
        ----------
        output_filename : str
            Output filename.
        fourcc_string : str
            The OpenCV FOURCC code that defines the video codec (check OpenCV
            documentation for more information).
        fps : Optional[float]
            Frames per second. If None, configured according to current
            parameters.

        """
        fourcc = cv2.VideoWriter_fourcc(*fourcc_string)
        if fps is None:
            fps = int(1000. / self._update_ms)
        self._video_writer = cv2.VideoWriter(
            output_filename, fourcc, fps, self._window_shape)

    def disable_videowriter(self):
        """ Disable writing videos.
        """
        self._video_writer = None

    def run(self, update_fun=None):
        """Start the image viewer.

        This method blocks until the user requests to close the window.

        Parameters
        ----------
        update_fun : Optional[Callable[] -> None]
            An optional callable that is invoked at each frame. May be used
            to play an animation/a video sequence.

        """
        if update_fun is not None:
            self._user_fun = update_fun

        self._terminate, is_paused = False, False
        # print("ImageViewer is paused, press space to start.")
        while not self._terminate:
            t0 = time.time()
            if not is_paused:
                self._terminate = not self._user_fun()
                if self._video_writer is not None:
                    self._video_writer.write(
                        cv2.resize(self.image, self._window_shape))
            t1 = time.time()
            remaining_time = max(1, int(self._update_ms - 1e3*(t1-t0)))
            cv2.imshow(
                self._caption, cv2.resize(self.image, self._window_shape[:2]))
            key = cv2.waitKey(remaining_time)
            if key & 255 == 27:  # ESC
                print("terminating")
                self._terminate = True
            elif key & 255 == 32:  # ' '
                print("toggling pause: " + str(not is_paused))
                is_paused = not is_paused
            elif key & 255 == 115:  # 's'
                print("stepping")
                self._terminate = not self._user_fun()
                is_paused = True

        # Due to a bug in OpenCV we must call imshow after destroying the
        # window. This will make the window appear again as soon as waitKey
        # is called.
        #
        # see https://github.com/Itseez/opencv/issues/4535
        self.image[:] = 0
        cv2.destroyWindow(self._caption)
        cv2.waitKey(1)
        cv2.imshow(self._caption, self.image)

    def stop(self):
        """Stop the control loop.

        After calling this method, the viewer will stop execution before the
        next frame and hand over control flow to the user.

        """
        self._terminate = True
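# Minimal usage sketch (ours, not from the original module): draw one box per
# frame at ~25 fps until the callback returns False.
#
#   viewer = ImageViewer(update_ms=40, window_shape=(640, 480))
#   def step():
#       viewer.image[:] = 0
#       viewer.color = (0, 255, 0)
#       viewer.rectangle(100, 80, 60, 120, label="1")
#       return True  # keep running; ESC quits, SPACE pauses, 's' steps
#   viewer.run(step)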
@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

def non_max_suppression(boxes, max_bbox_overlap, scores=None):
    """Suppress overlapping detections.

    Original code from [1]_ has been adapted to include confidence score.

    .. [1] http://www.pyimagesearch.com/2015/02/16/
           faster-non-maximum-suppression-python/

    Examples
    --------

    >>> boxes = [d.roi for d in detections]
    >>> scores = [d.confidence for d in detections]
    >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
    >>> detections = [detections[i] for i in indices]

    Parameters
    ----------
    boxes : ndarray
        Array of ROIs (x, y, width, height).
    max_bbox_overlap : float
        ROIs that overlap more than this value are suppressed.
    scores : Optional[array_like]
        Detector confidence score.

    Returns
    -------
    List[int]
        Returns indices of detections that have survived non-maxima suppression.

    """
    if np.size(boxes) == 0:
        return []

    boxes = boxes.astype(np.float)
    pick = []

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2] + boxes[:, 0]
    y2 = boxes[:, 3] + boxes[:, 1]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    if scores is not None:
        idxs = np.argsort(scores)
    else:
        idxs = np.argsort(y2)

    while np.size(idxs) > 0:
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])

        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        overlap = (w * h) / area[idxs[:last]]

        idxs = np.delete(
            idxs, np.concatenate(
                ([last], np.where(overlap > max_bbox_overlap)[0])))

    return pick
@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import colorsys
import numpy as np
from .image_viewer import ImageViewer


def create_unique_color_float(tag, hue_step=0.41):
    """Create a unique RGB color code for a given track id (tag).

    The color code is generated in HSV color space by moving along the
    hue angle and gradually changing the saturation.

    Parameters
    ----------
    tag : int
        The unique target identifying tag.
    hue_step : float
        Difference between two neighboring color codes in HSV space (more
        specifically, the distance in hue channel).

    Returns
    -------
    (float, float, float)
        RGB color code in range [0, 1]

    """
    h, v = (tag * hue_step) % 1, 1. - (int(tag * hue_step) % 4) / 5.
    r, g, b = colorsys.hsv_to_rgb(h, 1., v)
    return r, g, b


def create_unique_color_uchar(tag, hue_step=0.41):
    """Create a unique RGB color code for a given track id (tag).

    The color code is generated in HSV color space by moving along the
    hue angle and gradually changing the saturation.

    Parameters
    ----------
    tag : int
        The unique target identifying tag.
    hue_step : float
        Difference between two neighboring color codes in HSV space (more
        specifically, the distance in hue channel).

    Returns
    -------
    (int, int, int)
        RGB color code in range [0, 255]

    """
    r, g, b = create_unique_color_float(tag, hue_step)
    return int(255*r), int(255*g), int(255*b)


class NoVisualization:
    """
    A dummy visualization object that loops through all frames in a given
    sequence to update the tracker without performing any visualization.
    """

    def __init__(self, seq_info):
        self.frame_idx = seq_info["min_frame_idx"]
        self.last_idx = seq_info["max_frame_idx"]

    def set_image(self, image):
        pass

    def draw_groundtruth(self, track_ids, boxes):
        pass

    def draw_detections(self, detections):
        pass

    def draw_trackers(self, trackers):
        pass

    def run(self, frame_callback):
        while self.frame_idx <= self.last_idx:
            frame_callback(self, self.frame_idx)
            self.frame_idx += 1


class Visualization:
    """
    This class shows tracking output in an OpenCV image viewer.
    """

    def __init__(self, seq_info, update_ms):
        image_shape = seq_info["image_size"][::-1]
        aspect_ratio = float(image_shape[1]) / image_shape[0]
        image_shape = 1024, int(aspect_ratio * 1024)
        self.viewer = ImageViewer(
            update_ms, image_shape, "Figure %s" % seq_info["sequence_name"])
        self.viewer.thickness = 2
        self.frame_idx = seq_info["min_frame_idx"]
        self.last_idx = seq_info["max_frame_idx"]

    def run(self, frame_callback):
        self.viewer.run(lambda: self._update_fun(frame_callback))

    def _update_fun(self, frame_callback):
        if self.frame_idx > self.last_idx:
            return False  # Terminate
        frame_callback(self, self.frame_idx)
        self.frame_idx += 1
        return True

    def set_image(self, image):
        self.viewer.image = image

    def draw_groundtruth(self, track_ids, boxes):
        self.viewer.thickness = 2
        for track_id, box in zip(track_ids, boxes):
            self.viewer.color = create_unique_color_uchar(track_id)
            self.viewer.rectangle(*box.astype(np.int), label=str(track_id))

    def draw_detections(self, detections):
        self.viewer.thickness = 2
        self.viewer.color = 0, 0, 255
        for detection in detections:
            self.viewer.rectangle(*detection.tlwh)

    def draw_trackers(self, tracks):
        self.viewer.thickness = 2
        for track in tracks:
            if not track.is_confirmed() or track.time_since_update > 0:
                continue
            self.viewer.color = create_unique_color_uchar(track.track_id)
            self.viewer.rectangle(
                *track.to_tlwh().astype(np.int), label=str(track.track_id))
@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import cv2
import mindspore

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .original_model import Net

class Extractor:
    def __init__(self, model_path, batch_size=32):
        self.net = Net(reid=True)
        self.batch_size = batch_size
        param_dict = load_checkpoint(model_path)
        load_param_into_net(self.net, param_dict)
        self.size = (64, 128)

    def statistic_normalize_img(self, img, statistic_norm=True):
        """Statistic normalize images."""
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        if statistic_norm:
            img = (img - mean) / std
        img = img.astype(np.float32)
        return img

    def _preprocess(self, im_crops):
        """
        1. to float with scale from 0 to 1
        2. resize to (64, 128) as the Market-1501 dataset did
        3. normalize channel-wise
        4. convert to MindSpore tensors and concatenate to a batch
        """
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32)/255., size)
        im_batch = []
        for im in im_crops:
            im = _resize(im, self.size)
            im = self.statistic_normalize_img(im)
            im = mindspore.Tensor.from_numpy(im.transpose(2, 0, 1).copy())
            im = mindspore.ops.ExpandDims()(im, 0)
            im_batch.append(im)

        im_batch = mindspore.ops.Concat(axis=0)(tuple(im_batch))
        return im_batch

    def __call__(self, im_crops):
        out = np.zeros((len(im_crops), 128), np.float32)
        num_batches = int(len(im_crops)/self.batch_size)
        s, e = 0, 0
        # run full batches first, then one remainder batch if needed
        for i in range(num_batches):
            s, e = i * self.batch_size, (i + 1) * self.batch_size
            im_batch = self._preprocess(im_crops[s:e])
            feature = self.net(im_batch)
            out[s:e] = feature.asnumpy()
        if e < len(out):
            im_batch = self._preprocess(im_crops[e:])
            feature = self.net(im_batch)
            out[e:] = feature.asnumpy()
        return out
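# Usage sketch (ours; the checkpoint path is a placeholder, not a file name
# shipped with the repo):
#   extractor = Extractor("./ckpt_6/deepsort.ckpt")
#   crops = [bgr_image[y:y+h, x:x+w] for (x, y, w, h) in boxes]
#   features = extractor(crops)   # ndarray of shape (len(crops), 128)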
@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
import mindspore.nn as nn
import mindspore.ops as F

class BasicBlock(nn.Cell):
    def __init__(self, c_in, c_out, is_downsample=False):
        super(BasicBlock, self).__init__()
        self.add = mindspore.ops.Add()
        self.ReLU = F.ReLU()
        self.is_downsample = is_downsample
        if is_downsample:
            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, pad_mode='pad', padding=1,\
                                   has_bias=False, weight_init='HeUniform')
        else:
            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, pad_mode='same',\
                                   has_bias=False, weight_init='HeUniform')
        self.bn1 = nn.BatchNorm2d(c_out, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, pad_mode='pad', padding=1,\
                               has_bias=False, weight_init='HeUniform')
        self.bn2 = nn.BatchNorm2d(c_out, momentum=0.9)
        if is_downsample:
            self.downsample = nn.SequentialCell(
                [nn.Conv2d(c_in, c_out, 1, stride=2, pad_mode='same', has_bias=False, weight_init='HeUniform'),
                 nn.BatchNorm2d(c_out, momentum=0.9)]
            )
        elif c_in != c_out:
            self.downsample = nn.SequentialCell(
                [nn.Conv2d(c_in, c_out, 1, stride=1, pad_mode='pad', has_bias=False, weight_init='HeUniform'),
                 nn.BatchNorm2d(c_out, momentum=0.9)]
            )
            self.is_downsample = True

    def construct(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.is_downsample:
            x = self.downsample(x)
        y = self.add(x, y)
        y = self.ReLU(y)
        return y

def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    blocks = []
    for i in range(repeat_times):
        if i == 0:
            blocks.append(BasicBlock(c_in, c_out, is_downsample=is_downsample))
        else:
            blocks.append(BasicBlock(c_out, c_out))
    return nn.SequentialCell(blocks)


class Net(nn.Cell):
    def __init__(self, num_classes=751, reid=False, ascend=False):
        super(Net, self).__init__()
        # input: 3 x 128 x 64
        self.conv = nn.SequentialCell(
            [nn.Conv2d(3, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
             nn.BatchNorm2d(32, momentum=0.9),
             nn.ELU(),
             nn.Conv2d(32, 32, 3, stride=1, pad_mode='same', has_bias=True, weight_init='HeUniform'),
             nn.BatchNorm2d(32, momentum=0.9),
             nn.ELU(),
             nn.MaxPool2d(3, 2, pad_mode='same')]
        )
        # 32 x 64 x 32
        self.layer1 = make_layers(32, 32, 2, False)
        # 32 x 64 x 32
        self.layer2 = make_layers(32, 64, 2, True)
        # 64 x 32 x 16
        self.layer3 = make_layers(64, 128, 2, True)
        # 128 x 16 x 8
        self.dp = nn.Dropout(keep_prob=0.6)
        self.dense = nn.Dense(128*16*8, 128)
        self.bn1 = nn.BatchNorm1d(128, momentum=0.9)
        self.elu = nn.ELU()
        self.reid = reid
        self.ascend = ascend
        #self.flatten = nn.Flatten()
        self.div = F.Div()
        self.batch_norm = nn.BatchNorm1d(128, momentum=0.9)
        self.classifier = nn.Dense(128, num_classes)
        self.Norm = nn.Norm(axis=0, keep_dims=True)

    def construct(self, x):
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        #x = self.flatten(x)
        x = x.view((x.shape[0], -1))
        if self.reid:
            x = self.dp(x)
            x = self.dense(x)
            if self.ascend:
                x = self.bn1(x)
            else:
                f = self.Norm(x)
                x = self.div(x, f)
            return x
        x = self.dp(x)
        x = self.dense(x)
        x = self.bn1(x)
        x = self.elu(x)
        x = self.classifier(x)
        return x
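# Shape walkthrough (ours, following the layer comments above for a 128x64 BGR
# crop): (N, 3, 128, 64) -> conv + maxpool (N, 32, 64, 32) -> layer1
# (N, 32, 64, 32) -> layer2 (N, 64, 32, 16) -> layer3 (N, 128, 16, 8) ->
# flatten (N, 16384) -> dense (N, 128). With reid=True the 128-d embedding is
# normalized (batch norm on Ascend, norm-division otherwise); with reid=False
# the classifier maps it to num_classes logits for Market-1501 training.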
@ -0,0 +1,152 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.nn as nn
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.common import set_seed
|
||||
from original_model import Net
|
||||
set_seed(1234)
|
||||
def parse_args():
|
||||
""" Parse command line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument("--epoch", help="Path to custom detections.", type=int, default=100)
|
||||
parser.add_argument("--batch_size", help="Batch size for Training.", type=int, default=8)
|
||||
parser.add_argument("--num_parallel_workers", help="The number of parallel workers.", type=int, default=16)
|
||||
parser.add_argument("--pre_train", help='The ckpt file of model.', type=str, default=None)
|
||||
parser.add_argument("--save_check_point", help="Whether save the training resulting.", type=bool, default=True)
|
||||
|
||||
#learning rate
|
||||
parser.add_argument("--learning_rate", help="Learning rate.", type=float, default=0.1)
|
||||
parser.add_argument("--decay_epoch", help="decay epochs.", type=int, default=20)
|
||||
parser.add_argument('--gamma', type=float, default=0.10, help='learning rate decay.')
|
||||
parser.add_argument("--momentum", help="", type=float, default=0.9)
|
||||
|
||||
#run on where
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=True, help='Run distribute')
|
||||
|
||||
return parser.parse_args()
|
||||
def get_lr(base_lr, total_epochs, steps_per_epoch, step_size, gamma):
|
||||
lr_each_step = []
|
||||
for i in range(1, total_epochs+1):
|
||||
if i % step_size == 0:
|
||||
base_lr *= gamma
|
||||
for _ in range(steps_per_epoch):
|
||||
lr_each_step.append(base_lr)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step


args = parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

if args.run_modelarts:
    import moxing as mox
    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))
    args.batch_size = args.batch_size*int(8/device_num)
    context.set_context(device_id=device_id)
    local_data_url = '/cache/data'
    local_train_url = '/cache/train'
    mox.file.copy_parallel(args.data_url, local_data_url)
    if device_num > 1:
        init()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    DATA_DIR = local_data_url + '/'
else:
    if args.run_distribute:
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        args.batch_size = args.batch_size*int(8/device_num)
        context.set_context(device_id=device_id)
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        context.set_context(device_id=args.device_id)
        device_num = 1
        args.batch_size = args.batch_size*int(8/device_num)
        device_id = args.device_id
    DATA_DIR = args.data_url + '/'

data = ds.ImageFolderDataset(DATA_DIR, decode=True, shuffle=True,
                             num_parallel_workers=args.num_parallel_workers,
                             num_shards=device_num, shard_id=device_id)

transform_img = [
    C.RandomCrop((128, 64), padding=4),
    C.RandomHorizontalFlip(prob=0.5),
    C.Normalize([0.485*255, 0.456*255, 0.406*255], [0.229*255, 0.224*255, 0.225*255]),
    C.HWC2CHW()
]

num_classes = max(data.num_classes(), 0)

data = data.map(input_columns="image", operations=transform_img, num_parallel_workers=args.num_parallel_workers)
data = data.batch(batch_size=args.batch_size)

data_size = data.get_dataset_size()

loss_cb = LossMonitor(data_size)
time_cb = TimeMonitor(data_size=data_size)
callbacks = [time_cb, loss_cb]

# save training results
if args.save_check_point and (device_num == 1 or device_id == 0):
    model_save_path = './ckpt_' + str(6) + '/'
    config_ck = CheckpointConfig(
        save_checkpoint_steps=data_size*args.epoch, keep_checkpoint_max=args.epoch)
    if args.run_modelarts:
        ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=local_train_url, config=config_ck)
    else:
        ckpoint_cb = ModelCheckpoint(prefix='deepsort', directory=model_save_path, config=config_ck)
    callbacks += [ckpoint_cb]

# design learning rate
lr = Tensor(get_lr(args.learning_rate, args.epoch, data_size, args.decay_epoch, args.gamma))
# net definition
net = Net(num_classes=num_classes)

# loss and optimizer
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum)
# optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum, weight_decay=5e-4)
# optimizer = mindspore.nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=args.momentum)

# train
model = Model(net, loss_fn=loss, optimizer=optimizer)
model.train(args.epoch, data, callbacks=callbacks, dataset_sink_mode=True)
if args.run_modelarts:
    mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


class Detection:
    """
    This class represents a bounding box detection in a single image.

    Parameters
    ----------
    tlwh : array_like
        Bounding box in format `(x, y, w, h)`.
    confidence : float
        Detector confidence score.
    feature : array_like
        A feature vector that describes the object contained in this image.

    Attributes
    ----------
    tlwh : ndarray
        Bounding box in format `(top left x, top left y, width, height)`.
    confidence : float
        Detector confidence score.
    feature : ndarray | NoneType
        A feature vector that describes the object contained in this image.

    """

    def __init__(self, tlwh, confidence, feature):
        # np.float is a deprecated alias of the builtin float; use an explicit
        # dtype so the code keeps working on recent NumPy versions.
        self.tlwh = np.asarray(tlwh, dtype=np.float32)
        self.confidence = float(confidence)
        self.feature = np.asarray(feature, dtype=np.float32)

    def to_tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def to_xyah(self):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret
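

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for this edit, not part of the original
# file): converting one detection between box formats. Numbers are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    det = Detection(tlwh=[10., 20., 50., 100.], confidence=0.9,
                    feature=np.zeros(128, dtype=np.float32))
    print(det.to_tlbr())  # [ 10.  20.  60. 120.]  (min x, min y, max x, max y)
    print(det.to_xyah())  # [35.  70.   0.5 100.]  (center x, center y, w/h, h)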
@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
import numpy as np
from . import linear_assignment


def iou(bbox, candidates):
    """Compute intersection over union.

    Parameters
    ----------
    bbox : ndarray
        A bounding box in format `(top left x, top left y, width, height)`.
    candidates : ndarray
        A matrix of candidate bounding boxes (one per row) in the same format
        as `bbox`.

    Returns
    -------
    ndarray
        The intersection over union in [0, 1] between the `bbox` and each
        candidate. A higher score means a larger fraction of the `bbox` is
        occluded by the candidate.

    """
    bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
    candidates_tl = candidates[:, :2]
    candidates_br = candidates[:, :2] + candidates[:, 2:]

    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
    wh = np.maximum(0., br - tl)

    area_intersection = wh.prod(axis=1)
    area_bbox = bbox[2:].prod()
    area_candidates = candidates[:, 2:].prod(axis=1)
    return area_intersection / (area_bbox + area_candidates - area_intersection)


def iou_cost(tracks, detections, track_indices=None,
             detection_indices=None):
    """An intersection over union distance metric.

    Parameters
    ----------
    tracks : List[deep_sort.track.Track]
        A list of tracks.
    detections : List[deep_sort.detection.Detection]
        A list of detections.
    track_indices : Optional[List[int]]
        A list of indices to tracks that should be matched. Defaults to
        all `tracks`.
    detection_indices : Optional[List[int]]
        A list of indices to detections that should be matched. Defaults
        to all `detections`.

    Returns
    -------
    ndarray
        Returns a cost matrix of shape
        len(track_indices), len(detection_indices) where entry (i, j) is
        `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.

    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
    for row, track_idx in enumerate(track_indices):
        if tracks[track_idx].time_since_update > 1:
            cost_matrix[row, :] = linear_assignment.INFTY_COST
            continue

        bbox = tracks[track_idx].to_tlwh()
        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
        cost_matrix[row, :] = 1. - iou(bbox, candidates)
    return cost_matrix
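

# ---------------------------------------------------------------------------
# Illustrative sketch (added for this edit, not part of the original file):
# IoU of one box against two candidates. Because of the relative import above,
# run this as a module inside the package rather than as a loose script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    box = np.array([0., 0., 10., 10.])        # tlwh
    cands = np.array([[0., 0., 10., 10.],     # identical box
                      [5., 5., 10., 10.]])    # half-overlapping box
    print(iou(box, cands))  # approx [1.0, 0.142857] = [100/100, 25/175]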
@ -0,0 +1,237 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import scipy.linalg


# 0.95 quantile of the chi-square distribution with N degrees of freedom
# (N = 1, ..., 9), used as a gating threshold for the Mahalanobis distance.
chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919}


class KalmanFilter:
    """
    A simple Kalman filter for tracking bounding boxes in image space.

    The 8-dimensional state space

        x, y, a, h, vx, vy, va, vh

    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).

    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.

        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.

        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        cholesky_factor = np.linalg.cholesky(covariance)
        d = measurements - mean
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False,
            overwrite_b=True)
        squared_maha = np.sum(z * z, axis=0)
        return squared_maha
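

# ---------------------------------------------------------------------------
# Illustrative sketch (added for this edit, not part of the original file):
# one initiate/predict/update cycle on made-up (x, y, a, h) measurements.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    kf = KalmanFilter()
    mean, cov = kf.initiate(np.array([320., 240., 0.5, 100.]))  # new track
    mean, cov = kf.predict(mean, cov)                           # one frame ahead
    mean, cov = kf.update(mean, cov, np.array([322., 241., 0.5, 101.]))
    print(mean[:4])  # corrected (x, y, a, h), pulled toward the measurement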
@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
import numpy as np
# from sklearn.utils.linear_assignment_ import linear_assignment
from scipy.optimize import linear_sum_assignment as linear_assignment
from . import kalman_filter


INFTY_COST = 1e+5


def min_cost_matching(
        distance_metric, max_distance, tracks, detections, track_indices=None,
        detection_indices=None):
    """Solve linear assignment problem.

    Parameters
    ----------
    distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
        The distance metric is given a list of tracks and detections as well as
        a list of N track indices and M detection indices. The metric should
        return the NxM dimensional cost matrix, where element (i, j) is the
        association cost between the i-th track in the given track indices and
        the j-th detection in the given detection_indices.
    max_distance : float
        Gating threshold. Associations with cost larger than this value are
        disregarded.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : List[int]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above).
    detection_indices : List[int]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above).

    Returns
    -------
    (List[(int, int)], List[int], List[int])
        Returns a tuple with the following three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.

    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    # Use len() here: `not array` is ambiguous for multi-element ndarrays.
    if len(detection_indices) == 0 or len(track_indices) == 0:
        return [], track_indices, detection_indices  # Nothing to match.

    cost_matrix = distance_metric(
        tracks, detections, track_indices, detection_indices)
    cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5

    row_indices, col_indices = linear_assignment(cost_matrix)

    matches, unmatched_tracks, unmatched_detections = [], [], []
    for col, detection_idx in enumerate(detection_indices):
        if col not in col_indices:
            unmatched_detections.append(detection_idx)
    for row, track_idx in enumerate(track_indices):
        if row not in row_indices:
            unmatched_tracks.append(track_idx)
    for row, col in zip(row_indices, col_indices):
        track_idx = track_indices[row]
        detection_idx = detection_indices[col]
        if cost_matrix[row, col] > max_distance:
            unmatched_tracks.append(track_idx)
            unmatched_detections.append(detection_idx)
        else:
            matches.append((track_idx, detection_idx))
    return matches, unmatched_tracks, unmatched_detections


def matching_cascade(
        distance_metric, max_distance, cascade_depth, tracks, detections,
        track_indices=None, detection_indices=None):
    """Run matching cascade.

    Parameters
    ----------
    distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
        The distance metric is given a list of tracks and detections as well as
        a list of N track indices and M detection indices. The metric should
        return the NxM dimensional cost matrix, where element (i, j) is the
        association cost between the i-th track in the given track indices and
        the j-th detection in the given detection indices.
    max_distance : float
        Gating threshold. Associations with cost larger than this value are
        disregarded.
    cascade_depth: int
        The cascade depth, should be set to the maximum track age.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : Optional[List[int]]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above). Defaults to all tracks.
    detection_indices : Optional[List[int]]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above). Defaults to all
        detections.

    Returns
    -------
    (List[(int, int)], List[int], List[int])
        Returns a tuple with the following three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.

    """
    if track_indices is None:
        track_indices = list(range(len(tracks)))
    if detection_indices is None:
        detection_indices = list(range(len(detections)))

    unmatched_detections = detection_indices
    matches = []
    for level in range(cascade_depth):
        if not unmatched_detections:  # No detections left
            break

        track_indices_l = [
            k for k in track_indices
            if tracks[k].time_since_update == 1 + level
        ]
        if not track_indices_l:  # Nothing to match at this level
            continue

        matches_l, _, unmatched_detections = \
            min_cost_matching(
                distance_metric, max_distance, tracks, detections,
                track_indices_l, unmatched_detections)
        matches += matches_l
    unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
    return matches, unmatched_tracks, unmatched_detections


def gate_cost_matrix(
        kf, cost_matrix, tracks, detections, track_indices, detection_indices,
        gated_cost=INFTY_COST, only_position=False):
    """Invalidate infeasible entries in cost matrix based on the state
    distributions obtained by Kalman filtering.

    Parameters
    ----------
    kf : The Kalman filter.
    cost_matrix : ndarray
        The NxM dimensional cost matrix, where N is the number of track indices
        and M is the number of detection indices, such that entry (i, j) is the
        association cost between `tracks[track_indices[i]]` and
        `detections[detection_indices[j]]`.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : List[int]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above).
    detection_indices : List[int]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above).
    gated_cost : Optional[float]
        Entries in the cost matrix corresponding to infeasible associations are
        set to this value. Defaults to a very large value.
    only_position : Optional[bool]
        If True, only the x, y position of the state distribution is considered
        during gating. Defaults to False.

    Returns
    -------
    ndarray
        Returns the modified cost matrix.

    """
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray(
        [detections[i].to_xyah() for i in detection_indices])
    for row, track_idx in enumerate(track_indices):
        track = tracks[track_idx]
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = gated_cost
    return cost_matrix
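

# ---------------------------------------------------------------------------
# Illustrative sketch (added for this edit, not part of the original file):
# min_cost_matching on toy 1-D "tracks"/"detections" with absolute distance
# as the metric, showing how max_distance gates the assignment. Because of
# the relative import above, run this as a module inside the package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_tracks = [0.0, 10.0]
    toy_dets = [0.2, 30.0]

    def abs_metric(tracks, detections, track_indices, detection_indices):
        # |track - detection| as an NxM cost matrix.
        return np.array([[abs(tracks[t] - detections[d])
                          for d in detection_indices] for t in track_indices])

    print(min_cost_matching(abs_metric, 5.0, toy_tracks, toy_dets,
                            [0, 1], [0, 1]))
    # ([(0, 0)], [1], [1]): track 0 matches detection 0; the 10-vs-30 pair
    # exceeds max_distance, so both stay unmatched.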
@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


def _pdist(a, b):
    """Compute pair-wise squared distance between points in `a` and `b`.

    Parameters
    ----------
    a : array_like
        An NxM matrix of N samples of dimensionality M.
    b : array_like
        An LxM matrix of L samples of dimensionality M.

    Returns
    -------
    ndarray
        Returns a matrix of size len(a), len(b) such that element (i, j)
        contains the squared distance between `a[i]` and `b[j]`.

    """
    a, b = np.asarray(a), np.asarray(b)
    if np.size(a) == 0 or np.size(b) == 0:
        return np.zeros((len(a), len(b)))
    a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
    r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
    r2 = np.clip(r2, 0., float(np.inf))
    return r2


def _cosine_distance(a, b, data_is_normalized=False):
    """Compute pair-wise cosine distance between points in `a` and `b`.

    Parameters
    ----------
    a : array_like
        An NxM matrix of N samples of dimensionality M.
    b : array_like
        An LxM matrix of L samples of dimensionality M.
    data_is_normalized : Optional[bool]
        If True, assumes rows in a and b are unit length vectors.
        Otherwise, a and b are explicitly normalized to length 1.

    Returns
    -------
    ndarray
        Returns a matrix of size len(a), len(b) such that element (i, j)
        contains the cosine distance between `a[i]` and `b[j]`.

    """
    if not data_is_normalized:
        a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
        b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
    return 1. - np.dot(a, b.T)


def _nn_euclidean_distance(x, y):
    """ Helper function for nearest neighbor distance metric (Euclidean).

    Parameters
    ----------
    x : ndarray
        A matrix of N row-vectors (sample points).
    y : ndarray
        A matrix of M row-vectors (query points).

    Returns
    -------
    ndarray
        A vector of length M that contains for each entry in `y` the
        smallest Euclidean distance to a sample in `x`.

    """
    distances = _pdist(x, y)
    return np.maximum(0.0, distances.min(axis=0))


def _nn_cosine_distance(x, y):
    """ Helper function for nearest neighbor distance metric (cosine).

    Parameters
    ----------
    x : ndarray
        A matrix of N row-vectors (sample points).
    y : ndarray
        A matrix of M row-vectors (query points).

    Returns
    -------
    ndarray
        A vector of length M that contains for each entry in `y` the
        smallest cosine distance to a sample in `x`.

    """
    distances = _cosine_distance(x, y)
    return distances.min(axis=0)


class NearestNeighborDistanceMetric:
    """
    A nearest neighbor distance metric that, for each target, returns
    the closest distance to any sample that has been observed so far.

    Parameters
    ----------
    metric : str
        Either "euclidean" or "cosine".
    matching_threshold: float
        The matching threshold. Samples with larger distance are considered an
        invalid match.
    budget : Optional[int]
        If not None, fix samples per class to at most this number. Removes
        the oldest samples when the budget is reached.

    Attributes
    ----------
    samples : Dict[int -> List[ndarray]]
        A dictionary that maps from target identities to the list of samples
        that have been observed so far.

    """

    def __init__(self, metric, matching_threshold, budget=None):
        if metric == "euclidean":
            self._metric = _nn_euclidean_distance
        elif metric == "cosine":
            self._metric = _nn_cosine_distance
        else:
            raise ValueError(
                "Invalid metric; must be either 'euclidean' or 'cosine'")
        self.matching_threshold = matching_threshold
        self.budget = budget
        self.samples = {}

    def partial_fit(self, features, targets, active_targets):
        """Update the distance metric with new data.

        Parameters
        ----------
        features : ndarray
            An NxM matrix of N features of dimensionality M.
        targets : ndarray
            An integer array of associated target identities.
        active_targets : List[int]
            A list of targets that are currently present in the scene.

        """
        for feature, target in zip(features, targets):
            self.samples.setdefault(target, []).append(feature)
            if self.budget is not None:
                self.samples[target] = self.samples[target][-int(self.budget):]
        self.samples = {k: self.samples[k] for k in active_targets}

    def distance(self, features, targets):
        """Compute distance between features and targets.

        Parameters
        ----------
        features : ndarray
            An NxM matrix of N features of dimensionality M.
        targets : List[int]
            A list of targets to match the given `features` against.

        Returns
        -------
        ndarray
            Returns a cost matrix of shape len(targets), len(features), where
            element (i, j) contains the closest squared distance between
            `targets[i]` and `features[j]`.

        """
        cost_matrix = np.zeros((len(targets), len(features)))
        for i, target in enumerate(targets):
            cost_matrix[i, :] = self._metric(self.samples[target], features)
        return cost_matrix
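

# ---------------------------------------------------------------------------
# Illustrative sketch (added for this edit, not part of the original file):
# feeding two identities into the metric and querying a cost matrix. The
# feature values are random placeholders for real appearance embeddings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2,
                                           budget=100)
    feats = np.random.rand(4, 128).astype(np.float32)
    metric.partial_fit(feats, targets=np.array([1, 1, 2, 2]),
                       active_targets=[1, 2])
    cost = metric.distance(feats[:2], targets=[1, 2])
    print(cost.shape)  # (2, 2); row i is targets[i], column j is feats[j]
    print(cost[0])     # distances to target 1's stored samples (near 0 here)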
@ -0,0 +1,178 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


class TrackState:
    """
    Enumeration type for the single target track state. Newly created tracks are
    classified as `tentative` until enough evidence has been collected. Then,
    the track state is changed to `confirmed`. Tracks that are no longer alive
    are classified as `deleted` to mark them for removal from the set of active
    tracks.

    """

    Tentative = 1
    Confirmed = 2
    Deleted = 3


class Track:
    """
    A single target track with state space `(x, y, a, h)` and associated
    velocities, where `(x, y)` is the center of the bounding box, `a` is the
    aspect ratio and `h` is the height.

    Parameters
    ----------
    mean : ndarray
        Mean vector of the initial state distribution.
    covariance : ndarray
        Covariance matrix of the initial state distribution.
    track_id : int
        A unique track identifier.
    n_init : int
        Number of consecutive detections before the track is confirmed. The
        track state is set to `Deleted` if a miss occurs within the first
        `n_init` frames.
    max_age : int
        The maximum number of consecutive misses before the track state is
        set to `Deleted`.
    feature : Optional[ndarray]
        Feature vector of the detection this track originates from. If not None,
        this feature is added to the `features` cache.

    Attributes
    ----------
    mean : ndarray
        Mean vector of the initial state distribution.
    covariance : ndarray
        Covariance matrix of the initial state distribution.
    track_id : int
        A unique track identifier.
    hits : int
        Total number of measurement updates.
    age : int
        Total number of frames since first occurrence.
    time_since_update : int
        Total number of frames since last measurement update.
    state : TrackState
        The current track state.
    features : List[ndarray]
        A cache of features. On each measurement update, the associated feature
        vector is added to this list.

    """

    def __init__(self, mean, covariance, track_id, n_init, max_age,
                 feature=None):
        self.mean = mean
        self.covariance = covariance
        self.track_id = track_id
        self.hits = 1
        self.age = 1
        self.time_since_update = 0

        self.state = TrackState.Tentative
        self.features = []
        if feature is not None:
            self.features.append(feature)

        self._n_init = n_init
        self._max_age = max_age

    def to_tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        Returns
        -------
        ndarray
            The bounding box.

        """
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    def to_tlbr(self):
        """Get current position in bounding box format `(min x, min y, max x,
        max y)`.

        Returns
        -------
        ndarray
            The bounding box.

        """
        ret = self.to_tlwh()
        ret[2:] = ret[:2] + ret[2:]
        return ret

    def predict(self, kf):
        """Propagate the state distribution to the current time step using a
        Kalman filter prediction step.

        Parameters
        ----------
        kf : kalman_filter.KalmanFilter
            The Kalman filter.

        """
        self.mean, self.covariance = kf.predict(self.mean, self.covariance)
        self.age += 1
        self.time_since_update += 1

    def update(self, kf, detection):
        """Perform Kalman filter measurement update step and update the feature
        cache.

        Parameters
        ----------
        kf : kalman_filter.KalmanFilter
            The Kalman filter.
        detection : Detection
            The associated detection.

        """
        self.mean, self.covariance = kf.update(
            self.mean, self.covariance, detection.to_xyah())
        self.features.append(detection.feature)

        self.hits += 1
        self.time_since_update = 0
        if self.state == TrackState.Tentative and self.hits >= self._n_init:
            self.state = TrackState.Confirmed

    def mark_missed(self):
        """Mark this track as missed (no association at the current time step).
        """
        if self.state == TrackState.Tentative:
            self.state = TrackState.Deleted
        elif self.time_since_update > self._max_age:
            self.state = TrackState.Deleted

    def is_tentative(self):
        """Returns True if this track is tentative (unconfirmed).
        """
        return self.state == TrackState.Tentative

    def is_confirmed(self):
        """Returns True if this track is confirmed."""
        return self.state == TrackState.Confirmed

    def is_deleted(self):
        """Returns True if this track is dead and should be deleted."""
        return self.state == TrackState.Deleted
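

# ---------------------------------------------------------------------------
# Illustrative sketch (added for this edit, not part of the original file):
# a Track built from a hand-made 8-dim state vector instead of a Kalman
# filter, just to show the box conversions and the state flags.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np  # this module itself needs no imports
    mean = np.array([35., 70., 0.5, 100., 0., 0., 0., 0.])  # (x, y, a, h, v...)
    trk = Track(mean, covariance=np.eye(8), track_id=1, n_init=3, max_age=30)
    print(trk.to_tlwh())       # [ 10.  20.  50. 100.]
    print(trk.is_tentative())  # True until n_init consecutive updates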
@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from __future__ import absolute_import
import numpy as np
from . import kalman_filter
from . import linear_assignment
from . import iou_matching
from .track import Track


class Tracker:
    """
    This is the multi-target tracker.

    Parameters
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        A distance metric for measurement-to-track association.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of consecutive detections before the track is confirmed. The
        track state is set to `Deleted` if a miss occurs within the first
        `n_init` frames.

    Attributes
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        The distance metric used for measurement to track association.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of frames that a track remains in initialization phase.
    kf : kalman_filter.KalmanFilter
        A Kalman filter to filter target trajectories in image space.
    tracks : List[Track]
        The list of active tracks at the current time step.

    """

    def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
        self.metric = metric
        self.max_iou_distance = max_iou_distance
        self.max_age = max_age
        self.n_init = n_init

        self.kf = kalman_filter.KalmanFilter()
        self.tracks = []
        self._next_id = 1

    def predict(self):
        """Propagate track state distributions one time step forward.

        This function should be called once every time step, before `update`.
        """
        for track in self.tracks:
            track.predict(self.kf)

    def update(self, detections):
        """Perform measurement update and track management.

        Parameters
        ----------
        detections : List[deep_sort.detection.Detection]
            A list of detections at the current time step.

        """
        # Run matching cascade.
        matches, unmatched_tracks, unmatched_detections = \
            self._match(detections)

        # Update track set.
        for track_idx, detection_idx in matches:
            self.tracks[track_idx].update(
                self.kf, detections[detection_idx])
        for track_idx in unmatched_tracks:
            self.tracks[track_idx].mark_missed()
        for detection_idx in unmatched_detections:
            self._initiate_track(detections[detection_idx])
        self.tracks = [t for t in self.tracks if not t.is_deleted()]

        # Update distance metric.
        active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
        features, targets = [], []
        for track in self.tracks:
            if not track.is_confirmed():
                continue
            features += track.features
            targets += [track.track_id for _ in track.features]
            track.features = []
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets)

    def _match(self, detections):

        def gated_metric(tracks, dets, track_indices, detection_indices):
            features = np.array([dets[i].feature for i in detection_indices])
            targets = np.array([tracks[i].track_id for i in track_indices])
            cost_matrix = self.metric.distance(features, targets)
            cost_matrix = linear_assignment.gate_cost_matrix(
                self.kf, cost_matrix, tracks, dets, track_indices,
                detection_indices)

            return cost_matrix

        # Split track set into confirmed and unconfirmed tracks.
        confirmed_tracks = [
            i for i, t in enumerate(self.tracks) if t.is_confirmed()]
        unconfirmed_tracks = [
            i for i, t in enumerate(self.tracks) if not t.is_confirmed()]

        # Associate confirmed tracks using appearance features.
        matches_a, unmatched_tracks_a, unmatched_detections = \
            linear_assignment.matching_cascade(
                gated_metric, self.metric.matching_threshold, self.max_age,
                self.tracks, detections, confirmed_tracks)

        # Associate remaining tracks together with unconfirmed tracks using IOU.
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update == 1]
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update != 1]
        matches_b, unmatched_tracks_b, unmatched_detections = \
            linear_assignment.min_cost_matching(
                iou_matching.iou_cost, self.max_iou_distance, self.tracks,
                detections, iou_track_candidates, unmatched_detections)

        matches = matches_a + matches_b
        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
        return matches, unmatched_tracks, unmatched_detections

    def _initiate_track(self, detection):
        mean, covariance = self.kf.initiate(detection.to_xyah())
        self.tracks.append(Track(
            mean, covariance, self._next_id, self.n_init, self.max_age,
            detection.feature))
        self._next_id += 1
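

# ---------------------------------------------------------------------------
# Illustrative per-frame loop (added for this edit, not part of the original
# file). The detection values are placeholders for real detector output;
# because of the relative imports, run this as a module inside the package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from .nn_matching import NearestNeighborDistanceMetric
    from .detection import Detection

    tracker = Tracker(NearestNeighborDistanceMetric("cosine",
                                                    matching_threshold=0.2))
    frame_detections = [Detection([10., 20., 50., 100.], 0.9,
                                  np.zeros(128, dtype=np.float32))]
    tracker.predict()                  # propagate all tracks one frame
    tracker.update(frame_detections)   # associate detections, manage tracks
    print([t.track_id for t in tracker.tracks])  # [1]: one new tentative track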