!9262 Add FaceDetection net to /model_zoo/research/cv/

From: @zhanghuiyao
Reviewed-by: @oacjiewen,@linqingke
Signed-off-by: @linqingke
This commit is contained in:
mindspore-ci-bot 2020-12-09 12:52:57 +08:00 committed by Gitee
commit 2fa9b51ae8
23 changed files with 3985 additions and 0 deletions

View File

@@ -0,0 +1,245 @@
# Contents
- [Face Detection Description](#face-detection-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Running Example](#running-example)
- [Model Description](#model-description)
- [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Face Detection Description](#contents)
This is a Face Detection network based on YOLOv3, with support for training and evaluation on Ascend 910.
You only look once (YOLO) is a state-of-the-art, real-time object detection system. YOLOv3 is extremely fast and accurate.
Prior detection systems repurpose classifiers or localizers to perform detection: they apply the model to an image at multiple locations and scales, and high-scoring regions of the image are considered detections.
YOLOv3 uses a totally different approach. It applies a single neural network to the full image. This network divides the image into regions and predicts bounding boxes and probabilities for each region. These bounding boxes are weighted by the predicted probabilities.
[Paper](https://pjreddie.com/media/files/papers/YOLOv3.pdf): YOLOv3: An Incremental Improvement. Joseph Redmon, Ali Farhadi,
University of Washington
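To make the grid-based prediction concrete, here is a minimal NumPy sketch (not part of this repository; all names and values are illustrative) of how one anchor's raw outputs at a grid cell are decoded into a normalized box, mirroring the math in `src/FaceDetection/yolo_postprocess.py`:
```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decode_cell(raw, cx, cy, anchor_w, anchor_h, grid_w, grid_h):
    """Decode one anchor's raw (tx, ty, tw, th, obj) at grid cell (cx, cy)."""
    tx, ty, tw, th, obj = raw
    bx = (sigmoid(tx) + cx) / grid_w     # box center x, normalized to [0, 1]
    by = (sigmoid(ty) + cy) / grid_h     # box center y, normalized to [0, 1]
    bw = anchor_w * np.exp(tw) / grid_w  # box width scaled from the anchor prior
    bh = anchor_h * np.exp(th) / grid_h  # box height scaled from the anchor prior
    return bx, by, bw, bh, sigmoid(obj)  # last value is the confidence score

print(decode_cell(np.array([0.2, -0.1, 0.3, 0.0, 1.5]),
                  cx=3, cy=2, anchor_w=1.5, anchor_h=2.0, grid_w=12, grid_h=7))
```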
# [Model Architecture](#contents)
Face Detection uses a modified DarkNet53 network to perform feature extraction. It has 45 convolutional layers.
# [Dataset](#contents)
We use about 13K images as the training dataset and 3K images as the evaluation dataset in this example. You can also use your own dataset or open-source datasets (e.g. WiderFace).
- step 1: The dataset should follow the Pascal VOC data format for object detection. The directory structure is as follows, and a sketch for reading the annotation files is shown after the tree. (Because of the network's small input shape, faces smaller than 50*50 pixels at 1080P are removed from the evaluation dataset.)
```python
.
└─ dataset
├─ Annotations
├─ img1.xml
├─ img2.xml
├─ ...
├─ JPEGImages
├─ img1.jpg
├─ img2.jpg
├─ ...
└─ ImageSets
└─ Main
└─ train.txt or test.txt
```
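If you are building such a dataset yourself, the annotation files are plain Pascal VOC XML. Below is a minimal, illustrative sketch (not part of this repository) of reading one annotation file with the Python standard library; the tag names are the usual VOC ones:
```python
import xml.etree.ElementTree as ET

def parse_voc_annotation(xml_path):
    """Return [class_name, xmin, ymin, xmax, ymax] for every object in one VOC XML file."""
    root = ET.parse(xml_path).getroot()
    boxes = []
    for obj in root.iter('object'):
        name = obj.find('name').text  # e.g. 'face'
        bndbox = obj.find('bndbox')
        coords = [int(float(bndbox.find(tag).text))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        boxes.append([name] + coords)
    return boxes

print(parse_voc_annotation('dataset/Annotations/img1.xml'))
```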
- step 2: Convert the dataset to mindrecord:
```bash
python data_to_mindrecord_train.py
```
or
```bash
python data_to_mindrecord_eval.py
```
If your dataset is too large to convert in one pass, you can append data to an existing mindrecord file incrementally:
```bash
python data_to_mindrecord_train_append.py
```
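For reference, appending is typically done with `mindspore.mindrecord.FileWriter.open_for_append`. The sketch below shows the general pattern; the schema and record builder are illustrative, and `data_to_mindrecord_train_append.py` contains the real conversion logic:
```python
import numpy as np
from mindspore.mindrecord import FileWriter

def make_records(image_paths):
    # Illustrative records: raw JPEG bytes plus [cls, x, y, w, h, ignore] rows.
    return [{"image": open(p, "rb").read(),
             "annotation": np.zeros((1, 6), dtype=np.float32)} for p in image_paths]

# First pass: create the mindrecord file with a schema.
writer = FileWriter(file_name="train.mindrecord", shard_num=1)
writer.add_schema({"image": {"type": "bytes"},
                   "annotation": {"type": "float32", "shape": [-1, 6]}}, "face detection")
writer.write_raw_data(make_records(["img1.jpg"]))
writer.commit()

# Later passes: reopen the same file and append further chunks.
writer = FileWriter.open_for_append("train.mindrecord")
writer.write_raw_data(make_records(["img2.jpg"]))
writer.commit()
```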
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The entire code structure is as follows:
```python
.
└─ Face Detection
├─ README.md
├─ scripts
    ├─ run_standalone_train.sh          # launch standalone training (1p) on Ascend
    ├─ run_distribute_train.sh          # launch distributed training (8p) on Ascend
    ├─ run_eval.sh                      # launch evaluation on Ascend
    └─ run_export.sh                    # launch exporting the AIR model
├─ src
├─ FaceDetection
├─ voc_wrapper.py # get detection results
├─ yolo_loss.py # loss function
├─ yolo_postprocess.py # post process
└─ yolov3.py # network
├─ config.py # parameter configuration
├─ data_preprocess.py # preprocess
├─ logging.py # log function
├─ lrsche_factory.py # generate learning rate
├─ network_define.py # network define
├─ transforms.py # data transforms
    ├─ data_to_mindrecord_train.py         # convert dataset to mindrecord for training
    ├─ data_to_mindrecord_train_append.py  # append data to an existing mindrecord for training
    └─ data_to_mindrecord_eval.py          # convert dataset to mindrecord for evaluation
├─ train.py # training scripts
├─ eval.py # evaluation scripts
└─ export.py # export air model
```
## [Running Example](#contents)
### Train
- Standalone mode
```bash
cd ./scripts
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]
```
or (fine-tune)
```bash
cd ./scripts
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_standalone_train.sh /home/train.mindrecord 0 /home/a.ckpt
```
- Distribute mode (recommended)
```bash
cd ./scripts
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]
```
or (fine-tune)
```bash
cd ./scripts
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_distribute_train.sh /home/train.mindrecord ./rank_table_8p.json /home/a.ckpt
```
You will get the loss value of each step as follows in "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log". (When `overflow` is True, dynamic loss scaling reduces `loss_scale` for the following steps, as the drop from 1024.0 to 256.0 below shows.)
```python
rank[0], iter[0], loss[318555.8], overflow:False, loss_scale:1024.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[1], loss[95394.28], overflow:True, loss_scale:1024.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[2], loss[81332.92], overflow:True, loss_scale:512.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[3], loss[27250.805], overflow:True, loss_scale:256.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
...
rank[0], iter[62496], loss[2218.6282], overflow:False, loss_scale:256.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[62497], loss[3788.5146], overflow:False, loss_scale:256.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[62498], loss[3427.5479], overflow:False, loss_scale:256.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
rank[0], iter[62499], loss[4294.194], overflow:False, loss_scale:256.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6)
```
### Evaluation
```bash
cd ./scripts
sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_eval.sh /home/eval.mindrecord 0 /home/a.ckpt
```
You will get the result as follows in "./scripts/device0/eval.log":
```python
calculate [recall | persicion | ap]...
Saving ../../results/0-2441_61000/.._.._results_0-2441_61000_face_AP_0.760.png
```
The detection results and the precision-recall (P-R) curve will also be saved in "./results/[MODEL_NAME]/".
### Convert model
If you want to run inference with the network on Ascend 310, you should first convert the model to the AIR format:
```bash
cd ./scripts
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Face Detection |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB            |
| Uploaded Date              | 09/30/2020 (month/day/year)                                  |
| MindSpore Version | 1.0.0 |
| Dataset | 13K images |
| Training Parameters | epoch=2500, batch_size=64, momentum=0.5 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy, Sigmoid Cross Entropy, SmoothL1Loss |
| outputs                    | boxes and labels                                             |
| Speed | 1pc: 800~850 ms/step; 8pcs: 1000~1150 ms/step |
| Total time | 1pc: 120 hours; 8pcs: 18 hours |
| Checkpoint for Fine tuning | 37M (.ckpt file) |
### Evaluation Performance
| Parameters | Face Detection |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 3K images |
| batch_size | 1 |
| outputs | mAP |
| Accuracy | 8pcs: 76.0% |
| Model for inference | 37M (.ckpt file) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@@ -0,0 +1,215 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection eval."""
import os
import argparse
import matplotlib.pyplot as plt
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
import mindspore.dataset as de
from src.data_preprocess import SingleScaleTrans
from src.config import config
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.FaceDetection import voc_wrapper
from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_brambox, \
parse_gt_from_anno, parse_rets, calc_recall_presicion_ap
plt.switch_backend('agg')
devid = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def parse_args():
'''parse_args'''
parser = argparse.ArgumentParser('Yolov3 Face Detection')
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=1, help='current process number to support distributed')
args, _ = parser.parse_known_args()
return args
def val(args):
'''eval'''
print('=============yolov3 start evaluating==================')
# logger
args.batch_size = config.batch_size
args.input_shape = config.input_shape
args.result_path = config.result_path
args.conf_thresh = config.conf_thresh
args.nms_thresh = config.nms_thresh
context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE, device_num=args.world_size,
gradients_mean=True)
mindrecord_path = args.mindrecord_path
print('Loading data from {}'.format(mindrecord_path))
num_classes = config.num_classes
if num_classes > 1:
raise NotImplementedError('num_classes > 1: Yolov3 postprocess not implemented!')
anchors = config.anchors
anchors_mask = config.anchors_mask
num_anchors_list = [len(x) for x in anchors_mask]
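    # feature-map strides of the three detection heads (det1/det2/det3 downsample the input by 64/32/16)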
reduction_0 = 64.0
reduction_1 = 32.0
reduction_2 = 16.0
labels = ['face']
classes = {0: 'face'}
# dataloader
ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation", "image_name", "image_size"])
single_scale_trans = SingleScaleTrans(resize=args.input_shape)
ds = ds.batch(args.batch_size, per_batch_map=single_scale_trans,
input_columns=["image", "annotation", "image_name", "image_size"], num_parallel_workers=8)
args.steps_per_epoch = ds.get_dataset_size()
# backbone
network = backbone_HwYolov3(num_classes, num_anchors_list, args)
# load pretrain model
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
print('load model {} success'.format(args.pretrained))
else:
        print('load model {} failed, please check the model path; evaluation ends'.format(args.pretrained))
        exit(1)
ds = ds.repeat(1)
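    # per-image detections, original image sizes and ground-truth annotations, keyed by image name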
det = {}
img_size = {}
img_anno = {}
model_name = args.pretrained.split('/')[-1].replace('.ckpt', '')
result_path = os.path.join(args.result_path, model_name)
    if not os.path.isdir(result_path):
        os.makedirs(result_path, exist_ok=True)
# result file
ret_files_set = {
'face': os.path.join(result_path, 'comp4_det_test_face_rm5050.txt'),
}
test_net = BuildTestNetwork(network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes,
args)
print('conf_thresh:', args.conf_thresh)
eval_times = 0
for data in ds.create_tuple_iterator(output_numpy=True):
batch_images = data[0]
batch_labels = data[1]
batch_image_name = data[2]
batch_image_size = data[3]
eval_times += 1
img_tensor = Tensor(batch_images, mstype.float32)
dets = []
tdets = []
coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2 = test_net(img_tensor)
boxes_0, boxes_1, boxes_2 = get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2,
cls_scores_2, args.conf_thresh, args.input_shape,
num_classes)
converted_boxes_0, converted_boxes_1, converted_boxes_2 = tensor_to_brambox(boxes_0, boxes_1, boxes_2,
args.input_shape, labels)
tdets.append(converted_boxes_0)
tdets.append(converted_boxes_1)
tdets.append(converted_boxes_2)
batch = len(tdets[0])
for b in range(batch):
single_dets = []
for op in range(3):
single_dets.extend(tdets[op][b])
dets.append(single_dets)
det.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(dets)})
img_size.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_image_size)})
img_anno.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_labels)})
print('eval times:', eval_times)
print('batch size: ', args.batch_size)
netw, neth = args.input_shape
reorg_dets = voc_wrapper.reorg_detection(det, netw, neth, img_size)
voc_wrapper.gen_results(reorg_dets, result_path, img_size, args.nms_thresh)
# compute mAP
ground_truth = parse_gt_from_anno(img_anno, classes)
ret_list = parse_rets(ret_files_set)
iou_thr = 0.5
evaluate = calc_recall_presicion_ap(ground_truth, ret_list, iou_thr)
aps_str = ''
for cls in evaluate:
per_line, = plt.plot(evaluate[cls]['recall'], evaluate[cls]['presicion'], 'b-')
per_line.set_label('%s:AP=%.3f' % (cls, evaluate[cls]['ap']))
aps_str += '_%s_AP_%.3f' % (cls, evaluate[cls]['ap'])
plt.plot([i / 1000.0 for i in range(1, 1001)], [i / 1000.0 for i in range(1, 1001)], 'y--')
plt.axis([0, 1.2, 0, 1.2])
plt.xlabel('recall')
plt.ylabel('precision')
plt.grid()
plt.legend()
plt.title('PR')
# save mAP
ap_save_path = os.path.join(result_path, result_path.replace('/', '_') + aps_str + '.png')
print('Saving {}'.format(ap_save_path))
plt.savefig(ap_save_path)
print('=============yolov3 evaluating finished==================')
if __name__ == "__main__":
arg = parse_args()
val(arg)

View File

@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.config import config
devid = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def save_air(args):
'''save air'''
print('============= yolov3 start save air ==================')
num_classes = config.num_classes
anchors_mask = config.anchors_mask
num_anchors_list = [len(x) for x in anchors_mask]
network = backbone_HwYolov3(num_classes, num_anchors_list, args)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
print('load model {} success'.format(args.pretrained))
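    # random NCHW dummy input at the training resolution (3 x 448 x 768), used only to trace the graph for export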
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 448, 768)).astype(np.float32)
tensor_input_data = Tensor(input_data)
export(network, tensor_input_data,
file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR')
print("export model success.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert ckpt to air')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
arg = parser.parse_args()
save_air(arg)

View File

@@ -0,0 +1,81 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 -a $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
SCRIPT_NAME='train.py'
rm -rf ${current_exec_path}/device*
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
RANK_TABLE=$(get_real_path $2)
PRETRAINED_BACKBONE=''
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $MINDRECORD_FILE
echo $RANK_TABLE
echo $PRETRAINED_BACKBONE
export RANK_TABLE_FILE=$RANK_TABLE
export RANK_SIZE=8
echo 'start training'
for((i=0;i<=$RANK_SIZE-1;i++));
do
echo 'start rank '$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i
export RANK_ID=$i
dev=`expr $i + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--mindrecord_path=$MINDRECORD_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
done
echo 'running'

View File

@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='eval.py'
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start evaluating'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--mindrecord_path=$MINDRECORD_FILE \
--pretrained=$PRETRAINED_BACKBONE > eval.log 2>&1 &
echo 'running'

View File

@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='export.py'
ulimit -c unlimited
BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start converting'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--batch_size=$BATCH_SIZE \
--pretrained=$PRETRAINED_BACKBONE > convert.log 2>&1 &
echo 'running'

View File

@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 -a $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='train.py'
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=''
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start training'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--world_size=1 \
--mindrecord_path=$MINDRECORD_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
echo 'running'

View File

@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection compute final result."""
import os
import numpy as np
def remove_5050_face(dst_txt, img_size):
'''remove_5050_face'''
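    # drop boxes smaller than 50x50 pixels once the image is rescaled to a 1920x1080 reference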
dst_txt_rm5050 = dst_txt.replace('.txt', '') + '_rm5050.txt'
if os.path.exists(dst_txt_rm5050):
os.remove(dst_txt_rm5050)
write_lines = []
with open(dst_txt, 'r') as file:
lines = file.readlines()
for line in lines:
info = line.replace('\n', '').split(' ')
img_name = info[0]
size = img_size[img_name][0]
w = float(info[4]) - float(info[2])
h = float(info[5]) - float(info[3])
            ratio = max(float(size[0]) / 1920., float(size[1]) / 1080.)
            new_w = float(w) / ratio
            new_h = float(h) / ratio
if min(new_w, new_h) >= 50.:
write_lines.append(line)
with open(dst_txt_rm5050, 'a') as fw:
for line in write_lines:
fw.write(line)
def nms(boxes, threshold=0.5):
'''NMS.'''
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
scores = boxes[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
reserved_boxes = []
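    # greedy NMS: repeatedly keep the highest-scoring box and suppress boxes whose IoU with it exceeds the threshold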
while order.size > 0:
i = order[0]
reserved_boxes.append(i)
max_x1 = np.maximum(x1[i], x1[order[1:]])
max_y1 = np.maximum(y1[i], y1[order[1:]])
min_x2 = np.minimum(x2[i], x2[order[1:]])
min_y2 = np.minimum(y2[i], y2[order[1:]])
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indexs = np.where(ovr <= threshold)[0]
order = order[indexs + 1]
return reserved_boxes
def gen_results(reorg_dets, results_folder, img_size, nms_thresh=0.45):
'''gen_results'''
for label, pieces in reorg_dets.items():
ret = []
dst_fp = '%s/comp4_det_test_%s.txt' % (results_folder, label)
for name in pieces.keys():
pred = np.array(pieces[name], dtype=np.float32)
keep = nms(pred, nms_thresh)
for ik in keep:
line = '%s %f %s' % (name, pred[ik][-1], ' '.join([str(num) for num in pred[ik][:4]]))
ret.append(line)
with open(dst_fp, 'w') as fd:
fd.write('\n'.join(ret))
remove_5050_face(dst_fp, img_size)
def reorg_detection(dets, netw, neth, img_sizes):
'''reorg_detection'''
reorg_dets = {}
for k, v in dets.items():
name = k
orig_width, orig_height = img_sizes[k][0]
scale = min(float(netw)/orig_width, float(neth)/orig_height)
new_width = orig_width * scale
new_height = orig_height * scale
pad_w = (netw - new_width) / 2.0
pad_h = (neth - new_height) / 2.0
for iv in v:
xmin = iv.x_top_left
ymin = iv.y_top_left
xmax = xmin + iv.width
ymax = ymin + iv.height
conf = iv.confidence
class_label = iv.class_label
xmin = max(0, float(xmin - pad_w)/scale)
xmax = min(orig_width - 1, float(xmax - pad_w)/scale)
ymin = max(0, float(ymin - pad_h)/scale)
ymax = min(orig_height - 1, float(ymax - pad_h)/scale)
reorg_dets.setdefault(class_label, {})
reorg_dets[class_label].setdefault(name, [])
piece = (xmin, ymin, xmax, ymax, conf)
reorg_dets[class_label][name].append(piece)
return reorg_dets

View File

@@ -0,0 +1,278 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection loss."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore.common import dtype as mstype
class PtLinspace(Cell):
'''PtLinspace'''
def __init__(self):
super(PtLinspace, self).__init__()
self.tuple_to_array = P.TupleToArray()
def construct(self, start, end, steps):
lin_x = ()
step = (end - start + 1) / steps
for i in range(start, end + 1, step):
lin_x += (i,)
lin_x = self.tuple_to_array(lin_x)
return lin_x
class MSELoss(_Loss):
'''MSELoss'''
def __init__(self):
super(MSELoss, self).__init__()
self.sum = P.Sum()
self.mean = P.ReduceMean(keepdims=False)
self.pow = P.Pow()
self.sqrt = P.Sqrt()
def construct(self, nembeddings1, nembeddings2):
dist = nembeddings1 - nembeddings2
dist_pow = self.pow(dist, 2.0)
dist_sum = self.sum(dist_pow, 1)
dist_sqrt = self.sqrt(dist_sum)
loss = self.mean(dist_sqrt, 0)
return loss
class YoloLoss(Cell):
""" Computes yolo loss from darknet network output and target annotation.
Args:
num_classes (int): number of categories
anchors (list): 2D list representing anchor boxes
coord_scale (float): weight of bounding box coordinates
no_object_scale (float): weight of regions without target boxes
object_scale (float): weight of regions with target boxes
class_scale (float): weight of categorical predictions
thresh (float): minimum iou between a predicted box and ground truth for them to be considered matching
seen (int): How many images the network has already been trained on.
"""
def __init__(self, num_classes, anchors, anchors_mask, reduction=32, seen=0, coord_scale=1.0, no_object_scale=1.0,
object_scale=1.0, class_scale=1.0, thresh=0.5, head_idx=0.0):
super(YoloLoss, self).__init__()
self.num_classes = num_classes
self.num_anchors = len(anchors_mask)
self.anchor_step = len(anchors[0]) # each scale has step anchors
self.anchors = np.array(anchors, dtype=np.float32) / reduction # scale every anchor for every scale
self.tensor_anchors = Tensor(self.anchors, mstype.float32)
self.anchors_mask = anchors_mask
anchors_w = []
anchors_h = []
for i in range(len(anchors_mask)):
anchors_w.append(self.anchors[self.anchors_mask[i]][0])
anchors_h.append(self.anchors[self.anchors_mask[i]][1])
self.anchors_w = Tensor(np.array(anchors_w).reshape(len(self.anchors_mask), 1))
self.anchors_h = Tensor(np.array(anchors_h).reshape(len(self.anchors_mask), 1))
self.reduction = reduction
self.seen = seen
self.head_idx = head_idx
self.zero = Tensor(0)
self.coord_scale = coord_scale
self.no_object_scale = no_object_scale
self.object_scale = object_scale
self.class_scale = class_scale
self.thresh = thresh
self.info = {'avg_iou': 0, 'class': 0, 'obj': 0, 'no_obj': 0,
'recall50': 0, 'recall75': 0, 'obj_cur': 0, 'obj_all': 0,
'coord_xy': 0, 'coord_wh': 0}
self.shape = P.Shape()
self.reshape = P.Reshape()
self.sigmoid = P.Sigmoid()
self.zeros_like = P.ZerosLike()
self.concat0 = P.Concat(0)
self.concat0_2 = P.Concat(0)
self.concat0_3 = P.Concat(0)
self.concat0_4 = P.Concat(0)
self.concat1 = P.Concat(1)
self.concat1_2 = P.Concat(1)
self.concat1_3 = P.Concat(1)
self.concat1_4 = P.Concat(1)
self.concat2 = P.Concat(2)
self.concat2_2 = P.Concat(2)
self.concat2_3 = P.Concat(2)
self.concat2_4 = P.Concat(2)
self.tile = P.Tile()
self.transpose = P.Transpose()
self.cast = P.Cast()
self.exp = P.Exp()
self.sum = P.ReduceSum()
self.smooth_l1_loss = P.SmoothL1Loss()
self.bce = P.SigmoidCrossEntropyWithLogits()
self.ce = P.SoftmaxCrossEntropyWithLogits()
self.pt_linspace = PtLinspace()
self.one_hot = nn.OneHot(-1, self.num_classes, 1.0, 0.0)
self.squeeze_2 = P.Squeeze(2)
self.reduce_sum = P.ReduceSum()
self.select = P.Select()
self.iou = P.IOU()
def construct(self, output, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list):
"""
Compute Yolo loss.
"""
output_d = self.shape(output)
num_batch = output_d[0]
num_anchors = self.num_anchors
num_classes = self.num_classes
num_channels = output_d[1] / num_anchors
height = output_d[2]
width = output_d[3]
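        # flatten spatial dims: (batch, anchors, 4 box terms + 1 objectness + num_classes, H*W)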
output = self.reshape(output, (num_batch, num_anchors, num_channels, height * width))
coord_01 = output[:, :, :2] # tx,ty
coord_23 = output[:, :, 2:4] # tw,th
coord = self.concat2((coord_01, coord_23))
conf = self.squeeze_2(output[:, :, 4:5, :])
cls = output[:, :, 5:]
cls = self.reshape(cls, (num_batch*num_anchors, num_classes, height*width))
perm = (0, 2, 1)
cls = self.transpose(cls, perm)
cls_shp = self.shape(cls)
cls = self.reshape(cls, (cls_shp[0] * cls_shp[1] * cls_shp[2] / num_classes, num_classes))
lin_x = self.pt_linspace(0, width - 1, width)
lin_x = self.tile(lin_x, (height,))
lin_x = self.cast(lin_x, mstype.float32)
lin_y = self.pt_linspace(0, height - 1, height)
lin_y = self.reshape(lin_y, (height, 1))
lin_y = self.tile(lin_y, (1, width))
lin_y = self.reshape(lin_y, (self.shape(lin_y)[0] * self.shape(lin_y)[1],))
lin_y = self.cast(lin_y, mstype.float32)
anchor_w = self.anchors_w
anchor_h = self.anchors_h
anchor_w = self.cast(anchor_w, mstype.float32)
anchor_h = self.cast(anchor_h, mstype.float32)
coord_x = self.sigmoid(coord[:, :, 0:1, :])
pred_boxes_0 = self.squeeze_2(coord_x) + lin_x
shape_pb0 = self.shape(pred_boxes_0)
pred_boxes_0 = self.reshape(pred_boxes_0, (shape_pb0[0] * shape_pb0[1] * shape_pb0[2], 1))
coord_y = self.sigmoid(coord[:, :, 1:2, :])
pred_boxes_1 = self.squeeze_2(coord_y) + lin_y
shape_pb1 = self.shape(pred_boxes_1)
pred_boxes_1 = self.reshape(pred_boxes_1, (shape_pb1[0] * shape_pb1[1] * shape_pb1[2], 1))
pred_boxes_2 = self.exp(self.squeeze_2(coord[:, :, 2:3, :])) * anchor_w
shape_pb2 = self.shape(pred_boxes_2)
pred_boxes_2 = self.reshape(pred_boxes_2, (shape_pb2[0] * shape_pb2[1] * shape_pb2[2], 1))
pred_boxes_3 = self.exp(self.squeeze_2(coord[:, :, 3:4, :])) * anchor_h
shape_pb3 = self.shape(pred_boxes_3)
pred_boxes_3 = self.reshape(pred_boxes_3, (shape_pb3[0] * shape_pb3[1] * shape_pb3[2], 1))
pred_boxes_x1 = pred_boxes_0 - pred_boxes_2 / 2
pred_boxes_y1 = pred_boxes_1 - pred_boxes_3 / 2
pred_boxes_x2 = pred_boxes_0 + pred_boxes_2 / 2
pred_boxes_y2 = pred_boxes_1 + pred_boxes_3 / 2
pred_boxes_points = self.concat1_4((pred_boxes_x1, pred_boxes_y1, pred_boxes_x2, pred_boxes_y2))
total_anchors = num_anchors * height * width
mask_concat = None
conf_neg_mask_zero = self.zeros_like(conf_neg_mask)
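        # scale predictions and ground truth by the same constant so the float16 IOU below keeps enough precision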
pred_boxes_points = pred_boxes_points * 64
gt_list = gt_list * 64
for b in range(num_batch):
cur_pred_boxes = pred_boxes_points[b * total_anchors:(b + 1) * total_anchors]
iou_gt_pred = self.iou(self.cast(cur_pred_boxes, mstype.float16), self.cast(gt_list[b], mstype.float16))
mask = self.cast((iou_gt_pred > self.thresh), mstype.float16)
mask = self.reduce_sum(mask, 0)
mask = mask > 0
shape_neg = self.shape(conf_neg_mask[0])
mask = self.reshape(mask, (1, shape_neg[0], shape_neg[1]))
if b == 0:
mask_concat = mask
else:
mask_concat = self.concat0_2((mask_concat, mask))
conf_neg_mask = self.select(mask_concat, conf_neg_mask_zero, conf_neg_mask)
coord_mask = self.tile(coord_mask, (1, 1, 4, 1))
coord_mask = coord_mask[:, :, :2]
coord_center = coord[:, :, :2]
t_coord_center = t_coord[:, :, :2]
coord_wh = coord[:, :, 2:]
t_coord_wh = t_coord[:, :, 2:]
one_hot_label = None
shape_cls_mask = None
if num_classes > 1:
shape_t_cls = self.shape(t_cls)
t_cls = self.reshape(t_cls, (shape_t_cls[0] * shape_t_cls[1] * shape_t_cls[2],))
one_hot_label = self.one_hot(self.cast(t_cls, mstype.int32))
shape_cls_mask = self.shape(cls_mask)
cls_mask = self.reshape(cls_mask, (1, shape_cls_mask[0] * shape_cls_mask[1] * shape_cls_mask[2]))
added_scale = 1.0 + self.head_idx * 0.5
loss_coord_center = added_scale * 2.0 * 1.0 * self.coord_scale * self.sum(
coord_mask * self.bce(coord_center, t_coord_center), ())
loss_coord_wh = added_scale * 2.0 * 1.5 * self.coord_scale * self.sum(
coord_mask * self.smooth_l1_loss(coord_wh, t_coord_wh), ())
loss_coord = 1.0 * (loss_coord_center + loss_coord_wh)
loss_conf_pos = added_scale * 2.0 * self.object_scale * self.sum(conf_pos_mask * self.bce(conf, t_conf), ())
loss_conf_neg = 1.0 * self.no_object_scale * self.sum(conf_neg_mask * self.bce(conf, t_conf), ())
loss_conf = loss_conf_pos + loss_conf_neg
loss_cls = None
if num_classes > 1:
loss_cls = self.class_scale * 1.0 * self.sum(cls_mask * self.ce(cls, one_hot_label)[0], ())
else:
loss_cls = 0.0
cls = self.squeeze_2(output[:, :, 5:6, :])
loss_cls_pos = added_scale * 2.0 * self.object_scale * self.sum(conf_pos_mask * self.bce(cls, t_conf), ())
loss_cls_neg = 1.0 * self.no_object_scale * self.sum(conf_neg_mask * self.bce(cls, t_conf), ())
loss_cls = loss_cls_pos + loss_cls_neg
loss_tot = loss_coord + 0.5 * loss_conf + 0.5 * loss_cls
return loss_tot

View File

@@ -0,0 +1,125 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection yolov3 post-process."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore.common import dtype as mstype
class PtLinspace(Cell):
'''PtLinspace'''
def __init__(self):
super(PtLinspace, self).__init__()
self.tuple_to_array = P.TupleToArray()
def construct(self, start, end, steps):
lin_x = ()
step = (end - start + 1) / steps
for i in range(start, end + 1, step):
lin_x += (i,)
lin_x = self.tuple_to_array(lin_x)
return lin_x
class YoloPostProcess(Cell):
"""
Yolov3 post-process of network output.
"""
def __init__(self, num_classes, cur_anchors, conf_thresh, network_size, reduction, anchors_mask):
super(YoloPostProcess, self).__init__()
self.print = P.Print()
self.num_classes = num_classes
self.anchors = cur_anchors
self.conf_thresh = conf_thresh
self.network_size = network_size
self.reduction = reduction
self.anchors_mask = anchors_mask
self.num_anchors = len(anchors_mask)
anchors_w = []
anchors_h = []
for i in range(len(self.anchors_mask)):
anchors_w.append(self.anchors[i][0])
anchors_h.append(self.anchors[i][1])
self.anchors_w = Tensor(np.array(anchors_w).reshape((1, len(self.anchors_mask), 1)))
self.anchors_h = Tensor(np.array(anchors_h).reshape((1, len(self.anchors_mask), 1)))
self.shape = P.Shape()
self.reshape = P.Reshape()
self.sigmoid = P.Sigmoid()
self.cast = P.Cast()
self.exp = P.Exp()
self.concat3 = P.Concat(3)
self.tile = P.Tile()
self.expand_dims = P.ExpandDims()
self.pt_linspace = PtLinspace()
def construct(self, output):
'''construct'''
output_d = self.shape(output)
num_batch = output_d[0]
num_anchors = self.num_anchors
num_channels = output_d[1] / num_anchors
height = output_d[2]
width = output_d[3]
lin_x = self.pt_linspace(0, width - 1, width)
lin_x = self.tile(lin_x, (height,))
lin_x = self.cast(lin_x, mstype.float32)
lin_y = self.pt_linspace(0, height - 1, height)
lin_y = self.reshape(lin_y, (height, 1))
lin_y = self.tile(lin_y, (1, width))
lin_y = self.reshape(lin_y, (self.shape(lin_y)[0] * self.shape(lin_y)[1],))
lin_y = self.cast(lin_y, mstype.float32)
anchor_w = self.anchors_w
anchor_h = self.anchors_h
anchor_w = self.cast(anchor_w, mstype.float32)
anchor_h = self.cast(anchor_h, mstype.float32)
output = self.reshape(output, (num_batch, num_anchors, num_channels, height * width))
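        # decode boxes: sigmoid-offset centers on the grid and exp-scaled anchor sizes, normalized to [0, 1]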
coord_x = (self.sigmoid(output[:, :, 0, :]) + lin_x) / width
coord_y = (self.sigmoid(output[:, :, 1, :]) + lin_y) / height
coord_w = self.exp(output[:, :, 2, :]) * anchor_w / width
coord_h = self.exp(output[:, :, 3, :]) * anchor_h / height
obj_conf = self.sigmoid(output[:, :, 4, :])
cls_conf = 0.0
if self.num_classes > 1:
# num_classes > 1: not implemented!
pass
else:
cls_conf = self.sigmoid(output[:, :, 4, :])
cls_scores = obj_conf * cls_conf
coord_x_t = self.expand_dims(coord_x, 3)
coord_y_t = self.expand_dims(coord_y, 3)
coord_w_t = self.expand_dims(coord_w, 3)
coord_h_t = self.expand_dims(coord_h, 3)
coord_1 = self.concat3((coord_x_t, coord_y_t))
coord_2 = self.concat3((coord_w_t, coord_h_t))
coords = self.concat3((coord_1, coord_2))
return coords, cls_scores

View File

@@ -0,0 +1,267 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection yolov3 backbone."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn import Cell
class Conv2dBatchReLU(Cell):
'''Conv2dBatchReLU'''
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(Conv2dBatchReLU, self).__init__()
# Parameters
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
if isinstance(kernel_size, (list, tuple)):
self.padding = [int(ii / 2) for ii in kernel_size]
else:
self.padding = int(kernel_size / 2)
self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, has_bias=False,
pad_mode='pad', padding=self.padding)
self.bn = nn.BatchNorm2d(self.out_channels, momentum=0.9, eps=1e-5)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv2dBatch(Cell):
'''Conv2dBatch'''
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(Conv2dBatch, self).__init__()
# Parameters
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, has_bias=False,
pad_mode='pad', padding=self.padding)
self.bn = nn.BatchNorm2d(self.out_channels, momentum=0.9, eps=1e-5)
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class MakeYoloLayer(Cell):
'''MakeYoloLayer'''
def __init__(self, layer):
super(MakeYoloLayer, self).__init__()
self.layers = []
for x in layer:
if len(x) == 4:
self.layers.append(Conv2dBatchReLU(x[0], x[1], x[2], x[3]))
else:
self.layers.append(Conv2dBatch(x[0], x[1], x[2], x[3], x[4]))
self.layers = nn.CellList(self.layers)
def construct(self, x):
for block in self.layers:
x = block(x)
return x
class UpsampleLayer(Cell):
def __init__(self, factor):
super(UpsampleLayer, self).__init__()
self.upsample = P.Upsample(factor)
def construct(self, x):
x = self.upsample(x)
return x
class HwYolov3(Cell):
'''HwYolov3'''
def __init__(self, num_classes, num_anchors_list, args):
layer_index = {
# backbone
'0_conv_batch_relu': [3, 16],
'1_conv_batch_relu': [16, 32],
'2_conv_batch_relu': [32, 64],
'3_conv_batch_relu': [64, 64],
'4_conv_batch_relu': [64, 64],
'5_conv_batch_relu': [64, 128],
'6_conv_batch': [128, 64],
'7_conv_batch': [64, 64],
'8_conv_batch_relu': [64, 128],
'9_conv_batch': [128, 64],
'10_conv_batch': [64, 64],
'11_conv_batch_relu': [64, 128],
'12_conv_batch': [128, 128],
'13_conv_batch_relu': [128, 256],
'14_conv_batch': [256, 144],
'15_conv_batch': [144, 128],
'16_conv_batch_relu': [128, 256],
'17_conv_batch': [256, 128],
'18_conv_batch': [128, 128],
'19_conv_batch_relu': [128, 256],
'20_conv_batch': [256, 144],
'21_conv_batch': [144, 256],
'22_conv_batch_relu': [256, 512],
'23_conv_batch': [512, 256],
'24_conv_batch': [256, 256],
'25_conv_batch_relu': [256, 512],
'26_conv_batch': [512, 256],
'27_conv_batch': [256, 256],
'28_conv_batch_relu': [256, 512],
'30_deconv_up': [512, 64],
'31_conv_batch': [320, 160],
'32_conv_batch_relu': [160, 96],
'33_conv_batch_relu': [96, 96],
'34_conv_batch_relu': [96, 96],
'35_conv_batch': [96, 80],
'36_conv_batch_relu': [80, 128],
'37_conv_batch': [128, 96],
'38_conv_batch': [96, 128],
'39_conv_batch_relu': [128, 256],
'41_deconv_up': [256, 64],
'42_conv_batch_relu': [192, 64],
'43_conv_batch_relu': [64, 64],
'44_conv_batch_relu': [64, 64],
'45_conv_batch_relu': [64, 64],
'46_conv_batch_relu': [64, 96],
'47_conv_batch': [96, 64],
'48_conv_batch_relu': [64, 128],
# head
'29_conv': [512],
'40_conv': [256],
'49_conv': [128]
}
super(HwYolov3, self).__init__()
layer0 = [
(layer_index['0_conv_batch_relu'][0], layer_index['0_conv_batch_relu'][1], 3, 2),
(layer_index['1_conv_batch_relu'][0], layer_index['1_conv_batch_relu'][1], 3, 2),
(layer_index['2_conv_batch_relu'][0], layer_index['2_conv_batch_relu'][1], 3, 2),
(layer_index['3_conv_batch_relu'][0], layer_index['3_conv_batch_relu'][1], 3, 1),
(layer_index['4_conv_batch_relu'][0], layer_index['4_conv_batch_relu'][1], 3, 1),
]
layer1 = [
(layer_index['5_conv_batch_relu'][0], layer_index['5_conv_batch_relu'][1], 3, 2),
(layer_index['6_conv_batch'][0], layer_index['6_conv_batch'][1], 1, 1, 0),
(layer_index['7_conv_batch'][0], layer_index['7_conv_batch'][1], 3, 1, 1),
(layer_index['8_conv_batch_relu'][0], layer_index['8_conv_batch_relu'][1], 1, 1),
(layer_index['9_conv_batch'][0], layer_index['9_conv_batch'][1], 1, 1, 0),
(layer_index['10_conv_batch'][0], layer_index['10_conv_batch'][1], 3, 1, 1),
(layer_index['11_conv_batch_relu'][0], layer_index['11_conv_batch_relu'][1], 1, 1),
]
layer2 = [
(layer_index['12_conv_batch'][0], layer_index['12_conv_batch'][1], 3, 2, 1),
(layer_index['13_conv_batch_relu'][0], layer_index['13_conv_batch_relu'][1], 1, 1),
(layer_index['14_conv_batch'][0], layer_index['14_conv_batch'][1], 1, 1, 0),
(layer_index['15_conv_batch'][0], layer_index['15_conv_batch'][1], 3, 1, 1),
(layer_index['16_conv_batch_relu'][0], layer_index['16_conv_batch_relu'][1], 1, 1),
(layer_index['17_conv_batch'][0], layer_index['17_conv_batch'][1], 1, 1, 0),
(layer_index['18_conv_batch'][0], layer_index['18_conv_batch'][1], 3, 1, 1),
(layer_index['19_conv_batch_relu'][0], layer_index['19_conv_batch_relu'][1], 1, 1),
]
layer3 = [
(layer_index['20_conv_batch'][0], layer_index['20_conv_batch'][1], 1, 1, 0),
(layer_index['21_conv_batch'][0], layer_index['21_conv_batch'][1], 3, 2, 1),
(layer_index['22_conv_batch_relu'][0], layer_index['22_conv_batch_relu'][1], 1, 1),
(layer_index['23_conv_batch'][0], layer_index['23_conv_batch'][1], 1, 1, 0),
(layer_index['24_conv_batch'][0], layer_index['24_conv_batch'][1], 3, 1, 1),
(layer_index['25_conv_batch_relu'][0], layer_index['25_conv_batch_relu'][1], 1, 1),
(layer_index['26_conv_batch'][0], layer_index['26_conv_batch'][1], 1, 1, 0),
(layer_index['27_conv_batch'][0], layer_index['27_conv_batch'][1], 3, 1, 1),
(layer_index['28_conv_batch_relu'][0], layer_index['28_conv_batch_relu'][1], 1, 1),
]
layer4 = [
(layer_index['30_deconv_up'][0], layer_index['30_deconv_up'][1], 4, 2, 1),
]
layer5 = [
(layer_index['31_conv_batch'][0], layer_index['31_conv_batch'][1], 1, 1, 0),
(layer_index['32_conv_batch_relu'][0], layer_index['32_conv_batch_relu'][1], 3, 1),
(layer_index['33_conv_batch_relu'][0], layer_index['33_conv_batch_relu'][1], 3, 1),
(layer_index['34_conv_batch_relu'][0], layer_index['34_conv_batch_relu'][1], 3, 1),
(layer_index['35_conv_batch'][0], layer_index['35_conv_batch'][1], 1, 1, 0),
(layer_index['36_conv_batch_relu'][0], layer_index['36_conv_batch_relu'][1], 3, 1),
(layer_index['37_conv_batch'][0], layer_index['37_conv_batch'][1], 1, 1, 0),
(layer_index['38_conv_batch'][0], layer_index['38_conv_batch'][1], 3, 1, 1),
(layer_index['39_conv_batch_relu'][0], layer_index['39_conv_batch_relu'][1], 1, 1),
]
layer6 = [
(layer_index['41_deconv_up'][0], layer_index['41_deconv_up'][1], 4, 2, 1),
]
layer7 = [
(layer_index['42_conv_batch_relu'][0], layer_index['42_conv_batch_relu'][1], 1, 1),
(layer_index['43_conv_batch_relu'][0], layer_index['43_conv_batch_relu'][1], 3, 1),
(layer_index['44_conv_batch_relu'][0], layer_index['44_conv_batch_relu'][1], 3, 1),
(layer_index['45_conv_batch_relu'][0], layer_index['45_conv_batch_relu'][1], 3, 1),
(layer_index['46_conv_batch_relu'][0], layer_index['46_conv_batch_relu'][1], 3, 1),
(layer_index['47_conv_batch'][0], layer_index['47_conv_batch'][1], 3, 1, 1),
(layer_index['48_conv_batch_relu'][0], layer_index['48_conv_batch_relu'][1], 1, 1),
]
self.layer0 = MakeYoloLayer(layer0)
self.layer1 = MakeYoloLayer(layer1)
self.layer2 = MakeYoloLayer(layer2)
self.layer3 = MakeYoloLayer(layer3)
self.layer4 = nn.Conv2dTranspose(layer4[0][0], layer4[0][1], layer4[0][2], layer4[0][3], pad_mode='pad',
padding=layer4[0][4], has_bias=True)
self.args = args
self.concat = P.Concat(1)
self.layer5 = MakeYoloLayer(layer5)
self.layer6 = nn.Conv2dTranspose(layer6[0][0], layer6[0][1], layer6[0][2], layer6[0][3], pad_mode='pad',
padding=layer6[0][4], has_bias=True)
self.layer7 = MakeYoloLayer(layer7)
self.head1_conv = nn.Conv2d(layer_index['29_conv'][0], num_anchors_list[0]*(4 + 1 + num_classes), 1, 1,
has_bias=True)
self.head2_conv = nn.Conv2d(layer_index['40_conv'][0], num_anchors_list[1]*(4 + 1 + num_classes), 1, 1,
has_bias=True)
self.head3_conv = nn.Conv2d(layer_index['49_conv'][0], num_anchors_list[2]*(4 + 1 + num_classes), 1, 1,
has_bias=True)
self.relu = nn.ReLU()
def construct(self, x):
'''construct'''
stem = self.layer0(x)
stage4 = self.layer1(stem)
stage5 = self.layer2(stage4)
stage6_det1 = self.layer3(stage5)
upsample1 = self.layer4(stage6_det1)
upsample1_relu = self.relu(upsample1)
concat_s5_s6 = self.concat((upsample1_relu, stage5))
stage7_det2 = self.layer5(concat_s5_s6)
upsample2 = self.layer6(stage7_det2)
upsample2_relu = self.relu(upsample2)
concat_s4_s7 = self.concat((upsample2_relu, stage4))
stage8_det3 = self.layer7(concat_s4_s7)
det1 = self.head1_conv(stage6_det1)
det2 = self.head2_conv(stage7_det2)
det3 = self.head3_conv(stage8_det3)
return det1, det2, det3

View File

@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as ed
config = ed({
'batch_size': 64,
'warmup_lr': 0.0004,
'lr_rates': [0.002, 0.004, 0.002, 0.0008, 0.0004, 0.0002, 0.00008, 0.00004, 0.000004],
'lr_steps': [1000, 10000, 40000, 60000, 80000, 100000, 130000, 160000, 190000],
'gamma': 0.5,
'weight_decay': 0.0005,
'momentum': 0.5,
'max_epoch': 2500,
'log_interval': 10,
'ckpt_path': '../../output',
'ckpt_interval': 1000,
'result_path': '../../results',
'input_shape': [768, 448],
'jitter': 0.3,
'flip': 0.5,
'hue': 0.1,
'sat': 1.5,
'val': 1.5,
'num_classes': 1,
'anchors': [
[3, 4],
[5, 6],
[7, 9],
[10, 13],
[15, 19],
[21, 26],
[28, 36],
[38, 49],
[54, 71],
[77, 102],
[122, 162],
[207, 268],
],
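    # each tuple selects the 4 anchors used by one detection head, from the coarsest (stride 64) to the finest (stride 16)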
'anchors_mask': [(8, 9, 10, 11), (4, 5, 6, 7), (0, 1, 2, 3)],
'conf_thresh': 0.1,
'nms_thresh': 0.45,
})

View File

@@ -0,0 +1,244 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection yolov3 data pre-process."""
import numpy as np
import mindspore.dataset.vision.py_transforms as P
from src.transforms import RandomCropLetterbox, RandomFlip, HSVShift, ResizeLetterbox
from src.config import config
class SingleScaleTrans:
'''SingleScaleTrans'''
def __init__(self, resize, max_anno_count=200):
self.resize = (resize[0], resize[1])
self.max_anno_count = max_anno_count
def __call__(self, imgs, ann, image_names, image_size, batch_info):
size = self.resize
decode = P.Decode()
resize_letter_box_op = ResizeLetterbox(input_dim=size)
to_tensor = P.ToTensor()
ret_imgs = []
ret_anno = []
for i, image in enumerate(imgs):
img_pil = decode(image)
input_data = img_pil, ann[i]
input_data = resize_letter_box_op(*input_data)
image_arr = to_tensor(input_data[0])
ret_imgs.append(image_arr)
ret_anno.append(input_data[1])
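        # pad with zero rows (or truncate) so every image carries exactly max_anno_count annotations per batch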
for i, anno in enumerate(ret_anno):
anno_count = anno.shape[0]
if anno_count < self.max_anno_count:
ret_anno[i] = np.concatenate(
(ret_anno[i], np.zeros((self.max_anno_count - anno_count, 6), dtype=float)), axis=0)
else:
ret_anno[i] = ret_anno[i][:self.max_anno_count]
return np.array(ret_imgs), np.array(ret_anno), image_names, image_size
def check_gt_negative_or_empty(gt):
new_gt = []
for anno in gt:
for data in anno:
if data not in (0, -1):
new_gt.append(anno)
break
if not new_gt:
return True, new_gt
return False, new_gt
def bbox_ious_numpy(boxes1, boxes2):
""" Compute IOU between all boxes from ``boxes1`` with all boxes from ``boxes2``.
Args:
boxes1 (np.array): List of bounding boxes
boxes2 (np.array): List of bounding boxes
Note:
List format: [[xc, yc, w, h],...]
"""
b1x1, b1y1 = np.split((boxes1[:, :2] - (boxes1[:, 2:4] / 2)), 2, axis=1)
b1x2, b1y2 = np.split((boxes1[:, :2] + (boxes1[:, 2:4] / 2)), 2, axis=1)
b2x1, b2y1 = np.split((boxes2[:, :2] - (boxes2[:, 2:4] / 2)), 2, axis=1)
b2x2, b2y2 = np.split((boxes2[:, :2] + (boxes2[:, 2:4] / 2)), 2, axis=1)
dx = np.minimum(b1x2, b2x2.transpose()) - np.maximum(b1x1, b2x1.transpose())
dx = np.maximum(dx, 0)
dy = np.minimum(b1y2, b2y2.transpose()) - np.maximum(b1y1, b2y1.transpose())
dy = np.maximum(dy, 0)
intersections = dx * dy
areas1 = (b1x2 - b1x1) * (b1y2 - b1y1)
areas2 = (b2x2 - b2x1) * (b2y2 - b2y1)
unions = (areas1 + areas2.transpose()) - intersections
return intersections / unions
def build_targets_brambox(img, anno, reduction, img_shape_para, anchors_mask, anchors):
"""
Compare prediction boxes and ground truths, convert ground truths to network output tensors
"""
ground_truth = anno
img_shape = img.shape
n_h = int(img_shape[1] / img_shape_para) # height
n_w = int(img_shape[2] / img_shape_para) # width
anchors_ori = np.array(anchors) / reduction
num_anchor = len(anchors_mask)
conf_pos_mask = np.zeros((num_anchor, n_h * n_w), dtype=np.float32) # pos mask
conf_neg_mask = np.ones((num_anchor, n_h * n_w), dtype=np.float32) # neg mask
# coordination mask and classification mask
coord_mask = np.zeros((num_anchor, 1, n_h * n_w), dtype=np.float32) # coord mask
cls_mask = np.zeros((num_anchor, n_h * n_w), dtype=np.int32)  # np.int is deprecated in newer NumPy
# for target coordination confidence classification
t_coord = np.zeros((num_anchor, 4, n_h * n_w), dtype=np.float32)
t_conf = np.zeros((num_anchor, n_h * n_w), dtype=np.float32)
t_cls = np.zeros((num_anchor, n_h * n_w), dtype=np.float32)
gt_list = None
is_empty_or_negative, filtered_ground_truth = check_gt_negative_or_empty(ground_truth)
if is_empty_or_negative:
gt_list = np.zeros((len(ground_truth), 4), dtype=np.float32)
return coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list
# Build up tensors
anchors = np.concatenate([np.zeros_like(anchors_ori), anchors_ori], axis=1)
gt = np.zeros((len(filtered_ground_truth), 4), dtype=np.float32)
gt_np = np.zeros((len(ground_truth), 4), dtype=np.float32)
for i, annotation in enumerate(filtered_ground_truth):
# gt x y x h->x_c y_c w h
# reduction for remap the gt to the feature
gt[i, 0] = (annotation[1] + annotation[3] / 2) / reduction
gt[i, 1] = (annotation[2] + annotation[4] / 2) / reduction
gt[i, 2] = annotation[3] / reduction
gt[i, 3] = annotation[4] / reduction
gt_np[i, 0] = annotation[1] / reduction
gt_np[i, 1] = annotation[2] / reduction
gt_np[i, 2] = (annotation[1] + annotation[3]) / reduction
gt_np[i, 3] = (annotation[2] + annotation[4]) / reduction
gt_temp = gt_np[:]
gt_list = gt_temp
# Find best anchor for each gt
gt_wh = np.copy(gt)
gt_wh[:, :2] = 0
iou_gt_anchors = bbox_ious_numpy(gt_wh, anchors)
best_anchors = np.argmax(iou_gt_anchors, axis=1)
# Set masks and target values for each gt
for i, annotation in enumerate(filtered_ground_truth):
annotation_ignore = annotation[5]
annotation_width = annotation[3]
annotation_height = annotation[4]
annotation_class_id = annotation[0]
gi = min(n_w - 1, max(0, int(gt[i, 0])))
gj = min(n_h - 1, max(0, int(gt[i, 1])))
cur_n = best_anchors[i] # best anchors for current ground truth
if cur_n in anchors_mask:
best_n = np.where(np.array(anchors_mask) == cur_n)[0][0]
else:
continue
if annotation_ignore:
# current annotation is ignore for difficult
conf_pos_mask[best_n][gj * n_w + gi] = 0
conf_neg_mask[best_n][gj * n_w + gi] = 0
else:
coord_mask[best_n][0][gj * n_w + gi] = 2 - annotation_width * annotation_height / \
(n_w * n_h * reduction * reduction)
cls_mask[best_n][gj * n_w + gi] = 1
conf_pos_mask[best_n][gj * n_w + gi] = 1
conf_neg_mask[best_n][gj * n_w + gi] = 0
t_coord[best_n][0][gj * n_w + gi] = gt[i, 0] - gi
t_coord[best_n][1][gj * n_w + gi] = gt[i, 1] - gj
t_coord[best_n][2][gj * n_w + gi] = np.log(gt[i, 2] / anchors[cur_n, 2])
t_coord[best_n][3][gj * n_w + gi] = np.log(gt[i, 3] / anchors[cur_n, 3])
t_conf[best_n][gj * n_w + gi] = 1
t_cls[best_n][gj * n_w + gi] = annotation_class_id
return coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list
def preprocess_fn(image, annotation):
'''preprocess_fn'''
jitter = config.jitter
flip = config.flip
hue = config.hue
sat = config.sat
val = config.val
size = config.input_shape
max_anno_count = 200
reduction_0 = 64.0
reduction_1 = 32.0
reduction_2 = 16.0
anchors = config.anchors
anchors_mask = config.anchors_mask
decode = P.Decode()
random_crop_letter_box_op = RandomCropLetterbox(jitter=jitter, input_dim=size)
random_flip_op = RandomFlip(flip)
hsv_shift_op = HSVShift(hue, sat, val)
to_tensor = P.ToTensor()
img_pil = decode(image)
input_data = img_pil, annotation
input_data = random_crop_letter_box_op(*input_data)
input_data = random_flip_op(*input_data)
input_data = hsv_shift_op(*input_data)
image_arr = to_tensor(input_data[0])
ret_img = image_arr
ret_anno = input_data[1]
anno_count = ret_anno.shape[0]
if anno_count < max_anno_count:
ret_anno = np.concatenate((ret_anno, np.zeros((max_anno_count - anno_count, 6), dtype=float)), axis=0)
else:
ret_anno = ret_anno[:max_anno_count]
ret_img = np.array(ret_img)
ret_anno = np.array(ret_anno)
coord_mask_0, conf_pos_mask_0, conf_neg_mask_0, cls_mask_0, t_coord_0, t_conf_0, t_cls_0, gt_list_0 = \
build_targets_brambox(ret_img, ret_anno, reduction_0, int(reduction_0), anchors_mask[0], anchors)
coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1 = \
build_targets_brambox(ret_img, ret_anno, reduction_1, int(reduction_1), anchors_mask[1], anchors)
coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2 = \
build_targets_brambox(ret_img, ret_anno, reduction_2, int(reduction_2), anchors_mask[2], anchors)
return ret_img, ret_anno, coord_mask_0, conf_pos_mask_0, conf_neg_mask_0, cls_mask_0, t_coord_0, t_conf_0,\
t_cls_0, gt_list_0, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1,\
t_cls_1, gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, \
t_cls_2, gt_list_2
compose_map_func = preprocess_fn
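# Hedged wiring sketch (column names are assumptions matching the order of
# preprocess_fn's return values): compose_map_func is meant to be applied
# through dataset.map() so each record expands into the 26 training columns.
#
#   import mindspore.dataset as de
#   target_cols = []
#   for n in range(3):
#       target_cols += ["coord_mask_%d" % n, "conf_pos_mask_%d" % n,
#                       "conf_neg_mask_%d" % n, "cls_mask_%d" % n,
#                       "t_coord_%d" % n, "t_conf_%d" % n,
#                       "t_cls_%d" % n, "gt_list_%d" % n]
#   ds = de.MindDataset("Your_output_path/data.mindrecord0",
#                       columns_list=["image", "annotation"])
#   ds = ds.map(operations=compose_map_func,
#               input_columns=["image", "annotation"],
#               output_columns=["image", "annotation"] + target_cols)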

View File

@ -0,0 +1,174 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert VOC format dataset to mindrecord for evaluating Face detection."""
import os
import xml.etree.ElementTree as ET
import numpy as np
from PIL import Image
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
dataset_root_list = ["Your_VOC_dataset_path1",
"Your_VOC_dataset_path2",
"Your_VOC_dataset_pathN",
]
mindrecord_file_name = "Your_output_path/data.mindrecord"
mindrecord_num = 8
is_train = False
class_indexing_1 = {'face': 0}
def prepare_file_paths():
'''prepare_file_paths'''
image_files = []
anno_files = []
image_names = []
for dataset_root in dataset_root_list:
if not os.path.isdir(dataset_root):
raise ValueError("dataset root is unvalid!")
anno_dir = os.path.join(dataset_root, "Annotations")
image_dir = os.path.join(dataset_root, "JPEGImages")
if is_train:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
else:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")
ret_image_files, ret_anno_files, ret_image_names = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
image_files.extend(ret_image_files)
anno_files.extend(ret_anno_files)
image_names.extend(ret_image_names)
return image_files, anno_files, image_names
def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
'''filter_valid_files_by_txt'''
with open(valid_txt, "r") as txt:
valid_names = txt.readlines()
image_files = []
anno_files = []
image_names = []
for name in valid_names:
strip_name = name.strip("\n")
anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
if os.path.isfile(anno_joint_path):
image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
image_name = image_joint_path.split('/')[-1].replace('.jpg', '')
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
image_names.append(image_name)
continue
image_joint_path = os.path.join(image_dir, strip_name + ".png")
image_name = image_joint_path.split('/')[-1].replace('.png', '')
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
image_names.append(image_name)
return image_files, anno_files, image_names
def deserialize(member, class_indexing):
'''deserialize'''
class_name = member[0].text
if class_name in class_indexing:
class_num = class_indexing[class_name]
else:
return None
bnx = member.find('bndbox')
box_x_min = float(bnx.find('xmin').text)
box_y_min = float(bnx.find('ymin').text)
box_x_max = float(bnx.find('xmax').text)
box_y_max = float(bnx.find('ymax').text)
width = float(box_x_max - box_x_min + 1)
height = float(box_y_max - box_y_min + 1)
try:
ignore = float(member.find('ignore').text)
except (ValueError, AttributeError):  # <ignore> tag may be absent or malformed
ignore = 0.0
return [class_num, box_x_min, box_y_min, width, height, ignore]
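# Minimal self-check of deserialize() on a handcrafted VOC <object> element
# (all values are illustrative). A missing <ignore> tag is exactly the case
# the except branch above handles, defaulting ignore to 0.0.
if __name__ == '__main__':
    demo_obj = ET.fromstring(
        "<object><name>face</name>"
        "<bndbox><xmin>10</xmin><ymin>20</ymin><xmax>59</xmax><ymax>89</ymax></bndbox>"
        "</object>")
    print(deserialize(demo_obj, class_indexing_1))  # [0, 10.0, 20.0, 50.0, 70.0, 0.0]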
def get_data(image_file, anno_file, image_name):
'''get_data'''
count = 0
annotation = []
tree = ET.parse(anno_file)
root = tree.getroot()
with Image.open(image_file) as fd:
orig_width, orig_height = fd.size
with open(image_file, 'rb') as f:
img = f.read()
for member in root.findall('object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
for member in root.findall('Object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
if count == 0:
annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
count = 1
data = {
"image": img,
"annotation": np.array(annotation, dtype='float64'),
"image_name": image_name,
"image_size": np.array([orig_width, orig_height], dtype='int32')
}
return data
def convert_yolo_data_to_mindrecord():
'''convert_yolo_data_to_mindrecord'''
writer = FileWriter(mindrecord_file_name, mindrecord_num)
yolo_json = {
"image": {"type": "bytes"},
"annotation": {"type": "float64", "shape": [-1, 6]},
"image_name": {"type": "string"},
"image_size": {"type": "int32", "shape": [-1, 2]}
}
print('Loading eval data...')
image_files, anno_files, image_names = prepare_file_paths()
dataset_size = len(anno_files)
assert dataset_size == len(image_files)
assert dataset_size == len(image_names)
logger.info("#size of dataset: {}".format(dataset_size))
data = []
for i in range(dataset_size):
data.append(get_data(image_files[i], anno_files[i], image_names[i]))
print('Writing eval data to mindrecord...')
writer.add_schema(yolo_json, "yolo_json")
if not data:
raise ValueError("No data to write to mindrecord.")
writer.write_raw_data(data)
writer.commit()
convert_yolo_data_to_mindrecord()

View File

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert VOC format dataset to mindrecord for training Face detection."""
import os
import xml.etree.ElementTree as ET
import numpy as np
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
dataset_root_list = ["Your_VOC_dataset_path1",
"Your_VOC_dataset_path2",
"Your_VOC_dataset_pathN",
]
mindrecord_file_name = "Your_output_path/data.mindrecord"
mindrecord_num = 8
is_train = True
class_indexing_1 = {'face': 0}
def prepare_file_paths():
'''prepare_file_paths'''
image_files = []
anno_files = []
for dataset_root in dataset_root_list:
if not os.path.isdir(dataset_root):
raise ValueError("dataset root is unvalid!")
anno_dir = os.path.join(dataset_root, "Annotations")
image_dir = os.path.join(dataset_root, "JPEGImages")
if is_train:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
else:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")
ret_image_files, ret_anno_files = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
image_files.extend(ret_image_files)
anno_files.extend(ret_anno_files)
return image_files, anno_files
def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
'''filter_valid_files_by_txt'''
with open(valid_txt, "r") as txt:
valid_names = txt.readlines()
image_files = []
anno_files = []
for name in valid_names:
strip_name = name.strip("\n")
anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
if os.path.isfile(anno_joint_path):
image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
continue
image_joint_path = os.path.join(image_dir, strip_name + ".png")
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
return image_files, anno_files
def deserialize(member, class_indexing):
'''deserialize'''
class_name = member[0].text
if class_name in class_indexing:
class_num = class_indexing[class_name]
else:
return None
bnx = member.find('bndbox')
box_x_min = float(bnx.find('xmin').text)
box_y_min = float(bnx.find('ymin').text)
box_x_max = float(bnx.find('xmax').text)
box_y_max = float(bnx.find('ymax').text)
width = float(box_x_max - box_x_min + 1)
height = float(box_y_max - box_y_min + 1)
try:
ignore = float(member.find('ignore').text)
except (ValueError, AttributeError):  # <ignore> tag may be absent or malformed
ignore = 0.0
return [class_num, box_x_min, box_y_min, width, height, ignore]
def get_data(image_file, anno_file):
'''get_data'''
count = 0
annotation = []
tree = ET.parse(anno_file)
root = tree.getroot()
with open(image_file, 'rb') as f:
img = f.read()
for member in root.findall('object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
for member in root.findall('Object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
if count == 0:
annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
count = 1
data = {
"image": img,
"annotation": np.array(annotation, dtype='float64')
}
return data
def convert_yolo_data_to_mindrecord():
'''convert_yolo_data_to_mindrecord'''
writer = FileWriter(mindrecord_file_name, mindrecord_num)
yolo_json = {
"image": {"type": "bytes"},
"annotation": {"type": "float64", "shape": [-1, 6]}
}
print('Loading train data...')
image_files, anno_files = prepare_file_paths()
dataset_size = len(anno_files)
assert dataset_size == len(image_files)
logger.info("#size of dataset: {}".format(dataset_size))
data = []
for i in range(dataset_size):
data.append(get_data(image_files[i], anno_files[i]))
print('Writing train data to mindrecord...')
writer.add_schema(yolo_json, "yolo_json")
if not data:
raise ValueError("No data to write to mindrecord.")
writer.write_raw_data(data)
writer.commit()
convert_yolo_data_to_mindrecord()
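# Optional read-back check (hedged sketch; MindDataset resolves the sibling
# shard files from the first shard name when mindrecord_num > 1):
#
#   import mindspore.dataset as de
#   ds = de.MindDataset(mindrecord_file_name + "0", columns_list=["image", "annotation"])
#   for item in ds.create_dict_iterator(output_numpy=True):
#       print(item["annotation"].shape)  # one (N, 6) block per image
#       break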

View File

@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Add VOC format dataset to an existed mindrecord for training Face detection."""
import os
import xml.etree.ElementTree as ET
import numpy as np
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
dataset_root_list = ["Your_VOC_dataset_path1",
"Your_VOC_dataset_path2",
"Your_VOC_dataset_pathN",
]
mindrecord_file_name = "Your_previous_output_path/data.mindrecord0"
mindrecord_num = 8
is_train = True
class_indexing_1 = {'face': 0}
def prepare_file_paths():
'''prepare file paths'''
image_files = []
anno_files = []
for dataset_root in dataset_root_list:
if not os.path.isdir(dataset_root):
raise ValueError("dataset root is unvalid!")
anno_dir = os.path.join(dataset_root, "Annotations")
image_dir = os.path.join(dataset_root, "JPEGImages")
if is_train:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
else:
valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")
ret_image_files, ret_anno_files = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
image_files.extend(ret_image_files)
anno_files.extend(ret_anno_files)
return image_files, anno_files
def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
'''filter valid files by txt'''
with open(valid_txt, "r") as txt:
valid_names = txt.readlines()
image_files = []
anno_files = []
for name in valid_names:
strip_name = name.strip("\n")
anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
if os.path.isfile(anno_joint_path):
image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
continue
image_joint_path = os.path.join(image_dir, strip_name + ".png")
if os.path.isfile(image_joint_path):
image_files.append(image_joint_path)
anno_files.append(anno_joint_path)
return image_files, anno_files
def deserialize(member, class_indexing):
'''deserialize'''
class_name = member[0].text
if class_name in class_indexing:
class_num = class_indexing[class_name]
else:
return None
bnx = member.find('bndbox')
box_x_min = float(bnx.find('xmin').text)
box_y_min = float(bnx.find('ymin').text)
box_x_max = float(bnx.find('xmax').text)
box_y_max = float(bnx.find('ymax').text)
width = float(box_x_max - box_x_min + 1)
height = float(box_y_max - box_y_min + 1)
try:
ignore = float(member.find('ignore').text)
except (ValueError, AttributeError):  # <ignore> tag may be absent or malformed
ignore = 0.0
return [class_num, box_x_min, box_y_min, width, height, ignore]
def get_data(image_file, anno_file):
'''get_data'''
count = 0
annotation = []
tree = ET.parse(anno_file)
root = tree.getroot()
with open(image_file, 'rb') as f:
img = f.read()
for member in root.findall('object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
for member in root.findall('Object'):
anno = deserialize(member, class_indexing_1)
if anno is not None:
annotation.extend(anno)
count += 1
if count == 0:
annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
count = 1
data = {
"image": img,
"annotation": np.array(annotation, dtype='float64')
}
return data
def convert_yolo_data_to_mindrecord():
'''convert_yolo_data_to_mindrecord'''
print('Loading mindrecord...')
writer = FileWriter.open_for_append(mindrecord_file_name)
print('Loading train data...')
image_files, anno_files = prepare_file_paths()
dataset_size = len(anno_files)
assert dataset_size == len(image_files)
logger.info("#size of dataset: {}".format(dataset_size))
data = []
for i in range(dataset_size):
data.append(get_data(image_files[i], anno_files[i]))
print('Writing train data to mindrecord...')
if not data:
raise ValueError("No data to write to mindrecord.")
writer.write_raw_data(data)
writer.commit()
convert_yolo_data_to_mindrecord()

View File

@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import logging
import os
import sys
from datetime import datetime
logger_name_1 = 'yolov3_face_detection'
class LOGGER(logging.Logger):
'''LOGGER'''
def __init__(self, logger_name):
super(LOGGER, self).__init__(logger_name)
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
self.local_rank = 0
def setup_logging_file(self, log_dir, local_rank=0):
'''setup_logging_file'''
self.local_rank = local_rank
if self.local_rank == 0:
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER(logger_name_1)
logger.setup_logging_file(path, rank)
return logger
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
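# Example usage of AverageMeter (values are illustrative):
if __name__ == '__main__':
    loss_meter = AverageMeter('loss', ':.4f')
    for step_loss in (0.9, 0.7, 0.5):
        loss_meter.update(step_loss)
    print(loss_meter)  # loss:0.7000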

View File

@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection learning rate scheduler."""
from collections import Counter
import math
import numpy as np
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def warmup_step(args, gamma=0.1, lr_scale=1.0):
'''warmup_step'''
base_lr = args.lr
warmup_init_lr = 0
total_steps = int(args.max_epoch * args.steps_per_epoch)
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
milestones = args.lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone*args.steps_per_epoch
milestones_steps.append(milestones_step)
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_learning_rate(
i, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr_scale * lr * gamma**milestones_steps_counter[i]
print('i:{} lr:{}'.format(i, lr))
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_step_new(args, lr_scale=1.0):
'''warmup_step_new'''
warmup_lr = args.warmup_lr / args.batch_size
lr_rates = [lr_rate / args.batch_size * lr_scale for lr_rate in args.lr_rates]
total_steps = int(args.max_epoch * args.steps_per_epoch)
lr_steps = args.lr_steps
warmup_steps = lr_steps[0]
lr_left = 0
print('real warmup_lr', warmup_lr)
print('real lr_rates', lr_rates)
if args.max_epoch * args.steps_per_epoch > lr_steps[-1]:
lr_steps.append(args.max_epoch * args.steps_per_epoch)
lr_rates.append(lr_rates[-1])
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = warmup_lr
elif i < lr_steps[lr_left]:
lr = lr_rates[lr_left]
else:
lr_left = lr_left + 1
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
'''warmup_cosine_annealing_lr'''
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_learning_rate(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
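# Illustrative schedule (toy numbers, not the shipped training config): two
# warmup epochs ramp linearly to the base LR, then cosine annealing decays it.
if __name__ == '__main__':
    demo_lr = warmup_cosine_annealing_lr(lr=0.1, steps_per_epoch=100,
                                         warmup_epochs=2, max_epoch=10, t_max=10)
    print(demo_lr.shape)  # (1000,)
    print(demo_lr[0], demo_lr[199], demo_lr[-1])  # ramp start, peak ~0.1, decayed tail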

View File

@ -0,0 +1,635 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection network wrapper."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, \
LessEqual, ControlDepend
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from src.FaceDetection.yolo_postprocess import YoloPostProcess
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class BuildTrainNetwork(nn.Cell):
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class TrainOneStepWithLossScaleCell(nn.Cell):
'''TrainOneStepWithLossScaleCell'''
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap()
self.alloc_status = NPUAllocFloatStatus()
self.get_status = NPUGetFloatStatus()
self.clear_status = NPUClearFloatStatus()
self.reduce_sum = ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.reducer_flag = False
self.less_equal = LessEqual()
self.depend_parameter_use = ControlDepend(depend_mode=1)
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = None
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
def construct(self, data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list,
coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1,
coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2,
sens=None):
'''construct'''
weights = self.weights
loss = self.network(data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list,
coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1,
gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2,
t_cls_2, gt_list_2)
# init overflow buffer
init = self.alloc_status()
# clear overflow buffer
self.clear_status(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord,
t_conf, t_cls, gt_list, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1,
cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1, coord_mask_2,
conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2,
t_cls_2, gt_list_2, F.cast(scaling_sens, F.dtype(loss)))
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.reducer_flag:
grads = self.grad_reducer(grads)
# get the overflow buffer
self.get_status(init)
# sum overflow buffer elements, 0:not overflow , >0:overflow
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
opt = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, opt)
class BuildTrainNetworkV2(nn.Cell):
'''BuildTrainNetworkV2'''
def __init__(self, network, criterion0, criterion1, criterion2, args):
super(BuildTrainNetworkV2, self).__init__()
self.network = network
self.criterion0 = criterion0
self.criterion1 = criterion1
self.criterion2 = criterion2
self.args = args
def construct(self, input_data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list,
coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1,
coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2):
'''construct'''
output0, output1, output2 = self.network(input_data)
loss0 = self.criterion0(output0, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls,
gt_list)
loss1 = self.criterion1(output1, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1,
t_conf_1, t_cls_1, gt_list_1)
loss2 = self.criterion2(output2, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2,
t_conf_2, t_cls_2, gt_list_2)
total_loss = loss0 + loss1 + loss2
return total_loss
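# Hedged wiring sketch (constructor arguments are placeholders; see train.py
# for the real values): one YoloLoss per output scale feeds BuildTrainNetworkV2,
# which is then wrapped by TrainOneStepWithLossScaleCell together with the
# optimizer and a dynamic loss-scale update cell.
#
#   network = backbone_HwYolov3(num_classes, num_anchors_list, args)
#   loss_net = BuildTrainNetworkV2(network, criterion0, criterion1, criterion2, args)
#   opt = Momentum(params=network.trainable_params(), learning_rate=Tensor(lr), momentum=0.9)
#   scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10)
#   train_net = TrainOneStepWithLossScaleCell(loss_net, opt,
#                                             scale_update_cell=scale_manager.get_update_cell())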
class BuildTestNetwork(nn.Cell):
'''BuildTestNetwork'''
def __init__(self, network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes, args):
super(BuildTestNetwork, self).__init__()
self.print = P.Print()
self.network = network
self.reduction_0 = reduction_0
self.reduction_1 = reduction_1
self.reduction_2 = reduction_2
self.anchors = anchors
self.anchors_mask = anchors_mask
self.args = args
self.conf_thresh = self.args.conf_thresh
self.nms_thresh = self.args.nms_thresh
self.num_classes = num_classes
self.network_size = args.input_shape
cur_anchors_0 = [self.anchors[ii] for ii in self.anchors_mask[0]]
cur_anchors_0 = [(ii[0] / self.reduction_0, ii[1] / self.reduction_0) for ii in cur_anchors_0]
cur_anchors_1 = [self.anchors[ii] for ii in self.anchors_mask[1]]
cur_anchors_1 = [(ii[0] / self.reduction_1, ii[1] / self.reduction_1) for ii in cur_anchors_1]
cur_anchors_2 = [self.anchors[ii] for ii in self.anchors_mask[2]]
cur_anchors_2 = [(ii[0] / self.reduction_2, ii[1] / self.reduction_2) for ii in cur_anchors_2]
self.postprocess_0 = YoloPostProcess(self.num_classes, cur_anchors_0, self.conf_thresh, self.network_size,
self.reduction_0, self.anchors_mask[0])
self.postprocess_1 = YoloPostProcess(self.num_classes, cur_anchors_1, self.conf_thresh, self.network_size,
self.reduction_1, self.anchors_mask[1])
self.postprocess_2 = YoloPostProcess(self.num_classes, cur_anchors_2, self.conf_thresh, self.network_size,
self.reduction_2, self.anchors_mask[2])
def construct(self, input_data):
output0, output1, output2 = self.network(input_data)
coords_0, cls_scores_0 = self.postprocess_0(output0)
coords_1, cls_scores_1 = self.postprocess_1(output1)
coords_2, cls_scores_2 = self.postprocess_2(output2)
return coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2
class Box:
""" This is a generic bounding box representation.
This class provides some base functionality to both annotations and detections.
Attributes:
class_label (string): class string label; Default **''**
object_id (int): Object identifier for reid purposes; Default **0**
x_top_left (Number): X pixel coordinate of the top left corner of the bounding box; Default **0.0**
y_top_left (Number): Y pixel coordinate of the top left corner of the bounding box; Default **0.0**
width (Number): Width of the bounding box in pixels; Default **0.0**
height (Number): Height of the bounding box in pixels; Default **0.0**
"""
def __init__(self):
self.class_label = '' # class string label
self.object_id = 0 # object identifier
self.x_top_left = 0.0 # x pixel coordinate top left of the box
self.y_top_left = 0.0 # y pixel coordinate top left of the box
self.width = 0.0 # width of the box in pixels
self.height = 0.0 # height of the box in pixels
@classmethod
def create(cls, obj=None):
""" Create a bounding box from a string or other detection object.
Args:
obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
"""
instance = cls()
if obj is None:
return instance
if isinstance(obj, str):
instance.deserialize(obj)
elif isinstance(obj, Box):
instance.class_label = obj.class_label
instance.object_id = obj.object_id
instance.x_top_left = obj.x_top_left
instance.y_top_left = obj.y_top_left
instance.width = obj.width
instance.height = obj.height
else:
raise TypeError(f'Object is not of type Box or a string [{obj.__class__.__name__}]')
return instance
def __eq__(self, other):
return self.__dict__ == other.__dict__
def serialize(self):
""" abstract serializer, implement in derived classes. """
raise NotImplementedError
def deserialize(self, string):
""" abstract parser, implement in derived classes. """
raise NotImplementedError
class Detection(Box):
""" This is a generic detection class that provides some base functionality all detections need.
It builds upon :class:`~brambox.boxes.box.Box`.
Attributes:
confidence (Number): confidence score between 0-1 for that detection; Default **0.0**
"""
def __init__(self):
""" x_top_left,y_top_left,width,height are in pixel coordinates """
super(Detection, self).__init__()
self.confidence = 0.0 # Confidence score between 0-1
@classmethod
def create(cls, obj=None):
""" Create a detection from a string or other box object.
Args:
obj (Box or string, optional): Bounding box object to copy attributes from or string to deserialize
Note:
The obj can be both an :class:`~brambox.boxes.annotations.Annotation` or
a :class:`~brambox.boxes.detections.Detection`.
For Detections the confidence score is copied over, for Annotations it is set to 1.
"""
instance = super(Detection, cls).create(obj)
if obj is None:
return instance
if isinstance(obj, Detection):
instance.confidence = obj.confidence
return instance
def __repr__(self):
""" Unambiguous representation """
string = f'{self.__class__.__name__} ' + '{'
string += f'class_label = {self.class_label}, '
string += f'object_id = {self.object_id}, '
string += f'x = {self.x_top_left}, '
string += f'y = {self.y_top_left}, '
string += f'w = {self.width}, '
string += f'h = {self.height}, '
string += f'confidence = {self.confidence}'
return string + '}'
def __str__(self):
""" Pretty print """
string = 'Detection {'
string += f'\'{self.class_label}\' {self.object_id}, '
string += f'[{int(self.x_top_left)}, {int(self.y_top_left)}, {int(self.width)}, {int(self.height)}]'
string += f', {round(self.confidence*100, 2)} %'
return string + '}'
def serialize(self):
""" abstract serializer, implement in derived classes. """
raise NotImplementedError
def deserialize(self, string):
""" abstract parser, implement in derived classes. """
raise NotImplementedError
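# Illustrative use of the Detection container (values are made up):
if __name__ == '__main__':
    demo_det = Detection()
    demo_det.class_label, demo_det.confidence = 'face', 0.87
    demo_det.x_top_left, demo_det.y_top_left = 32.0, 48.0
    demo_det.width, demo_det.height = 120.0, 96.0
    print(demo_det)  # Detection {'face' 0, [32, 48, 120, 96], 87.0 %}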
def get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2, conf_thresh,
input_shape, num_classes):
'''get_bounding_boxes'''
coords_0 = coords_0.asnumpy()
coords_1 = coords_1.asnumpy()
coords_2 = coords_2.asnumpy()
cls_scores_0 = cls_scores_0.asnumpy()
cls_scores_1 = cls_scores_1.asnumpy()
cls_scores_2 = cls_scores_2.asnumpy()
batch = cls_scores_0.shape[0]
w_0 = int(input_shape[0] / 64)
h_0 = int(input_shape[1] / 64)
w_1 = int(input_shape[0] / 32)
h_1 = int(input_shape[1] / 32)
w_2 = int(input_shape[0] / 16)
h_2 = int(input_shape[1] / 16)
num_anchors_0 = cls_scores_0.shape[1]
num_anchors_1 = cls_scores_1.shape[1]
num_anchors_2 = cls_scores_2.shape[1]
score_thresh_0 = cls_scores_0 > conf_thresh
score_thresh_1 = cls_scores_1 > conf_thresh
score_thresh_2 = cls_scores_2 > conf_thresh
score_thresh_flat_0 = score_thresh_0.reshape(-1)
score_thresh_flat_1 = score_thresh_1.reshape(-1)
score_thresh_flat_2 = score_thresh_2.reshape(-1)
score_thresh_expand_0 = np.expand_dims(score_thresh_0, axis=3)
score_thresh_expand_1 = np.expand_dims(score_thresh_1, axis=3)
score_thresh_expand_2 = np.expand_dims(score_thresh_2, axis=3)
score_thresh_cat_0 = np.concatenate((score_thresh_expand_0, score_thresh_expand_0), axis=3)
score_thresh_cat_0 = np.concatenate((score_thresh_cat_0, score_thresh_cat_0), axis=3)
score_thresh_cat_1 = np.concatenate((score_thresh_expand_1, score_thresh_expand_1), axis=3)
score_thresh_cat_1 = np.concatenate((score_thresh_cat_1, score_thresh_cat_1), axis=3)
score_thresh_cat_2 = np.concatenate((score_thresh_expand_2, score_thresh_expand_2), axis=3)
score_thresh_cat_2 = np.concatenate((score_thresh_cat_2, score_thresh_cat_2), axis=3)
coords_0 = coords_0[score_thresh_cat_0].reshape(-1, 4)
coords_1 = coords_1[score_thresh_cat_1].reshape(-1, 4)
coords_2 = coords_2[score_thresh_cat_2].reshape(-1, 4)
scores_0 = cls_scores_0[score_thresh_0].reshape(-1, 1)
scores_1 = cls_scores_1[score_thresh_1].reshape(-1, 1)
scores_2 = cls_scores_2[score_thresh_2].reshape(-1, 1)
idx_0 = np.tile((np.arange(num_classes)), (batch, num_anchors_0, w_0 * h_0))
idx_0 = idx_0[score_thresh_0].reshape(-1, 1)
idx_1 = np.tile((np.arange(num_classes)), (batch, num_anchors_1, w_1 * h_1))
idx_1 = idx_1[score_thresh_1].reshape(-1, 1)
idx_2 = np.tile((np.arange(num_classes)), (batch, num_anchors_2, w_2 * h_2))
idx_2 = idx_2[score_thresh_2].reshape(-1, 1)
detections_0 = np.concatenate([coords_0, scores_0, idx_0.astype(np.float32)], axis=1)
detections_1 = np.concatenate([coords_1, scores_1, idx_1.astype(np.float32)], axis=1)
detections_2 = np.concatenate([coords_2, scores_2, idx_2.astype(np.float32)], axis=1)
max_det_per_batch_0 = num_anchors_0 * h_0 * w_0 * num_classes
slices_0 = [slice(max_det_per_batch_0 * i, max_det_per_batch_0 * (i + 1)) for i in range(batch)]
det_per_batch_0 = np.array([score_thresh_flat_0[s].astype(np.int32).sum() for s in slices_0], dtype=np.int32)
max_det_per_batch_1 = num_anchors_1 * h_1 * w_1 * num_classes
slices_1 = [slice(max_det_per_batch_1 * i, max_det_per_batch_1 * (i + 1)) for i in range(batch)]
det_per_batch_1 = np.array([score_thresh_flat_1[s].astype(np.int32).sum() for s in slices_1], dtype=np.int32)
max_det_per_batch_2 = num_anchors_2 * h_2 * w_2 * num_classes
slices_2 = [slice(max_det_per_batch_2 * i, max_det_per_batch_2 * (i + 1)) for i in range(batch)]
det_per_batch_2 = np.array([score_thresh_flat_2[s].astype(np.int32).sum() for s in slices_2], dtype=np.int32)
split_idx_0 = np.cumsum(det_per_batch_0, axis=0)
split_idx_1 = np.cumsum(det_per_batch_1, axis=0)
split_idx_2 = np.cumsum(det_per_batch_2, axis=0)
boxes_0 = []
boxes_1 = []
boxes_2 = []
start = 0
for end in split_idx_0:
boxes_0.append(detections_0[start: end])
start = end
start = 0
for end in split_idx_1:
boxes_1.append(detections_1[start: end])
start = end
start = 0
for end in split_idx_2:
boxes_2.append(detections_2[start: end])
start = end
return boxes_0, boxes_1, boxes_2
def convert_tensor_to_brambox(boxes, width, height, class_label_map):
'''convert_tensor_to_brambox'''
boxes[:, 0:3:2] = boxes[:, 0:3:2] * width
boxes[:, 0] -= boxes[:, 2] / 2
boxes[:, 1:4:2] = boxes[:, 1:4:2] * height
boxes[:, 1] -= boxes[:, 3] / 2
brambox = []
for box in boxes:
det = Detection()
det.x_top_left = box[0]
det.y_top_left = box[1]
det.width = box[2]
det.height = box[3]
det.confidence = box[4]
if class_label_map is not None:
det.class_label = class_label_map[int(box[5])]
else:
det.class_label = str(int(box[5]))
brambox.append(det)
return brambox
def tensor_to_brambox(boxes_0, boxes_1, boxes_2, input_shape, labels):
'''tensor_to_brambox'''
converted_boxes_0 = []
converted_boxes_1 = []
converted_boxes_2 = []
for box in boxes_0:
if box.size == 0:
converted_boxes_0.append([])
else:
converted_boxes_0.append(convert_tensor_to_brambox(box, input_shape[0], input_shape[1], labels))
for box in boxes_1:
if box.size == 0:
converted_boxes_1.append([])
else:
converted_boxes_1.append(convert_tensor_to_brambox(box, input_shape[0], input_shape[1], labels))
for box in boxes_2:
if box.size == 0:
converted_boxes_2.append([])
else:
converted_boxes_2.append(convert_tensor_to_brambox(box, input_shape[0], input_shape[1], labels))
return converted_boxes_0, converted_boxes_1, converted_boxes_2
def parse_gt_from_anno(img_anno, classes):
'''parse_gt_from_anno'''
print('parse ground truth files...')
ground_truth = {}
for img_name, annos in img_anno.items():
objs = []
for anno in annos:
if anno[1] == 0. and anno[2] == 0. and anno[3] == 0. and anno[4] == 0.:
continue
if int(anno[0]) == -1:
continue
xmin = anno[1]
ymin = anno[2]
xmax = xmin + anno[3] - 1
ymax = ymin + anno[4] - 1
xmin = int(xmin)
ymin = int(ymin)
xmax = int(xmax)
ymax = int(ymax)
cls = classes[int(anno[0])]
gt_box = {'class': cls, 'box': [xmin, ymin, xmax, ymax]}
objs.append(gt_box)
ground_truth[img_name] = objs
return ground_truth
def parse_rets(ret_files_set):
'''parse_rets'''
print('parse ret files...')
ret_list = {}
for cls in ret_files_set:
ret_file = open(ret_files_set[cls])
ret_list[cls] = []
for line in ret_file.readlines():
info = line.strip().split()
img_name = info[0]
score = float(info[1])
xmin = float(info[2])
ymin = float(info[3])
xmax = float(info[4])
ymax = float(info[5])
ret_list[cls].append({'img_name': img_name, 'score': score, 'ret': [xmin, ymin, xmax, ymax]})
return ret_list
def calc_gt_count(gt_set, cls):
count = 0
for img in gt_set:
for obj in gt_set[img]:
if obj['class'] == cls:
count += 1
return count
def calc_rect_area(rect):
return (rect[2] - rect[0] + 0.001) * (rect[3] - rect[1] + 0.001)
def calc_iou(rect1, rect2):
bd_i = (max(rect1[0], rect2[0]), max(rect1[1], rect2[1]),
min(rect1[2], rect2[2]), min(rect1[3], rect2[3]))
iw = bd_i[2] - bd_i[0] + 0.001
ih = bd_i[3] - bd_i[1] + 0.001
iou = 0
if iw > 0 and ih > 0:
ua = calc_rect_area(rect1) + calc_rect_area(rect2) - iw * ih
iou = iw * ih / ua
return iou
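# Worked example for calc_iou (boxes in [xmin, ymin, xmax, ymax] form): two
# 2x2 squares offset by one pixel overlap in a 1x1 region, so IOU ~= 1/7.
if __name__ == '__main__':
    print(calc_iou([0, 0, 2, 2], [1, 1, 3, 3]))  # ~0.143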
def cal_ap_voc2012(recall, precision):
'''cal_ap_voc2012'''
ap_val = 0.0
eps = 1e-6
assert len(recall) == len(precision)
length = len(recall)
cur_prec = precision[length - 1]
cur_rec = recall[length - 1]
for i in range(0, length - 1)[::-1]:
cur_prec = max(precision[i], cur_prec)
if abs(recall[i] - cur_rec) > eps:
ap_val += cur_prec * abs(recall[i] - cur_rec)
cur_rec = recall[i]
return ap_val
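# Toy check for cal_ap_voc2012 (illustrative numbers): with constant precision
# 1.0, the integral covers only the spans between successive recall points, so
# the area below the first recall value (here 0.25) is not counted.
if __name__ == '__main__':
    demo_recall = [0.25, 0.5, 0.75, 1.0]
    demo_precision = [1.0, 1.0, 1.0, 1.0]
    print(cal_ap_voc2012(demo_recall, demo_precision))  # 0.75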
def cal_ap_11point(recall, precision):
'''cal_ap_11point'''
ap_val = 0.0
assert len(recall) == len(precision)
num = len(recall)
max_precs = np.zeros(10 + 1)
start_idx = num - 1
for j in range(0, 11)[::-1]:
for i in range(0, start_idx + 1)[::-1]:
if recall[i] < (j / 10.0):
start_idx = i
if j > 0:
max_precs[j - 1] = max_precs[j]
break
else:
if max_precs[j] < precision[i]:
max_precs[j] = precision[i]
for j in range(0, 11):
ap_val += max_precs[j] / 11.0
return ap_val
def calc_recall_presicion_ap(ground_truth, ret_list, iou_thr=0.5):
'''calc_recall_presicion_ap'''
print('calculate [recall | precision | ap]...')
evaluate = {}
for cls in ret_list:
ret = ret_list[cls]
n_gt_obj = calc_gt_count(ground_truth, cls)
print('class [%s] ground truth:%d' % (cls, n_gt_obj))
ret = sorted(ret, key=lambda det: det['score'], reverse=True)
tp = np.zeros(len(ret))
fp = np.zeros(len(ret))
for ret_idx, info in enumerate(ret):
img_name = info['img_name']
if img_name not in ground_truth:
print('%s not in ground truth' % img_name)
continue
else:
img_gts = ground_truth[img_name]
max_iou = 0
max_idx = -1
for idx, gt in enumerate(img_gts):
if (not gt['class'] == cls) or 'used' in gt:
continue
iou = calc_iou(info['ret'], gt['box'])
if iou > max_iou:
max_iou = iou
max_idx = idx
if max_iou > iou_thr:
tp[ret_idx] = 1
img_gts[max_idx]['used'] = 1
else:
fp[ret_idx] = 1
tp = tp.cumsum()
fp = fp.cumsum()
recall = tp / n_gt_obj
presicion = tp / (tp + fp)
ap = cal_ap_voc2012(recall, presicion)
evaluate[cls] = {'recall': recall, 'presicion': presicion, 'ap': ap}
return evaluate

View File

@ -0,0 +1,423 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data transform."""
import random
import numpy as np
from PIL import Image, ImageOps
try:
import cv2
except ImportError:
print('OpenCV is not installed and cannot be used')
cv2 = None
__all__ = ['RandomCropLetterbox', 'RandomFlip', 'HSVShift', 'ResizeLetterbox']
class RandomCropLetterbox():
""" Take random crop from the image.
Args:
jitter (Number [0-1]): Indicates how much of the image we can crop
crop_anno(Boolean, optional): Whether we crop the annotations inside the image crop; Default **False**
intersection_threshold(number or list, optional): Argument passed on to :class:
`brambox.boxes.util.modifiers.CropModifier`
Note:
Create 1 RandomCrop object and use it for both image and annotation transforms.
This object will save data from the image transform and use that on the annotation transform.
"""
def __init__(self, jitter, fill_color=127, input_dim=(1408, 768)):
self.fill_color = fill_color
self.jitter = jitter
self.crop_info = None
self.output_w = None
self.output_h = None
self.input_dim = input_dim
def __call__(self, img, annos):
if img is None:
return None, None
if isinstance(img, Image.Image):
img, _ = self._tf_pil(img)
elif isinstance(img, np.ndarray):
img, _ = self._tf_cv(img)
annos = self._tf_anno(annos)
annos = np.asarray(annos)
return (img, annos)
def _tf_cv(self, img, save_info=None):
""" Take random crop from image """
self.output_w, self.output_h = self.input_dim
orig_h, orig_w = img.shape[:2]
channels = img.shape[2] if len(img.shape) > 2 else 1
dw = int(self.jitter * orig_w)
dh = int(self.jitter * orig_h)
if save_info is None:
new_ar = float(orig_w + random.randint(-dw, dw)) / (orig_h + random.randint(-dh, dh))
else:
new_ar = save_info[0]
if save_info is None:
scale = random.random() * (2 - 0.25) + 0.25
else:
scale = save_info[1]
if new_ar < 1:
nh = int(scale * orig_h)
nw = int(nh * new_ar)
else:
nw = int(scale * orig_w)
nh = int(nw / new_ar)
if save_info is None:
if self.output_w > nw:
dx = random.randint(0, self.output_w - nw)
else:
dx = random.randint(self.output_w - nw, 0)
else:
dx = save_info[2]
if save_info is None:
if self.output_h > nh:
dy = random.randint(0, self.output_h - nh)
else:
dy = random.randint(self.output_h - nh, 0)
else:
dy = save_info[3]
nxmin = max(0, -dx)
nymin = max(0, -dy)
nxmax = min(nw, -dx + self.output_w - 1)
nymax = min(nh, -dy + self.output_h - 1)
sx, sy = float(orig_w) / nw, float(orig_h) / nh
orig_xmin = int(nxmin * sx)
orig_ymin = int(nymin * sy)
orig_xmax = int(nxmax * sx)
orig_ymax = int(nymax * sy)
orig_crop = img[orig_ymin:orig_ymax, orig_xmin:orig_xmax, :]
orig_crop_resize = cv2.resize(orig_crop, (nxmax - nxmin, nymax - nymin), interpolation=cv2.INTER_CUBIC)
output_img = np.ones((self.output_h, self.output_w, channels), dtype=np.uint8) * self.fill_color
y_lim = int(min(output_img.shape[0], orig_crop_resize.shape[0]))
x_lim = int(min(output_img.shape[1], orig_crop_resize.shape[1]))
output_img[:y_lim, :x_lim, :] = orig_crop_resize[:y_lim, :x_lim, :]
self.crop_info = [sx, sy, nxmin, nymin, nxmax, nymax]
if save_info is None:
return output_img, [new_ar, scale, dx, dy]
return output_img, save_info
def _tf_pil(self, img, save_info=None):
""" Take random crop from image """
self.output_w, self.output_h = self.input_dim
orig_w, orig_h = img.size
img_np = np.array(img)
channels = img_np.shape[2] if len(img_np.shape) > 2 else 1
dw = int(self.jitter * orig_w)
dh = int(self.jitter * orig_h)
if save_info is None:
new_ar = float(orig_w + random.randint(-dw, dw)) / (orig_h + random.randint(-dh, dh))
else:
new_ar = save_info[0]
if save_info is None:
scale = random.random() * (2 - 0.25) + 0.25
else:
scale = save_info[1]
if new_ar < 1:
nh = int(scale * orig_h)
nw = int(nh * new_ar)
else:
nw = int(scale * orig_w)
nh = int(nw / new_ar)
if save_info is None:
if self.output_w > nw:
dx = random.randint(0, self.output_w - nw)
else:
dx = random.randint(self.output_w - nw, 0)
else:
dx = save_info[2]
if save_info is None:
if self.output_h > nh:
dy = random.randint(0, self.output_h - nh)
else:
dy = random.randint(self.output_h - nh, 0)
else:
dy = save_info[3]
nxmin = max(0, -dx)
nymin = max(0, -dy)
nxmax = min(nw, -dx + self.output_w - 1)
nymax = min(nh, -dy + self.output_h - 1)
sx, sy = float(orig_w) / nw, float(orig_h) / nh
orig_xmin = int(nxmin * sx)
orig_ymin = int(nymin * sy)
orig_xmax = int(nxmax * sx)
orig_ymax = int(nymax * sy)
orig_crop = img.crop((orig_xmin, orig_ymin, orig_xmax, orig_ymax))
orig_crop_resize = orig_crop.resize((nxmax - nxmin, nymax - nymin))
output_img = Image.new(img.mode, (self.output_w, self.output_h), color=(self.fill_color,) * channels)
output_img.paste(orig_crop_resize, (0, 0))
self.crop_info = [sx, sy, nxmin, nymin, nxmax, nymax]
if save_info is None:
return output_img, [new_ar, scale, dx, dy]
return output_img, save_info
def _tf_anno(self, annos):
""" Change coordinates of an annotation, according to the previous crop """
def is_negative(anno):
for value in anno:
if value != -1:
return False
return True
sx, sy, crop_xmin, crop_ymin, crop_xmax, crop_ymax = self.crop_info
for i in range(len(annos) - 1, -1, -1):
anno = annos[i]
if is_negative(anno):
continue
else:
x1 = max(crop_xmin, int(anno[1] / sx))
x2 = min(crop_xmax, int((anno[1] + anno[3]) / sx))
y1 = max(crop_ymin, int(anno[2] / sy))
y2 = min(crop_ymax, int((anno[2] + anno[4]) / sy))
w = x2 - x1
h = y2 - y1
if w <= 2 or h <= 2:
annos[i] = np.zeros(6)
continue
annos[i][1] = x1 - crop_xmin
annos[i][2] = y1 - crop_ymin
annos[i][3] = w
annos[i][4] = h
return annos
class RandomFlip():
""" Randomly flip image.
Args:
threshold (Number [0-1]): Chance of flipping the image
Note:
Create 1 RandomFlip object and use it for both image and annotation transforms.
This object will save data from the image transform and use that on the annotation transform.
"""
def __init__(self, threshold):
self.threshold = threshold
self.flip = False
self.im_w = None
def __call__(self, img, annos):
if img is None and annos is None:
return None, None
if isinstance(img, Image.Image):
img = self._tf_pil(img)
elif isinstance(img, np.ndarray):
img = self._tf_cv(img)
annos = [self._tf_anno(anno) for anno in annos]
annos = np.asarray(annos)
return (img, annos)
def _tf_pil(self, img):
""" Randomly flip image """
self.flip = self._get_flip()
self.im_w = img.size[0]
if self.flip:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def _tf_cv(self, img):
""" Randomly flip image """
self.flip = self._get_flip()
self.im_w = img.shape[1]
if self.flip:
img = cv2.flip(img, 1)
return img
def _get_flip(self):
flip = random.random() < self.threshold
return flip
def _tf_anno(self, anno):
""" Change coordinates of an annotation, according to the previous flip """
def is_negative(anno):
for value in anno:
if value not in (-1, 0):
return False
return True
if is_negative(anno):
return anno
if self.flip and self.im_w is not None:
anno[1] = self.im_w - anno[1] - anno[3]
return anno
class HSVShift():
""" Perform random HSV shift on the RGB data.
Args:
hue (Number): Random number between -hue,hue is used to shift the hue
saturation (Number): Random number between 1,saturation is used to shift the saturation; 50% chance to
get 1/dSaturation instead of dSaturation
value (Number): Random number between 1,value is used to shift the value; 50% chance to get 1/dValue
instead of dValue
Warning:
If you use OpenCV as your image processing library, make sure the image is RGB before using this transform.
By default OpenCV uses BGR, so you must use `cvtColor`_ function to transform it to RGB.
.. _cvtColor: https://docs.opencv.org/master/d7/d1b/group__imgproc__misc.html#ga397ae87e1288a81d2363b61574eb8cab
"""
def __init__(self, hue, saturation, value):
self.hue = hue
self.saturation = saturation
self.value = value
def __call__(self, img, annos):
dh = random.uniform(-self.hue, self.hue)
ds = random.uniform(1, self.saturation)
if random.random() < 0.5:
ds = 1 / ds
dv = random.uniform(1, self.value)
if random.random() < 0.5:
dv = 1 / dv
if img is None:
return None
if isinstance(img, Image.Image):
img = self._tf_pil(img, dh, ds, dv)
return (img, annos)
if isinstance(img, np.ndarray):
return (self._tf_cv(img, dh, ds, dv), annos)
print(f'HSVShift only works with <PIL images> or <OpenCV images> [{type(img)}]')
return (img, annos)
@staticmethod
def _tf_pil(img, dh, ds, dv):
""" Random hsv shift """
img = img.convert('HSV')
channels = list(img.split())
def change_hue(x):
x += int(dh * 255)
if x > 255:
x -= 255
elif x < 0:
x += 255  # wrap hue from below back into [0, 255], mirroring wrap_hue in _tf_cv
return x
channels[0] = channels[0].point(change_hue)
channels[1] = channels[1].point(lambda i: min(255, max(0, int(i * ds))))
channels[2] = channels[2].point(lambda i: min(255, max(0, int(i * dv))))
img = Image.merge(img.mode, tuple(channels))
img = img.convert('RGB')
return img
@staticmethod
def _tf_cv(img, dh, ds, dv):
""" Random hsv shift """
img = img.astype(np.float32) / 255.0
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
def wrap_hue(x):
x[x >= 360.0] -= 360.0
x[x < 0.0] += 360.0
return x
img[:, :, 0] = wrap_hue(img[:, :, 0] + (360.0 * dh))
img[:, :, 1] = np.clip(ds * img[:, :, 1], 0.0, 1.0)
img[:, :, 2] = np.clip(dv * img[:, :, 2], 0.0, 1.0)
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
img = (img * 255).astype(np.uint8)
return img
class ResizeLetterbox:
""" Resize the image to input_dim.
Args:
input_dim: Input size of network.
"""
def __init__(self, fill_color=127, input_dim=(1408, 768)):
self.fill_color = fill_color
self.crop_info = None
self.output_w = None
self.output_h = None
self.input_dim = input_dim
self.pad = None
self.scale = None
def __call__(self, img, annos):
if img is None:
return None, None
if isinstance(img, Image.Image):
img = self._tf_pil(img)
annos = np.asarray(annos)
return img, annos
def _tf_pil(self, img):
""" Letterbox an image to fit in the network """
net_w, net_h = self.input_dim
im_w, im_h = img.size
if im_w == net_w and im_h == net_h:
self.scale = None
self.pad = None
return img
# Rescaling
if im_w / net_w >= im_h / net_h:
self.scale = net_w / im_w
else:
self.scale = net_h / im_h
if self.scale != 1:
resample_mode = Image.NEAREST
img = img.resize((int(self.scale * im_w), int(self.scale * im_h)), resample_mode)
im_w, im_h = img.size
if im_w == net_w and im_h == net_h:
self.pad = None
return img
# Padding
img_np = np.array(img)
channels = img_np.shape[2] if len(img_np.shape) > 2 else 1
pad_w = (net_w - im_w) / 2
pad_h = (net_h - im_h) / 2
self.pad = (int(pad_w), int(pad_h), int(pad_w + .5), int(pad_h + .5))
img = ImageOps.expand(img, border=self.pad, fill=(self.fill_color,) * channels)
return img
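# Worked example (added for illustration): with input_dim = (1408, 768) and a
# 1920x1080 source image, 1920/1408 < 1080/768, so scale = 768/1080 ~= 0.711;
# the image is resized to 1365x768 and self.pad = (21, 0, 22, 0) fills the
# remaining 43 columns with fill_color.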

View File

@ -0,0 +1,317 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection train."""
import os
import time
import datetime
import argparse
import numpy as np
from mindspore import context
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore import Tensor
from mindspore.nn import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
import mindspore.dataset as de
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.FaceDetection.yolo_loss import YoloLoss
from src.network_define import BuildTrainNetworkV2, TrainOneStepWithLossScaleCell
from src.lrsche_factory import warmup_step_new
from src.logging import get_logger
from src.data_preprocess import compose_map_func
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def parse_args():
'''parse_args'''
parser = argparse.ArgumentParser('Yolov3 Face Detection')
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
args, _ = parser.parse_known_args()
return args
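# Example invocation (added for illustration; the paths are placeholders):
#   python train.py --mindrecord_path=/home/data.mindrecord --pretrained=/home/pretrain.ckpt
# Under distributed training, local_rank and world_size are overwritten inside
# train() from the initialized communication group.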
def train(args):
'''train'''
    print('=============yolov3 start training==================')
# init distributed
if args.world_size != 1:
init()
args.local_rank = get_rank()
args.world_size = get_group_size()
args.batch_size = config.batch_size
args.warmup_lr = config.warmup_lr
args.lr_rates = config.lr_rates
args.lr_steps = config.lr_steps
args.gamma = config.gamma
args.weight_decay = config.weight_decay
args.momentum = config.momentum
args.max_epoch = config.max_epoch
args.log_interval = config.log_interval
args.ckpt_path = config.ckpt_path
args.ckpt_interval = config.ckpt_interval
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
print('args.outputs_dir', args.outputs_dir)
args.logger = get_logger(args.outputs_dir, args.local_rank)
if args.world_size != 8:
args.lr_steps = [i * 8 // args.world_size for i in args.lr_steps]
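    # Note (added for clarity): the decay boundaries are scaled by 8/world_size so
    # that the schedule decays after the same number of samples as the 8-device
    # defaults in config, regardless of how many devices are actually used.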
if args.world_size == 1:
args.weight_decay = 0.
if args.world_size != 1:
parallel_mode = ParallelMode.DATA_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.world_size, gradients_mean=True)
mindrecord_path = args.mindrecord_path
num_classes = config.num_classes
anchors = config.anchors
anchors_mask = config.anchors_mask
num_anchors_list = [len(x) for x in anchors_mask]
momentum = args.momentum
args.logger.info('train opt momentum:{}'.format(momentum))
weight_decay = args.weight_decay * float(args.batch_size)
args.logger.info('real weight_decay:{}'.format(weight_decay))
lr_scale = args.world_size / 8
args.logger.info('lr_scale:{}'.format(lr_scale))
# dataloader
args.logger.info('start create dataloader')
epoch = args.max_epoch
ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation"], num_shards=args.world_size,
shard_id=args.local_rank)
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0',
'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1',
'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1',
't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2',
'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'],
column_order=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0',
'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1',
'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1',
't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2',
'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'],
operations=compose_map_func, num_parallel_workers=16, python_multiprocessing=True)
ds = ds.batch(args.batch_size, drop_remainder=True, num_parallel_workers=8)
args.steps_per_epoch = ds.get_dataset_size()
lr = warmup_step_new(args, lr_scale=lr_scale)
ds = ds.repeat(epoch)
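    # Note (added for clarity): compose_map_func expands each (image, annotation)
    # pair into per-head YOLO training targets: coordinate/confidence/class masks,
    # regression targets and ground-truth lists for the three detection heads
    # (column suffixes _0/_1/_2), so the loss targets are precomputed on the host.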
args.logger.info('args.steps_per_epoch:{}'.format(args.steps_per_epoch))
args.logger.info('args.world_size:{}'.format(args.world_size))
args.logger.info('args.local_rank:{}'.format(args.local_rank))
args.logger.info('end create dataloader')
args.logger.save_args(args)
args.logger.important_info('start create network')
create_network_start = time.time()
# backbone and loss
network = backbone_HwYolov3(num_classes, num_anchors_list, args)
criterion0 = YoloLoss(num_classes, anchors, anchors_mask[0], 64, 0, head_idx=0.0)
criterion1 = YoloLoss(num_classes, anchors, anchors_mask[1], 32, 0, head_idx=1.0)
criterion2 = YoloLoss(num_classes, anchors, anchors_mask[2], 16, 0, head_idx=2.0)
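    # Note (added for clarity): the per-head constants 64/32/16 are presumably the
    # feature-map strides (reduction factors) of the three detection heads, with
    # anchors_mask selecting the anchor subset assigned to each head.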
# load pretrain model
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(args.pretrained))
train_net = BuildTrainNetworkV2(network, criterion0, criterion1, criterion2, args)
# optimizer
opt = Momentum(params=train_net.trainable_params(), learning_rate=Tensor(lr), momentum=momentum,
weight_decay=weight_decay)
# package training process
train_net = TrainOneStepWithLossScaleCell(train_net, opt)
train_net.set_broadcast_flag()
# checkpoint
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
train_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
cb_params = _InternalCallbackParam()
cb_params.train_network = train_net
cb_params.epoch_num = ckpt_max_num
cb_params.cur_epoch_num = 1
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)
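    # Note (added for clarity): this script runs a manual training loop, so it
    # drives ModelCheckpoint itself through RunContext and _InternalCallbackParam
    # instead of relying on Model.train's callback dispatch.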
train_net.set_train()
t_end = time.time()
t_epoch = time.time()
old_progress = -1
i = 0
scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=2000)
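    # Note (added for clarity): the loss scale starts at 2**10, is divided by
    # scale_factor (2) whenever an overflow is detected, and is increased again
    # after scale_window (2000) consecutive overflow-free steps, which is the
    # standard dynamic loss-scaling policy.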
for data in ds.create_tuple_iterator(output_numpy=True):
batch_images = data[0]
batch_labels = data[1]
coord_mask_0 = data[2]
conf_pos_mask_0 = data[3]
conf_neg_mask_0 = data[4]
cls_mask_0 = data[5]
t_coord_0 = data[6]
t_conf_0 = data[7]
t_cls_0 = data[8]
gt_list_0 = data[9]
coord_mask_1 = data[10]
conf_pos_mask_1 = data[11]
conf_neg_mask_1 = data[12]
cls_mask_1 = data[13]
t_coord_1 = data[14]
t_conf_1 = data[15]
t_cls_1 = data[16]
gt_list_1 = data[17]
coord_mask_2 = data[18]
conf_pos_mask_2 = data[19]
conf_neg_mask_2 = data[20]
cls_mask_2 = data[21]
t_coord_2 = data[22]
t_conf_2 = data[23]
t_cls_2 = data[24]
gt_list_2 = data[25]
img_tensor = Tensor(batch_images, mstype.float32)
coord_mask_tensor_0 = Tensor(coord_mask_0.astype(np.float32))
conf_pos_mask_tensor_0 = Tensor(conf_pos_mask_0.astype(np.float32))
conf_neg_mask_tensor_0 = Tensor(conf_neg_mask_0.astype(np.float32))
cls_mask_tensor_0 = Tensor(cls_mask_0.astype(np.float32))
t_coord_tensor_0 = Tensor(t_coord_0.astype(np.float32))
t_conf_tensor_0 = Tensor(t_conf_0.astype(np.float32))
t_cls_tensor_0 = Tensor(t_cls_0.astype(np.float32))
gt_list_tensor_0 = Tensor(gt_list_0.astype(np.float32))
coord_mask_tensor_1 = Tensor(coord_mask_1.astype(np.float32))
conf_pos_mask_tensor_1 = Tensor(conf_pos_mask_1.astype(np.float32))
conf_neg_mask_tensor_1 = Tensor(conf_neg_mask_1.astype(np.float32))
cls_mask_tensor_1 = Tensor(cls_mask_1.astype(np.float32))
t_coord_tensor_1 = Tensor(t_coord_1.astype(np.float32))
t_conf_tensor_1 = Tensor(t_conf_1.astype(np.float32))
t_cls_tensor_1 = Tensor(t_cls_1.astype(np.float32))
gt_list_tensor_1 = Tensor(gt_list_1.astype(np.float32))
coord_mask_tensor_2 = Tensor(coord_mask_2.astype(np.float32))
conf_pos_mask_tensor_2 = Tensor(conf_pos_mask_2.astype(np.float32))
conf_neg_mask_tensor_2 = Tensor(conf_neg_mask_2.astype(np.float32))
cls_mask_tensor_2 = Tensor(cls_mask_2.astype(np.float32))
t_coord_tensor_2 = Tensor(t_coord_2.astype(np.float32))
t_conf_tensor_2 = Tensor(t_conf_2.astype(np.float32))
t_cls_tensor_2 = Tensor(t_cls_2.astype(np.float32))
gt_list_tensor_2 = Tensor(gt_list_2.astype(np.float32))
scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32)
loss0, overflow, _ = train_net(img_tensor, coord_mask_tensor_0, conf_pos_mask_tensor_0,
conf_neg_mask_tensor_0, cls_mask_tensor_0, t_coord_tensor_0,
t_conf_tensor_0, t_cls_tensor_0, gt_list_tensor_0,
coord_mask_tensor_1, conf_pos_mask_tensor_1, conf_neg_mask_tensor_1,
cls_mask_tensor_1, t_coord_tensor_1, t_conf_tensor_1,
t_cls_tensor_1, gt_list_tensor_1, coord_mask_tensor_2,
conf_pos_mask_tensor_2, conf_neg_mask_tensor_2,
cls_mask_tensor_2, t_coord_tensor_2, t_conf_tensor_2,
t_cls_tensor_2, gt_list_tensor_2, scaling_sens)
        overflow = np.all(overflow.asnumpy())
        scale_manager.update_loss_scale(overflow)
args.logger.info('rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, '
'batch_labels:{}'.format(args.local_rank, i, loss0, overflow, scaling_sens, lr[i],
batch_images.shape, batch_labels.shape))
# save ckpt
cb_params.cur_step_num = i + 1 # current step number
cb_params.batch_num = i + 2
if args.local_rank == 0:
ckpt_cb.step_end(run_context)
# save Log
if i == 0:
time_for_graph_compile = time.time() - create_network_start
args.logger.important_info('Yolov3, graph compile time={:.2f}s'.format(time_for_graph_compile))
if i % args.steps_per_epoch == 0:
cb_params.cur_epoch_num += 1
if i % args.log_interval == 0 and args.local_rank == 0:
time_used = time.time() - t_end
epoch = int(i / args.steps_per_epoch)
fps = args.batch_size * (i - old_progress) * args.world_size / time_used
args.logger.info('epoch[{}], iter[{}], loss:[{}], {:.2f} imgs/sec'.format(epoch, i, loss0, fps))
t_end = time.time()
old_progress = i
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
epoch_time_used = time.time() - t_epoch
epoch = int(i / args.steps_per_epoch)
fps = args.batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
args.logger.info('=================================================')
args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
args.logger.info('=================================================')
t_epoch = time.time()
i = i + 1
args.logger.info('=============yolov3 training finished==================')
if __name__ == "__main__":
arg = parse_args()
train(arg)