This commit is contained in:
Ren Wenhao 2021-08-05 16:24:18 +08:00
parent 6b46fa9046
commit b0187ba8ff
16 changed files with 1500 additions and 0 deletions

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""start eval """
from __future__ import absolute_import
import argparse
import os
import sys
from got10k.experiments import ExperimentOTB
from mindspore import context
from src import SiamFCTracker
sys.path.append(os.getcwd())
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='siamfc tracking')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend')
parser.add_argument('--model_path', default='/root/SiamFC/models/siamfc_{}.ckpt/SiamFC-6650.ckpt'
, type=str, help='checkpoint path used for evaluation')
parser.add_argument('--dataset_path', default='/root/datasets/OTB2013', type=str)
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
device_id=args.device_id,
save_graphs=False,
device_target='Ascend')
tracker = SiamFCTracker(model_path=args.model_path)
root_dir = os.path.abspath(args.dataset_path)
e = ExperimentOTB(root_dir, version=2013)
e.run(tracker, visualize=False)
prec_score, succ_score, succ_rate = e.report([tracker.name])
ss = '-prec_score:%.3f -succ_score:%.3f -succ_rate:%.3f' % (float(prec_score),
float(succ_score),
float(succ_rate))
print(args.model_path.split('/')[-1], ss)

View File

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into models"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, export, load_param_into_net
from src.alexnet import SiameseAlexNet
parser = argparse.ArgumentParser(description='siamfc export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument('--model_path', default='/root/HRBEU-MedAI/SiamFC/models/siamfc_{}.ckpt/',
type=str, help='checkpoint path of the model to export')
parser.add_argument('--file_name', type=str, default='/root/HRBEU-MedAI/SiamFC/models',
help='SiamFC output file name.')
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR',
help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
net = SiameseAlexNet(train=False)
load_param_into_net(net, load_checkpoint(args.model_path), strict_load=True)
net.set_train(False)
input_data_exemplar = Tensor(np.zeros([3, 256, 6, 6]), ms.float32)
input_data_instance = Tensor(np.zeros([3, 3, 255, 255]), ms.float32)
export(net, input_data_exemplar, input_data_instance, file_name=args.file_name,
file_format=args.file_format)

View File

@ -0,0 +1,195 @@
# Contents
- [SiamFC Description](#SiamFC-Description)
- [Model Architecture](#SiamFC-Architecture)
- [Dataset](#SiamFC-dataset)
- [Environmental requirements](#Environmental)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
# [SiamFC Description](#Contents)
SiamFC proposes a fully-convolutional Siamese network as the basic tracking algorithm, trained end-to-end on the ILSVRC15 video dataset. The tracker runs beyond real-time frame rates and, despite its simplicity, achieves state-of-the-art performance on multiple benchmarks.
[paper](https://arxiv.org/pdf/1606.09549.pdf): Luca Bertinetto, Jack Valmadre, João F. Henriques, Andrea Vedaldi, Philip H. S. Torr.
Department of Engineering Science, University of Oxford
# [Model Architecture](#Contents)
SiamFC uses a fully-convolutional AlexNet for feature extraction both offline (training) and online (tracking), and a Siamese network with shared weights to embed the exemplar (template) and the search region. Online, once the bounding box of the first frame is given, the exemplar is center-cropped and a checkpoint is loaded to track the subsequent frames. To locate the target, a series of penalties is applied to the score map, and the final prediction is obtained from the bicubically up-sampled response; a minimal sketch of this post-processing is shown below.
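The score-map post-processing at tracking time can be summarized with the short sketch below. It mirrors the steps in `src/tracker.py` (bicubic up-sampling, scale penalty, cosine-window blending, argmax, displacement); the random score maps and the variable names here are illustrative placeholders, not the actual network output.
```python
import numpy as np
import cv2

# Illustrative sizes taken from src/config.py: 17x17 raw response,
# up-sampled by a stride of 16, three search scales.
response_sz, up_stride, num_scale = 17, 16, 3
scale_penalty, window_influence = 0.9745, 0.176
interp_sz = response_sz * up_stride

# Placeholder score maps for the three scales (in the real tracker these
# come from cross-correlating exemplar and instance features).
response_maps = np.random.rand(num_scale, response_sz, response_sz).astype(np.float32)

# 1. Up-sample each scale's response with bicubic interpolation.
maps_up = [cv2.resize(m, (interp_sz, interp_sz), interpolation=cv2.INTER_CUBIC)
           for m in response_maps]

# 2. Penalize scale changes and keep the best scale.
penalty = np.full(num_scale, scale_penalty)
penalty[num_scale // 2] = 1.0
scale_idx = int((np.array([m.max() for m in maps_up]) * penalty).argmax())

# 3. Normalize the chosen map and blend it with a cosine (Hanning) window.
response = maps_up[scale_idx]
response -= response.min()
response /= response.sum()
hann = np.outer(np.hanning(interp_sz), np.hanning(interp_sz)).astype(np.float32)
hann /= hann.sum()
response = (1 - window_influence) * response + window_influence * hann

# 4. The argmax gives the displacement from the center of the search region.
max_r, max_c = np.unravel_index(response.argmax(), response.shape)
disp = np.array([max_c, max_r]) - (interp_sz - 1) / 2.0
print("chosen scale:", scale_idx, "displacement (x, y):", disp)
```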
# [Dataset](#Contents)
Dataset used: [ILSVRC2015-VID](http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz)
- Dataset size: 85 GB, 30 object categories in total
- Training set: 3862 videos with their frame images and box annotations
- Validation set: 555 videos with their frame images and box annotations
- Test set: 973 videos with their frame images and box annotations
- Data format: each image is stored as H\*W\*C; each box annotation gives the xmin/ymin/xmax/ymax corner coordinates in an XML file that has to be parsed (see the parsing sketch below)
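The annotations follow the standard ImageNet VID layout; a minimal parsing sketch in the spirit of the worker in `src/create_dataset_ILSVRC.py` (the annotation path below is only a placeholder):
```python
import xml.etree.ElementTree as ET

# Placeholder path; real annotations live under Annotations/VID/ in the dataset.
anno_path = "ILSVRC2015_train_00000000/000000.xml"

tree = ET.parse(anno_path)
root = tree.getroot()
for obj in root.iter('object'):
    trkid = int(obj.find('trackid').text)          # track id of the object
    bndbox = obj.find('bndbox')                    # corner coordinates of the box
    xmin, ymin, xmax, ymax = (int(bndbox.find(tag).text)
                              for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
    print(trkid, (xmin, ymin, xmax, ymax))
```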
# [Environmental requirements](#Contents)
- Hardware (Ascend)
    - Prepare an Ascend processor to set up the hardware environment
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For details, please refer to the following resources:
    - [MindSpore course](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
- Additional dependencies
    - got10k toolkit
    - opencv
    - lmdb
# [quick start](#Contents)
After installing MindSpore from the official website, you can follow the steps below to train and evaluate:
- Run the Python script to preprocess the dataset
python src/create_dataset_ILSVRC.py --d data_dir --o output_dir
- Run the Python script to create the LMDB
python src/create_lmdb.py --d data_dir --o output_dir
For example:
data_dir = '/data/VID/ILSVRC_VID_CURATION_train'
output_dir = '/data/VID/ILSVRC_VID_CURATION_train.lmdb'
__Note: the MD5 hash of each image path is used as the LMDB key. Therefore, you cannot move the dataset
after creating the LMDB, because images are looked up by that key (see the key-lookup sketch at the end of this section).__
- Run the script for training
bash run_standalone_train_ascend.sh [Device_ID] [Dataset_path]
Note: [Dataset_path] is the location of the preprocessed training set.
- More: this example is single-card training.
- Run the evaluation
python eval.py --device_id=[Device_ID] --model_path=[Model_path] --dataset_path=[Dataset_path]
The evaluation requires the got10k toolkit; the dataset is OTB2013 (50 sequences) or OTB2015 (100 sequences).
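As noted above, LMDB records are keyed by the MD5 digest of the curated image path. A minimal sketch of the write/read round trip, following `src/create_lmdb.py` and `src/dataset.py` (the paths below are placeholders):
```python
import hashlib
import cv2
import lmdb
import numpy as np

image_name = '/data/VID/ILSVRC_VID_CURATION_train/some_video/000000.00.x.jpg'  # placeholder
key = hashlib.md5(image_name.encode()).digest()   # the LMDB key is the MD5 digest of the path

db = lmdb.open('/data/VID/ILSVRC_VID_CURATION_train.lmdb', map_size=int(50e12))

# Writing: store the JPEG-encoded image under that key.
with db.begin(write=True) as txn:
    img = cv2.imread(image_name)
    _, img_encode = cv2.imencode('.jpg', img)
    txn.put(key, img_encode.tobytes())

# Reading: the same path (and therefore the same key) must be used,
# which is why the curated dataset cannot be moved after the LMDB is built.
with db.begin(write=False) as txn:
    buf = np.frombuffer(txn.get(key), np.uint8)
    img_back = cv2.imdecode(buf, cv2.IMREAD_COLOR)
```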
# [Script description](#Contents)
## Script and sample code
```python
├── SiamFC
├── README.md // Notes on siamfc
├── scripts
│ ├──ma-pre-start.sh // Create environment before modelarts training
│ ├──run_standalone_train_ascend.sh // Single card training in ascend
│ ├──run_distribution_ascend.sh // Multi card distributed training in ascend
├── src
│ ├──alexnet.py // AlexNet backbone and Siamese network
│ ├──config.py // Configuration
│ ├──custom_transforms.py // Data augmentation transforms
│ ├──dataset.py // Dataset wrapped by GeneratorDataset
│ ├──Groupconv.py // MindSpore does not support group convolution at present; this is an alternative
│ ├──lr_generator.py // Dynamic learning rate
│ ├──tracker.py // Tracking script
│ ├──utils.py // Utilities
│ ├──create_dataset_ILSVRC.py // Dataset preprocessing
│ ├──create_lmdb.py // Create LMDB
├── train.py // Training script
├── eval.py // Evaluation script
```
## Script parameters
The main parameters of train.py and config.py are as follows (an excerpt of the defaults is shown after this list):
- data_path: absolute path to the training and evaluation datasets.
- epoch_size: total number of training epochs.
- batch_size: training batch size.
- image_height: image height used as model input.
- image_width: image width used as model input.
- exemplar_size: exemplar (template) size.
- instance_size: instance (search region) size.
- lr: learning rate.
- frame_range: maximum frame interval between the exemplar and the instance.
- response_scale: scaling factor of the score map.
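For reference, a short excerpt of the corresponding defaults from `src/config.py` in this commit (image_height/image_width do not appear under those exact names there; the input sizes are governed by exemplar_size and instance_size):
```python
class Config:
    # dataset related
    exemplar_size = 127       # exemplar (template) size
    instance_size = 255       # instance (search region) size
    frame_range = 120         # frame range when choosing the instance
    # training related
    train_batch_size = 8      # training batch size
    lr = 1e-2                 # initial learning rate
    end_lr = 1e-5             # final learning rate of the polynomial decay
    epoch = 50                # total epochs
    response_scale = 1e-3     # scaling factor of the score map

config = Config()
```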
## Training process
### Training
- Running in ascend processor environment
```python
python train.py --device_id=${DEVICE_ID} --data_path=${DATASET_PATH}
```
- During training, the loss values are printed as follows:
```bash
grep "loss is " log
epoch: 1 step: 1, loss is 1.14123213
...
epoch: 1 step: 1536, loss is 0.5234123
epoch: 1 step: 1537, loss is 0.4523326
epoch: 1 step: 1538, loss is 0.6235748
...
```
- Model checkpoints are saved in the current directory.
- At the end of training, the loss values are as follows:
```bash
grep "loss is " log:
epoch: 30 step: 1, loss is 0.12534634
...
epoch: 30 step: 1560, loss is 0.2364573
epoch: 30 step: 1561, loss is 0.156347
epoch: 30 step: 1561, loss is 0.173423
```
## Evaluation process
Check the checkpoint path used for evaluation before running the following command.
- Running in ascend processor environment
```bash
python eval.py --device_id=${DEVICE_ID} --model_path=${MODEL_PATH}
```
The results are as follows:
```bash
SiamFC_159_50_6650.ckpt -prec_score:0.777 -succ_score:0.589 -succ_rate:0.754
```
# [Model description](#Contents)
## Performance
### Evaluation performance
| Parameter | Ascend |
| -------------------------- | ---------------------------------------------- |
| Resources | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Upload date | 2021.5.20 |
| MindSpore version | 1.2.0 |
| Training parameters | epoch=50, step=6650, batch_size=8, lr_init=1e-2, lr_end=1e-5 |
| Optimizer | SGD (momentum=0.0, weight_decay=0.0) |
| Loss function | BCEWithLogits |
| Training speed | epoch time: 285693.557 ms; per step time: 42.961 ms |
| Total time | about 5 hours |
| Script URL | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/SiamFC |
| Random seed | set_seed = 1234 |

View File

@ -0,0 +1,5 @@
lmdb==1.1.1
tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work
opencv-python==3.4.5
got10k==0.1.3
numpy==1.19.5

View File

@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
export MODEL_PATH=$2
export DATASET_PATH=$3
python eval.py --device_id=${DEVICE_ID} --model_path=${MODEL_PATH} --dataset_path=${DATASET_PATH} > log.txt 2>&1 &

View File

@ -0,0 +1,19 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
export DATASET_PATH=$2
python train.py --device_id=${DEVICE_ID} --data_path=${DATASET_PATH} > log.txt 2>&1 &

View File

@ -0,0 +1,21 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
init.
"""
from .tracker import SiamFCTracker
from .config import config
from .utils import get_instance_image
from .dataset import ImagnetVIDDataset

View File

@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""alexnet backbone"""
import numpy as np
import mindspore.nn as nn
import mindspore.numpy as n_p
from mindspore.common.initializer import HeNormal
from mindspore import Parameter, Tensor, ops
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .config import config
class SiameseAlexNet(nn.Cell):
"""
define alexnet both used in train and eval
if train = True
Returns: loss
else
if the first image pair used in evel:
Return: exemplar
if the other image pair used in evel:
Return: score map
"""
def __init__(self, train=True):
super(SiameseAlexNet, self).__init__()
self.conv1 = nn.Conv2d(3, 96, 11, has_bias=True, stride=2, pad_mode='valid',
weight_init=HeNormal(mode='fan_out', nonlinearity='relu'),
bias_init=0)
self.conv2 = nn.Conv2d(96, 256, 5, has_bias=True, stride=1, pad_mode='valid',
weight_init=HeNormal(mode='fan_out', nonlinearity='relu'),
bias_init=0)
self.conv3 = nn.Conv2d(256, 384, 3, has_bias=True, stride=1, pad_mode='valid',
weight_init=HeNormal(mode='fan_out', nonlinearity='relu'),
bias_init=0)
self.conv4 = nn.Conv2d(384, 384, 3, has_bias=True, stride=1, pad_mode='valid',
weight_init=HeNormal(mode='fan_out', nonlinearity='relu'),
bias_init=0)
self.conv5 = nn.Conv2d(384, 256, 3, has_bias=True, stride=1, pad_mode='valid',
weight_init=HeNormal(mode='fan_out', nonlinearity='relu'),
bias_init=0)
self.bn1 = nn.BatchNorm2d(96)
self.bn2 = nn.BatchNorm2d(256)
self.bn3 = nn.BatchNorm2d(384)
self.bn4 = nn.BatchNorm2d(384)
self.relu = nn.ReLU()
self.max_pool2d_1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
self.max_pool2d_2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
self.training = train
self.conv2d_train = ops.Conv2D(out_channel=8, kernel_size=6, stride=1, pad_mode='valid')
self.corr_bias = Parameter(n_p.zeros(1))
self.cast = P.Cast()
if train:
gt_train, weight_train = self._create_gt_mask((config.train_response_sz,
config.train_response_sz))
self.train_gt = Tensor(gt_train).astype(mstype.float32)
self.train_weight = Tensor(weight_train).astype(mstype.float32)
self.loss = nn.BCEWithLogitsLoss(reduction='sum', weight=self.train_weight)
self.adjust_batchnormal = nn.BatchNorm2d(1)
self.groups = config.train_batch_size
self.feature_spilt = P.Split(axis=1, output_num=self.groups)
self.kernel_split = P.Split(axis=0, output_num=self.groups)
self.feature_spilt_evel = P.Split(axis=1, output_num=3)
self.kernel_split_evel = P.Split(axis=0, output_num=3)
self.Conv2D_1 = ops.Conv2D(out_channel=1, kernel_size=6, stride=1, pad_mode='valid')
self.op_concat = P.Concat(axis=1)
self.op_concat_exemplar = P.Concat(axis=0)
self.seq = nn.SequentialCell([self.conv1,
self.bn1,
self.relu,
self.max_pool2d_1,
self.conv2,
self.bn2,
self.relu,
self.max_pool2d_2,
self.conv3,
self.bn3,
self.relu,
self.conv4,
self.bn4,
self.relu,
self.conv5,
])
def construct(self, x, y):
"""network construct"""
if self.training:
x = n_p.squeeze(x)
y = n_p.squeeze(y)
exemplar = x
instance = y
exemplar = self.seq(exemplar)
instance = self.seq(instance)
nx, cx, h, w = instance.shape
instance = instance.view(1, nx * cx, h, w)
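# Fold the batch into the channel axis so that each exemplar can be cross-correlated
# with its own instance: features and kernels are split per sample and convolved pairwise.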
features = self.feature_spilt(instance)
kernel = self.kernel_split(exemplar)
outputs = ()
for i in range(self.groups):
outputs = outputs + ((self.Conv2D_1(self.cast(features[i], mstype.float32),
self.cast(kernel[i], mstype.float32))),)
score_map = self.op_concat(outputs)
score_map = n_p.transpose(score_map, (1, 0, 2, 3))
score_map = score_map*1e-3+self.corr_bias
score = self.loss(score_map, self.train_gt)/8
else:
exemplar = x
instance = y
if exemplar.size is not None and instance.size == 1:
exemplar = self.seq(exemplar)
return exemplar
instance = self.seq(instance)
nx, cx, h, w = instance.shape
instance = n_p.reshape(instance, [1, nx*cx, h, w])
features = self.feature_spilt_evel(instance)
kernel = self.kernel_split_evel(exemplar)
outputs = ()
outputs = outputs + (self.Conv2D_1(features[0], kernel[0]),)
outputs = outputs + (self.Conv2D_1(features[1], kernel[1]),)
outputs = outputs + (self.Conv2D_1(features[2], kernel[2]),)
score_map = self.op_concat(outputs)
score = n_p.transpose(score_map, (1, 0, 2, 3))
return score
def _create_gt_mask(self, shape):
"""crete label """
# same for all pairs
h, w = shape
y = np.arange(h, dtype=np.float32) - (h - 1) / 2.
x = np.arange(w, dtype=np.float32) - (w - 1) / 2.
y, x = np.meshgrid(y, x)
dist = np.sqrt(x ** 2 + y ** 2)
mask = np.zeros((h, w))
mask[dist <= config.radius / config.total_stride] = 1
mask = mask[np.newaxis, :, :]
weights = np.ones_like(mask)
weights[mask == 1] = 0.5 / np.sum(mask == 1)
weights[mask == 0] = 0.5 / np.sum(mask == 0)
mask = np.repeat(mask, config.train_batch_size, axis=0)[:, np.newaxis, :, :]
return mask.astype(np.float32), weights.astype(np.float32)

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config used in train, evel and deal with image """
class Config:
"""set config"""
# dataset related
exemplar_size = 127 # exemplar size
instance_size = 255 # instance size
context_amount = 0.5 # context amount
# training related
num_per_epoch = 53200 # num of samples per epoch
train_ratio = 0.9 # training ratio of VID dataset
frame_range = 120 # frame range of choosing the instance
train_batch_size = 8 # training batch size
valid_batch_size = 8 # validation batch size
train_num_workers = 8 # number of workers of train dataloader
valid_num_workers = 8 # number of workers of validation dataloader
lr = 1e-2
end_lr = 1e-5 # learning rate of SGD
momentum = 0.0 # 0.9 # momentum of SGD
weight_decay = 0.0 # 5e-4 # weight decay of the optimizer
step_size = 1 # step size of LR_Schedular
gamma = 0.8685 # decay rate of LR_Schedular
epoch = 50 # total epoch
seed = 1234 # seed to sample training videos
log_dir = './models/logs' # log dirs
radius = 16 # radius of positive label
response_scale = 1e-3 # normalize of response
max_translate = 3 # max translation of random shift
# tracking related
scale_step = 1.0375 # scale step of instance image
num_scale = 3 # number of scales
scale_lr = 0.59 # scale learning rate
response_up_stride = 16 # response upsample stride
response_sz = 17 # response size
train_response_sz = 15 # train response size
window_influence = 0.176 # window influence
scale_penalty = 0.9745 # scale penalty
total_stride = 8 # total stride of backbone
sample_type = 'uniform'
gray_ratio = 0.25
blur_ratio = 0.15
dataset_OTB2013 = 'OTB2013'
dataset_OTB2015 = 'OTB2015'
dataset_VOT2015 = 'VOT2015'
dataset_VOT2018 = 'VOT2018'
config = Config()

View File

@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""start create_dataset_ILSVRC"""
import pickle
import os
import functools
import xml.etree.ElementTree as ET
import sys
import argparse
import multiprocessing
from multiprocessing import Pool
from glob import glob
import cv2
from tqdm import tqdm
from src import config, get_instance_image
sys.path.append(os.getcwd())
multiprocessing.set_start_method('spawn', True)
def worker(output_dir, video_dir):
"""worker used read image and box position"""
image_names = glob(os.path.join(video_dir, '*.JPEG'))
video_name = video_dir.split('/')[-1]
save_folder = os.path.join(output_dir, video_name)
if not os.path.exists(save_folder):
os.mkdir(save_folder)
trajs = {}
for image_name in image_names:
img = cv2.imread(image_name)
img_mean = tuple(map(int, img.mean(axis=(0, 1))))
anno_name = image_name.replace('Data', 'Annotations')
anno_name = anno_name.replace('JPEG', 'xml')
tree = ET.parse(anno_name)
root = tree.getroot()
filename = root.find('filename').text
for obj in root.iter('object'):
bbox = obj.find(
'bndbox')
bbox = list(map(int, [bbox.find('xmin').text,
bbox.find('ymin').text,
bbox.find('xmax').text,
bbox.find('ymax').text]))
trkid = int(obj.find('trackid').text)
if trkid in trajs:
trajs[trkid].append(filename)
else:
trajs[trkid] = [filename]
instance_img, _, _ = get_instance_image(img, bbox,
config.exemplar_size,
config.instance_size,
config.context_amount,
img_mean)
instance_img_name = os.path.join(save_folder, filename + ".{:02d}.x.jpg".format(trkid))
cv2.imwrite(instance_img_name, instance_img)
return video_name, trajs
def processing(data_dir, output_dir, num_threads=32):
"""
Main process: preprocess the pictures with a pool of worker processes.
"""
video_dir = os.path.join(data_dir, 'Data/VID')
all_videos = glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0000/*')) + \
glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0001/*')) + \
glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0002/*')) + \
glob(os.path.join(video_dir, 'train/ILSVRC2015_VID_train_0003/*')) + \
glob(os.path.join(video_dir, 'val/*'))
meta_data = []
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with Pool(processes=num_threads) as pool:
for ret in tqdm(pool.imap_unordered(functools.partial(worker, output_dir), all_videos),
total=len(all_videos)):
meta_data.append(ret)
pickle.dump(meta_data, open(os.path.join(output_dir, "meta_data.pkl"), 'wb'))
Data_dir = '/data/VID/ILSVRC2015'
Output_dir = '/data/VID/ILSVRC_VID_CURATION_train'
Num_threads = 32
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Demo SiamFC")
parser.add_argument('--d', default=Data_dir, type=str, help="data_dir")
parser.add_argument('--o', default=Output_dir, type=str, help="output dir")
parser.add_argument('--t', default=Num_threads, type=int, help="thread_num")
args = parser.parse_args()
processing(args.d, args.o, args.t)

View File

@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""start create lmdb"""
import os
import hashlib
import functools
import argparse
from glob import glob
import multiprocessing
from multiprocessing import Pool
from tqdm import tqdm
import lmdb
import cv2
multiprocessing.set_start_method('spawn', True)
def worker(video_name):
"""
Worker that builds key/value pairs: MD5 digest of the image path -> JPEG-encoded image.
"""
image_names = glob(video_name + '/*')
kv = {}
for image_name in image_names:
img = cv2.imread(image_name)
_, img_encode = cv2.imencode('.jpg', img)
img_encode = img_encode.tobytes()
kv[hashlib.md5(image_name.encode()).digest()] = img_encode
return kv
def create_lmdb(data_dir, output_dir, num_threads):
"""
Create the LMDB database with a pool of worker processes.
"""
video_names = glob(data_dir + '/*')
video_names = [x for x in video_names if os.path.isdir(x)]
db = lmdb.open(output_dir, map_size=int(50e12))
with Pool(processes=num_threads) as pool:
for ret in tqdm(pool.imap_unordered(functools.partial(worker),
video_names),
total=len(video_names)):
with db.begin(write=True) as txn:
for k, v in ret.items():
txn.put(k, v)
Data_dir = '/data/VID/ILSVRC_VID_CURATION_train'
Output_dir = '/data/VID/ILSVRC_VID_CURATION_train.lmdb'
Num_threads = 32
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Demo SiamFC")
parser.add_argument('--d', default=Data_dir, type=str, help="data_dir")
parser.add_argument('--o', default=Output_dir, type=str, help="output dir")
parser.add_argument('--n', default=Num_threads, type=int, help="thread_num")
args = parser.parse_args()
create_lmdb(args.d, args.o, args.n)

View File

@ -0,0 +1,186 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""used in data enhance"""
import numpy as np
import cv2
class RandomStretch():
"""
Random resize image according to the stretch
Args:
max_stretch(float): 0 to 1 value
"""
def __init__(self, max_stretch=0.05):
self.max_stretch = max_stretch #
def __call__(self, sample):
"""
Args:
sample(numpy array): 3 or 1 dim image
"""
scale_h = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
scale_w = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
h, w = sample.shape[:2]
shape = (int(w * scale_w), int(h * scale_h))  # cv2.resize expects (width, height)
return cv2.resize(sample, shape, interpolation=cv2.INTER_LINEAR)
class CenterCrop():
"""
Crop the image at the center according to the given size;
if the size is greater than the image size, zero padding is adopted.
Args:
size (tuple): desired size
"""
def __init__(self, size):
self.size = size # z-> 127x127 x->255x255
def __call__(self, sample):
"""
Args:
sample(numpy array): 3 or 1 dim image
"""
shape = sample.shape[:2]
cy, cx = (shape[0]-1) // 2, (shape[1]-1) // 2
ymin, xmin = cy - self.size[0]//2, cx - self.size[1] // 2
ymax, xmax = cy + self.size[0]//2 + self.size[0] % 2,\
cx + self.size[1]//2 + self.size[1] % 2
left = right = top = bottom = 0
im_h, im_w = shape
if xmin < 0:
left = int(abs(xmin))
if xmax > im_w:
right = int(xmax - im_w)
if ymin < 0:
top = int(abs(ymin))
if ymax > im_h:
bottom = int(ymax - im_h)
xmin = int(max(0, xmin))
xmax = int(min(im_w, xmax))
ymin = int(max(0, ymin))
ymax = int(min(im_h, ymax))
im_patch = sample[ymin:ymax, xmin:xmax]
if left != 0 or right != 0 or top != 0 or bottom != 0:
im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=0)
return im_patch
class RandomCrop():
"""
Crop the image at a randomly shifted center according to the given size;
if the crop extends beyond the image, zero padding is adopted.
Args:
size (tuple): desired size
max_translate: max translate of random shift
"""
def __init__(self, size, max_translate):
self.size = size # 255 - 2*stride stride=8
self.max_translate = max_translate # 255 - 2*stride
def __call__(self, sample):
"""
Args:
sample(numpy array): 3 or 1 dim image
"""
shape = sample.shape[:2]
cy_o = (shape[0] - 1) // 2
cx_o = (shape[1] - 1) // 2
cy = np.random.randint(cy_o - self.max_translate,
cy_o + self.max_translate+1)
cx = np.random.randint(cx_o - self.max_translate,
cx_o + self.max_translate+1)
assert abs(cy-cy_o) <= self.max_translate and \
abs(cx-cx_o) <= self.max_translate
ymin = cy - self.size[0] // 2
xmin = cx - self.size[1] // 2
ymax = cy + self.size[0] // 2 + self.size[0] % 2
xmax = cx + self.size[1] // 2 + self.size[1] % 2
left = right = top = bottom = 0
im_h, im_w = shape
if xmin < 0:
left = int(abs(xmin))
if xmax > im_w:
right = int(xmax - im_w)
if ymin < 0:
top = int(abs(ymin))
if ymax > im_h:
bottom = int(ymax - im_h)
xmin = int(max(0, xmin))
xmax = int(min(im_w, xmax))
ymin = int(max(0, ymin))
ymax = int(min(im_h, ymax))
im_patch = sample[ymin:ymax, xmin:xmax]
if left != 0 or right != 0 or top != 0 or bottom != 0:
im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=0)
return im_patch
class ColorAug():
"""
Color augmentation based on the eigenvalues of a fixed RGB covariance matrix.
"""
def __init__(self, type_in='z'):
if type_in == 'z':
rgb_var = np.array([[3.2586416e+03, 2.8992207e+03, 2.6392236e+03],
[2.8992207e+03, 3.0958174e+03, 2.9321748e+03],
[2.6392236e+03, 2.9321748e+03, 3.4533721e+03]])
if type_in == 'x':
rgb_var = np.array([[2.4847285e+03, 2.1796064e+03, 1.9766885e+03],
[2.1796064e+03, 2.3441289e+03, 2.2357402e+03],
[1.9766885e+03, 2.2357402e+03, 2.7369697e+03]])
self.v, _ = np.linalg.eig(rgb_var)
self.v = np.sqrt(self.v)
def __call__(self, sample):
return sample + 0.1 * self.v * np.random.randn(3)
class RandomBlur():
"""Randomblur"""
def __init__(self, ratio):
self.ratio = ratio
def __call__(self, sample):
if np.random.rand(1) < self.ratio:
# random kernel size
kernel_size = np.random.choice([3, 5, 7])
# random gaussian sigma
sigma = np.random.rand() * 5
sample_gaussian = cv2.GaussianBlur(sample, (kernel_size, kernel_size), sigma)
else:
return sample
return sample_gaussian
class Normalize():
"""
Scale the image to [0, 1] and normalize with ImageNet mean and std.
"""
def __init__(self):
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def __call__(self, sample):
return (sample / 255. - self.mean) / self.std
class ToTensor():
"""transpose and totensor"""
def __call__(self, sample):
sample = np.transpose(sample, (2, 0, 1))
return np.array(sample, dtype=np.float32)

View File

@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""VID dataset"""
import os
import pickle
import hashlib
import cv2
import numpy as np
from src.config import config
class ImagnetVIDDataset():
"""
Generates exemplar/instance image pairs; used by GeneratorDataset.
Args:
db : lmdb file
video_names : all video name
data_dir : the location of image pair
z_transforms : the transforms list used in exemplar
x_transforms : the transforms list used in instance
training : status of training
"""
def __init__(self, db, video_names, data_dir, z_transforms, x_transforms, training=True):
self.video_names = video_names
self.data_dir = data_dir
self.z_transforms = z_transforms
self.x_transforms = x_transforms
meta_data_path = os.path.join(data_dir, 'meta_data.pkl')
self.meta_data = pickle.load(open(meta_data_path, 'rb'))
self.meta_data = {x[0]: x[1] for x in self.meta_data}
for key in self.meta_data.keys():
trajs = self.meta_data[key]
for trkid in list(trajs.keys()):
if len(trajs[trkid]) < 2:
del trajs[trkid]
self.txn = db.begin(write=False)
self.num = len(self.video_names) if config.num_per_epoch is None or not \
training else config.num_per_epoch
def imread(self, path):
"""
Read an image from LMDB according to its path.
Args :
path : the image path
"""
key = hashlib.md5(path.encode()).digest()
img_buffer = self.txn.get(key)
img_buffer = np.frombuffer(img_buffer, np.uint8)
img = cv2.imdecode(img_buffer, cv2.IMREAD_COLOR)
return img
def _sample_weights(self, center, low_idx, high_idx, s_type='uniform'):
"""
Compute sampling weights for picking another frame around the center frame;
the weights depend on the chosen sampling distribution.
Args:
center : the position of center image
low_idx : the minimum of id
high_idx : the max of id
s_type : choose different distribution. "uniform", "sqrt", "linear"
can be chosen
"""
weights = list(range(low_idx, high_idx))
weights.remove(center)
weights = np.array(weights)
if s_type == 'linear':
weights = abs(weights - center)
elif s_type == 'sqrt':
weights = np.sqrt(abs(weights - center))
elif s_type == 'uniform':
weights = np.ones_like(weights)
return weights / sum(weights)
def __getitem__(self, idx):
idx = idx % len(self.video_names)
video = self.video_names[idx]
trajs = self.meta_data[video]
trkid = np.random.choice(list(trajs.keys()))
traj = trajs[trkid]
assert len(traj) > 1, "video_name: {}".format(video)
exemplar_idx = np.random.choice(list(range(len(traj))))
exemplar_name = os.path.join(self.data_dir, video,
traj[exemplar_idx] + ".{:02d}.x.jpg".format(trkid))
exemplar_img = self.imread(exemplar_name)
exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_BGR2RGB)
# sample instance
low_idx = max(0, exemplar_idx - config.frame_range)
up_idx = min(len(traj), exemplar_idx + config.frame_range)
weights = self._sample_weights(exemplar_idx, low_idx, up_idx, config.sample_type)
instance = np.random.choice(traj[low_idx:exemplar_idx] + traj[exemplar_idx + 1:up_idx],
p=weights)
instance_name = os.path.join(self.data_dir, video, instance + ".{:02d}.x.jpg".format(trkid))
instance_img = self.imread(instance_name)
instance_img = cv2.cvtColor(instance_img, cv2.COLOR_BGR2RGB)
if np.random.rand(1) < config.gray_ratio:
exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_RGB2GRAY)
exemplar_img = cv2.cvtColor(exemplar_img, cv2.COLOR_GRAY2RGB)
instance_img = cv2.cvtColor(instance_img, cv2.COLOR_RGB2GRAY)
instance_img = cv2.cvtColor(instance_img, cv2.COLOR_GRAY2RGB)
exemplar_img = self.z_transforms(exemplar_img)
instance_img = self.x_transforms(instance_img)
return exemplar_img, instance_img
def __len__(self):
return self.num

View File

@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""set tracker"""
import time
import numpy as np
import cv2
from tqdm import tqdm
import mindspore.numpy as ms_np
from mindspore import load_checkpoint, load_param_into_net, ops, Tensor
from .alexnet import SiameseAlexNet
from .config import config
from .utils import get_exemplar_image, get_pyramid_instance_image, show_image
class SiamFCTracker:
"""
Tracker used in evaluation.
Args:
model_path : checkpoint path
"""
def __init__(self, model_path, is_deterministic=True):
self.network = SiameseAlexNet(train=False)
load_param_into_net(self.network, load_checkpoint(model_path), strict_load=True)
self.network.set_train(False)
self.name = 'SiamFC'
self.is_deterministic = is_deterministic
def _cosine_window(self, size):
"""
get the cosine window
"""
cos_window = np.hanning(int(size[0]))[:, np.newaxis].dot(np.hanning(int(size[1]))
[np.newaxis, :])
cos_window = cos_window.astype(np.float32)
cos_window /= np.sum(cos_window)
return cos_window
def init(self, frame, bbox):
""" initialize siamfc trackers
Args:
frame: an RGB image
bbox: one-based bounding box [x, y, width, height]
"""
self.bbox = (bbox[0]-1, bbox[1]-1, bbox[0]-1+bbox[2], bbox[1]-1 + bbox[3]) # zero based
self.pos = np.array(
[bbox[0]-1+(bbox[2]-1)/2, bbox[1]-1+(bbox[3]-1)/2]) # center x, center y, zero based
self.target_sz = np.array([bbox[2], bbox[3]]) # width, height
# get exemplar img
self.img_mean = tuple(map(int, frame.mean(axis=(0, 1))))
exemplar_img, scale_z, s_z = get_exemplar_image(frame, self.bbox,
config.exemplar_size, config.context_amount,
self.img_mean) # context_amount = 0.5
# exemplar_size=127
exemplar_img = np.transpose(exemplar_img, (2, 0, 1))[None, :, :, :]
exemplar_img = Tensor(exemplar_img, dtype=ms_np.float32)
none = ms_np.ones(1)
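# A size-1 dummy instance tells the network (see SiameseAlexNet.construct) to
# return only the exemplar feature instead of a score map.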
self.exemplar = self.network(exemplar_img, none)
self.exemplar = ops.repeat_elements(self.exemplar, rep=3, axis=0)
self.penalty = np.ones((config.num_scale)) * config.scale_penalty # [1, 1, 1] * 0.9745
self.penalty[config.num_scale // 2] = 1
# create cosine window
self.interp_response_sz = config.response_up_stride * config.response_sz
self.cosine_window = self._cosine_window((self.interp_response_sz,
self.interp_response_sz))
self.scales = config.scale_step**np.arange(np.ceil(3/2)-3,
np.floor(3/2)+1) # [0.96385542,1,1.0375]
# create s_x
self.s_x = s_z + (config.instance_size - config.exemplar_size) / scale_z
# arbitrary scale saturation
self.min_s_x = 0.2 * self.s_x
self.max_s_x = 5 * self.s_x
def update(self, frame):
"""
track object based on the previous frame
Args:
frame: an RGB image
Returns:
bbox: tuple of 1-based bounding box(xmin, ymin, xmax, ymax)
"""
size_x_scales = self.s_x * self.scales
pyramid = get_pyramid_instance_image(frame, self.pos, config.instance_size,
size_x_scales, self.img_mean)
x_crops_tensor = ()
for k in range(3):
tmp_x_crop = Tensor(pyramid[k], dtype=ms_np.float32)
tmp_x_crop = ops.transpose(tmp_x_crop, (2, 0, 1))[np.newaxis, :, :, :]
x_crops_tensor = x_crops_tensor+(tmp_x_crop,)
instance_imgs = ms_np.concatenate(x_crops_tensor, axis=0)
response_maps = self.network(self.exemplar, instance_imgs)
response_maps = ms_np.squeeze(response_maps)
response_maps = Tensor.asnumpy(response_maps)
response_maps_up = [cv2.resize(x, (self.interp_response_sz, self.interp_response_sz),
interpolation=cv2.INTER_CUBIC) for x in response_maps]
# get max score
max_score = np.array([x.max() for x in response_maps_up]) * self.penalty
# penalty scale change
scale_idx = max_score.argmax()
response_map = response_maps_up[scale_idx]
response_map -= response_map.min()
response_map /= response_map.sum()
response_map = (1 - config.window_influence) * response_map + \
config.window_influence * self.cosine_window
max_r, max_c = np.unravel_index(response_map.argmax(), response_map.shape)
# displacement in interpolation response
disp_response_interp = np.array([max_c, max_r]) - (self.interp_response_sz - 1) / 2.
# displacement in input
disp_response_input = disp_response_interp*config.total_stride/config.response_up_stride
# displacement in frame
scale = self.scales[scale_idx]
disp_response_frame = disp_response_input*(self.s_x * scale)/config.instance_size
# position in frame coordinates
self.pos += disp_response_frame
# scale damping and saturation
self.s_x *= ((1 - config.scale_lr) + config.scale_lr * scale)
self.s_x = max(self.min_s_x, min(self.max_s_x, self.s_x))
self.target_sz = ((1 - config.scale_lr) + config.scale_lr * scale) * self.target_sz
box = np.array([
self.pos[0] + 1 - (self.target_sz[0]) / 2,
self.pos[1] + 1 - (self.target_sz[1]) / 2,
self.target_sz[0], self.target_sz[1]])
return box
def track(self, img_files, box, visualize=False):
"""
Run the tracker over all frames, recording the per-frame boxes and timings.
Args :
img_files : the image file paths of the sequence
box : the first-frame box (x, y, width, height)
"""
frame_num = len(img_files)
boxes = np.zeros((frame_num, 4))
boxes[0] = box # xy, w, h
times = np.zeros(frame_num)
for f, img_file in tqdm(enumerate(img_files), total=len(img_files)):
img = cv2.imread(img_file, cv2.IMREAD_COLOR)
begin = time.time()
if f == 0:
self.init(img, box)
else:
boxes[f, :] = self.update(img)
times[f] = time.time() - begin
if visualize:
show_image(img, boxes[f, :])
return boxes, times

View File

@ -0,0 +1,161 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the utils used in dataset"""
import numpy as np
import cv2
def get_center(x):
"""get box center"""
return (x - 1.) / 2.
def xyxy2cxcywh(bbox):
"""change box format"""
return get_center(bbox[0]+bbox[2]), \
get_center(bbox[1]+bbox[3]), \
(bbox[2]-bbox[0]), \
(bbox[3]-bbox[1])
def crop_and_pad(img, cx, cy, model_sz, original_sz, img_mean=None):
"""crop and pad image, pad with image mean """
xmin = cx - original_sz // 2
xmax = cx + original_sz // 2
ymin = cy - original_sz // 2
ymax = cy + original_sz // 2
im_h, im_w, _ = img.shape
left = right = top = bottom = 0
if xmin < 0:
left = int(abs(xmin))
if xmax > im_w:
right = int(xmax - im_w)
if ymin < 0:
top = int(abs(ymin))
if ymax > im_h:
bottom = int(ymax - im_h)
xmin = int(max(0, xmin))
xmax = int(min(im_w, xmax))
ymin = int(max(0, ymin))
ymax = int(min(im_h, ymax))
im_patch = img[ymin:ymax, xmin:xmax]
if left != 0 or right != 0 or top != 0 or bottom != 0:
if img_mean is None:
img_mean = tuple(map(int, img.mean(axis=(0, 1))))
im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=img_mean)
if model_sz != original_sz:
im_patch = cv2.resize(im_patch, (model_sz, model_sz))
return im_patch
def get_exemplar_image(img, bbox, size_z, context_amount, img_mean=None):
"""get exemplar according to given size"""
cx, cy, w, h = xyxy2cxcywh(bbox)
wc_z = w + context_amount * (w+h)#w+0.5*(w+h)
hc_z = h + context_amount * (w+h)#h+0.5*(w+h)
#orginal_sz
s_z = np.sqrt(wc_z * hc_z)
#model_sz
scale_z = size_z / s_z
exemplar_img = crop_and_pad(img, cx, cy, size_z, s_z, img_mean)
return exemplar_img, scale_z, s_z
def get_instance_image(img, bbox, size_z, size_x, context_amount, img_mean=None):
"""get instance according to given size"""
cx, cy, w, h = xyxy2cxcywh(bbox)
wc_z = w + context_amount * (w+h)
hc_z = h + context_amount * (w+h)
s_z = np.sqrt(wc_z * hc_z)
scale_z = size_z / s_z
d_search = (size_x - size_z) / 2
pad = d_search / scale_z
s_x = s_z + 2 * pad
scale_x = size_x / s_x
instance_img = crop_and_pad(img, cx, cy, size_x, s_x, img_mean)
return instance_img, scale_x, s_x
def get_pyramid_instance_image(img, center, size_x, size_x_scales,
img_mean=None):
"""get pyramid instance"""
if img_mean is None:
img_mean = tuple(map(int, img.mean(axis=(0, 1))))
pyramid = [crop_and_pad(img, center[0], center[1], size_x,
size_x_scale, img_mean)
for size_x_scale in size_x_scales]
return pyramid
def show_image(img, boxes=None, box_fmt='ltwh', colors=None,
thickness=3, fig_n=1, delay=1, visualize=True,
cvt_code=cv2.COLOR_RGB2BGR):
"""show image define """
if cvt_code is not None:
img = cv2.cvtColor(img, cvt_code)
# resize img if necessary
max_size = 960
if max(img.shape[:2]) > max_size:
scale = max_size / max(img.shape[:2])
out_size = (
int(img.shape[1] * scale),
int(img.shape[0] * scale))
img = cv2.resize(img, out_size)
if boxes is not None:
boxes = np.array(boxes, dtype=np.float32) * scale
if boxes is not None:
assert box_fmt in ['ltwh', 'ltrb']
boxes = np.array(boxes, dtype=np.float32)
if boxes.ndim == 1:
boxes = np.expand_dims(boxes, axis=0)
if box_fmt == 'ltrb':
boxes[:, 2:] -= boxes[:, :2]
# clip bounding boxes
bound = np.array(img.shape[1::-1])[None, :]
boxes[:, :2] = np.clip(boxes[:, :2], 0, bound)
boxes[:, 2:] = np.clip(boxes[:, 2:], 0, bound - boxes[:, :2])
if colors is None:
colors = [
(0, 0, 255),
(0, 255, 0),
(255, 0, 0),
(0, 255, 255),
(255, 0, 255),
(255, 255, 0),
(0, 0, 128),
(0, 128, 0),
(128, 0, 0),
(0, 128, 128),
(128, 0, 128),
(128, 128, 0)]
colors = np.array(colors, dtype=np.int32)
if colors.ndim == 1:
colors = np.expand_dims(colors, axis=0)
for i, box in enumerate(boxes):
color = colors[i % len(colors)]
pt1 = (int(box[0]), int(box[1]))
pt2 = (int(box[0] + box[2]), int(box[1] + box[3]))
img = cv2.rectangle(img, pt1, pt2, color.tolist(), thickness)
if visualize:
winname = 'window_{}'.format(fig_n)
cv2.imshow(winname, img)
cv2.waitKey(delay)
return img

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""start train """
import sys
import os
import pickle
import argparse
import lmdb
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore import context
from mindspore.context import ParallelMode
import mindspore.dataset as ds
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
import mindspore.dataset.transforms.py_transforms as py_transforms
from src.config import config
from src.alexnet import SiameseAlexNet
from src.dataset import ImagnetVIDDataset
from src.custom_transforms import ToTensor, RandomStretch, RandomCrop, CenterCrop
sys.path.append(os.getcwd())
def train(data_dir):
"""set train """
# loading meta data
meta_data_path = os.path.join(data_dir, "meta_data.pkl")
meta_data = pickle.load(open(meta_data_path, 'rb'))
all_videos = [x[0] for x in meta_data]
set_seed(1234)
random_crop_size = config.instance_size - 2 * config.total_stride
train_z_transforms = py_transforms.Compose([
RandomStretch(),
CenterCrop((config.exemplar_size, config.exemplar_size)),
ToTensor()
])
train_x_transforms = py_transforms.Compose([
RandomStretch(),
RandomCrop((random_crop_size, random_crop_size),
config.max_translate),
ToTensor()
])
db_open = lmdb.open(data_dir + '.lmdb', readonly=True, map_size=int(50e12))
# create dataset
train_dataset = ImagnetVIDDataset(db_open, all_videos, data_dir,
train_z_transforms, train_x_transforms)
dataset = ds.GeneratorDataset(train_dataset, ["exemplar_img", "instance_img"], shuffle=True,
num_parallel_workers=8)
dataset = dataset.batch(batch_size=8, drop_remainder=True)
#set network
network = SiameseAlexNet(train=True)
decay_lr = nn.polynomial_decay_lr(config.lr,
config.end_lr,
total_step=config.epoch * config.num_per_epoch,
step_per_epoch=config.num_per_epoch,
decay_epoch=config.epoch,
power=1.0)
optim = nn.SGD(params=network.trainable_params(),
learning_rate=decay_lr,
momentum=config.momentum,
weight_decay=config.weight_decay)
loss_scale_manager = DynamicLossScaleManager()
model = Model(network,
optimizer=optim,
loss_scale_manager=loss_scale_manager,
metrics=None,
amp_level='O3')
config_ck_train = CheckpointConfig(save_checkpoint_steps=6650, keep_checkpoint_max=20)
ckpoint_cb_train = ModelCheckpoint(prefix='SiamFC',
directory='./models/siamfc_{}.ckpt',
config=config_ck_train)
time_cb_train = TimeMonitor(data_size=config.num_per_epoch)
loss_cb_train = LossMonitor()
model.train(epoch=config.epoch,
train_dataset=dataset,
callbacks=[time_cb_train, ckpoint_cb_train, loss_cb_train],
dataset_sink_mode=True
)
if __name__ == '__main__':
ARGPARSER = argparse.ArgumentParser(description=" SiamFC Train")
ARGPARSER.add_argument('--device_target',
type=str,
default="Ascend",
choices=['GPU', 'CPU', 'Ascend'],
help='the target device to run on, supports "Ascend", "GPU", "CPU"')
ARGPARSER.add_argument('--data_path',
default="/data/VID/ILSVRC_VID_CURATION_train",
type=str,
help=" the path of data")
ARGPARSER.add_argument('--sink_size',
type=int, default=-1,
help='control the amount of data in each sink')
ARGPARSER.add_argument('--device_id',
type=int, default=7,
help='device id of GPU or Ascend')
ARGS = ARGPARSER.parse_args()
DEVICENUM = int(os.environ.get("DEVICE_NUM", 1))
DEVICETARGET = ARGS.device_target
if DEVICETARGET == "Ascend":
context.set_context(
mode=context.GRAPH_MODE,
device_id=ARGS.device_id,
save_graphs=False,
device_target=ARGS.device_target)
if DEVICENUM > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=DEVICENUM,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
# train
train(ARGS.data_path)