forked from mindspore-Ecosystem/mindspore
fix
This commit is contained in:
parent
5ee636c519
commit
791745a9b4
|
@ -0,0 +1,247 @@
|
|||
# Contents
|
||||
|
||||
- [Description](#description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Model Checkpoints](#model-checkpoints)
|
||||
- [Running](#running)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [Description](#contents)
|
||||
|
||||
There has been remarkable progress on object detection and re-identification in recent years which are the core components for multi-object tracking. However, little attention has been focused on accomplishing the two tasks in a single network to improve the inference speed. The initial attempts along this path ended up with degraded results mainly because the re-identification branch is not appropriately learned. In this work, we study the essential reasons behind the failure, and accordingly present a simple baseline to addresses the problems. It remarkably outperforms the state-of-the-arts on the MOT challenge datasets at 30 FPS. This baseline could inspire and help evaluate new ideas in this field. More detail about this model can be found in:
|
||||
|
||||
Zhang Y , Wang C , Wang X , et al. FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking[J]. 2020.
|
||||
|
||||
This repository contains a Mindspore implementation of FairMot based upon original Pytorch implementation (<https://github.com/ifzhang/FairMOT>). The training and validating scripts are also included, and the evaluation results are shown in the [Performance](#performance) section.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
The overall network architecture of FairMOT is shown below:
|
||||
|
||||
[Link](https://arxiv.org/pdf/1804.06208.pdf)
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: ETH, CalTech, MOT17, CUHK-SYSU, PRW, CityPerson
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## [Mixed Precision](#contents)
|
||||
|
||||
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
To run the python scripts in the repository, you need to prepare the environment as follow:
|
||||
|
||||
- Python and dependencies
|
||||
- opencv-python 4.5.1.48
|
||||
- Cython 0.29.23
|
||||
- cython-bbox 0.1.3
|
||||
- sympy 1.7.1
|
||||
- yacs
|
||||
- numba
|
||||
- progress
|
||||
- motmetrics 1.2.0
|
||||
- matplotlib 3.4.1
|
||||
- lap 0.4.0
|
||||
- openpyxl 3.0.7
|
||||
- Pillow 8.1.0
|
||||
- tensorboardX 2.2
|
||||
- python 3.7
|
||||
- mindspore 1.2.0
|
||||
- pycocotools 2.0
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
## [Dataset Preparation](#contents)
|
||||
|
||||
FairMot model uses mix dataset to train and validate in this repository. We use the training data as [JDE](https://github.com/Zhongdao/Towards-Realtime-MOT) in this part and we call it "MIX". Please refer to their [DATA ZOO](https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/DATASET_ZOO.md) to download and prepare all the training data including Caltech Pedestrian, CityPersons, CUHK-SYSU, PRW, ETHZ, MOT17 and MOT16.
|
||||
|
||||
## [Model Checkpoints](#contents)
|
||||
|
||||
Before you start your training process, you need to obtain mindspore pretrained models.
|
||||
The FairMOT model (DLA-34 backbone_conv) can be downloaded here:
|
||||
[dla34-ba72cf86.pth](http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth)
|
||||
|
||||
## [Running](#contents)
|
||||
|
||||
To train the model, run the shell script `scripts/train_standalone.sh` with the format below:
|
||||
|
||||
```shell
|
||||
# standalone training
|
||||
sh scripts/run_standalone_train.sh [device_id]
|
||||
|
||||
# distributed training
|
||||
sh scripts/run_distribute_train.sh [device_num]
|
||||
```
|
||||
|
||||
To validate the model, change the settings in `src/opts.py` to the path of the model you want to validate. For example:
|
||||
|
||||
```python
|
||||
self.parser.add_argument('--load_model', default='XXX.ckpt',
|
||||
help='path to pretrained model')
|
||||
```
|
||||
|
||||
Then, run the shell script `scripts/run_eval.sh` with the format below:
|
||||
|
||||
```shell
|
||||
sh scripts/run_eval.sh [device_id]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
The structure of the files in this repository is shown below.
|
||||
|
||||
```text
|
||||
└─mindspore-fairmot
|
||||
├─scripts
|
||||
│ ├─run_eval.sh // launch ascend standalone evaluation
|
||||
│ ├─run_distribute_train.sh // launch ascend distributed training
|
||||
│ └─run_standalone_train.sh // launch ascend standalone training
|
||||
├─src
|
||||
│ ├─tracker
|
||||
│ │ ├─basetrack.py // basic tracker
|
||||
│ │ ├─matching.py // calculating box distance
|
||||
│ │ └─multitracker.py // JDETracker
|
||||
│ ├─tracking_utils
|
||||
│ │ ├─evaluation.py // evaluate tracking results
|
||||
│ │ ├─kalman_filter.py // Kalman filter for tracking bounding boxes
|
||||
│ │ ├─log.py // logging tools
|
||||
│ │ ├─io.py //I/o tool
|
||||
│ │ ├─timer.py // evaluation of time consuming
|
||||
│ │ ├─utils.py // check that the folder exists
|
||||
│ │ └─visualization.py // display image tool
|
||||
│ ├─utils
|
||||
│ │ ├─callback.py // custom callback functions
|
||||
│ │ ├─image.py // image processing
|
||||
│ │ ├─jde.py // LoadImage
|
||||
│ │ ├─logger.py // a summary writer logging
|
||||
│ │ ├─lr_schedule.py // learning ratio generator
|
||||
│ │ ├─pth2ckpt.py // pth transformer
|
||||
│ │ └─tools.py // image processing tool
|
||||
│ ├─fairmot_poase.py // WithLossCell
|
||||
│ ├─losses.py // loss
|
||||
│ ├─opts.py // total config
|
||||
│ ├─util.py // routine operation
|
||||
│ ├─infer_net.py // infer net
|
||||
│ └─backbone_dla_conv.py // dla34_conv net
|
||||
├─fairmot_eval.py // eval fairmot
|
||||
├─fairmot_run.py // run fairmot
|
||||
├─fairmot_train.py // train fairmot
|
||||
├─fairmot_export.py // export fairmot
|
||||
└─README.md // descriptions about this repository
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### [Training](#contents)
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
Run `scripts/run_standalone_train.sh` to train the model standalone. The usage of the script is:
|
||||
|
||||
```shell
|
||||
sh scripts/run_standalone_train.sh DEVICE_ID DATA_CFG LOAD_PRE_MODEL
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the training procedure.
|
||||
|
||||
```shell
|
||||
sh run_standalone_train.sh 0 ./dataset/ ./dla34.ckpt
|
||||
```
|
||||
|
||||
The model checkpoint will be saved into `./ckpt`.
|
||||
|
||||
### [Distributed Training](#contents)
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
Run `scripts/run_distribute_train.sh` to train the model distributed. The usage of the script is:
|
||||
|
||||
```shell
|
||||
sh run_distribute.sh RANK_SIZE DATA_CFG LOAD_PRE_MODEL
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the distributed training procedure.
|
||||
|
||||
```shell
|
||||
sh run_distribute.sh 8 ./data.json ./dla34.ckpt
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/tran[X].log` as follows:
|
||||
|
||||
The model checkpoint will be saved into `train_parallel[X]/ckpt`.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
The evaluation data set was [MOT20](https://motchallenge.net/data/MOT20/)
|
||||
|
||||
### Running on Ascend
|
||||
|
||||
Run `scripts/run_eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:
|
||||
|
||||
```shell
|
||||
sh run_eval.sh DEVICE_ID LOAD_MODEL
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the validation procedure.
|
||||
|
||||
```shell
|
||||
sh run_eval.sh 0 ./dla34.ckpt
|
||||
```
|
||||
|
||||
The tracing results can be viewed in `/MOT20/distribute_dla34_conv`.
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### FairMot on MIX dataset with detector
|
||||
|
||||
#### Performance parameters
|
||||
|
||||
| Parameters | Standalone | Distributed |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| Model Version | FairMotNet | FairMotNet |
|
||||
| Resource | Ascend 910 | 8 Ascend 910 cards |
|
||||
| Uploaded Date | 25/06/2021 (month/day/year) | 25/06/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 | 1.2.0 |
|
||||
| Training Dataset | MIX | MIX |
|
||||
| Evaluation Dataset | MOT20 | MOT20 |
|
||||
| Training Parameters | epoch=30, batch_size=4 | epoch=30, batch_size=4 |
|
||||
| Optimizer | Adam | Adam |
|
||||
| Loss Function | FocalLoss,RegLoss | FocalLoss,RegLoss |
|
||||
| Train Performance | MOTA:43.8% Prcn:90.9% | MOTA:42.5% Prcn:91.9%% |
|
||||
| Speed | 1pc: 380.528 ms/step | 8pc: 700.371 ms/step |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We also use random seed in `src/utils/backbone_dla_conv.py` to initial network weights.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"root":"/opt_data/xidian_wks/kasor/Fairmot/dataset",
|
||||
"train":
|
||||
{
|
||||
"cuhksysu":"/data/cuhksysu.train",
|
||||
"caltech":"/data/caltech.all",
|
||||
"citypersons":"/data/citypersons.train",
|
||||
"mot17":"/data/mot17.half",
|
||||
"prw":"/data/prw.train",
|
||||
"eth":"/data/eth.train"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval fairmot."""
|
||||
import os
|
||||
import os.path as osp
|
||||
import logging
|
||||
from src.backbone_dla_conv import DLASegConv
|
||||
from src.infer_net import InferNet
|
||||
from src.opts import Opts
|
||||
from src.fairmot_pose import WithNetCell
|
||||
from src.tracking_utils import visualization as vis
|
||||
from src.tracker.multitracker import JDETracker
|
||||
from src.tracking_utils.log import logger
|
||||
from src.tracking_utils.utils import mkdir_if_missing
|
||||
from src.tracking_utils.evaluation import Evaluator
|
||||
from src.tracking_utils.timer import Timer
|
||||
import src.utils.jde as datasets
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
import cv2
|
||||
import motmetrics as mm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def write_results(filename, results, data_type):
|
||||
"""write eval results."""
|
||||
if data_type == 'mot':
|
||||
save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
|
||||
elif data_type == 'kitti':
|
||||
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
|
||||
else:
|
||||
raise ValueError(data_type)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
for frame_id, tlwhs, track_ids in results:
|
||||
if data_type == 'kitti':
|
||||
frame_id -= 1
|
||||
for tlwh, track_id in zip(tlwhs, track_ids):
|
||||
if track_id < 0:
|
||||
continue
|
||||
x1, y1, w, h = tlwh
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
|
||||
f.write(line)
|
||||
logger.info('save results to %s', filename)
|
||||
|
||||
|
||||
def eval_seq(opt, net, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
||||
"""evaluation sequence."""
|
||||
if save_dir:
|
||||
mkdir_if_missing(save_dir)
|
||||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||
timer = Timer()
|
||||
results = []
|
||||
frame_id = 0
|
||||
# for path, img, img0 in dataloader:
|
||||
for _, img, img0 in dataloader:
|
||||
if frame_id % 20 == 0:
|
||||
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
|
||||
# run tracking
|
||||
timer.tic()
|
||||
blob = np.expand_dims(img, 0)
|
||||
blob = Tensor(blob, mstype.float32)
|
||||
img0 = Tensor(img0, mstype.float32)
|
||||
height, width = img0.shape[0], img0.shape[1]
|
||||
inp_height, inp_width = [blob.shape[2], blob.shape[3]]
|
||||
c = np.array([width / 2., height / 2.], dtype=np.float32)
|
||||
s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
|
||||
meta = {'c': c, 's': s, 'out_height': inp_height // opt.down_ratio,
|
||||
'out_width': inp_width // opt.down_ratio}
|
||||
id_feature, dets = net(blob)
|
||||
online_targets = tracker.update(id_feature.asnumpy(), dets, meta)
|
||||
online_tlwhs = []
|
||||
online_ids = []
|
||||
for t in online_targets:
|
||||
tlwh = t.tlwh
|
||||
tid = t.track_id
|
||||
vertical = tlwh[2] / tlwh[3] > 1.6
|
||||
if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
|
||||
online_tlwhs.append(tlwh)
|
||||
online_ids.append(tid)
|
||||
timer.toc()
|
||||
results.append((frame_id + 1, online_tlwhs, online_ids))
|
||||
if show_image or save_dir is not None:
|
||||
online_im = vis.plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id,
|
||||
fps=1. / timer.average_time)
|
||||
if show_image:
|
||||
cv2.imshow('online_im', online_im)
|
||||
if save_dir is not None:
|
||||
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
|
||||
frame_id += 1
|
||||
write_results(result_filename, results, data_type)
|
||||
return frame_id, timer.average_time, timer.calls
|
||||
|
||||
|
||||
def main(opt, data_root, seqs=None, exp_name='MOT17_test_public_dla34',
|
||||
save_images=True, save_videos=False, show_image=False):
|
||||
"""evaluation sequence."""
|
||||
logger.setLevel(logging.INFO)
|
||||
result_root = os.path.join(data_root, '..', 'results', exp_name)
|
||||
mkdir_if_missing(result_root)
|
||||
data_type = 'mot'
|
||||
# run tracking
|
||||
accs = []
|
||||
n_frame = 0
|
||||
timer_avgs, timer_calls = [], []
|
||||
# eval=eval_seq(opt, data_type, show_image=True)
|
||||
backbone_net = DLASegConv(opt.heads,
|
||||
down_ratio=4,
|
||||
final_kernel=1,
|
||||
last_level=5,
|
||||
head_conv=256)
|
||||
load_checkpoint(opt.load_model, net=backbone_net)
|
||||
infer_net = InferNet()
|
||||
net = WithNetCell(backbone_net, infer_net)
|
||||
net.set_train(False)
|
||||
for sequence in seqs:
|
||||
output_dir = os.path.join(data_root, '..', 'outputs', exp_name, sequence) \
|
||||
if save_images or save_videos else None
|
||||
logger.info('start seq: %s', sequence)
|
||||
dataloader = datasets.LoadImages(osp.join(data_root, sequence, 'img1'), (1088, 608))
|
||||
result_filename = os.path.join(result_root, '{}.txt'.format(sequence))
|
||||
meta_info = open(os.path.join(data_root, sequence, 'seqinfo.ini')).read()
|
||||
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:meta_info.find('\nseqLength')])
|
||||
nf, ta, tc = eval_seq(opt, net, dataloader, data_type, result_filename,
|
||||
save_dir=output_dir, show_image=show_image, frame_rate=frame_rate)
|
||||
n_frame += nf
|
||||
timer_avgs.append(ta)
|
||||
timer_calls.append(tc)
|
||||
logger.info('Evaluate seq: %s', sequence)
|
||||
evaluator = Evaluator(data_root, sequence, data_type)
|
||||
accs.append(evaluator.eval_file(result_filename))
|
||||
if save_videos:
|
||||
print(output_dir)
|
||||
output_video_path = osp.join(output_dir, '{}.mp4'.format(sequence))
|
||||
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(output_dir, output_video_path)
|
||||
os.system(cmd_str)
|
||||
timer_avgs = np.asarray(timer_avgs)
|
||||
timer_calls = np.asarray(timer_calls)
|
||||
all_time = np.dot(timer_avgs, timer_calls)
|
||||
avg_time = all_time / np.sum(timer_calls)
|
||||
logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))
|
||||
|
||||
# get summary
|
||||
metrics = mm.metrics.motchallenge_metrics
|
||||
mh = mm.metrics.create()
|
||||
summary = Evaluator.get_summary(accs, seqs, metrics)
|
||||
strsummary = mm.io.render_summary(
|
||||
summary,
|
||||
formatters=mh.formatters,
|
||||
namemap=mm.io.motchallenge_metric_names
|
||||
)
|
||||
print(strsummary)
|
||||
Evaluator.save_summary(summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opts = Opts().init()
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
# mode=context.PYNATIVE_MODE,
|
||||
device_target="Ascend",
|
||||
device_id=opts.id,
|
||||
save_graphs=False)
|
||||
seqs_str = '''MOT20-01
|
||||
MOT20-02
|
||||
MOT20-03
|
||||
MOT20-05 '''
|
||||
data_roots = os.path.join(opts.data_dir, 'MOT20/train')
|
||||
seq = [seq.strip() for seq in seqs_str.split()]
|
||||
main(opts,
|
||||
data_root=data_roots,
|
||||
seqs=seq,
|
||||
exp_name='MOT20_distribute_dla34_conv',
|
||||
show_image=False,
|
||||
save_images=False,
|
||||
save_videos=False)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export fairmot."""
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor, export
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from src.opts import Opts
|
||||
from src.backbone_dla_conv import DLASegConv
|
||||
from src.infer_net import InferNet
|
||||
from src.fairmot_pose import WithNetCell
|
||||
|
||||
|
||||
def fairmot_export(opt):
|
||||
"""export fairmot to mindir or air."""
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=opt.id)
|
||||
backbone_net = DLASegConv(opt.heads,
|
||||
down_ratio=4,
|
||||
final_kernel=1,
|
||||
last_level=5,
|
||||
head_conv=256,
|
||||
is_training=True)
|
||||
load_checkpoint(opt.load_model, net=backbone_net)
|
||||
infer_net = InferNet()
|
||||
net = WithNetCell(backbone_net, infer_net)
|
||||
net.set_train(False)
|
||||
input_data = Tensor(np.zeros([1, 3, 608, 1088]), mstype.float32)
|
||||
export(net, input_data, file_name='fairmot', file_format="MINDIR")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt_ = Opts().init()
|
||||
fairmot_export(opt_)
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""run fairmot."""
|
||||
import os
|
||||
import os.path as osp
|
||||
from src.backbone_dla_conv import DLASegConv
|
||||
from src.opts import Opts
|
||||
from src.infer_net import InferNet
|
||||
from src.fairmot_pose import WithNetCell
|
||||
from src.tracking_utils.utils import mkdir_if_missing
|
||||
from src.tracking_utils.log import logger
|
||||
import src.utils.jde as datasets
|
||||
import fairmot_eval
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
|
||||
def export(opt):
|
||||
"""run fairmot."""
|
||||
result_root = opt.output_root if opt.output_root != '' else '.'
|
||||
mkdir_if_missing(result_root)
|
||||
|
||||
logger.info('Starting tracking...')
|
||||
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
|
||||
result_filename = os.path.join(result_root, 'results.txt')
|
||||
frame_rate = dataloader.frame_rate
|
||||
frame_dir = None if opt.output_format == 'text' else osp.join(result_root, 'frame')
|
||||
backbone_net = DLASegConv(opt.heads,
|
||||
down_ratio=4,
|
||||
final_kernel=1,
|
||||
last_level=5,
|
||||
head_conv=256,
|
||||
is_training=True)
|
||||
load_checkpoint(opt.load_model, net=backbone_net)
|
||||
infer_net = InferNet()
|
||||
net = WithNetCell(backbone_net, infer_net)
|
||||
net.set_train(False)
|
||||
fairmot_eval.eval_seq(opt, net, dataloader, 'mot', result_filename,
|
||||
save_dir=frame_dir, show_image=False, frame_rate=frame_rate)
|
||||
if opt.output_format == 'video':
|
||||
output_video_path = osp.join(result_root, 'MOT16-03-results.mp4')
|
||||
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -b 5000k -c:v mpeg4 {}' \
|
||||
.format(osp.join(result_root, 'frame'),
|
||||
output_video_path)
|
||||
os.system(cmd_str)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opts = Opts().init()
|
||||
export(opts)
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train fairmot."""
|
||||
import json
|
||||
import os
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Model
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import TimeMonitor, ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.opts import Opts
|
||||
from src.losses import CenterNetMultiPoseLossCell
|
||||
from src.backbone_dla_conv import DLASegConv
|
||||
from src.fairmot_pose import WithLossCell
|
||||
from src.utils.lr_schedule import dynamic_lr
|
||||
from src.utils.jde import JointDataset
|
||||
from src.utils.callback import LossCallback
|
||||
|
||||
|
||||
def train(opt):
|
||||
"""train fairmot."""
|
||||
local_data_path = '/cache/data'
|
||||
if opt.is_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, max_call_depth=10000)
|
||||
context.set_context(device_id=device_id)
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
local_data_path = os.path.join(local_data_path, str(device_id))
|
||||
opt.data_cfg = os.path.join(local_data_path, 'data_half.json')
|
||||
output_path = opt.train_url
|
||||
if opt.arch == 'dla_34':
|
||||
load_path = os.path.join(local_data_path, 'crowdhuman_ms.ckpt')
|
||||
elif opt.arch == 'hrnet_18':
|
||||
load_path = os.path.join(local_data_path, 'hrnet_ms.ckpt')
|
||||
else:
|
||||
load_path = os.path.join(local_data_path, 'dla34-ba72cf86_ms.ckpt')
|
||||
print('local_data_path:', local_data_path)
|
||||
print('mixdata_path:', opt.data_cfg)
|
||||
print('output_path:', output_path)
|
||||
print('load_path', load_path)
|
||||
# data download
|
||||
print('Download data.')
|
||||
mox.file.copy_parallel(src_url=opt.data_url, dst_url=local_data_path)
|
||||
elif opt.run_distribute:
|
||||
load_path = opt.load_pre_model
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
device_num=device_num,
|
||||
parameter_broadcast=True
|
||||
)
|
||||
init()
|
||||
else:
|
||||
load_path = opt.load_pre_model
|
||||
device_id = opt.id
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
# mode=context.PYNATIVE_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=device_id)
|
||||
f = open(opt.data_cfg)
|
||||
data_config = json.load(f)
|
||||
train_set_paths = data_config['train']
|
||||
dataset_root = data_config['root']
|
||||
f.close()
|
||||
if opt.is_modelarts:
|
||||
dataset_root = local_data_path
|
||||
dataset = JointDataset(opt, dataset_root, train_set_paths, (1088, 608), augment=True)
|
||||
opt = Opts().update_dataset_info_and_set_heads(opt, dataset)
|
||||
if opt.is_modelarts or opt.run_distribute:
|
||||
Ms_dataset = ds.GeneratorDataset(dataset, ['input', 'hm', 'reg_mask', 'ind', 'wh', 'reg', 'ids'],
|
||||
shuffle=True, num_parallel_workers=8,
|
||||
num_shards=device_num, shard_id=device_id)
|
||||
else:
|
||||
Ms_dataset = ds.GeneratorDataset(dataset, ['input', 'hm', 'reg_mask', 'ind', 'wh', 'reg', 'ids'],
|
||||
shuffle=True)
|
||||
Ms_dataset = Ms_dataset.batch(batch_size=opt.batch_size, drop_remainder=True)
|
||||
batch_dataset_size = Ms_dataset.get_dataset_size()
|
||||
net = DLASegConv(opt.heads,
|
||||
down_ratio=4,
|
||||
final_kernel=1,
|
||||
last_level=5,
|
||||
head_conv=256)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(load_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
loss = CenterNetMultiPoseLossCell(opt)
|
||||
lr = Tensor(dynamic_lr(20, opt.num_epochs, batch_dataset_size),
|
||||
mstype.float32)
|
||||
optimizer = nn.Adam(net.trainable_params(), learning_rate=lr)
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
fairmot_net = nn.TrainOneStepCell(net_with_loss, optimizer)
|
||||
|
||||
# define callback
|
||||
loss_cb = LossCallback(opt.batch_size)
|
||||
time_cb = TimeMonitor()
|
||||
config_ckpt = CheckpointConfig(saved_network=net)
|
||||
if opt.is_modelarts:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='Fairmot_{}'.format(device_id), directory=local_data_path + '/output/ckpt',
|
||||
config=config_ckpt)
|
||||
else:
|
||||
ckpoint_cb = ModelCheckpoint(prefix='Fairmot_{}'.format(device_id), directory='./ckpt/', config=config_ckpt)
|
||||
callbacks = [loss_cb, ckpoint_cb, time_cb]
|
||||
|
||||
# train
|
||||
model = Model(fairmot_net)
|
||||
model.train(opt.num_epochs, Ms_dataset, callbacks=callbacks)
|
||||
if opt.is_modelarts:
|
||||
mox.file.copy_parallel(local_data_path + "/output", output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt_ = Opts().parse()
|
||||
train(opt_)
|
|
@ -0,0 +1,12 @@
|
|||
yacs
|
||||
opencv-python
|
||||
cython
|
||||
scipy
|
||||
numba
|
||||
progress
|
||||
motmetrics
|
||||
matplotlib
|
||||
lap
|
||||
openpyxl
|
||||
Pillow
|
||||
tensorboardX
|
|
@ -0,0 +1,87 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "========================================================================"
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distribute.sh RANK_SIZE DATA_CFG(options) LOAD_PRE_MODEL(options)"
|
||||
echo "For example: bash run_distribute.sh 8 ./data.json ./dla34.ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "========================================================================"
|
||||
set -e
|
||||
|
||||
RANK_SIZE=$1
|
||||
DATA_CFG=$2
|
||||
LOAD_PRE_MODEL=$3
|
||||
export RANK_SIZE
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
||||
export RANK_SIZE=8
|
||||
}
|
||||
|
||||
test_dist_2pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
||||
export RANK_SIZE=2
|
||||
}
|
||||
|
||||
test_dist_${RANK_SIZE}pcs
|
||||
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf distribute_train
|
||||
mkdir distribute_train
|
||||
cd distribute_train
|
||||
for((i=0;i<${RANK_SIZE};i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cd ./device$i
|
||||
mkdir src
|
||||
cd src
|
||||
mkdir utils
|
||||
cd ../../../
|
||||
cp ./fairmot_train.py ./distribute_train/device$i
|
||||
cp ./src/*.py ./distribute_train/device$i/src
|
||||
cp ./src/utils/*.py ./distribute_train/device$i/src/utils
|
||||
cd ./distribute_train/device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
if [ -f ${DATA_CFG} ]
|
||||
then
|
||||
python fairmot_train.py --run_distribute True --data_cfg ${DATA_CFG} --load_pre_model ${LOAD_PRE_MODEL} --is_modelarts False > train$i.log 2>&1 &
|
||||
else
|
||||
python fairmot_train.py --run_distribute True --is_modelarts False > train$i.log 2>&1 &
|
||||
fi
|
||||
echo "$i finish"
|
||||
cd ../
|
||||
done
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "training success"
|
||||
else
|
||||
echo "training failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "========================================================================"
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh DEVICE_ID LOAD_MODEL(options)"
|
||||
echo "For example: bash run_eval.sh 0 ./dla34.ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "========================================================================"
|
||||
export DEVICE_ID=$1
|
||||
export LOAD_MODEL=$2
|
||||
echo "start training for device $DEVICE_ID"
|
||||
python -u ../fairmot_eval.py --id=$DEVICE_ID --load_model=$LOAD_MODEL > eval_$DEVICE_ID.txt 2>&1 &
|
||||
echo "finish"
|
|
@ -0,0 +1,27 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "========================================================================"
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_standalone_train.sh DEVICE_ID DATA_CFG(options) LOAD_PRE_MODEL(options)"
|
||||
echo "For example: bash run_standalone_train.sh 0 ./dataset/ ./dla34.ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "========================================================================"
|
||||
export DEVICE_ID=$1
|
||||
DATA_CFG=$2
|
||||
LOAD_PRE_MODEL=$3
|
||||
echo "start training for device $DEVICE_ID"
|
||||
python -u ../fairmot_train.py --id ${DEVICE_ID} --data_cfg ${DATA_CFG} --load_pre_model ${LOAD_PRE_MODEL} --is_modelarts False > train${DEVICE_ID}.log 2>&1 &
|
||||
echo "finish"
|
|
@ -0,0 +1,401 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
deep layer aggregation backbone
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common.initializer import Constant
|
||||
|
||||
BN_MOMENTUM = 0.1
|
||||
|
||||
|
||||
class BasicBlock(nn.Cell):
|
||||
"""
|
||||
Basic residual block for dla.
|
||||
|
||||
Args:
|
||||
cin(int): Input channel.
|
||||
cout(int): Output channel.
|
||||
stride(int): Covolution stride. Default: 1.
|
||||
dilation(int): The dilation rate to be used for dilated convolution. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, the feature after covolution.
|
||||
"""
|
||||
|
||||
def __init__(self, cin, cout, stride=1, dilation=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv_bn_act = nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride, pad_mode='pad',
|
||||
padding=dilation, has_bias=False, dilation=dilation,
|
||||
has_bn=True, momentum=BN_MOMENTUM,
|
||||
activation='relu', after_fake=False)
|
||||
self.conv_bn = nn.Conv2dBnAct(cout, cout, kernel_size=3, stride=1, pad_mode='same',
|
||||
has_bias=False, dilation=dilation, has_bn=True,
|
||||
momentum=BN_MOMENTUM, activation=None)
|
||||
self.relu = ops.ReLU()
|
||||
|
||||
def construct(self, x, residual=None):
|
||||
"""
|
||||
Basic residual block for dla.
|
||||
"""
|
||||
if residual is None:
|
||||
residual = x
|
||||
|
||||
out = self.conv_bn_act(x)
|
||||
out = self.conv_bn(out)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Root(nn.Cell):
|
||||
"""
|
||||
Get HDA node which play as the root of tree in each stage
|
||||
|
||||
Args:
|
||||
cin(int): Input channel.
|
||||
cout(int):Output channel.
|
||||
kernel_size(int): Covolution kernel size.
|
||||
residual(bool): Add residual or not.
|
||||
|
||||
Returns:
|
||||
Tensor, HDA node after aggregation.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, residual):
|
||||
super(Root, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, has_bias=False,
|
||||
pad_mode='pad', padding=(kernel_size - 1) // 2)
|
||||
self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
|
||||
self.relu = ops.ReLU()
|
||||
self.residual = residual
|
||||
self.cat = ops.Concat(axis=1)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
Get HDA node which play as the root of tree in each stage
|
||||
"""
|
||||
children = x
|
||||
x = self.conv(self.cat(x))
|
||||
x = self.bn(x)
|
||||
if self.residual:
|
||||
x += children[0]
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Tree(nn.Cell):
|
||||
"""
|
||||
Construct the deep aggregation network through recurrent. Each stage can be seen as a tree with multiple children.
|
||||
|
||||
Args:
|
||||
levels(list int): Tree height of each stage.
|
||||
block(Cell): Basic block of the tree.
|
||||
in_channels(list int): Input channel of each stage.
|
||||
out_channels(list int): Output channel of each stage.
|
||||
stride(int): Covolution stride. Default: 1.
|
||||
level_root(bool): Whether is the root of tree or not. Default: False.
|
||||
root_dim(int): Input channel of the root node. Default: 0.
|
||||
root_kernel_size(int): Covolution kernel size at the root. Default: 1.
|
||||
dilation(int): The dilation rate to be used for dilated convolution. Default: 1.
|
||||
root_residual(bool): Add residual or not. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the root ida node.
|
||||
"""
|
||||
|
||||
def __init__(self, levels, block, in_channels, out_channels, stride=1, level_root=False,
|
||||
root_dim=0, root_kernel_size=1, dilation=1, root_residual=False):
|
||||
super(Tree, self).__init__()
|
||||
self.levels = levels
|
||||
if root_dim == 0:
|
||||
root_dim = 2 * out_channels
|
||||
if level_root:
|
||||
root_dim += in_channels
|
||||
if self.levels == 1:
|
||||
self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
|
||||
self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
|
||||
else:
|
||||
self.tree1 = Tree(levels - 1, block, in_channels, out_channels, stride, root_dim=0,
|
||||
root_kernel_size=root_kernel_size, dilation=dilation, root_residual=root_residual)
|
||||
self.tree2 = Tree(levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels,
|
||||
root_kernel_size=root_kernel_size, dilation=dilation, root_residual=root_residual)
|
||||
if self.levels == 1:
|
||||
self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
|
||||
self.level_root = level_root
|
||||
self.root_dim = root_dim
|
||||
self.downsample = None
|
||||
self.project = None
|
||||
if stride > 1:
|
||||
self.downsample = nn.MaxPool2d(stride, stride=stride)
|
||||
if in_channels != out_channels:
|
||||
self.project = nn.Conv2dBnAct(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same',
|
||||
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
|
||||
activation=None, after_fake=False)
|
||||
|
||||
def construct(self, x, residual=None, children=None):
|
||||
"""construct each stage tree recurrently"""
|
||||
children = () if children is None else children
|
||||
bottom = self.downsample(x) if self.downsample else x
|
||||
residual = self.project(bottom) if self.project else bottom
|
||||
if self.level_root:
|
||||
children += (bottom,)
|
||||
x1 = self.tree1(x, residual)
|
||||
if self.levels == 1:
|
||||
x2 = self.tree2(x1)
|
||||
ida_node = (x2, x1) + children
|
||||
x = self.root(ida_node)
|
||||
else:
|
||||
children += (x1,)
|
||||
x = self.tree2(x1, children=children)
|
||||
return x
|
||||
|
||||
|
||||
class DLA34(nn.Cell):
|
||||
"""
|
||||
Construct the downsampling deep aggregation network.
|
||||
|
||||
Args:
|
||||
levels(list int): Tree height of each stage.
|
||||
channels(list int): Input channel of each stage
|
||||
block(Cell): Initial basic block. Default: BasicBlock.
|
||||
residual_root(bool): Add residual or not. Default: False
|
||||
|
||||
Returns:
|
||||
tuple of Tensor, the root node of each stage.
|
||||
"""
|
||||
|
||||
def __init__(self, levels, channels, block=None, residual_root=False):
|
||||
super(DLA34, self).__init__()
|
||||
self.channels = channels
|
||||
self.base_layer = nn.Conv2dBnAct(3, channels[0], kernel_size=7, stride=1, pad_mode='same',
|
||||
has_bias=False, has_bn=True, momentum=BN_MOMENTUM,
|
||||
activation='relu', after_fake=False)
|
||||
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
|
||||
self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
|
||||
self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
|
||||
level_root=False, root_residual=residual_root)
|
||||
self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
|
||||
level_root=True, root_residual=residual_root)
|
||||
self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
|
||||
level_root=True, root_residual=residual_root)
|
||||
self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
|
||||
level_root=True, root_residual=residual_root)
|
||||
self.dla_fn = [self.level0, self.level1, self.level2, self.level3, self.level4, self.level5]
|
||||
|
||||
def _make_conv_level(self, cin, cout, convs, stride=1, dilation=1):
|
||||
modules = []
|
||||
for i in range(convs):
|
||||
modules.append(nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride if i == 0 else 1,
|
||||
pad_mode='pad', padding=dilation, has_bias=False, dilation=dilation,
|
||||
has_bn=True, momentum=BN_MOMENTUM, activation='relu', after_fake=False))
|
||||
cin = cout
|
||||
return nn.SequentialCell(modules)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
Construct the downsampling deep aggregation network.
|
||||
"""
|
||||
y = []
|
||||
x = self.base_layer(x)
|
||||
for i in range(len(self.channels)):
|
||||
x = self.dla_fn[i](x)
|
||||
y.append(x)
|
||||
return y
|
||||
|
||||
|
||||
class DeformConv(nn.Cell):
|
||||
"""
|
||||
Deformable convolution v2.
|
||||
|
||||
Args:
|
||||
cin(int): Input channel
|
||||
cout(int): Output_channel
|
||||
|
||||
Returns:
|
||||
Tensor, results after deformable convolution and activation
|
||||
"""
|
||||
|
||||
def __init__(self, cin, cout):
|
||||
super(DeformConv, self).__init__()
|
||||
self.actf = nn.SequentialCell([
|
||||
nn.BatchNorm2d(cout, momentum=BN_MOMENTUM),
|
||||
nn.ReLU()
|
||||
])
|
||||
self.conv = nn.Conv2d(cin, cout, kernel_size=3, stride=1, has_bias=False)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
Deformable convolution v2.
|
||||
"""
|
||||
x = self.conv(x)
|
||||
x = self.actf(x)
|
||||
return x
|
||||
|
||||
|
||||
class IDAUp(nn.Cell):
|
||||
"""IDAUp sample."""
|
||||
|
||||
def __init__(self, o, channels, up_f):
|
||||
super(IDAUp, self).__init__()
|
||||
proj_list = []
|
||||
up_list = []
|
||||
node_list = []
|
||||
for i in range(1, len(channels)):
|
||||
c = channels[i]
|
||||
f = int(up_f[i])
|
||||
proj = DeformConv(c, o)
|
||||
node = DeformConv(o, o)
|
||||
up = nn.Conv2dTranspose(o, o, f * 2, stride=f, pad_mode='pad', padding=f // 2, group=o)
|
||||
proj_list.append(proj)
|
||||
up_list.append(up)
|
||||
node_list.append(node)
|
||||
self.proj = nn.CellList(proj_list)
|
||||
self.up = nn.CellList(up_list)
|
||||
self.node = nn.CellList(node_list)
|
||||
|
||||
def construct(self, layers, startp, endp):
|
||||
"""IDAUp sample."""
|
||||
for i in range(startp + 1, endp):
|
||||
upsample = self.up[i - startp - 1]
|
||||
project = self.proj[i - startp - 1]
|
||||
layers[i] = upsample(project(layers[i]))
|
||||
node = self.node[i - startp - 1]
|
||||
layers[i] = node(layers[i] + layers[i - 1])
|
||||
return layers
|
||||
|
||||
|
||||
class DLAUp(nn.Cell):
|
||||
"""DLAUp sample."""
|
||||
|
||||
def __init__(self, startp, channels, scales, in_channels=None):
|
||||
super(DLAUp, self).__init__()
|
||||
self.startp = startp
|
||||
if in_channels is None:
|
||||
in_channels = channels
|
||||
self.channels = channels
|
||||
channels = list(channels)
|
||||
scales = np.array(scales, dtype=int)
|
||||
self.ida = []
|
||||
for i in range(len(channels) - 1):
|
||||
j = -i - 2
|
||||
self.ida.append(IDAUp(channels[j], in_channels[j:],
|
||||
scales[j:] // scales[j]))
|
||||
# setattr(self, 'ida_{}'.format(i),
|
||||
# IDAUp(channels[j], in_channels[j:],
|
||||
# scales[j:] // scales[j]))
|
||||
scales[j + 1:] = scales[j]
|
||||
in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
|
||||
self.ida_nfs = nn.CellList(self.ida)
|
||||
|
||||
def construct(self, layers):
|
||||
"""DLAUp sample."""
|
||||
out = [layers[-1]] # start with 32
|
||||
for i in range(len(layers) - self.startp - 1):
|
||||
ida = self.ida_nfs[i]
|
||||
layers = ida(layers, len(layers) - i - 2, len(layers))
|
||||
out.append(layers[-1])
|
||||
a = []
|
||||
i = len(out)
|
||||
while i > 0:
|
||||
a.append(out[i - 1])
|
||||
i -= 1
|
||||
return a
|
||||
|
||||
|
||||
class DLASegConv(nn.Cell):
|
||||
"""
|
||||
The DLA backbone network.
|
||||
|
||||
Args:
|
||||
down_ratio(int): The ratio of input and output resolution
|
||||
last_level(int): The ending stage of the final upsampling
|
||||
stage_levels(list int): The tree height of each stage block
|
||||
stage_channels(list int): The feature channel of each stage
|
||||
|
||||
Returns:
|
||||
Tensor, the feature map extracted by dla network
|
||||
"""
|
||||
|
||||
def __init__(self, heads, down_ratio, final_kernel,
|
||||
last_level, head_conv, out_channel=0, is_training=True):
|
||||
super(DLASegConv, self).__init__()
|
||||
assert down_ratio in [2, 4, 8, 16]
|
||||
self.first_level = int(np.log2(down_ratio))
|
||||
self.last_level = last_level
|
||||
self.is_training = is_training
|
||||
self.base = DLA34([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock)
|
||||
channels = [16, 32, 64, 128, 256, 512]
|
||||
scales = [2 ** i for i in range(len(channels[self.first_level:]))]
|
||||
# self.dla_up = DLAUp(self.first_level, stage_channels[self.first_level:], last_level)
|
||||
self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)
|
||||
if out_channel == 0:
|
||||
out_channel = channels[self.first_level]
|
||||
self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level],
|
||||
[2 ** i for i in range(self.last_level - self.first_level)])
|
||||
self.heads = heads
|
||||
for head in self.heads:
|
||||
classes = self.heads[head]
|
||||
if head_conv > 0:
|
||||
if 'hm' in head:
|
||||
conv2d = nn.Conv2d(head_conv, classes, kernel_size=final_kernel, has_bias=True,
|
||||
bias_init=Constant(-2.19))
|
||||
self.hm_fc = nn.SequentialCell(
|
||||
[nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, has_bias=True), nn.ReLU(),
|
||||
conv2d])
|
||||
elif 'wh' in head:
|
||||
conv2d = nn.Conv2d(head_conv, classes, kernel_size=final_kernel, has_bias=True)
|
||||
self.wh_fc = nn.SequentialCell(
|
||||
[nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, has_bias=True), nn.ReLU(),
|
||||
conv2d])
|
||||
elif 'id' in head:
|
||||
conv2d = nn.Conv2d(head_conv, classes, kernel_size=final_kernel, has_bias=True)
|
||||
self.id_fc = nn.SequentialCell(
|
||||
[nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, has_bias=True), nn.ReLU(),
|
||||
conv2d])
|
||||
else:
|
||||
conv2d = nn.Conv2d(head_conv, classes, kernel_size=final_kernel, has_bias=True)
|
||||
self.reg_fc = nn.SequentialCell(
|
||||
[nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, has_bias=True), nn.ReLU(),
|
||||
conv2d])
|
||||
else:
|
||||
if 'hm' in head:
|
||||
self.hm_fc = nn.Conv2d(channels[self.first_level], classes, kernel_size=final_kernel, has_bias=True,
|
||||
bias_init=Constant(-2.19))
|
||||
elif 'wh' in head:
|
||||
self.wh_fc = nn.Conv2d(channels[self.first_level], classes, kernel_size=final_kernel, has_bias=True)
|
||||
elif 'id' in head:
|
||||
self.id_fc = nn.Conv2d(channels[self.first_level], classes, kernel_size=final_kernel, has_bias=True)
|
||||
else:
|
||||
self.reg_fc = nn.Conv2d(channels[self.first_level], classes, kernel_size=final_kernel,
|
||||
has_bias=True)
|
||||
|
||||
def construct(self, image):
|
||||
"""The DLA backbone network."""
|
||||
x = self.base(image)
|
||||
x = self.dla_up(x)
|
||||
y = []
|
||||
for i in range(self.last_level - self.first_level):
|
||||
y.append(x[i])
|
||||
y = self.ida_up(y, 0, len(y))
|
||||
hm = self.hm_fc(y[-1])
|
||||
wh = self.wh_fc(y[-1])
|
||||
feature_id = self.id_fc(y[-1])
|
||||
reg = self.reg_fc(y[-1])
|
||||
feature = {"hm": hm, "feature_id": feature_id, "wh": wh, "reg": reg}
|
||||
return feature
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Fairmot for training and evaluation
|
||||
"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""Cell with loss function.."""
|
||||
|
||||
def __init__(self, net, loss):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._net = net
|
||||
self._loss = loss
|
||||
|
||||
def construct(self, image, hm, reg_mask, ind, wh, reg, ids):
|
||||
"""Cell with loss function."""
|
||||
feature = self._net(image)
|
||||
return self._loss(feature, hm, reg_mask, ind, wh, reg, ids)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""Return net."""
|
||||
return self._net
|
||||
|
||||
|
||||
class WithNetCell(nn.Cell):
|
||||
"""Cell with infer_net function.."""
|
||||
|
||||
def __init__(self, net, infer_net):
|
||||
super(WithNetCell, self).__init__(auto_prefix=False)
|
||||
self._net = net
|
||||
self._infer_net = infer_net
|
||||
|
||||
def construct(self, image):
|
||||
"""Cell with loss function."""
|
||||
feature = self._net(image)
|
||||
return self._infer_net(feature)
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""infer net."""
|
||||
from mindspore import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from src.util import Sigmoid
|
||||
|
||||
|
||||
class GatherFeat(nn.Cell):
|
||||
"""gather feature."""
|
||||
|
||||
def __init__(self):
|
||||
super(GatherFeat, self).__init__()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
self.gather = ops.GatherD()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""gather feature."""
|
||||
dim = feat.shape[2]
|
||||
ind = self.expand_dims(ind, 2)
|
||||
shape = (ind.shape[0], ind.shape[1], dim)
|
||||
broadcast_to = ops.BroadcastTo(shape)
|
||||
ind = broadcast_to(ind)
|
||||
feat = self.gather(feat, 1, ind)
|
||||
# if mask is not None:
|
||||
# mask = self.expand_dims(mask, 2)
|
||||
# broadcast = ops.BroadcastTo(feat.shape)
|
||||
# mask = broadcast_to(mask)
|
||||
# # feat = feat[mask]
|
||||
# # feat = feat.view(-1, dim)
|
||||
return feat
|
||||
|
||||
|
||||
class TranposeAndGatherFeat(nn.Cell):
|
||||
"""transpose and gather feature."""
|
||||
|
||||
def __init__(self):
|
||||
super(TranposeAndGatherFeat, self).__init__()
|
||||
self.transpose = ops.Transpose()
|
||||
self.GatherFeat = GatherFeat()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""transpose and gather feature."""
|
||||
feat = self.transpose(feat, (0, 2, 3, 1))
|
||||
feat = feat.view(feat.shape[0], -1, feat.shape[3])
|
||||
feat = self.GatherFeat(feat, ind)
|
||||
return feat
|
||||
|
||||
|
||||
class MotDecode(nn.Cell):
|
||||
"""
|
||||
Network tracking results of the decoder
|
||||
"""
|
||||
|
||||
def __init__(self, ltrb=False):
|
||||
super(MotDecode, self).__init__()
|
||||
self.cast = ops.Cast()
|
||||
self.concat = ops.Concat(axis=2)
|
||||
self.ltrb = ltrb
|
||||
self.topk = ops.TopK(sorted=True)
|
||||
self.div = ops.Div()
|
||||
self.GatherFeat = GatherFeat()
|
||||
self.TranposeAndGatherFeat = TranposeAndGatherFeat()
|
||||
self.select = ops.Select()
|
||||
self.zeroslike = ops.ZerosLike()
|
||||
self.equal = ops.Equal()
|
||||
self.pool = nn.MaxPool2d((3, 3), stride=1, pad_mode='same')
|
||||
|
||||
def construct(self, heat, wh, K, reg):
|
||||
"""
|
||||
Network tracking results of the decoder
|
||||
"""
|
||||
batch, cat, height, width = heat.shape
|
||||
heat = self.cast(heat, mstype.float16)
|
||||
hmax = self.pool(heat)
|
||||
keep = self.equal(hmax, heat)
|
||||
input_x = self.zeroslike(heat)
|
||||
M = self.select(keep, heat, input_x)
|
||||
heat = self.cast(M, mstype.float32)
|
||||
topk_scores, topk_inds = self.topk(heat.view(batch, cat, -1), K)
|
||||
topk_inds = topk_inds % (height * width)
|
||||
topk_ys = self.cast(self.div(topk_inds, width), mstype.float32)
|
||||
topk_xs = self.cast((topk_inds % width), mstype.float32)
|
||||
scores, topk_ind = self.topk(topk_scores.view(batch, -1), K)
|
||||
clses = self.cast(self.div(topk_ind, K), mstype.int32)
|
||||
inds = self.GatherFeat(
|
||||
topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
ys = self.GatherFeat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
xs = self.GatherFeat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
if reg is not None:
|
||||
reg = self.TranposeAndGatherFeat(reg, inds)
|
||||
reg = reg.view(batch, K, 2)
|
||||
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
|
||||
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
|
||||
else:
|
||||
xs = xs.view(batch, K, 1) + 0.5
|
||||
ys = ys.view(batch, K, 1) + 0.5
|
||||
wh = self.TranposeAndGatherFeat(wh, inds)
|
||||
if self.ltrb:
|
||||
wh = wh.view(batch, K, 4)
|
||||
else:
|
||||
wh = wh.view(batch, K, 2)
|
||||
clses = clses.view(batch, K, 1)
|
||||
clses = self.cast(clses, mstype.float32)
|
||||
scores = scores.view(batch, K, 1)
|
||||
if self.ltrb:
|
||||
bboxes = self.concat((xs - wh[..., 0:1],
|
||||
ys - wh[..., 1:2],
|
||||
xs + wh[..., 2:3],
|
||||
ys + wh[..., 3:4]))
|
||||
else:
|
||||
bboxes = self.concat((xs - wh[..., 0:1] / 2,
|
||||
ys - wh[..., 1:2] / 2,
|
||||
xs + wh[..., 0:1] / 2,
|
||||
ys + wh[..., 1:2] / 2))
|
||||
detections = self.concat((bboxes, scores, clses))
|
||||
return detections, inds
|
||||
|
||||
|
||||
class InferNet(nn.Cell):
|
||||
"""
|
||||
Network tracking results of the decoder
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(InferNet, self).__init__()
|
||||
self.sigmoid = Sigmoid()
|
||||
self.l2_normalize = ops.L2Normalize(axis=1, epsilon=1e-12)
|
||||
self.mot_decode = MotDecode(ltrb=True)
|
||||
self.TranposeAndGatherFeat = TranposeAndGatherFeat()
|
||||
self.squeeze = ops.Squeeze(0)
|
||||
|
||||
def construct(self, feature):
|
||||
hm = self.sigmoid(feature['hm'])
|
||||
id_feature = self.l2_normalize(feature['feature_id'])
|
||||
dets, inds = self.mot_decode(hm, feature['wh'], 500, feature['reg'])
|
||||
id_feature = self.TranposeAndGatherFeat(id_feature, inds)
|
||||
id_feature = self.squeeze(id_feature)
|
||||
return id_feature, dets
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
loss
|
||||
"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
import mindspore as ms
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from src.util import Sigmoid, TransposeGatherFeature
|
||||
|
||||
|
||||
class FocalLoss(nn.Cell):
|
||||
"""
|
||||
Warpper for focal loss.
|
||||
|
||||
Args:
|
||||
alpha(int): Super parameter in focal loss to mimic loss weight. Default: 2.
|
||||
beta(int): Super parameter in focal loss to mimic imbalance between positive and negative samples. Default: 4.
|
||||
|
||||
Returns:
|
||||
Tensor, focal loss.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=2, beta=4):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.pow = ops.Pow()
|
||||
self.log = ops.Log()
|
||||
self.select = ops.Select()
|
||||
self.equal = ops.Equal()
|
||||
self.less = ops.Less()
|
||||
self.cast = ops.Cast()
|
||||
self.fill = ops.Fill()
|
||||
self.dtype = ops.DType()
|
||||
self.shape = ops.Shape()
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
|
||||
def construct(self, out, target):
|
||||
"""focal loss"""
|
||||
pos_inds = self.cast(self.equal(target, 1.0), mstype.float32)
|
||||
neg_inds = self.cast(self.less(target, 1.0), mstype.float32)
|
||||
neg_weights = self.pow(1 - target, self.beta)
|
||||
|
||||
pos_loss = self.log(out) * self.pow(1 - out, self.alpha) * pos_inds
|
||||
neg_loss = self.log(1 - out) * self.pow(out, self.alpha) * neg_weights * neg_inds
|
||||
|
||||
num_pos = self.reduce_sum(pos_inds, ())
|
||||
num_pos = self.select(self.equal(num_pos, 0.0),
|
||||
self.fill(self.dtype(num_pos), self.shape(num_pos), 1.0), num_pos)
|
||||
pos_loss = self.reduce_sum(pos_loss, ())
|
||||
neg_loss = self.reduce_sum(neg_loss, ())
|
||||
loss = - (pos_loss + neg_loss) / num_pos
|
||||
return loss
|
||||
|
||||
|
||||
class RegLoss(nn.Cell):
|
||||
"""
|
||||
Warpper for regression loss.
|
||||
|
||||
Args:
|
||||
mode(str): L1 or Smoothed L1 loss. Default: "l1"
|
||||
|
||||
Returns:
|
||||
Tensor, regression loss.
|
||||
"""
|
||||
|
||||
def __init__(self, mode='l1'):
|
||||
super(RegLoss, self).__init__()
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
self.cast = ops.Cast()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
self.reshape = ops.Reshape()
|
||||
self.gather_feature = TransposeGatherFeature()
|
||||
if mode == 'l1':
|
||||
self.loss = nn.L1Loss(reduction='sum')
|
||||
elif mode == 'sl1':
|
||||
self.loss = nn.SmoothL1Loss()
|
||||
else:
|
||||
self.loss = None
|
||||
|
||||
def construct(self, output, mask, ind, target):
|
||||
"""Warpper for regression loss."""
|
||||
pred = self.gather_feature(output, ind)
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
num = self.reduce_sum(mask, ())
|
||||
mask = self.expand_dims(mask, 2)
|
||||
target = target * mask
|
||||
pred = pred * mask
|
||||
regr_loss = self.loss(pred, target)
|
||||
regr_loss = regr_loss / (num + 1e-4)
|
||||
return regr_loss
|
||||
|
||||
|
||||
class CenterNetMultiPoseLossCell(nn.Cell):
|
||||
"""
|
||||
Provide pose estimation network losses.
|
||||
|
||||
Args:
|
||||
net_config: The config info of CenterNet network.
|
||||
|
||||
Returns:
|
||||
Tensor, total loss.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(CenterNetMultiPoseLossCell, self).__init__()
|
||||
self.crit = FocalLoss()
|
||||
# self.crit_wh = RegWeightedL1Loss() if not config.net.dense_hp else nn.L1Loss(reduction='sum')
|
||||
self.crit_wh = RegLoss(opt.reg_loss)
|
||||
# wh
|
||||
self.crit_reg = RegLoss(opt.reg_loss) # reg_loss = 'l1'
|
||||
self.hm_weight = opt.hm_weight # hm_weight = 1 :loss weight for keypoint heatmaps
|
||||
self.wh_weight = opt.wh_weight # wh_weight = 0.1 : loss weight for bounding box size
|
||||
self.off_weight = opt.off_weight # off_weight = 1 : loss weight for keypoint local offsets
|
||||
self.reg_offset = opt.reg_offset # reg_offset = True : regress local offset
|
||||
|
||||
# self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
|
||||
self.reg_ind = "reg" if self.reg_offset else "wh"
|
||||
|
||||
# define id
|
||||
self.emb_dim = opt.reid_dim # dataset.reid_dim = 128
|
||||
self.nID = opt.nID # nId = 14455
|
||||
self.classifier = nn.Dense(self.emb_dim, self.nID).to_float(mstype.float16)
|
||||
self.IDLoss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) # fix np
|
||||
self.s_det = Parameter(Tensor(-1.85 * np.ones(1), mstype.float32))
|
||||
self.s_id = Parameter(Tensor(-1.05 * np.ones(1), mstype.float32))
|
||||
# self.s_id = Tensor(-1.05 * self.ones(1, mindspore.float32))
|
||||
|
||||
self.normalize = ops.L2Normalize(axis=1)
|
||||
self.greater = ops.Greater()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
self.tile = ops.Tile()
|
||||
self.multiples_1 = (1, 1, 128)
|
||||
# self.multiples_2 = (1, 1, 14455)
|
||||
self.select = ops.Select()
|
||||
self.zeros = ops.Zeros()
|
||||
self.exp = ops.Exp()
|
||||
self.squeeze = ops.Squeeze(0)
|
||||
self.TransposeGatherFeature = TransposeGatherFeature()
|
||||
self.reshape = ops.Reshape()
|
||||
self.reshape_mul = opt.batch_size * 500
|
||||
self.cast = ops.Cast()
|
||||
self.sigmoid = Sigmoid()
|
||||
|
||||
def construct(self, feature, hm, reg_mask, ind, wh, reg, ids):
|
||||
"""Defines the computation performed."""
|
||||
output_hm = feature["hm"] # FocalLoss()
|
||||
output_hm = self.sigmoid(output_hm)
|
||||
|
||||
hm_loss = self.crit(output_hm, hm)
|
||||
|
||||
output_id = feature["feature_id"] # SoftmaxCrossEntropyWithLogits()
|
||||
id_head = self.TransposeGatherFeature(output_id, ind) # id_head=[1,500,128]
|
||||
# print(id_head.shape)
|
||||
|
||||
# id_head = id_head[reg_mask > 0]
|
||||
cond = self.greater(reg_mask, 0) # cond=[1,500]
|
||||
cond_cast = self.cast(cond, ms.int32)
|
||||
expand_output = self.expand_dims(cond_cast, 2)
|
||||
tile_out = self.tile(expand_output, self.multiples_1)
|
||||
tile_cast = self.cast(tile_out, ms.bool_)
|
||||
fill_zero = self.zeros(id_head.shape, mstype.float32) # fill_zero=[1,500,128]
|
||||
id_head = self.select(tile_cast, id_head, fill_zero) # id_head=[1,500,128]
|
||||
|
||||
id_head = self.emb_scale * self.normalize(id_head) # id_head=[1,500,128]
|
||||
# id_head = self.emb_scale * ops.L2Normalize(id_head)
|
||||
|
||||
zero_input = self.zeros(ids.shape, mstype.int32)
|
||||
id_target = self.select(cond, ids, zero_input) # id_target=[1,500]
|
||||
id_target_out = self.reshape(id_target, (self.reshape_mul,))
|
||||
# expand_output = self.expand_dims(id_target, 2)
|
||||
# tile_out = self.tile(expand_output, self.multiples_2)
|
||||
|
||||
c_out = self.reshape(id_head, (self.reshape_mul, 128))
|
||||
c_out = self.cast(c_out, mstype.float16)
|
||||
id_output = self.classifier(c_out) # id_output=[1,500,14455]
|
||||
id_output = self.cast(id_output, ms.float32)
|
||||
# id_output = self.squeeze(id_output) # id_output=[500,14455]
|
||||
# id_target = self.squeeze(tile_out) # id_target=[500,14455]
|
||||
id_loss = self.IDLoss(id_output, id_target_out)
|
||||
|
||||
output_wh = feature["wh"] # Regl1Loss
|
||||
wh_loss = self.crit_reg(output_wh, reg_mask, ind, wh)
|
||||
|
||||
off_loss = 0
|
||||
if self.reg_offset and self.off_weight > 0: # Regl1Loss
|
||||
output_reg = feature[self.reg_ind]
|
||||
off_loss = self.crit_reg(output_reg, reg_mask, ind, reg)
|
||||
|
||||
det_loss = self.hm_weight * hm_loss + self.wh_weight * wh_loss + self.off_weight * off_loss
|
||||
loss = self.exp(-self.s_det) * det_loss + self.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
|
||||
loss *= 0.5
|
||||
|
||||
return loss
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
config
|
||||
"""
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
|
||||
class Opts:
|
||||
"""
|
||||
parameter configuration
|
||||
"""
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
# basic experiment setting
|
||||
self.parser.add_argument('--load_model', default="/Fairmot/ckpt/Fairmot_7-30_1595.ckpt",
|
||||
help='path to pretrained model')
|
||||
self.parser.add_argument('--load_pre_model', default="/Fairmot/ckpt/dla34-ba72cf86_ms.ckpt",
|
||||
help='path to pretrained model')
|
||||
self.parser.add_argument('--data_cfg', type=str,
|
||||
default='/fairmot/data/data.json', help='load data from cfg')
|
||||
self.parser.add_argument('--arch', default='dla_34',
|
||||
help='model architecture. Currently tested'
|
||||
'resdcn_34 | resdcn_50 | resfpndcn_34 |' 'dla_34 | hrnet_18')
|
||||
self.parser.add_argument('--num_epochs', type=int, default=30, help='total training epochs.')
|
||||
self.parser.add_argument('--lr', type=float, default=1e-4,
|
||||
help='learning rate for batch size 12.')
|
||||
self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
|
||||
self.parser.add_argument('--input-video', type=str,
|
||||
default='/videos/MOT16-03.mp4', help='path to the input video')
|
||||
self.parser.add_argument('--output-root', type=str, default='./exports', help='expected output root path')
|
||||
self.parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False,
|
||||
help="Run distribute in modelarts, default is false.")
|
||||
self.parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
self.parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
self.parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
self.parser.add_argument('--id', type=int, default=0)
|
||||
# model
|
||||
self.parser.add_argument('--head_conv', type=int, default=-1,
|
||||
help='conv layer channels for output head')
|
||||
self.parser.add_argument('--down_ratio', type=int, default=4,
|
||||
help='output stride. Currently only supports 4.')
|
||||
# input
|
||||
self.parser.add_argument('--input_res', type=int, default=-1,
|
||||
help='input height and width. -1 for default from '
|
||||
'dataset. Will be overridden by input_h | input_w')
|
||||
self.parser.add_argument('--input_h', type=int, default=-1,
|
||||
help='input height. -1 for default from dataset.')
|
||||
self.parser.add_argument('--input_w', type=int, default=-1,
|
||||
help='input width. -1 for default from dataset.')
|
||||
# test
|
||||
self.parser.add_argument('--K', type=int, default=500, help='max number of output objects.')
|
||||
self.parser.add_argument('--not_prefetch_test', action='store_true',
|
||||
help='not use parallal data pre-processing.')
|
||||
self.parser.add_argument('--fix_res', action='store_true',
|
||||
help='fix testing resolution or keep the original resolution')
|
||||
self.parser.add_argument('--keep_res', action='store_true',
|
||||
help='keep the original resolution during validation.')
|
||||
# tracking
|
||||
self.parser.add_argument('--conf_thres', type=float, default=0.3, help='confidence thresh for tracking')
|
||||
self.parser.add_argument('--det_thres', type=float, default=0.3, help='confidence thresh for detection')
|
||||
self.parser.add_argument('--nms_thres', type=float, default=0.4, help='iou thresh for nms')
|
||||
self.parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
|
||||
self.parser.add_argument('--min-box-area', type=float, default=100, help='filter out tiny boxes')
|
||||
self.parser.add_argument('--output-format', type=str, default='video', help='video or text')
|
||||
# mot
|
||||
self.parser.add_argument('--data_dir', default='/opt_data/xidian_wks/kasor/Fairmot/dataset')
|
||||
# loss
|
||||
self.parser.add_argument('--mse_loss', action='store_true',
|
||||
help='use mse loss or focal loss to train keypoint heatmaps.')
|
||||
self.parser.add_argument('--reg_loss', default='l1',
|
||||
help='regression loss: sl1 | l1 | l2')
|
||||
self.parser.add_argument('--hm_weight', type=float, default=1,
|
||||
help='loss weight for keypoint heatmaps.')
|
||||
self.parser.add_argument('--off_weight', type=float, default=1,
|
||||
help='loss weight for keypoint local offsets.')
|
||||
self.parser.add_argument('--wh_weight', type=float, default=0.1,
|
||||
help='loss weight for bounding box size.')
|
||||
self.parser.add_argument('--id_loss', default='ce',
|
||||
help='reid loss: ce | triplet')
|
||||
self.parser.add_argument('--id_weight', type=float, default=1,
|
||||
help='loss weight for id')
|
||||
self.parser.add_argument('--reid_dim', type=int, default=128,
|
||||
help='feature dim for reid')
|
||||
self.parser.add_argument('--ltrb', default=True,
|
||||
help='regress left, top, right, bottom of bbox')
|
||||
self.parser.add_argument('--norm_wh', action='store_true',
|
||||
help='L1(\\hat(y) / y, 1) or L1(\\hat(y), y)')
|
||||
self.parser.add_argument('--dense_wh', action='store_true',
|
||||
help='apply weighted regression near center or '
|
||||
'just apply regression on center point.')
|
||||
self.parser.add_argument('--cat_spec_wh', action='store_true',
|
||||
help='category specific bounding box size.')
|
||||
self.parser.add_argument('--not_reg_offset', action='store_true',
|
||||
help='not regress local offset.')
|
||||
|
||||
def parse(self, args=''):
|
||||
"""parameter parse"""
|
||||
if args == '':
|
||||
opt = self.parser.parse_args()
|
||||
else:
|
||||
opt = self.parser.parse_args(args)
|
||||
|
||||
opt.fix_res = not opt.keep_res
|
||||
print('Fix size testing.' if opt.fix_res else 'Keep resolution testing.')
|
||||
opt.reg_offset = not opt.not_reg_offset
|
||||
|
||||
if opt.head_conv == -1: # init default head_conv
|
||||
opt.head_conv = 256 if 'dla' in opt.arch else 256
|
||||
opt.pad = 31
|
||||
opt.num_stacks = 1
|
||||
|
||||
|
||||
return opt
|
||||
|
||||
def update_dataset_info_and_set_heads(self, opt, dataset):
|
||||
"""update dataset info and set heads"""
|
||||
input_h, input_w = dataset.default_resolution
|
||||
opt.mean, opt.std = dataset.mean, dataset.std
|
||||
opt.num_classes = dataset.num_classes
|
||||
|
||||
# input_h(w): opt.input_h overrides opt.input_res overrides dataset default
|
||||
input_h = opt.input_res if opt.input_res > 0 else input_h
|
||||
input_w = opt.input_res if opt.input_res > 0 else input_w
|
||||
opt.input_h = opt.input_h if opt.input_h > 0 else input_h
|
||||
opt.input_w = opt.input_w if opt.input_w > 0 else input_w
|
||||
opt.output_h = opt.input_h // opt.down_ratio
|
||||
opt.output_w = opt.input_w // opt.down_ratio
|
||||
opt.input_res = max(opt.input_h, opt.input_w)
|
||||
opt.output_res = max(opt.output_h, opt.output_w)
|
||||
|
||||
opt.heads = {'hm': opt.num_classes,
|
||||
'wh': 2 if not opt.ltrb else 4,
|
||||
'id': opt.reid_dim}
|
||||
if opt.reg_offset:
|
||||
opt.heads.update({'reg': 2})
|
||||
opt.nID = dataset.nID
|
||||
opt.img_size = (1088, 608)
|
||||
# opt.img_size = (864, 480)
|
||||
# opt.img_size = (576, 320)
|
||||
print('heads', opt.heads)
|
||||
return opt
|
||||
|
||||
def init(self, args=''):
|
||||
"""opt init"""
|
||||
default_dataset_info = {
|
||||
'mot': {'default_resolution': [608, 1088], 'num_classes': 1,
|
||||
'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
|
||||
'dataset': 'jde', 'nID': 14455},
|
||||
}
|
||||
|
||||
class Struct:
|
||||
"""opt struct"""
|
||||
def __init__(self, entries):
|
||||
for k, v in entries.items():
|
||||
self.__setattr__(k, v)
|
||||
|
||||
opt = self.parse(args)
|
||||
dataset = Struct(default_dataset_info['mot'])
|
||||
opt.dataset = dataset.dataset
|
||||
opt = self.update_dataset_info_and_set_heads(opt, dataset)
|
||||
return opt
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Fairmot for track
|
||||
"""
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""TrackState"""
|
||||
New = 0
|
||||
Tracked = 1
|
||||
Lost = 2
|
||||
Removed = 3
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
"""
|
||||
Fairmot for BaseTrack
|
||||
"""
|
||||
_count = 0
|
||||
|
||||
track_id = 0
|
||||
is_activated = False
|
||||
state = TrackState.New
|
||||
|
||||
history = collections.OrderedDict()
|
||||
features = []
|
||||
curr_feature = None
|
||||
score = 0
|
||||
start_frame = 0
|
||||
frame_id = 0
|
||||
time_since_update = 0
|
||||
|
||||
# multi-camera
|
||||
location = (np.inf, np.inf)
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
"""end frame"""
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
"""next id"""
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""activate"""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""predict"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""update"""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
"""mark lost"""
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
"""mark removed"""
|
||||
self.state = TrackState.Removed
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Fairmot for track
|
||||
"""
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
import lap
|
||||
from cython_bbox import bbox_overlaps as bbox_ious
|
||||
from src.tracking_utils import kalman_filter
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
|
||||
"""linear assignment"""
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
matches, unmatched_a, unmatched_b = [], [], []
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
for ix, mx in enumerate(x):
|
||||
if mx >= 0:
|
||||
matches.append([ix, mx])
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
matches = np.asarray(matches)
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def ious(atlbrs, btlbrs):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
|
||||
:rtype ious np.ndarray
|
||||
"""
|
||||
iou = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
|
||||
if iou.size == 0:
|
||||
return iou
|
||||
|
||||
iou = bbox_ious(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float)
|
||||
)
|
||||
|
||||
return iou
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atracks: list[Track]
|
||||
:type btracks: list[Track]
|
||||
|
||||
:rtype cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
if (atracks and isinstance(atracks[0], np.ndarray)) or (
|
||||
btracks and isinstance(btracks[0], np.ndarray)):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlbr for track in atracks]
|
||||
btlbrs = [track.tlbr for track in btracks]
|
||||
_ious = ious(atlbrs, btlbrs)
|
||||
cost_matrix = 1 - _ious
|
||||
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='cosine'):
|
||||
"""
|
||||
:param tracks: list[Track]
|
||||
:param detections: list[BaseTrack]
|
||||
:param metric:
|
||||
:return: cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float)
|
||||
# for i, track in enumerate(tracks):
|
||||
# cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
|
||||
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float)
|
||||
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
|
||||
"""gate cost matrix"""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = kalman_filter.chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(
|
||||
track.mean, track.covariance, measurements, only_position)
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
|
||||
"""fuse motion"""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = kalman_filter.chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(
|
||||
track.mean, track.covariance, measurements, only_position, metric='maha')
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
|
||||
return cost_matrix
|
|
@ -0,0 +1,385 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Fairmot for multitracker
|
||||
"""
|
||||
import collections
|
||||
import numpy as np
|
||||
from src.tracker import matching
|
||||
from src.utils.tools import ctdet_post_process
|
||||
from src.tracker.basetrack import BaseTrack, TrackState
|
||||
from src.tracking_utils.kalman_filter import KalmanFilter
|
||||
|
||||
|
||||
class Track(BaseTrack):
|
||||
"""
|
||||
Fairmot for Track
|
||||
"""
|
||||
shared_kalman = KalmanFilter()
|
||||
|
||||
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
|
||||
|
||||
# wait activate
|
||||
self._tlwh = np.asarray(tlwh, dtype=np.float)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
self.track_id = None
|
||||
self.start_frame = None
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.frame_id = None
|
||||
self.smooth_feat = None
|
||||
self.update_features(temp_feat)
|
||||
self.features = collections.deque([], maxlen=buffer_size)
|
||||
self.alpha = 0.9
|
||||
self.state = None
|
||||
|
||||
def update_features(self, feat):
|
||||
"""update features"""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
self.smooth_feat = feat
|
||||
else:
|
||||
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
|
||||
self.features.append(feat)
|
||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""predict"""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""multi predict"""
|
||||
if stracks:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = Track.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet"""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
|
||||
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
if frame_id == 1:
|
||||
self.is_activated = True
|
||||
# self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""reactivate a matched track"""
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
|
||||
)
|
||||
|
||||
self.update_features(new_track.curr_feat)
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
if new_id:
|
||||
self.track_id = self.next_id()
|
||||
|
||||
def update(self, new_track, frame_id, update_feature=True):
|
||||
"""
|
||||
Update a matched track
|
||||
:type new_track: Track
|
||||
:type frame_id: int
|
||||
:type update_feature: bool
|
||||
:return:
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
||||
new_tlwh = new_track.tlwh
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
|
||||
self.score = new_track.score
|
||||
if update_feature:
|
||||
self.update_features(new_track.curr_feat)
|
||||
|
||||
@property
|
||||
# @jit(nopython=True)
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
# @jit(nopython=True)
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
# @jit(nopython=True)
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
def to_xyah(self):
|
||||
"""to xyah"""
|
||||
return self.tlwh_to_xyah(self.tlwh)
|
||||
|
||||
@staticmethod
|
||||
# @jit(nopython=True)
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
"""tlbr to tlwh"""
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
# @jit(nopython=True)
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
"""tlwh to tlbr"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
||||
|
||||
|
||||
class JDETracker:
|
||||
"""
|
||||
Fairmot for JDETracker
|
||||
"""
|
||||
|
||||
def __init__(self, opt, frame_rate=30):
|
||||
self.opt = opt
|
||||
self.tracked_stracks = [] # type: list[Track]
|
||||
self.lost_stracks = [] # type: list[Track]
|
||||
self.removed_stracks = [] # type: list[Track]
|
||||
|
||||
self.frame_id = 0
|
||||
self.det_thresh = opt.conf_thres
|
||||
self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
|
||||
self.max_time_lost = self.buffer_size
|
||||
self.max_per_image = opt.K
|
||||
self.mean = np.array([[[0.408, 0.447, 0.47]]], dtype=np.float32)
|
||||
self.std = np.array([[[0.289, 0.274, 0.278]]], dtype=np.float32)
|
||||
self.kalman_filter = KalmanFilter()
|
||||
|
||||
def update(self, id_feature, dets, meta):
|
||||
"""update track frame"""
|
||||
self.frame_id += 1
|
||||
activated_starcks, refind_stracks, lost_stracks, removed_stracks = [], [], [], []
|
||||
dets = self.post_process(dets, meta)
|
||||
dets = self.merge_outputs([dets])[1]
|
||||
remain_inds = dets[:, 4] > self.opt.conf_thres
|
||||
dets = dets[remain_inds]
|
||||
id_feature = id_feature[remain_inds]
|
||||
detections = self.create_detections(dets, id_feature)
|
||||
# Add newly detected tracklets to tracked_stracks
|
||||
unconfirmed, tracked_stracks = [], []
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
unconfirmed.append(track)
|
||||
else:
|
||||
tracked_stracks.append(track)
|
||||
# Step 2: First association, with embedding
|
||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
Track.multi_predict(strack_pool)
|
||||
dists = matching.embedding_distance(strack_pool, detections)
|
||||
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.4)
|
||||
for itracked, idet in matches:
|
||||
track, det = strack_pool[itracked], detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Step 3: Second association, with IOU
|
||||
detections = [detections[i] for i in u_detection]
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if not track.state == TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
# Deal with unconfirmed tracks, usually tracks with only one beginning frame
|
||||
detections = [detections[i] for i in u_detection]
|
||||
dists = matching.iou_distance(unconfirmed, detections)
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(unconfirmed[itracked])
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
# Step 4: Init new stracks
|
||||
activated_starcks = self.init_new_stracks(u_detection, detections, activated_starcks)
|
||||
# Step 5: Update state
|
||||
removed_stracks = self.update_state(removed_stracks)
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
||||
# get scores of lost tracks
|
||||
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
|
||||
return output_stracks
|
||||
|
||||
def post_process(self, dets, meta):
|
||||
"""post process"""
|
||||
dets = dets.asnumpy()
|
||||
dets = dets.reshape(1, -1, dets.shape[2])
|
||||
dets = ctdet_post_process(
|
||||
dets.copy(), [meta['c']], [meta['s']],
|
||||
meta['out_height'], meta['out_width'], self.opt.num_classes)
|
||||
for j in range(1, self.opt.num_classes + 1):
|
||||
dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
|
||||
return dets[0]
|
||||
|
||||
def create_detections(self, dets, id_feature):
|
||||
"""create detections"""
|
||||
if np.shape(dets)[0]:
|
||||
detections = []
|
||||
for tlbrs, f in zip(dets[:, :5], id_feature):
|
||||
detections.append(Track(Track.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30))
|
||||
else:
|
||||
detections = []
|
||||
return detections
|
||||
|
||||
def merge_outputs(self, detections):
|
||||
"""merge outputs"""
|
||||
results = {}
|
||||
for j in range(1, self.opt.num_classes + 1):
|
||||
results[j] = np.concatenate(
|
||||
[detection[j] for detection in detections], axis=0).astype(np.float32)
|
||||
|
||||
scores = np.hstack(
|
||||
[results[j][:, 4] for j in range(1, self.opt.num_classes + 1)])
|
||||
if len(scores) > self.max_per_image:
|
||||
kth = len(scores) - self.max_per_image
|
||||
thresh = np.partition(scores, kth)[kth]
|
||||
for j in range(1, self.opt.num_classes + 1):
|
||||
keep_inds = (results[j][:, 4] >= thresh)
|
||||
results[j] = results[j][keep_inds]
|
||||
return results
|
||||
|
||||
def init_new_stracks(self, u_detection, detections, activated_starcks):
|
||||
"""init new stracks"""
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
if track.score < self.det_thresh:
|
||||
continue
|
||||
track.activate(self.kalman_filter, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
return activated_starcks
|
||||
|
||||
def update_state(self, removed_stracks):
|
||||
"""update state"""
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
return removed_stracks
|
||||
|
||||
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""joint stracks"""
|
||||
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
exists[t.track_id] = 1
|
||||
res.append(t)
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if not exists.get(tid, 0):
|
||||
exists[tid] = 1
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
|
||||
def sub_stracks(tlista, tlistb):
|
||||
"""sub stracks"""
|
||||
stracks = {}
|
||||
for t in tlista:
|
||||
stracks[t.track_id] = t
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
|
||||
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""remove duplicate stracks"""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = list(), list()
|
||||
for p, q in zip(*pairs):
|
||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
||||
timeq = stracksb[q].frame_id - stracksb[q].start_frame
|
||||
if timep > timeq:
|
||||
dupb.append(q)
|
||||
else:
|
||||
dupa.append(p)
|
||||
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
|
||||
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
|
||||
return resa, resb
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Fairmot for evaluation
|
||||
"""
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
import motmetrics as mm
|
||||
from src.tracking_utils.io import read_results, unzip_objs
|
||||
|
||||
mm.lap.default_solver = 'lap'
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""
|
||||
Evaluate Tracking Results
|
||||
"""
|
||||
def __init__(self, data_root, seq_name, data_type):
|
||||
self.data_root = data_root
|
||||
self.seq_name = seq_name
|
||||
self.data_type = data_type
|
||||
|
||||
self.load_annotations()
|
||||
self.reset_accumulator()
|
||||
|
||||
def load_annotations(self):
|
||||
"""load annotations"""
|
||||
assert self.data_type == 'mot'
|
||||
|
||||
gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
|
||||
self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
|
||||
self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
|
||||
|
||||
def reset_accumulator(self):
|
||||
"""reset accumulator"""
|
||||
self.acc = mm.MOTAccumulator(auto_id=True)
|
||||
|
||||
def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
|
||||
"""eval frame"""
|
||||
# results
|
||||
trk_tlwhs = np.copy(trk_tlwhs)
|
||||
trk_ids = np.copy(trk_ids)
|
||||
|
||||
# gts
|
||||
gt_objs = self.gt_frame_dict.get(frame_id, [])
|
||||
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
|
||||
|
||||
# ignore boxes
|
||||
ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
|
||||
ignore_tlwhs = unzip_objs(ignore_objs)[0]
|
||||
|
||||
# remove ignored results
|
||||
keep = np.ones(len(trk_tlwhs), dtype=bool)
|
||||
iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
|
||||
if iou_distance:
|
||||
match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
|
||||
match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
|
||||
match_ious = iou_distance[match_is, match_js]
|
||||
|
||||
match_js = np.asarray(match_js, dtype=int)
|
||||
match_js = match_js[np.logical_not(np.isnan(match_ious))]
|
||||
keep[match_js] = False
|
||||
trk_tlwhs = trk_tlwhs[keep]
|
||||
trk_ids = trk_ids[keep]
|
||||
# match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
|
||||
# match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
|
||||
# match_ious = iou_distance[match_is, match_js]
|
||||
|
||||
# match_js = np.asarray(match_js, dtype=int)
|
||||
# match_js = match_js[np.logical_not(np.isnan(match_ious))]
|
||||
# keep[match_js] = False
|
||||
# trk_tlwhs = trk_tlwhs[keep]
|
||||
# trk_ids = trk_ids[keep]
|
||||
|
||||
# get distance matrix
|
||||
iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
|
||||
|
||||
# acc
|
||||
self.acc.update(gt_ids, trk_ids, iou_distance)
|
||||
|
||||
if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
|
||||
events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
|
||||
else:
|
||||
events = None
|
||||
return events
|
||||
|
||||
def eval_file(self, filename):
|
||||
"""eval f ile"""
|
||||
self.reset_accumulator()
|
||||
|
||||
result_frame_dict = read_results(filename, self.data_type, is_gt=False)
|
||||
# frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
|
||||
frames = sorted(list(set(result_frame_dict.keys())))
|
||||
for frame_id in frames:
|
||||
trk_objs = result_frame_dict.get(frame_id, [])
|
||||
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
|
||||
self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
|
||||
|
||||
return self.acc
|
||||
|
||||
@staticmethod
|
||||
def get_summary(accs, names, metrics=None):
|
||||
"""summary results"""
|
||||
names = copy.deepcopy(names)
|
||||
if metrics is None:
|
||||
metrics = mm.metrics.motchallenge_metrics
|
||||
metrics = copy.deepcopy(metrics)
|
||||
|
||||
mh = mm.metrics.create()
|
||||
summary = mh.compute_many(
|
||||
accs,
|
||||
metrics=metrics,
|
||||
names=names,
|
||||
generate_overall=True
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def save_summary(summary, filename):
|
||||
"""save summary"""
|
||||
import pandas as pd
|
||||
writer = pd.ExcelWriter(filename)
|
||||
summary.to_excel(writer)
|
||||
writer.save()
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""write/read."""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from src.tracking_utils.log import logger
|
||||
|
||||
|
||||
def write_results(filename, results_dict, data_type):
|
||||
"""write results"""
|
||||
if not filename:
|
||||
return
|
||||
path = os.path.dirname(filename)
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
if data_type in ('mot', 'mcmot', 'lab'):
|
||||
save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
|
||||
elif data_type == 'kitti':
|
||||
save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
|
||||
else:
|
||||
raise ValueError(data_type)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
for frame_id, frame_data in results_dict.items():
|
||||
if data_type == 'kitti':
|
||||
frame_id -= 1
|
||||
for tlwh, track_id in frame_data:
|
||||
if track_id < 0:
|
||||
continue
|
||||
x1, y1, w, h = tlwh
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
|
||||
f.write(line)
|
||||
logger.info('Save results to {}'.format(filename))
|
||||
|
||||
|
||||
def read_results(filename, data_type, is_gt=False, is_ignore=False):
|
||||
"""read results"""
|
||||
if data_type in ('mot', 'lab'):
|
||||
read_fun = read_mot_results
|
||||
else:
|
||||
raise ValueError('Unknown data type: {}'.format(data_type))
|
||||
|
||||
return read_fun(filename, is_gt, is_ignore)
|
||||
|
||||
|
||||
def read_mot_results(filename, is_gt, is_ignore):
|
||||
"""read_mot_results"""
|
||||
valid_labels = {1}
|
||||
ignore_labels = {2, 7, 8, 12}
|
||||
results_dict = dict()
|
||||
if os.path.isfile(filename):
|
||||
with open(filename, 'r') as f:
|
||||
for line in f.readlines():
|
||||
linelist = line.split(',')
|
||||
if len(linelist) < 7:
|
||||
continue
|
||||
fid = int(linelist[0])
|
||||
if fid < 1:
|
||||
continue
|
||||
results_dict.setdefault(fid, list())
|
||||
|
||||
if is_gt:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
mark = int(float(linelist[6]))
|
||||
if mark == 0 or label not in valid_labels:
|
||||
continue
|
||||
score = 1
|
||||
elif is_ignore:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
vis_ratio = float(linelist[8])
|
||||
if label not in ignore_labels and vis_ratio >= 0:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
score = 1
|
||||
else:
|
||||
score = float(linelist[6])
|
||||
|
||||
# if box_size > 7000:
|
||||
# if box_size <= 7000 or box_size >= 15000:
|
||||
# if box_size < 15000:
|
||||
# continue
|
||||
|
||||
tlwh = tuple(map(float, linelist[2:6]))
|
||||
target_id = int(linelist[1])
|
||||
|
||||
results_dict[fid].append((tlwh, target_id, score))
|
||||
|
||||
return results_dict
|
||||
|
||||
|
||||
def unzip_objs(objs):
|
||||
"""unzip objs"""
|
||||
if objs:
|
||||
tlwhs, ids, scores = zip(*objs)
|
||||
else:
|
||||
tlwhs, ids, scores = [], [], []
|
||||
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
|
||||
|
||||
return tlwhs, ids, scores
|
|
@ -0,0 +1,275 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
kalman_filter
|
||||
"""
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
# """
|
||||
# Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
||||
# freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
|
||||
# function and used as Mahalanobis gating threshold.
|
||||
# """
|
||||
chi2inv95 = {
|
||||
1: 3.8415,
|
||||
2: 5.9915,
|
||||
3: 7.8147,
|
||||
4: 9.4877,
|
||||
5: 11.070,
|
||||
6: 12.592,
|
||||
7: 14.067,
|
||||
8: 15.507,
|
||||
9: 16.919}
|
||||
|
||||
|
||||
class KalmanFilter:
|
||||
"""
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
1e-2,
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
1e-5,
|
||||
10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[3],
|
||||
1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
# mean = np.dot(self._motion_mat, mean)
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrics of the object states at the
|
||||
previous time step.
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_position * mean[:, 3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]),
|
||||
self._std_weight_velocity * mean[:, 3]]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = []
|
||||
for i in range(len(mean)):
|
||||
motion_cov.append(np.diag(sqr[i]))
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(
|
||||
projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve(
|
||||
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((
|
||||
kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements,
|
||||
only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
result = np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(
|
||||
cholesky_factor, d.T, lower=True, check_finite=False,
|
||||
overwrite_b=True)
|
||||
result = np.sum(z * z, axis=0)
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
||||
return result
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
log utils
|
||||
"""
|
||||
import logging
|
||||
|
||||
|
||||
def get_logger(name='root'):
|
||||
"""
|
||||
get logger
|
||||
"""
|
||||
formatter = logging.Formatter(
|
||||
# fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
|
||||
fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger_ = logging.getLogger(name)
|
||||
logger_.setLevel(logging.DEBUG)
|
||||
logger_.addHandler(handler)
|
||||
return logger_
|
||||
|
||||
|
||||
logger = get_logger('root')
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# --------------------------------------------------------
|
||||
# Fast R-CNN
|
||||
# Copyright (c) 2015 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Ross Girshick
|
||||
# --------------------------------------------------------
|
||||
"""A simple timer."""
|
||||
import time
|
||||
|
||||
|
||||
class Timer:
|
||||
"""A simple timer."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_time = 0.
|
||||
self.calls = 0
|
||||
self.start_time = 0.
|
||||
self.diff = 0.
|
||||
self.average_time = 0.
|
||||
|
||||
self.duration = 0.
|
||||
|
||||
def tic(self):
|
||||
""" using time.time instead of time.clock because time time.clock
|
||||
does not normalize for multithreading"""
|
||||
self.start_time = time.time()
|
||||
|
||||
def toc(self, average=True):
|
||||
"""toc"""
|
||||
self.diff = time.time() - self.start_time
|
||||
self.total_time += self.diff
|
||||
self.calls += 1
|
||||
self.average_time = self.total_time / self.calls
|
||||
if average:
|
||||
self.duration = self.average_time
|
||||
else:
|
||||
self.duration = self.diff
|
||||
return self.duration
|
||||
|
||||
def clear(self):
|
||||
"""clear"""
|
||||
self.total_time = 0.
|
||||
self.calls = 0
|
||||
self.start_time = 0.
|
||||
self.diff = 0.
|
||||
self.average_time = 0.
|
||||
self.duration = 0.
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
track utils
|
||||
"""
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
|
||||
def mkdir_if_missing(d):
|
||||
"""mkdir if missing"""
|
||||
if not osp.exists(d):
|
||||
os.makedirs(d)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
img utils
|
||||
"""
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def get_color(idx):
|
||||
"""get color"""
|
||||
idx = idx * 3
|
||||
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
|
||||
|
||||
return color
|
||||
|
||||
|
||||
def plot_tracking(image, tlwhs, obj_ids, frame_id=0, fps=0., ids2=None):
|
||||
"""plot tracking"""
|
||||
im = np.ascontiguousarray(np.copy(image))
|
||||
|
||||
text_scale = max(1, image.shape[1] / 1600.)
|
||||
text_thickness = 2
|
||||
line_thickness = max(1, int(image.shape[1] / 500.))
|
||||
|
||||
cv2.putText(im, 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
|
||||
(0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), thickness=2)
|
||||
|
||||
for i, tlwh in enumerate(tlwhs):
|
||||
x1, y1, w, h = tlwh
|
||||
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
|
||||
obj_id = int(obj_ids[i])
|
||||
id_text = '{}'.format(int(obj_id))
|
||||
if ids2 is not None:
|
||||
id_text = id_text + ', {}'.format(int(ids2[i]))
|
||||
# _line_thickness = 1 if obj_id <= 0 else line_thickness
|
||||
color = get_color(abs(obj_id))
|
||||
cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
|
||||
cv2.putText(im, id_text, (intbox[0], intbox[1] + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
|
||||
thickness=text_thickness)
|
||||
return im
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Functional Cells to be used.
|
||||
"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class GatherFeature(nn.Cell):
|
||||
"""
|
||||
Gather feature at specified position
|
||||
|
||||
Args: None
|
||||
|
||||
Returns:
|
||||
Tensor, feature at spectified position
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GatherFeature, self).__init__()
|
||||
self.tile = ops.Tile()
|
||||
self.shape = ops.Shape()
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.reshape = ops.Reshape()
|
||||
self.gather_nd = ops.GatherNd()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""gather by specified index"""
|
||||
# (b, N)->(b*N, 1)
|
||||
b, N = self.shape(ind)
|
||||
ind = self.reshape(ind, (-1, 1))
|
||||
ind_b = nn.Range(0, b, 1)()
|
||||
ind_b = self.reshape(ind_b, (-1, 1))
|
||||
ind_b = self.tile(ind_b, (1, N))
|
||||
ind_b = self.reshape(ind_b, (-1, 1))
|
||||
index = self.concat((ind_b, ind))
|
||||
# (b, N, 2)
|
||||
index = self.reshape(index, (b, N, -1))
|
||||
# (b, N, c)
|
||||
feat = self.gather_nd(feat, index)
|
||||
return feat
|
||||
|
||||
|
||||
class TransposeGatherFeature(nn.Cell):
|
||||
"""
|
||||
Transpose and gather feature at specified position
|
||||
|
||||
Args: None
|
||||
|
||||
Returns:
|
||||
Tensor, feature at spectified position
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(TransposeGatherFeature, self).__init__()
|
||||
self.shape = ops.Shape()
|
||||
self.reshape = ops.Reshape()
|
||||
self.transpose = ops.Transpose()
|
||||
self.perm_list = (0, 2, 3, 1)
|
||||
self.gather_feat = GatherFeature()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""(b, c, h, w)->(b, h*w, c)"""
|
||||
feat = self.transpose(feat, self.perm_list)
|
||||
b, _, _, c = self.shape(feat)
|
||||
feat = self.reshape(feat, (b, -1, c))
|
||||
# (b, N, c)
|
||||
feat = self.gather_feat(feat, ind)
|
||||
return feat
|
||||
|
||||
|
||||
class Sigmoid(nn.Cell):
|
||||
"""
|
||||
Sigmoid and then Clip by value
|
||||
|
||||
Args: None
|
||||
|
||||
Returns:
|
||||
Tensor, feature after sigmoid and clip.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Sigmoid, self).__init__()
|
||||
self.cast = ops.Cast()
|
||||
self.dtype = ops.DType()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.clip_by_value = ops.clip_by_value
|
||||
|
||||
def construct(self, x, min_value=1e-4, max_value=1 - 1e-4):
|
||||
"""Sigmoid and then Clip by value"""
|
||||
x = self.sigmoid(x)
|
||||
dt = self.dtype(x)
|
||||
x = self.clip_by_value(x, self.cast(ops.tuple_to_array((min_value,)), dt),
|
||||
self.cast(ops.tuple_to_array((max_value,)), dt))
|
||||
return x
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
custom callback
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
class LossCallback(Callback):
|
||||
"""StopAtTime"""
|
||||
|
||||
def __init__(self, bach_size):
|
||||
"""init"""
|
||||
super(LossCallback, self).__init__()
|
||||
self.bach_size = bach_size
|
||||
self.time_start = time.time()
|
||||
|
||||
def begin(self, run_context):
|
||||
"""train begin"""
|
||||
cb_params = run_context.original_args()
|
||||
batch_num = cb_params.batch_num
|
||||
epoch_num = cb_params.epoch_num
|
||||
device_number = cb_params.device_number
|
||||
|
||||
print("Starting Training : device_number={},per_step_size={},batch_size={}, epoch={}".format(device_number,
|
||||
batch_num,
|
||||
self.bach_size,
|
||||
epoch_num))
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end"""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch_num = cb_params.cur_epoch_num
|
||||
cur_step_num = cb_params.cur_step_num
|
||||
loss = cb_params.net_outputs[0].asnumpy()
|
||||
batch_num = cb_params.batch_num
|
||||
if cur_step_num > batch_num:
|
||||
cur_step_num = cur_step_num % batch_num + 1
|
||||
epoch_num = cb_params.epoch_num
|
||||
print("epoch: {}/{}, step: {}/{}, loss is {}".format(cur_epoch_num, epoch_num, cur_step_num, batch_num, loss))
|
||||
|
||||
def end(self, run_context):
|
||||
"""train end"""
|
||||
cb_params = run_context.original_args()
|
||||
device_number = cb_params.device_number
|
||||
time_end = time.time()
|
||||
seconds = time_end - self.time_start
|
||||
seconds = seconds % (24 * 3600)
|
||||
hour = seconds // 3600
|
||||
seconds %= 3600
|
||||
minutes = seconds // 60
|
||||
seconds %= 60
|
||||
print(device_number, "device totally cost:%02d:%02d:%02d" % (hour, minutes, seconds))
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
img utils
|
||||
"""
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def flip(img):
|
||||
"""flip"""
|
||||
return img[:, :, ::-1].copy()
|
||||
|
||||
|
||||
def transform_preds(coords, center, scale, output_size):
|
||||
"""transform preds"""
|
||||
target_coords = np.zeros(coords.shape)
|
||||
trans = get_affine_transform(center, scale, 0, output_size, shift=np.array([0, 0], dtype=np.float32), inv=1)
|
||||
for p in range(coords.shape[0]):
|
||||
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
|
||||
return target_coords
|
||||
|
||||
|
||||
def get_affine_transform(center,
|
||||
scale,
|
||||
rot,
|
||||
output_size,
|
||||
shift=None,
|
||||
inv=0):
|
||||
"""get affine transform"""
|
||||
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
|
||||
scale = np.array([scale, scale], dtype=np.float32)
|
||||
|
||||
scale_tmp = scale
|
||||
src_w = scale_tmp[0]
|
||||
dst_w = output_size[0]
|
||||
dst_h = output_size[1]
|
||||
|
||||
rot_rad = np.pi * rot / 180
|
||||
src_dir = get_dir([0, src_w * -0.5], rot_rad)
|
||||
dst_dir = np.array([0, dst_w * -0.5], np.float32)
|
||||
|
||||
src = np.zeros((3, 2), dtype=np.float32)
|
||||
dst = np.zeros((3, 2), dtype=np.float32)
|
||||
src[0, :] = center + scale_tmp * shift
|
||||
src[1, :] = center + src_dir + scale_tmp * shift
|
||||
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
|
||||
|
||||
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
|
||||
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
|
||||
|
||||
if inv:
|
||||
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||
else:
|
||||
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||
|
||||
return trans
|
||||
|
||||
|
||||
def affine_transform(pt, t):
|
||||
"""affine transform"""
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
|
||||
new_pt = np.dot(t, new_pt)
|
||||
return new_pt[:2]
|
||||
|
||||
|
||||
def get_3rd_point(a, b):
|
||||
"""get 3rd point"""
|
||||
direct = a - b
|
||||
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
|
||||
|
||||
|
||||
def get_dir(src_point, rot_rad):
|
||||
"""get dir"""
|
||||
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||
|
||||
src_result = [0, 0]
|
||||
src_result[0] = src_point[0] * cs - src_point[1] * sn
|
||||
src_result[1] = src_point[0] * sn + src_point[1] * cs
|
||||
|
||||
return src_result
|
||||
|
||||
|
||||
def crop(img, center, scale, output_size, rot=0):
|
||||
"""crop"""
|
||||
trans = get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32))
|
||||
|
||||
dst_img = cv2.warpAffine(img,
|
||||
trans,
|
||||
(int(output_size[0]), int(output_size[1])),
|
||||
flags=cv2.INTER_LINEAR)
|
||||
|
||||
return dst_img
|
||||
|
||||
|
||||
def gaussian_radius(det_size, min_overlap=0.7):
|
||||
"""gaussian radius"""
|
||||
height, width = det_size
|
||||
|
||||
a1 = 1
|
||||
b1 = (height + width)
|
||||
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
|
||||
sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
|
||||
r1 = (b1 + sq1) / 2
|
||||
|
||||
a2 = 4
|
||||
b2 = 2 * (height + width)
|
||||
c2 = (1 - min_overlap) * width * height
|
||||
sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
|
||||
r2 = (b2 + sq2) / 2
|
||||
|
||||
a3 = 4 * min_overlap
|
||||
b3 = -2 * min_overlap * (height + width)
|
||||
c3 = (min_overlap - 1) * width * height
|
||||
sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
|
||||
r3 = (b3 + sq3) / 2
|
||||
return min(r1, r2, r3)
|
||||
|
||||
|
||||
def gaussian2D(shape, sigma=1):
|
||||
"""gaussian2D"""
|
||||
m, n = [(ss - 1.) / 2. for ss in shape]
|
||||
y, x = np.ogrid[-m:m + 1, -n:n + 1]
|
||||
|
||||
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
|
||||
h[h < np.finfo(h.dtype).eps * h.max()] = 0
|
||||
return h
|
||||
|
||||
|
||||
def draw_umich_gaussian(heatmap, center, radius, k=1):
|
||||
"""draw umich gaussian"""
|
||||
diameter = 2 * radius + 1
|
||||
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
|
||||
|
||||
x, y = int(center[0]), int(center[1])
|
||||
|
||||
height, width = heatmap.shape[0:2]
|
||||
|
||||
left, right = min(x, radius), min(width - x, radius + 1)
|
||||
top, bottom = min(y, radius), min(height - y, radius + 1)
|
||||
|
||||
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
|
||||
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
|
||||
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
|
||||
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
|
||||
return heatmap
|
||||
|
||||
|
||||
def draw_msra_gaussian(heatmap, center, sigma):
|
||||
"""draw msra gaussian"""
|
||||
tmp_size = sigma * 3
|
||||
mu_x = int(center[0] + 0.5)
|
||||
mu_y = int(center[1] + 0.5)
|
||||
w, h = heatmap.shape[0], heatmap.shape[1]
|
||||
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
||||
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
||||
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
|
||||
return heatmap
|
||||
size = 2 * tmp_size + 1
|
||||
x = np.arange(0, size, 1, np.float32)
|
||||
y = x[:, np.newaxis]
|
||||
x0 = y0 = size // 2
|
||||
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
|
||||
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
|
||||
img_x = max(0, ul[0]), min(br[0], h)
|
||||
img_y = max(0, ul[1]), min(br[1], w)
|
||||
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
|
||||
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]],
|
||||
g[g_y[0]:g_y[1], g_x[0]:g_x[1]])
|
||||
return heatmap
|
||||
|
||||
|
||||
def grayscale(image):
|
||||
"""grayscale"""
|
||||
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
|
||||
def lighting_(data_rng, image, alphastd, eigval, eigvec):
|
||||
"""lighting"""
|
||||
alpha = data_rng.normal(scale=alphastd, size=(3,))
|
||||
image += np.dot(eigvec, eigval * alpha)
|
||||
|
||||
|
||||
def blend_(alpha, image1, image2):
|
||||
"""blend"""
|
||||
image1 *= alpha
|
||||
image2 *= (1 - alpha)
|
||||
image1 += image2
|
||||
|
||||
|
||||
def saturation_(data_rng, image, gs, var):
|
||||
"""saturation"""
|
||||
alpha = 1. + data_rng.uniform(low=-var, high=var)
|
||||
blend_(alpha, image, gs[:, :, None])
|
||||
|
||||
|
||||
def brightness_(data_rng, image, var):
|
||||
"""brightness"""
|
||||
alpha = 1. + data_rng.uniform(low=-var, high=var)
|
||||
image *= alpha
|
||||
|
||||
|
||||
def contrast_(data_rng, image, gs_mean, var):
|
||||
"""contrast"""
|
||||
alpha = 1. + data_rng.uniform(low=-var, high=var)
|
||||
blend_(alpha, image, gs_mean)
|
||||
|
||||
|
||||
def color_aug(data_rng, image, eig_val, eig_vec):
|
||||
"""color aug"""
|
||||
functions = [brightness_, contrast_, saturation_]
|
||||
random.shuffle(functions)
|
||||
|
||||
gs = grayscale(image)
|
||||
gs_mean = gs.mean()
|
||||
functions[0](data_rng, image, 0.4)
|
||||
functions[1](data_rng, image, gs_mean, 0.4)
|
||||
functions[2](data_rng, image, gs, 0.4)
|
||||
lighting_(data_rng, image, 0.1, eig_val, eig_vec)
|
|
@ -0,0 +1,454 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
load img
|
||||
"""
|
||||
import collections
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
import math
|
||||
import copy
|
||||
import glob
|
||||
import numpy as np
|
||||
import cv2
|
||||
from src.utils.tools import xyxy2xywh
|
||||
from src.utils.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian
|
||||
|
||||
|
||||
class LoadImages:
|
||||
""" for inference"""
|
||||
|
||||
def __init__(self, path, img_size=(1088, 608)):
|
||||
if os.path.isdir(path):
|
||||
image_format = ['.jpg', '.jpeg', '.png', '.tif']
|
||||
self.files = sorted(glob.glob('%s/*.*' % path))
|
||||
self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files))
|
||||
elif os.path.isfile(path):
|
||||
self.files = [path]
|
||||
|
||||
self.nF = len(self.files) # number of image files
|
||||
self.width = img_size[0]
|
||||
self.height = img_size[1]
|
||||
self.count = 0
|
||||
|
||||
assert self.nF > 0, 'No images found in ' + path
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.count += 1
|
||||
if self.count == self.nF:
|
||||
raise StopIteration
|
||||
img_path = self.files[self.count]
|
||||
|
||||
# Read image
|
||||
img0 = cv2.imread(img_path) # BGR
|
||||
assert img0 is not None, 'Failed to load ' + img_path
|
||||
|
||||
# Padded resize
|
||||
img, _, _, _ = letterbox(img0, height=self.height, width=self.width)
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img, dtype=np.float32)
|
||||
img /= 255.0
|
||||
|
||||
# cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
|
||||
return img_path, img, img0
|
||||
|
||||
def __getitem__(self, idx):
|
||||
idx = idx % self.nF
|
||||
img_path = self.files[idx]
|
||||
|
||||
# Read image
|
||||
img0 = cv2.imread(img_path) # BGR
|
||||
assert img0 is not None, 'Failed to load ' + img_path
|
||||
|
||||
# Padded resize
|
||||
img, _, _, _ = letterbox(img0, height=self.height, width=self.width)
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img, dtype=np.float32)
|
||||
img /= 255.0
|
||||
|
||||
return img_path, img, img0
|
||||
|
||||
def __len__(self):
|
||||
return self.nF # number of files
|
||||
|
||||
|
||||
class LoadVideo:
|
||||
""" for inference"""
|
||||
|
||||
def __init__(self, path, img_size=(1088, 608)):
|
||||
self.cap = cv2.VideoCapture(path)
|
||||
self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
|
||||
self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
self.vh = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
self.width = img_size[0]
|
||||
self.height = img_size[1]
|
||||
self.count = 0
|
||||
|
||||
self.w, self.h = 1920, 1080
|
||||
print('Lenth of the video: {:d} frames'.format(self.vn))
|
||||
|
||||
def get_size(self, vw, vh, dw, dh):
|
||||
"""get size"""
|
||||
|
||||
wa, ha = float(dw) / vw, float(dh) / vh
|
||||
a = min(wa, ha)
|
||||
return int(vw * a), int(vh * a)
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.count += 1
|
||||
if self.count == len(self):
|
||||
raise StopIteration
|
||||
# Read image
|
||||
_, img0 = self.cap.read() # BGR
|
||||
assert img0 is not None, 'Failed to load frame {:d}'.format(self.count)
|
||||
img0 = cv2.resize(img0, (self.w, self.h))
|
||||
|
||||
# Padded resize
|
||||
img, _, _, _ = letterbox(img0, height=self.height, width=self.width)
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img, dtype=np.float32)
|
||||
img /= 255.0
|
||||
|
||||
# cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
|
||||
return self.count, img, img0
|
||||
|
||||
def __len__(self):
|
||||
return self.vn # number of files
|
||||
|
||||
|
||||
class JointDataset:
|
||||
""" for training"""
|
||||
|
||||
default_resolution = [1088, 608]
|
||||
mean = None
|
||||
std = None
|
||||
num_classes = 1
|
||||
|
||||
def __init__(self, opt, root, paths, img_size=(1088, 608), augment=False):
|
||||
self.opt = opt
|
||||
self.img_files = collections.OrderedDict()
|
||||
self.label_files = collections.OrderedDict()
|
||||
self.tid_num = collections.OrderedDict()
|
||||
self.tid_start_index = collections.OrderedDict()
|
||||
self.num_classes = 1
|
||||
for ds, path in paths.items():
|
||||
path = root + path
|
||||
with open(path, 'r') as file:
|
||||
self.img_files[ds] = file.readlines()
|
||||
self.img_files[ds] = [osp.join(root, x.strip()) for x in self.img_files[ds]]
|
||||
self.img_files[ds] = list(filter(lambda x: len(x) > 0, self.img_files[ds]))
|
||||
self.label_files[ds] = [
|
||||
x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt')
|
||||
for x in self.img_files[ds]]
|
||||
for ds, label_paths in self.label_files.items():
|
||||
max_index = -1
|
||||
for lp in label_paths:
|
||||
lb = np.loadtxt(lp)
|
||||
if np.shape(lb)[0] < 1:
|
||||
continue
|
||||
if len(lb.shape) < 2:
|
||||
img_max = lb[1]
|
||||
else:
|
||||
img_max = np.max(lb[:, 1])
|
||||
if img_max > max_index:
|
||||
max_index = img_max
|
||||
self.tid_num[ds] = max_index + 1
|
||||
last_index = 0
|
||||
for k, v in self.tid_num.items():
|
||||
self.tid_start_index[k] = last_index
|
||||
last_index += v
|
||||
self.nID = int(last_index + 1) # 多个数据集中总的identity数目
|
||||
print('nID', self.nID)
|
||||
self.nds = [len(x) for x in self.img_files.values()] # 图片数量
|
||||
print('nds', self.nds)
|
||||
self.cds = [sum(self.nds[:i]) for i in range(len(self.nds))]
|
||||
self.nF = sum(self.nds)
|
||||
self.width = img_size[0]
|
||||
self.height = img_size[1]
|
||||
self.max_objs = opt.K
|
||||
self.augment = augment
|
||||
|
||||
print('=' * 80)
|
||||
print('dataset summary')
|
||||
print(self.tid_num)
|
||||
print('total # identities:', self.nID)
|
||||
print('start index')
|
||||
print(self.tid_start_index)
|
||||
print('=' * 80)
|
||||
|
||||
def __getitem__(self, files_index):
|
||||
|
||||
for i, c in enumerate(self.cds):
|
||||
if files_index >= c:
|
||||
ds = list(self.label_files.keys())[i]
|
||||
start_index = c
|
||||
|
||||
img_path = self.img_files[ds][files_index - start_index]
|
||||
label_path = self.label_files[ds][files_index - start_index]
|
||||
imgs, labels, img_path = self.get_data(img_path, label_path)
|
||||
for i, _ in enumerate(labels):
|
||||
if labels[i, 1] > -1:
|
||||
labels[i, 1] += self.tid_start_index[ds]
|
||||
|
||||
output_h = imgs.shape[1] // self.opt.down_ratio
|
||||
output_w = imgs.shape[2] // self.opt.down_ratio
|
||||
num_classes = self.num_classes
|
||||
num_objs = labels.shape[0]
|
||||
hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
|
||||
if self.opt.ltrb:
|
||||
wh = np.zeros((self.max_objs, 4), dtype=np.float32)
|
||||
else:
|
||||
wh = np.zeros((self.max_objs, 2), dtype=np.float32)
|
||||
reg = np.zeros((self.max_objs, 2), dtype=np.float32)
|
||||
ind = np.zeros((self.max_objs,), dtype=np.int32)
|
||||
reg_mask = np.zeros((self.max_objs,), dtype=np.uint8)
|
||||
ids = np.zeros((self.max_objs,), dtype=np.int32)
|
||||
bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
|
||||
|
||||
draw_gaussian = draw_msra_gaussian if self.opt.mse_loss else draw_umich_gaussian
|
||||
for k in range(num_objs):
|
||||
label = labels[k]
|
||||
bbox = label[2:]
|
||||
cls_id = int(label[0])
|
||||
bbox[[0, 2]] = bbox[[0, 2]] * output_w
|
||||
bbox[[1, 3]] = bbox[[1, 3]] * output_h
|
||||
bbox_amodal = copy.deepcopy(bbox)
|
||||
bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
|
||||
bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
|
||||
bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
|
||||
bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
|
||||
bbox[0] = np.clip(bbox[0], 0, output_w - 1)
|
||||
bbox[1] = np.clip(bbox[1], 0, output_h - 1)
|
||||
h = bbox[3]
|
||||
w = bbox[2]
|
||||
|
||||
bbox_xy = copy.deepcopy(bbox)
|
||||
bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
|
||||
bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
|
||||
bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
|
||||
bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
|
||||
|
||||
if h > 0 and w > 0:
|
||||
radius = gaussian_radius((math.ceil(h), math.ceil(w)))
|
||||
radius = max(0, int(radius))
|
||||
radius = 6 if self.opt.mse_loss else radius
|
||||
# radius = max(1, int(radius)) if self.opt.mse_loss else radius
|
||||
ct = np.array(
|
||||
[bbox[0], bbox[1]], dtype=np.float32)
|
||||
ct_int = ct.astype(np.int32)
|
||||
draw_gaussian(hm[cls_id], ct_int, radius)
|
||||
if self.opt.ltrb:
|
||||
wh[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
|
||||
bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
|
||||
else:
|
||||
wh[k] = 1. * w, 1. * h
|
||||
ind[k] = ct_int[1] * output_w + ct_int[0]
|
||||
reg[k] = ct - ct_int
|
||||
reg_mask[k] = 1
|
||||
ids[k] = label[1]
|
||||
bbox_xys[k] = bbox_xy
|
||||
# ret = {'input': imgs, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh, 'reg': reg, 'ids': ids,
|
||||
# 'bbox': bbox_xys}
|
||||
return imgs, hm, reg_mask, ind, wh, reg, ids
|
||||
|
||||
def __len__(self):
|
||||
return self.nF
|
||||
|
||||
def get_data(self, img_path, label_path):
|
||||
"""get data"""
|
||||
height = self.height
|
||||
width = self.width
|
||||
img = cv2.imread(img_path) # BGR
|
||||
if img is None:
|
||||
raise ValueError('File corrupt {}'.format(img_path))
|
||||
augment_hsv = True
|
||||
if self.augment and augment_hsv:
|
||||
# SV augmentation by 50%
|
||||
fraction = 0.50
|
||||
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
S = img_hsv[:, :, 1].astype(np.float32)
|
||||
V = img_hsv[:, :, 2].astype(np.float32)
|
||||
|
||||
a = (random.random() * 2 - 1) * fraction + 1
|
||||
S *= a
|
||||
if a > 1:
|
||||
np.clip(S, a_min=0, a_max=255, out=S)
|
||||
|
||||
a = (random.random() * 2 - 1) * fraction + 1
|
||||
V *= a
|
||||
if a > 1:
|
||||
np.clip(V, a_min=0, a_max=255, out=V)
|
||||
|
||||
img_hsv[:, :, 1] = S.astype(np.uint8)
|
||||
img_hsv[:, :, 2] = V.astype(np.uint8)
|
||||
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
|
||||
h, w, _ = img.shape
|
||||
img, ratio, padw, padh = letterbox(img, height=height, width=width)
|
||||
# Load labels
|
||||
if os.path.isfile(label_path):
|
||||
labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
|
||||
|
||||
# Normalized xywh to pixel xyxy format
|
||||
labels = labels0.copy()
|
||||
labels[:, 2] = ratio * w * (labels0[:, 2] - labels0[:, 4] / 2) + padw
|
||||
labels[:, 3] = ratio * h * (labels0[:, 3] - labels0[:, 5] / 2) + padh
|
||||
labels[:, 4] = ratio * w * (labels0[:, 2] + labels0[:, 4] / 2) + padw
|
||||
labels[:, 5] = ratio * h * (labels0[:, 3] + labels0[:, 5] / 2) + padh
|
||||
else:
|
||||
labels = np.array([])
|
||||
|
||||
# Augment image and labels
|
||||
if self.augment:
|
||||
img, labels, _ = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20))
|
||||
|
||||
plotFlag = False
|
||||
if plotFlag:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(50, 50))
|
||||
plt.imshow(img[:, :, ::-1])
|
||||
plt.plot(labels[:, [1, 3, 3, 1, 1]].T, labels[:, [2, 2, 4, 4, 2]].T, '.-')
|
||||
plt.axis('off')
|
||||
plt.savefig('test.jpg')
|
||||
time.sleep(10)
|
||||
|
||||
nL = len(labels)
|
||||
if nL > 0:
|
||||
# convert xyxy to xywh
|
||||
labels[:, 2:6] = xyxy2xywh(labels[:, 2:6].copy()) # / height
|
||||
labels[:, 2] /= width
|
||||
labels[:, 3] /= height
|
||||
labels[:, 4] /= width
|
||||
labels[:, 5] /= height
|
||||
if self.augment:
|
||||
# random left-right flip
|
||||
lr_flip = True
|
||||
if lr_flip & (random.random() > 0.5):
|
||||
img = np.fliplr(img)
|
||||
if nL > 0:
|
||||
labels[:, 2] = 1 - labels[:, 2]
|
||||
|
||||
img = np.ascontiguousarray(img[:, :, ::-1]) # BGR to RGB
|
||||
img = np.array(img, dtype=np.float32) / 255
|
||||
img = img.transpose((2, 0, 1))
|
||||
return img, labels, img_path
|
||||
|
||||
|
||||
def letterbox(img, height=608, width=1088,
|
||||
color=(127.5, 127.5, 127.5)):
|
||||
"""resize a rectangular image to a padded rectangular"""
|
||||
shape = img.shape[:2] # shape = [height, width]
|
||||
ratio = min(float(height) / shape[0], float(width) / shape[1])
|
||||
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # new_shape = [width, height]
|
||||
dw = (width - new_shape[0]) / 2 # width padding
|
||||
dh = (height - new_shape[1]) / 2 # height padding
|
||||
top, bottom = round(dh - 0.1), round(dh + 0.1)
|
||||
left, right = round(dw - 0.1), round(dw + 0.1)
|
||||
img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded rectangular
|
||||
return img, ratio, dw, dh
|
||||
|
||||
|
||||
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
|
||||
borderValue=(127.5, 127.5, 127.5)):
|
||||
"""
|
||||
https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
|
||||
"""
|
||||
|
||||
border = 0 # width of added border (optional)
|
||||
height = img.shape[0]
|
||||
width = img.shape[1]
|
||||
|
||||
# Rotation and Scale
|
||||
R = np.eye(3)
|
||||
a = random.random() * (degrees[1] - degrees[0]) + degrees[0]
|
||||
# a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations
|
||||
s = random.random() * (scale[1] - scale[0]) + scale[0]
|
||||
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)
|
||||
|
||||
# Translation
|
||||
T = np.eye(3)
|
||||
T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border # x translation (pixels)
|
||||
T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border # y translation (pixels)
|
||||
|
||||
# Shear
|
||||
S = np.eye(3)
|
||||
S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg)
|
||||
S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg)
|
||||
|
||||
M = np.matmul(S, np.matmul(T, R)) # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
|
||||
imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR,
|
||||
borderValue=borderValue) # BGR order borderValue
|
||||
|
||||
# Return warped points also
|
||||
if targets is not None:
|
||||
if np.shape(targets)[0] > 0:
|
||||
n = targets.shape[0]
|
||||
points = targets[:, 2:6].copy()
|
||||
area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
|
||||
|
||||
# warp points
|
||||
xy = np.ones((n * 4, 3))
|
||||
xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
||||
xy = np.matmul(xy, M.T)[:, :2].reshape(n, 8)
|
||||
# create new boxes
|
||||
x = xy[:, [0, 2, 4, 6]]
|
||||
y = xy[:, [1, 3, 5, 7]]
|
||||
xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
|
||||
|
||||
# apply angle-based reduction
|
||||
radians = a * math.pi / 180
|
||||
reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
|
||||
x = (xy[:, 2] + xy[:, 0]) / 2
|
||||
y = (xy[:, 3] + xy[:, 1]) / 2
|
||||
w = (xy[:, 2] - xy[:, 0]) * reduction
|
||||
h = (xy[:, 3] - xy[:, 1]) * reduction
|
||||
xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
|
||||
|
||||
# reject warped points outside of image
|
||||
# np.clip(xy[:, 0], 0, width, out=xy[:, 0])
|
||||
# np.clip(xy[:, 2], 0, width, out=xy[:, 2])
|
||||
# np.clip(xy[:, 1], 0, height, out=xy[:, 1])
|
||||
# np.clip(xy[:, 3], 0, height, out=xy[:, 3])
|
||||
w = xy[:, 2] - xy[:, 0]
|
||||
h = xy[:, 3] - xy[:, 1]
|
||||
area = w * h
|
||||
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
|
||||
i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
|
||||
|
||||
targets = targets[i]
|
||||
targets[:, 2:6] = xy[i]
|
||||
|
||||
return imw, targets, M
|
||||
return imw
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
log utils
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
USE_TENSORBOARD = False
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
|
||||
def __init__(self, opt):
|
||||
if not os.path.exists(opt.save_dir):
|
||||
os.makedirs(opt.save_dir)
|
||||
|
||||
time_str = time.strftime('%Y-%m-%d-%H-%M')
|
||||
args = {}
|
||||
for name in dir(opt):
|
||||
if not name.startswith('_'):
|
||||
args[name] = getattr(opt, name)
|
||||
file_name = os.path.join(opt.save_dir, 'opt.txt')
|
||||
with open(file_name, 'wt') as opt_file:
|
||||
opt_file.write('==> Cmd:\n')
|
||||
opt_file.write(str(sys.argv))
|
||||
opt_file.write('\n==> Opt:\n')
|
||||
for k, v in sorted(args.items()):
|
||||
opt_file.write(' %s: %s\n' % (str(k), str(v)))
|
||||
|
||||
log_dir = opt.save_dir + '/logs_{}'.format(time_str)
|
||||
if not os.path.exists(os.path.dirname(log_dir)):
|
||||
os.mkdir(os.path.dirname(log_dir))
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
self.log = open(log_dir + '/log.txt', 'w')
|
||||
try:
|
||||
os.system('cp {}/opt.txt {}/'.format(opt.save_dir, log_dir))
|
||||
except IOError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
self.start_line = True
|
||||
|
||||
def write(self, txt):
|
||||
"""write"""
|
||||
if self.start_line:
|
||||
time_str = time.strftime('%Y-%m-%d-%H-%M')
|
||||
self.log.write('{}: {}'.format(time_str, txt))
|
||||
else:
|
||||
self.log.write(txt)
|
||||
self.start_line = False
|
||||
if '\n' in txt:
|
||||
self.start_line = True
|
||||
self.log.flush()
|
||||
|
||||
def close(self):
|
||||
"""close"""
|
||||
self.log.close()
|
||||
|
||||
def scalar_summary(self, tag, value, step):
|
||||
"""Log a scalar variable."""
|
||||
if USE_TENSORBOARD:
|
||||
self.writer.add_scalar(tag, value, step)
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
custom lr schedule
|
||||
"""
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
def dynamic_lr(num_epoch_per_decay, total_epochs, steps_per_epoch):
|
||||
"""dynamic learning rate generator"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
decay_steps = steps_per_epoch * num_epoch_per_decay
|
||||
lr = nn.PolynomialDecayLR(1e-4, 1e-5, decay_steps, 0.5)
|
||||
for i in range(total_steps):
|
||||
if i < decay_steps:
|
||||
i = Tensor(i, mstype.int32)
|
||||
lr_each_step.append(lr(i).asnumpy())
|
||||
else:
|
||||
lr_each_step.append(1e-5)
|
||||
return lr_each_step
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
pth to ckpt
|
||||
"""
|
||||
import torch
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
|
||||
def pth2ckpt(path='fairmot_dla34.pth'):
|
||||
"""pth to ckpt """
|
||||
par_dict = torch.load(path, map_location='cpu')
|
||||
new_params_list = []
|
||||
for name in par_dict:
|
||||
param_dict = {}
|
||||
parameter = par_dict[name]
|
||||
name = 'base.' + name
|
||||
name = name.replace('level0', 'level0.0', 1)
|
||||
name = name.replace('level1', 'level1.0', 1)
|
||||
name = name.replace('hm', 'hm_fc', 1)
|
||||
name = name.replace('id.0', 'id_fc.0', 1)
|
||||
name = name.replace('id.2', 'id_fc.2', 1)
|
||||
name = name.replace('reg', 'reg_fc', 1)
|
||||
name = name.replace('wh', 'wh_fc', 1)
|
||||
name = name.replace('bn.running_mean', 'bn.moving_mean', 1)
|
||||
name = name.replace('bn.running_var', 'bn.moving_variance', 1)
|
||||
if name.endswith('.0.weight'):
|
||||
name = name[:name.rfind('0.weight')]
|
||||
name = name + 'conv.weight'
|
||||
if name.endswith('.1.weight'):
|
||||
name = name[:name.rfind('1.weight')]
|
||||
name = name + 'batchnorm.gamma'
|
||||
if name.endswith('.1.bias'):
|
||||
name = name[:name.rfind('1.bias')]
|
||||
name = name + 'batchnorm.beta'
|
||||
if name.endswith('.1.running_mean'):
|
||||
name = name[:name.rfind('1.running_mean')]
|
||||
name = name + 'batchnorm.moving_mean'
|
||||
if name.endswith('.1.running_var'):
|
||||
name = name[:name.rfind('1.running_var')]
|
||||
name = name + 'batchnorm.moving_variance'
|
||||
if name.endswith('conv1.weight'):
|
||||
name = name[:name.rfind('conv1.weight')]
|
||||
name = name + 'conv_bn_act.conv.weight'
|
||||
if name.endswith('bn1.weight'):
|
||||
name = name[:name.rfind('bn1.weight')]
|
||||
name = name + 'conv_bn_act.batchnorm.gamma'
|
||||
if name.endswith('bn1.bias'):
|
||||
name = name[:name.rfind('bn1.bias')]
|
||||
name = name + 'conv_bn_act.batchnorm.beta'
|
||||
if name.endswith('bn1.running_mean'):
|
||||
name = name[:name.rfind('bn1.running_mean')]
|
||||
name = name + 'conv_bn_act.batchnorm.moving_mean'
|
||||
if name.endswith('bn1.running_var'):
|
||||
name = name[:name.rfind('bn1.running_var')]
|
||||
name = name + 'conv_bn_act.batchnorm.moving_variance'
|
||||
if name.endswith('conv2.weight'):
|
||||
name = name[:name.rfind('conv2.weight')]
|
||||
name = name + 'conv_bn.conv.weight'
|
||||
if name.endswith('bn2.weight'):
|
||||
name = name[:name.rfind('bn2.weight')]
|
||||
name = name + 'conv_bn.batchnorm.gamma'
|
||||
if name.endswith('bn2.bias'):
|
||||
name = name[:name.rfind('bn2.bias')]
|
||||
name = name + 'conv_bn.batchnorm.beta'
|
||||
if name.endswith('bn2.running_mean'):
|
||||
name = name[:name.rfind('bn2.running_mean')]
|
||||
name = name + 'conv_bn.batchnorm.moving_mean'
|
||||
if name.endswith('bn2.running_var'):
|
||||
name = name[:name.rfind('bn2.running_var')]
|
||||
name = name + 'conv_bn.batchnorm.moving_variance'
|
||||
if name.endswith('bn.weight'):
|
||||
name = name[:name.rfind('bn.weight')]
|
||||
name = name + 'bn.gamma'
|
||||
if name.endswith('bn.bias'):
|
||||
name = name[:name.rfind('bn.bias')]
|
||||
name = name + 'bn.beta'
|
||||
param_dict['name'] = name
|
||||
param_dict['data'] = Tensor(parameter.numpy())
|
||||
new_params_list.append(param_dict)
|
||||
save_checkpoint(new_params_list, '{}_ms.ckpt'.format(path[:path.rfind('.pth')]))
|
||||
|
||||
|
||||
pth2ckpt('dla34-ba72cf86.pth')
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
img tools
|
||||
"""
|
||||
import numpy as np
|
||||
from src.utils.image import transform_preds
|
||||
|
||||
|
||||
def xyxy2xywh(x):
|
||||
# Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
|
||||
y = np.zeros(x.shape)
|
||||
y[:, 0] = (x[:, 0] + x[:, 2]) / 2
|
||||
y[:, 1] = (x[:, 1] + x[:, 3]) / 2
|
||||
y[:, 2] = x[:, 2] - x[:, 0]
|
||||
y[:, 3] = x[:, 3] - x[:, 1]
|
||||
return y
|
||||
|
||||
|
||||
def ctdet_post_process(dets, c, s, h, w, num_classes):
|
||||
"""
|
||||
dets: batch x max_dets x dim
|
||||
return 1-based class det dict
|
||||
"""
|
||||
ret = []
|
||||
for i in range(dets.shape[0]):
|
||||
top_preds = {}
|
||||
dets[i, :, :2] = transform_preds(
|
||||
dets[i, :, 0:2], c[i], s[i], (w, h))
|
||||
dets[i, :, 2:4] = transform_preds(
|
||||
dets[i, :, 2:4], c[i], s[i], (w, h))
|
||||
classes = dets[i, :, -1]
|
||||
for j in range(num_classes):
|
||||
inds = (classes == j)
|
||||
top_preds[j + 1] = np.concatenate([
|
||||
dets[i, inds, :4].astype(np.float32),
|
||||
dets[i, inds, 4:5].astype(np.float32)], axis=1).tolist()
|
||||
ret.append(top_preds)
|
||||
return ret
|
Loading…
Reference in New Issue