Add deeptext.

This commit is contained in:
疯子奏乐 2020-12-26 18:49:05 +08:00
parent 4a940e4be6
commit e735b85802
21 changed files with 3651 additions and 0 deletions

View File

@ -0,0 +1,223 @@
# DeepText for Ascend
- [DeepText Description](#DeepText-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DeepText Description](#contents)
DeepText is a convolutional neural network architecture for text detection in non-specific scenarios. The DeepText system is based on the elegant framwork of Faster R-CNN. This idea was proposed in the paper "DeepText: A new approach for text proposal generation and text detection in natural images.", published in 2017.
[Paper](https://arxiv.org/pdf/1605.07314v1.pdf) Zhuoyao Zhong, Lianwen Jin, Shuangping Huang, South China University of Technology (SCUT), Published in ICASSP 2017.
# [Model architecture](#contents)
The overall network architecture of InceptionV4 is show below:
[Link](https://arxiv.org/pdf/1605.07314v1.pdf)
# [Dataset](#contents)
Here we used 4 datasets for training, and 1 datasets for Evaluation.
- Dataset1: ICDAR 2013: Focused Scene Text
- Train: 142MB, 229 images
- Test: 110MB, 233 images
- Dataset2: ICDAR 2013: Born-Digital Images
- Train: 27.7MB, 410 images
- Dataset3: SCUT-FORU: Flickr OCR Universal Database
- Train: 388MB, 1715 images
- Dataset4: CocoText v2(Subset of MSCOCO2014):
- Train: 13GB, 63686 images
# [Features](#contents)
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```shell
.
└─deeptext
├─README.md
├─scripts
├─run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p)
├─run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p)
└─run_eval_ascend.sh # launch evaluating with ascend platform
├─src
├─DeepText
├─__init__.py # package init file
├─anchor_genrator.py # anchor generator
├─bbox_assign_sample.py # proposal layer for stage 1
├─bbox_assign_sample_stage2.py # proposal layer for stage 2
├─deeptext_vgg16.py # main network defination
├─proposal_generator.py # proposal generator
├─rcnn.py # rcnn
├─roi_align.py # roi_align cell wrapper
├─rpn.py # region-proposal network
└─vgg16.py # backbone
├─config.py # training configuration
├─dataset.py # data proprocessing
├─lr_schedule.py # learning rate scheduler
├─network_define.py # network defination
└─utils.py # some functions which is commonly used
├─eval.py # eval net
├─export.py # export checkpoint, surpport .onnx, .air, .mindir convert
└─train.py # train net
```
## [Training process](#contents)
### Usage
- Ascend:
```bash
# distribute training example(8p)
sh run_distribute_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [RANK_TABLE_FILE] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]
# standalone training
sh run_standalone_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]
# evaluation:
sh run_eval_ascend.sh [IMGS_PATH] [ANNOS_PATH] [CHECKPOINT_PATH] [COCO_TEXT_PARSER_PATH]
```
> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
>
> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh`
>
> The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. The name of weight in dict should be totally the same, also the batch_norm should be enabled in the trainig of vgg16, otherwise fails in further steps.
> COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).
>
### Launch
```bash
# training example
shell:
Ascend:
# distribute training example(8p)
sh run_distribute_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [RANK_TABLE_FILE] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]
# standalone training
sh run_standalone_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]
```
### Result
Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training log will be redirected to `./log`, also the loss will be redirected to `./loss_0.log` like followings.
```python
469 epoch: 1 step: 982 ,rpn_loss: 0.03940, rcnn_loss: 0.48169, rpn_cls_loss: 0.02910, rpn_reg_loss: 0.00344, rcnn_cls_loss: 0.41943, rcnn_reg_loss: 0.06223, total_loss: 0.52109
659 epoch: 2 step: 982 ,rpn_loss: 0.03607, rcnn_loss: 0.32129, rpn_cls_loss: 0.02916, rpn_reg_loss: 0.00230, rcnn_cls_loss: 0.25732, rcnn_reg_loss: 0.06390, total_loss: 0.35736
847 epoch: 3 step: 982 ,rpn_loss: 0.07074, rcnn_loss: 0.40527, rpn_cls_loss: 0.03494, rpn_reg_loss: 0.01193, rcnn_cls_loss: 0.30591, rcnn_reg_loss: 0.09937, total_loss: 0.47601
```
## [Eval process](#contents)
### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
- Ascend:
```bash
sh run_eval_ascend.sh [IMGS_PATH] [ANNOS_PATH] [CHECKPOINT_PATH] [COCO_TEXT_PARSER_PATH]
```
### Launch
```bash
# eval example
shell:
Ascend:
sh run_eval_ascend.sh [IMGS_PATH] [ANNOS_PATH] [CHECKPOINT_PATH] [COCO_TEXT_PARSER_PATH]
```
> checkpoint can be produced in training process.
### Result
Evaluation result will be stored in the example path, you can find result like the followings in `log`.
```python
========================================
class 1 precision is 88.01%, recall is 82.77%
```
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | Deeptext |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| uploaded Date | 12/26/2020 |
| MindSpore Version | 1.1.0 |
| Dataset | 66040 images |
| Batch_size | 2 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropyWithLogits for classification, SmoothL2Loss for bbox regression|
| Loss | ~0.008 |
| Accuracy (8p) | precision=0.8854, recall=0.8024 |
| Total time (8p) | 4h |
| Scripts | [deeptext script](https://gitee.com/mindspore/mindspore/tree/r1.1/mindspore/official/cv/deeptext) |
#### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | Deeptext |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| Uploaded Date | 12/26/2020 |
| MindSpore Version | 1.1.0 |
| Dataset | 229 images |
| Batch_size | 2 |
| Accuracy | precision=0.8854, recall=0.8024 |
| Total time | 1 min |
| Model for inference | 3492M (.ckpt file) |
#### Training performance results
| **Ascend** | train performance |
| :--------: | :---------------: |
| 1p | 42 img/s |
| **Ascend** | train performance |
| :--------: | :---------------: |
| 8p | 330 img/s |
# [Description of Random Situation](#contents)
We set seed to 1 in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,138 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for Deeptext"""
import argparse
import os
import time
import numpy as np
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
from src.utils import metrics
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
set_seed(1)
parser = argparse.ArgumentParser(description="Deeptext evaluation")
parser.add_argument("--checkpoint_path", type=str, default='test', help="Checkpoint file path.")
parser.add_argument("--imgs_path", type=str, required=True,
help="Test images files paths, multiple paths can be separated by ','.")
parser.add_argument("--annos_path", type=str, required=True,
help="Annotations files paths of test images, multiple paths can be separated by ','.")
parser.add_argument("--device_id", type=int, default=7, help="Device id, default is 7.")
parser.add_argument("--mindrecord_prefix", type=str, default='Deeptext-TEST', help="Prefix of mindrecord.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
def Deeptext_eval_test(dataset_path='', ckpt_path=''):
"""Deeptext evaluation."""
ds = create_deeptext_dataset(dataset_path, batch_size=config.test_batch_size,
repeat_num=1, is_training=False)
total = ds.get_dataset_size()
net = Deeptext_VGG16(config)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
eval_iter = 0
print("\n========================================\n")
print("Processing, please wait a moment.")
max_num = 32
pred_data = []
for data in ds.create_dict_iterator():
eval_iter = eval_iter + 1
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
start = time.time()
# run net
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
gt_bboxes = gt_bboxes.asnumpy()
gt_bboxes = gt_bboxes[gt_num.asnumpy().astype(bool), :]
print(gt_bboxes)
gt_labels = gt_labels.asnumpy()
gt_labels = gt_labels[gt_num.asnumpy().astype(bool)]
print(gt_labels)
end = time.time()
print("Iter {} cost time {}".format(eval_iter, end - start))
# output
all_bbox = output[0]
all_label = output[1] + 1
all_mask = output[2]
for j in range(config.test_batch_size):
all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :])
all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :])
all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :])
all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
all_labels_tmp_mask = all_label_squee[all_mask_squee]
if all_bboxes_tmp_mask.shape[0] > max_num:
inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
inds = inds[:max_num]
all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
all_labels_tmp_mask = all_labels_tmp_mask[inds]
pred_data.append({"boxes": all_bboxes_tmp_mask,
"labels": all_labels_tmp_mask,
"gt_bboxes": gt_bboxes,
"gt_labels": gt_labels})
percent = round(eval_iter / total * 100, 2)
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
precisions, recalls = metrics(pred_data)
print("\n========================================\n")
for i in range(config.num_classes - 1):
j = i + 1
print("class {} precision is {:.2f}%, recall is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100))
if config.use_ambigous_sample:
break
if __name__ == '__main__':
prefix = args_opt.mindrecord_prefix
config.test_images = args_opt.imgs_path
config.test_txts = args_opt.annos_path
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
print("CHECKING MINDRECORD FILES ...")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(False, prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
print("CHECKING MINDRECORD FILES DONE!")
print("Start Eval!")
Deeptext_eval_test(mindrecord_file, args_opt.checkpoint_path)

View File

@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 5 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [RANK_TABLE_FILE] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2
PATH3=$(get_real_path $3)
echo $PATH3
PATH4=$(get_real_path $4)
echo $PATH4
PATH5=$(get_real_path $5)
echo $PATH5
if [ ! -f $PATH3 ]
then
echo "error: RANK_TABLE_FILE=$PATH3 is not a file"
exit 1
fi
if [ ! -f $PATH4 ]
then
echo "error: PRETRAINED_PATH=$PATH4 is not a file"
exit 1
fi
if [ ! -f $PATH5 ]
then
echo "error: COCO_TEXT_PARSER_PATH=$PATH5 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH3
cp $PATH5 ../src/
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$i --rank_id=$i --imgs_path=$PATH1 --annos_path=$PATH2 --run_distribute=True --device_num=$DEVICE_NUM --pre_trained=$PATH4 &> log &
cd ..
done

View File

@ -0,0 +1,70 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: sh run_eval_ascend.sh [IMGS_PATH] [ANNOS_PATH] [CHECKPOINT_PATH] [COCO_TEXT_PARSER_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
PATH4=$(get_real_path $4)
echo $PATH1
echo $PATH2
echo $PATH3
echo $PATH4
if [ ! -f $PATH3 ]
then
echo "error: CHECKPOINT_PATH=$PATH3 is not a file"
exit 1
fi
if [ ! -f $PATH4 ]
then
echo "error: COCO_TEXT_PARSER_PATH=$PATH4 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
cp $PATH4 ../src/
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_id=$DEVICE_ID --imgs_path=$PATH1 --annos_path=$PATH2 --checkpoint_path=$PATH3 &> log &
cd ..

View File

@ -0,0 +1,70 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 4 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [IMGS_PATH] [ANNOS_PATH] [PRETRAINED_PATH] [COCO_TEXT_PARSER_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2
PATH3=$(get_real_path $3)
echo $PATH3
PATH4=$(get_real_path $4)
echo $PATH4
if [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_PATH=$PATH3 is not a file"
exit 1
fi
if [ ! -f $PATH4 ]
then
echo "error: COCO_TEXT_PARSER_PATH=$PATH4 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
cp $PATH4 ../src/
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --imgs_path=$PATH1 --annos_path=$PATH2 --pre_trained=$PATH3 &> log &
cd ..

View File

@ -0,0 +1,29 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext Init."""
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
__all__ = [
"BboxAssignSample",
"Proposal", "Rcnn",
"RPN", "SingleRoIExtractor", "AnchorGenerator"
]

View File

@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for Deeptext."""
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
"""Anchor generator init method."""
self.base_size = base_size
self.scales = np.array(scales)
self.ratios = np.array(ratios)
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
def gen_base_anchors(self):
"""Generate a single anchor."""
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
else:
ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)
base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1).round()
return base_anchors
def _meshgrid(self, x, y, row_major=True):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
if row_major:
return xx, yy
return yy, xx
def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)
return all_anchors

View File

@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler defination.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSample(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals
if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_),
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes,
self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j + 1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j + 1:1, ::])
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)

View File

@ -0,0 +1,257 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext tpositive and negative sample screening for Rcnn."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler defination.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSampleForRcnn(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
self.use_ambigous_sample = cfg.use_ambigous_sample
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
self.min_pos_iou = cfg.min_pos_iou_stage2
self.num_gts = cfg.num_gts
self.num_bboxes = num_bboxes
self.num_expected_pos = cfg.num_expected_pos_stage2
self.num_expected_amb = cfg.num_expected_amb_stage2
self.num_expected_neg = cfg.num_expected_neg_stage2
self.num_expected_total = cfg.num_expected_total_stage2
self.add_gt_as_proposals = add_gt_as_proposals
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
dtype=np.int32))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.gatherV2 = P.GatherV2()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_amb = P.RandomChoiceWithMask(self.num_expected_amb)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_amb = Tensor(np.array(-3 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.range_amb_size = Tensor(np.arange(self.num_expected_amb).astype(np.float16))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
if self.use_ambigous_sample:
self.check_neg_mask = Tensor(
np.array(np.ones(self.num_expected_neg - self.num_expected_pos - self.num_expected_amb), dtype=np.bool))
check_neg_mask_ignore_end = np.array(np.ones(self.num_expected_neg), dtype=np.bool)
check_neg_mask_ignore_end[-1] = False
self.check_neg_mask_ignore_end = Tensor(check_neg_mask_ignore_end)
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.bboxs_amb_mask = Tensor(np.zeros((self.num_expected_amb, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.labels_amb_mask = Tensor(np.array(np.zeros(self.num_expected_amb) + 2, dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_amb = (self.num_expected_amb, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)
self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), \
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_zero),
self.less(max_overlaps_w_gt,
self.scalar_neg_iou_thr))
assigned_gt_inds = self.assigned_gt_inds
if self.use_ambigous_sample:
amb_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_neg_iou_thr),
self.less(max_overlaps_w_gt,
self.scalar_pos_iou_thr))
assigned_gt_inds = self.select(amb_sample_iou_mask, self.assigned_amb, self.assigned_gt_inds)
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j + 1:1]
overlaps_w_ac_j = overlaps[j:j + 1:1, ::]
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
pos_mask_j = self.logicaland(temp1, temp2)
assigned_gt_inds3 = self.select(pos_mask_j, (j + 1) * self.assigned_gt_ones, assigned_gt_inds3)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
bboxes = self.concat((gt_bboxes_i, bboxes))
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
pos_index = pos_index * valid_pos_index
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
# Get ambiguous index
num_amb = None
amb_index = None
valid_amb_index = None
if self.use_ambigous_sample:
amb_index, valid_amb_index = self.random_choice_with_mask_amb(self.equal(assigned_gt_inds5, -3))
amb_check_valid = self.cast(self.equal(assigned_gt_inds5, -3), mstype.float16)
amb_check_valid = self.sum_inds(amb_check_valid, -1)
valid_amb_index = self.less(self.range_amb_size, amb_check_valid)
amb_index = amb_index * self.reshape(self.cast(valid_amb_index, mstype.int32), (self.num_expected_amb, 1))
num_amb = self.sum_inds(self.cast(self.logicalnot(valid_amb_index), mstype.float16), -1)
valid_amb_index = self.cast(valid_amb_index, mstype.int32)
amb_index = self.reshape(amb_index, self.reshape_shape_amb)
valid_amb_index = self.reshape(valid_amb_index, self.reshape_shape_amb)
amb_index = amb_index * valid_amb_index
# Get neg index
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
if self.use_ambigous_sample:
unvalid_amb_index = self.less(self.range_amb_size, num_amb)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_amb_index, unvalid_pos_index)),
valid_neg_index)
else:
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
valid_neg_index = self.logicaland(valid_neg_index, self.check_neg_mask_ignore_end)
# import pdb
# pdb.set_trace()
neg_index = self.reshape(neg_index, self.reshape_shape_neg)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
neg_index = neg_index * valid_neg_index
pos_bboxes_ = self.gatherND(bboxes, pos_index)
amb_bboxes_ = None
if self.use_ambigous_sample:
amb_bboxes_ = self.gatherND(bboxes, amb_index)
neg_bboxes_ = self.gatherV2(bboxes, self.squeeze(neg_index), 0)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
if self.use_ambigous_sample:
total_bboxes = self.concat((pos_bboxes_, amb_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_amb_mask, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_amb_mask, self.labels_neg_mask))
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
total_mask = self.concat((valid_pos_index, valid_neg_index))
if self.use_ambigous_sample:
valid_amb_index = self.reshape(valid_amb_index, self.reshape_shape_amb)
total_mask = self.concat((valid_pos_index, valid_amb_index, valid_neg_index))
return total_bboxes, total_deltas, total_labels, total_mask

View File

@ -0,0 +1,432 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext based on VGG16."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from .anchor_generator import AnchorGenerator
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .vgg16 import VGG16FeatureExtraction
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shp_weight_conv = (out_channels, in_channels, kernel_size, kernel_size)
shp_bias_conv = (out_channels,)
weights = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()
layers = []
layers += [nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias_conv)]
# layers += [nn.BatchNorm2d(out_channels)]
return nn.SequentialCell(layers)
class Deeptext_VGG16(nn.Cell):
"""
Deeptext_VGG16 Network.
Note:
backbone = vgg16
Returns:
Tuple, tuple of output tensor.
rpn_loss: Scalar, Total loss of RPN subnet.
rcnn_loss: Scalar, Total loss of RCNN subnet.
rpn_cls_loss: Scalar, Classification loss of RPN subnet.
rpn_reg_loss: Scalar, Regression loss of RPN subnet.
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
Examples:
net = Deeptext_VGG16()
"""
def __init__(self, config):
super(Deeptext_VGG16, self).__init__()
self.train_batch_size = config.batch_size
self.num_classes = config.num_classes
self.anchor_scales = config.anchor_scales
self.anchor_ratios = config.anchor_ratios
self.anchor_strides = config.anchor_strides
self.target_means = tuple(config.rcnn_target_means)
self.target_stds = tuple(config.rcnn_target_stds)
# Anchor generator
anchor_base_sizes = None
self.anchor_base_sizes = list(
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
featmap_sizes = config.feature_shapes
assert len(featmap_sizes) == len(self.anchor_generators)
self.anchor_list = self.get_anchors(featmap_sizes)
# Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
self.rpn_with_loss = RPN(config,
self.train_batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
# Proposal
self.proposal_generator = Proposal(config,
self.train_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator.set_train_local(config, True)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
# Assign and sampler stage two
self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
config.num_bboxes_stage2, True)
self.decode = P.BoundingBoxDecode(max_shape=(576, 960), means=self.target_means, \
stds=self.target_stds)
# Rcnn
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
self.train_batch_size, self.num_classes)
# Op declare
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.concat_1 = P.Concat(axis=1)
self.concat_2 = P.Concat(axis=2)
self.reshape = P.Reshape()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
# Test mode
self.test_batch_size = config.test_batch_size
self.split = P.Split(axis=0, output_num=self.test_batch_size)
self.split_shape = P.Split(axis=0, output_num=4)
self.split_scores = P.Split(axis=1, output_num=self.num_classes)
self.split_cls = P.Split(axis=0, output_num=self.num_classes - 1)
self.tile = P.Tile()
self.gather = P.GatherNd()
self.rpn_max_num = config.rpn_max_num
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
self.ones_mask, self.zeros_mask), axis=1))
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
self.logicand = P.LogicalAnd()
self.oneslike = P.OnesLike()
self.test_topk = P.TopK(sorted=True)
self.test_num_proposal = self.test_batch_size * self.rpn_max_num
# Improve speed
self.concat_start = (self.num_classes - 2)
self.concat_end = (self.num_classes - 1)
# Init tensor
self.use_ambigous_sample = config.use_ambigous_sample
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]
if self.use_ambigous_sample:
roi_align_index = [np.array(np.ones((
config.num_expected_pos_stage2 + config.num_expected_amb_stage2 + config.num_expected_neg_stage2,
1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
self.roi_align4 = P.ROIAlign(pooled_width=7, pooled_height=7, spatial_scale=0.125)
self.roi_align5 = P.ROIAlign(pooled_width=7, pooled_height=7, spatial_scale=0.0625)
self.concat1 = P.Concat(axis=1)
self.roi_align_fuse = _conv(in_channels=1024, out_channels=512, kernel_size=1, padding=0, stride=1)
self.vgg16_feature_extractor = VGG16FeatureExtraction()
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
# f1, f2, f3, f4, f5 = self.vgg16_feature_extractor(img_data)
_, _, _, f4, f5 = self.vgg16_feature_extractor(img_data)
f4 = self.cast(f4, mstype.float16)
f5 = self.cast(f5, mstype.float16)
x = (f4, f5)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
self.gt_labels_stage1,
gt_valids)
if self.training:
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
else:
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
gt_labels = self.cast(gt_labels, mstype.int32)
gt_valids = self.cast(gt_valids, mstype.int32)
bboxes_tuple = ()
deltas_tuple = ()
labels_tuple = ()
mask_tuple = ()
if self.training:
for i in range(self.train_batch_size):
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
gt_labels_i,
proposal_mask[i],
proposal[i][::, 0:4:1],
gt_valids_i)
bboxes_tuple += (bboxes,)
deltas_tuple += (deltas,)
labels_tuple += (labels,)
mask_tuple += (mask,)
bbox_targets = self.concat(deltas_tuple)
rcnn_labels = self.concat(labels_tuple)
bbox_targets = F.stop_gradient(bbox_targets)
rcnn_labels = F.stop_gradient(rcnn_labels)
rcnn_labels = self.cast(rcnn_labels, mstype.int32)
else:
mask_tuple += proposal_mask
bbox_targets = proposal_mask
rcnn_labels = proposal_mask
for p_i in proposal:
bboxes_tuple += (p_i[::, 0:4:1],)
if self.training:
if self.train_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
else:
if self.test_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
rois = self.cast(rois, mstype.float32)
rois = F.stop_gradient(rois)
roi_feats = self.roi_align5(x[1], rois)
roi_align4_out = self.roi_align4(x[0], rois)
roi_align4_out = self.cast(roi_align4_out, mstype.float32)
roi_feats = self.cast(roi_feats, mstype.float32)
roi_feats = self.concat1((roi_feats, roi_align4_out))
roi_feats = self.cast(roi_feats, mstype.float16)
roi_feats = self.roi_align_fuse(roi_feats)
roi_feats = self.cast(roi_feats, mstype.float16)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
bbox_targets,
rcnn_labels,
rcnn_mask_squeeze)
output = ()
if self.training:
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
else:
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)
return output
def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
"""Get the actual detection box."""
scores = self.softmax(cls_logits)
boxes_all = ()
for i in range(self.num_classes):
k = i * 4
reg_logits_i = self.squeeze(reg_logits[::, k:k + 4:1])
out_boxes_i = self.decode(rois, reg_logits_i)
boxes_all += (out_boxes_i,)
# img_metas_all = self.split(img_metas)
scores_all = self.split(scores)
mask_all = self.split(self.cast(mask_logits, mstype.int32))
boxes_all_with_batchsize = ()
for i in range(self.test_batch_size):
# scale = self.split_shape(self.squeeze(img_metas_all[i]))
# scale_h = scale[2]
# scale_w = scale[3]
boxes_tuple = ()
for j in range(self.num_classes):
boxes_tmp = self.split(boxes_all[j])
out_boxes_h = boxes_tmp[i] / 1
out_boxes_w = boxes_tmp[i] / 1
boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
boxes_all_with_batchsize += (boxes_tuple,)
output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)
return output
def multiclass_nms(self, boxes_all, scores_all, mask_all):
"""Multiscale postprocessing."""
all_bboxes = ()
all_labels = ()
all_masks = ()
for i in range(self.test_batch_size):
bboxes = boxes_all[i]
scores = scores_all[i]
masks = self.cast(mask_all[i], mstype.bool_)
res_boxes_tuple = ()
res_labels_tuple = ()
res_masks_tuple = ()
for j in range(self.num_classes - 1):
k = j + 1
_cls_scores = scores[::, k:k + 1:1]
_bboxes = self.squeeze(bboxes[k])
_mask_o = self.reshape(masks, (self.rpn_max_num, 1))
cls_mask = self.greater(_cls_scores, self.test_score_thresh)
_mask = self.logicand(_mask_o, cls_mask)
_reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)
_bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
_cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
__cls_scores = self.squeeze(_cls_scores)
scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
_bboxes_sorted = self.gather(_bboxes, topk_inds)
_mask_sorted = self.gather(_mask, topk_inds)
scores_sorted = self.tile(scores_sorted, (1, 4))
cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))
cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
_index = self.reshape(_index, (self.rpn_max_num, 1))
_mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))
_mask_n = self.gather(_mask_sorted, _index)
_mask_n = self.logicand(_mask_n, _mask_nms)
cls_labels = self.oneslike(_index) * j
res_boxes_tuple += (cls_dets,)
res_labels_tuple += (cls_labels,)
res_masks_tuple += (_mask_n,)
res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
res_masks_start = self.concat(res_masks_tuple[:self.concat_start])
res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])
res_boxes = self.concat((res_boxes_start, res_boxes_end))
res_labels = self.concat((res_labels_start, res_labels_end))
res_masks = self.concat((res_masks_start, res_masks_end))
reshape_size = (self.num_classes - 1) * self.rpn_max_num
if self.use_ambigous_sample:
res_boxes = res_boxes_tuple[0]
res_labels = res_labels_tuple[0]
res_masks = res_masks_tuple[0]
reshape_size = self.rpn_max_num
res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
res_labels = self.reshape(res_labels, (1, reshape_size, 1))
res_masks = self.reshape(res_masks, (1, reshape_size, 1))
all_bboxes += (res_boxes,)
all_labels += (res_labels,)
all_masks += (res_masks,)
all_bboxes = self.concat(all_bboxes)
all_labels = self.concat(all_labels)
all_masks = self.concat(all_masks)
return all_bboxes, all_labels, all_masks
def get_anchors(self, featmap_sizes):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = ()
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(np.float16)),)
return multi_level_anchors

View File

@ -0,0 +1,199 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, tuple of output tensor,(proposal, mask).
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
self.num_levels = len(cfg.anchor_strides)
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
means=self.target_means, \
stds=self.target_stds)
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.nms_across_levels = cfg.rpn_nms_across_levels
self.max_num = cfg.rpn_max_num
for shp in self.feature_shapes:
k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
total_max_topk_input += k_num
self.topK_stage1 += (k_num,)
self.topK_shape += ((k_num, 1),)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
cls_score_list = ()
bbox_pred_list = ()
for i in range(self.num_levels):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id + 1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id + 1:1, ::, ::, ::])
cls_score_list = cls_score_list + (rpn_cls_score_i,)
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
for idx in range(self.num_levels):
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
anchors = mlvl_anchors[idx]
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks

View File

@ -0,0 +1,181 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext Rcnn network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
class DenseNoTranpose(nn.Cell):
"""Dense method"""
def __init__(self, input_channels, output_channels, weight_init):
super(DenseNoTranpose, self).__init__()
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16),
name="weight")
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias")
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
def construct(self, x):
output = self.bias_add(self.matmul(x, self.weight), self.bias)
return output
class Rcnn(nn.Cell):
"""
Rcnn subnet.
Args:
config (dict) - Config.
representation_size (int) - Channels of shared dense.
batch_size (int) - Batchsize.
num_classes (int) - Class number.
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]).
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tuple, tuple of output tensor.
Examples:
Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
representation_size,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(Rcnn, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.train_batch_size = batch_size
self.test_batch_size = cfg.test_batch_size
self.use_ambigous_sample = cfg.use_ambigous_sample
shape_0 = (self.rcnn_fc_out_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor()
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1)
cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16).to_tensor()
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight)
self.flatten = P.Flatten()
self.relu = P.ReLU()
self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0)
self.reshape = P.Reshape()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.equal = P.Equal()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.gather = P.GatherNd()
self.argmax = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
if self.use_ambigous_sample:
self.num_bboxes = (
cfg.num_expected_pos_stage2 + cfg.num_expected_amb_stage2 + cfg.num_expected_neg_stage2) * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size
range_max = np.arange(self.num_bboxes_test).astype(np.int32)
self.range_max = Tensor(range_max)
def construct(self, featuremap, bbox_targets, labels, mask):
x = self.flatten(featuremap)
x = self.relu(self.shared_fc_0(x))
x = self.relu(self.shared_fc_1(x))
x_cls = self.cls_scores(x)
x_reg = self.reg_scores(x)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
if self.use_ambigous_sample:
bbox_weights = self.cast(self.logicaland(self.equal(labels, 1), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
out = (loss, loss_cls, loss_reg, loss_print)
else:
out = (x_cls, (x_cls / self.value), x_reg, x_cls)
return out
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
"""Loss method."""
loss_print = ()
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, mstype.float16)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
if not self.use_ambigous_sample:
bbox_weights = bbox_weights * self.rmv_first_tensor
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
loss_reg = self.sum_loss(loss_reg, (2,))
loss_reg = loss_reg * bbox_weights
loss_reg = loss_reg / self.sum_loss(weights, (0,))
loss_reg = self.sum_loss(loss_reg, (0, 1))
loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
loss_print += (loss_cls, loss_reg)
return loss, loss_cls, loss_reg, loss_print

View File

@ -0,0 +1,181 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext ROIAlign module."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.tensor import Tensor
class ROIAlign(nn.Cell):
"""
Extract RoI features from mulitple feature map.
Args:
out_size_h (int) - RoI height.
out_size_w (int) - RoI width.
spatial_scale (int) - RoI spatial scale.
sample_num (int) - RoI sample number.
"""
def __init__(self,
out_size_h,
out_size_w,
spatial_scale,
sample_num=0):
super(ROIAlign, self).__init__()
self.out_size = (out_size_h, out_size_w)
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
self.spatial_scale, self.sample_num)
def construct(self, features, rois):
return self.align_op(features, rois)
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
self.out_size, self.spatial_scale, self.sample_num)
return format_str
class SingleRoIExtractor(nn.Cell):
"""
Extract RoI features from a single level feature map.
If there are mulitple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
config (dict): Config
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
batch_size (int) Batchsize.
finest_scale (int): Scale threshold of mapping to level 0.
"""
def __init__(self,
config,
roi_layer,
out_channels,
featmap_strides,
batch_size=1,
finest_scale=56):
super(SingleRoIExtractor, self).__init__()
cfg = config
self.train_batch_size = batch_size
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.out_size = roi_layer['out_size']
self.sample_num = roi_layer['sample_num']
self.roi_layers = self.build_roi_layers(self.featmap_strides)
self.roi_layers = L.CellList(self.roi_layers)
self.sqrt = P.Sqrt()
self.log = P.Log()
self.finest_scale_ = finest_scale
self.clamp = C.clip_by_value
self.cast = P.Cast()
self.equal = P.Equal()
self.select = P.Select()
_mode_16 = False
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
self.set_train_local(cfg, training=True)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
# Init tensor
self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num
self.batch_size = self.train_batch_size * self.batch_size \
if self.training_local else cfg.test_batch_size * self.batch_size
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
self.finest_scale = Tensor(finest_scale)
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.dtype(1e-6))
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32) * (self.num_levels - 1))
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
self.out_size, self.out_size)), dtype=self.dtype))
def num_inputs(self):
return len(self.featmap_strides)
def init_weights(self):
pass
def log2(self, value):
return self.log(value) / self.log(self.twos)
def build_roi_layers(self, featmap_strides):
roi_layers = []
for s in featmap_strides:
layer_cls = ROIAlign(self.out_size, self.out_size,
spatial_scale=1 / s,
sample_num=self.sample_num)
roi_layers.append(layer_cls)
return roi_layers
def _c_map_roi_levels(self, rois):
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale * 2: level 0
- finest_scale * 2 <= scale < finest_scale * 4: level 1
- finest_scale * 4 <= scale < finest_scale * 8: level 2
- scale >= finest_scale * 8: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)
target_lvls = self.log2(scale / self.finest_scale + self.epslion)
target_lvls = P.Floor()(target_lvls)
target_lvls = self.cast(target_lvls, mstype.int32)
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)
return target_lvls
def construct(self, rois, feat1, feat2):
feats = (feat1, feat2)
res = self.res_
target_lvls = self._c_map_roi_levels(rois)
for i in range(self.num_levels):
mask = self.equal(target_lvls, P.ScalarToArray()(i))
mask = P.Reshape()(mask, (-1, 1, 1, 1))
roi_feats_t = self.roi_layers[i](feats[i], rois)
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, 7, 7)), mstype.bool_)
res = self.select(mask, roi_feats_t, res)
return res

View File

@ -0,0 +1,332 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for deeptext"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from .bbox_assign_sample import BboxAssignSample
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shp_weight_conv = (out_channels, in_channels, kernel_size, kernel_size)
shp_bias_conv = (out_channels,)
weights = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()
layers = []
layers += [nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias_conv)]
return nn.SequentialCell(layers)
class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer
Args:
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
weight_conv (Tensor) - weight init for rpn conv.
bias_conv (Tensor) - bias init for rpn conv.
weight_cls (Tensor) - weight init for rpn cls conv.
bias_cls (Tensor) - bias init for rpn cls conv.
weight_reg (Tensor) - weight init for rpn reg conv.
bias_reg (Tensor) - bias init for rpn reg conv.
Returns:
Tensor, output tensor.
"""
def __init__(self,
in_channels,
feat_channels,
num_anchors,
cls_out_channels,
weight_conv,
bias_conv,
weight_cls,
bias_cls,
weight_reg,
bias_reg):
super(RpnRegClsBlock, self).__init__()
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same',
has_bias=True, weight_init=weight_conv, bias_init=bias_conv)
self.relu = nn.ReLU()
self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_cls, bias_init=bias_cls)
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_reg, bias_init=bias_reg)
self.rpn_conv1x1 = _conv(in_channels=in_channels, out_channels=128, kernel_size=1, stride=1, padding=0)
self.rpn_conv3x3 = _conv(in_channels=in_channels, out_channels=384, kernel_size=3, stride=1, padding=1)
self.rpn_conv5x5 = _conv(in_channels=in_channels, out_channels=128, kernel_size=5, stride=1, padding=2)
self.rpn_output = P.Concat(axis=1)
def construct(self, x):
x1 = self.rpn_conv1x1(x)
x2 = self.rpn_conv3x3(x)
x3 = self.rpn_conv5x5(x)
x = self.relu(self.rpn_output((x1, x2, x3)))
x1 = self.rpn_cls(x)
x2 = self.rpn_reg(x)
return x1, x2
class RPN(nn.Cell):
"""
ROI proposal network..
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tuple, tuple of output tensor.
Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
self.slice_index += (0,)
index = 0
for shape in cfg_rpn.feature_shapes:
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,)
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,)
index += 1
self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 1
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
num_anchors, cls_out_channels))
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.reshape_shape_cls = (-1,)
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0 / 9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network
Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
List, list of RpnRegClsBlock cells.
"""
rpn_layer = []
shp_weight_conv = (feat_channels, in_channels, 3, 3)
shp_bias_conv = (feat_channels,)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
shp_bias_cls = (num_anchors * cls_out_channels,)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor()
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor()
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
shp_bias_reg = (num_anchors * 4,)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor()
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor()
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
weight_conv, bias_conv, weight_cls, \
bias_cls, weight_reg, bias_reg))
rpn_layer[0].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
rpn_layer[0].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
rpn_layer[0].rpn_reg.weight = rpn_layer[0].rpn_reg.weight
rpn_layer[0].rpn_conv.bias = rpn_layer[0].rpn_conv.bias
rpn_layer[0].rpn_cls.bias = rpn_layer[0].rpn_cls.bias
rpn_layer[0].rpn_reg.bias = rpn_layer[0].rpn_reg.bias
return rpn_layer
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
loss_print = ()
rpn_cls_score = ()
rpn_bbox_pred = ()
rpn_cls_score_total = ()
rpn_bbox_pred_total = ()
x1, x2 = self.rpn_convs_list[0](inputs[1])
rpn_cls_score_total = rpn_cls_score_total + (x1,)
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)
x1 = self.transpose(x1, self.trans_shape)
x1 = self.reshape(x1, self.reshape_shape_cls)
x2 = self.transpose(x2, self.trans_shape)
x2 = self.reshape(x2, self.reshape_shape_reg)
rpn_cls_score = rpn_cls_score + (x1,)
rpn_bbox_pred = rpn_bbox_pred + (x2,)
loss = self.loss
clsloss = self.clsloss
regloss = self.regloss
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()
output = ()
if self.training:
for i in range(self.batch_size):
multi_level_flags = ()
anchor_list_tuple = ()
res = self.cast(self.CheckValid(anchor_list[0], self.squeeze(img_metas[i:i + 1:1, ::])),
mstype.int32)
multi_level_flags = multi_level_flags + (res,)
anchor_list_tuple = anchor_list_tuple + (anchor_list[0],)
valid_flag_list = self.concat(multi_level_flags)
anchor_using_list = self.concat(anchor_list_tuple)
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_using_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label = self.cast(label, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
begin = self.slice_index[0]
end = self.slice_index[0 + 1]
stride = 1
bbox_targets += (bbox_target[begin:end:stride, ::],)
bbox_weights += (bbox_weight[begin:end:stride],)
labels += (label[begin:end:stride],)
label_weights += (label_weight[begin:end:stride],)
bbox_target_using = ()
bbox_weight_using = ()
label_using = ()
label_weight_using = ()
for j in range(self.batch_size):
bbox_target_using += (bbox_targets[0 + (self.num_layers * j)],)
bbox_weight_using += (bbox_weights[0 + (self.num_layers * j)],)
label_using += (labels[0 + (self.num_layers * j)],)
label_weight_using += (label_weights[0 + (self.num_layers * j)],)
bbox_target_with_batchsize = self.concat(bbox_target_using)
bbox_weight_with_batchsize = self.concat(bbox_weight_using)
label_with_batchsize = self.concat(label_using)
label_weight_with_batchsize = self.concat(label_weight_using)
# stop
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
cls_score_i = rpn_cls_score[0]
reg_score_i = rpn_bbox_pred[0]
loss_cls = self.loss_cls(cls_score_i, label_)
loss_cls_item = loss_cls * label_weight_
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total
loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[0], 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg_item = self.sum_loss(loss_reg, (1,))
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total
loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item
loss += loss_total
loss_print += (loss_total, loss_cls_item, loss_reg_item)
clsloss += loss_cls_item
regloss += loss_reg_item
output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
else:
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)
return output

View File

@ -0,0 +1,104 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
# """VGG16 for deeptext"""
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
# shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = 'ones'
layers = []
layers += [nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)]
layers += [nn.BatchNorm2d(out_channels)]
return nn.SequentialCell(layers)
class VGG16FeatureExtraction(nn.Cell):
"""VGG16FeatureExtraction for deeptext"""
def __init__(self):
super(VGG16FeatureExtraction, self).__init__()
self.relu = nn.ReLU()
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3, padding=1)
self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3, padding=1)
self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3, padding=1)
self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3, padding=1)
self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.cast = P.Cast()
def construct(self, x):
x = self.cast(x, mstype.float32)
x = self.conv1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
x = self.relu(x)
x = self.max_pool(x)
f1 = x
x = self.conv2_1(x)
x = self.relu(x)
x = self.conv2_2(x)
x = self.relu(x)
x = self.max_pool(x)
f2 = x
x = self.conv3_1(x)
x = self.relu(x)
x = self.conv3_2(x)
x = self.relu(x)
x = self.conv3_3(x)
x = self.relu(x)
x = self.max_pool(x)
f3 = x
x = self.conv4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.relu(x)
x = self.conv4_3(x)
x = self.relu(x)
f4 = x
x = self.max_pool(x)
x = self.conv5_1(x)
x = self.relu(x)
x = self.conv5_2(x)
x = self.relu(x)
x = self.conv5_3(x)
x = self.relu(x)
f5 = x
return f1, f2, f3, f4, f5

View File

@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# " :===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"img_width": 960,
"img_height": 576,
"keep_ratio": False,
"flip_ratio": 0.0,
"photo_ratio": 0.0,
"expand_ratio": 0.3,
# anchor
"feature_shapes": [(36, 60)],
"anchor_scales": [2, 4, 6, 8, 12],
"anchor_ratios": [0.2, 0.5, 0.8, 1.0, 1.2, 1.5],
"anchor_strides": [16],
"num_anchors": 5 * 6,
# rpn
"rpn_in_channels": 512,
"rpn_feat_channels": 640,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 3.0,
"rpn_cls_out_channels": 1,
"rpn_target_means": [0., 0., 0., 0.],
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
# bbox_assign_sampler
"neg_iou_thr": 0.3,
"pos_iou_thr": 0.5,
"min_pos_iou": 0.3,
"num_bboxes": 5 * 6 * 36 * 60,
"num_gts": 128,
"num_expected_neg": 256,
"num_expected_pos": 128,
# proposal
"activate_num_classes": 2,
"use_sigmoid_cls": True,
# roi_align
"roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2),
# bbox_assign_sampler_stage2
"neg_iou_thr_stage2": 0.2,
"pos_iou_thr_stage2": 0.5,
"min_pos_iou_stage2": 0.5,
"num_bboxes_stage2": 2000,
"use_ambigous_sample": True,
"num_expected_pos_stage2": 128,
"num_expected_amb_stage2": 128,
"num_expected_neg_stage2": 640,
"num_expected_total_stage2": 640,
# rcnn
"rcnn_in_channels": 512,
"rcnn_fc_out_channels": 4096,
"rcnn_loss_cls_weight": 1,
"rcnn_loss_reg_weight": 1,
"rcnn_target_means": [0., 0., 0., 0.],
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 2000,
"rpn_proposal_max_num": 2000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 0,
# test proposal
"rpn_nms_across_levels": False,
"rpn_nms_pre": 300,
"rpn_nms_post": 300,
"rpn_max_num": 300,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 0,
"test_score_thr": 0.95,
"test_iou_thr": 0.5,
"test_max_per_img": 100,
"test_batch_size": 2,
"rpn_head_loss_type": "CrossEntropyLoss",
"rpn_head_use_sigmoid": True,
"rpn_head_weight": 1.0,
# LR
"base_lr": 0.02,
"base_step": 982 * 8,
"total_epoch": 70,
"warmup_step": 50,
"warmup_mode": "linear",
"warmup_ratio": 1 / 3.0,
"sgd_step": [8, 11],
"sgd_momentum": 0.9,
# train
"batch_size": 2,
"loss_scale": 1,
"momentum": 0.91,
"weight_decay": 1e-4,
"epoch_size": 70,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./",
"mindrecord_dir": "/home/deeptext_sustech/data/mindrecord/full_ori",
"use_coco": True,
"coco_root": "/d0/dataset/coco2017",
"cocotext_json": "/home/deeptext_sustech/data/cocotext.v2.json",
"coco_train_data_type": "train2017",
"num_classes": 3
})

View File

@ -0,0 +1,504 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext dataset"""
from __future__ import division
import os
import numpy as np
from numpy import random
import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as CC
import mindspore.common.dtype as mstype
from mindspore.mindrecord import FileWriter
from src.config import config
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
Args:
bboxes1(ndarray): shape (n, 4)
bboxes2(ndarray): shape (k, 4)
mode(str): iou (intersection over union) or iof (intersection
over foreground)
Returns:
ious(ndarray): shape (n, k)
"""
assert mode in ['iou', 'iof']
bboxes1 = bboxes1.astype(np.float32)
bboxes2 = bboxes2.astype(np.float32)
rows = bboxes1.shape[0]
cols = bboxes2.shape[0]
ious = np.zeros((rows, cols), dtype=np.float32)
if rows * cols == 0:
return ious
exchange = False
if bboxes1.shape[0] > bboxes2.shape[0]:
bboxes1, bboxes2 = bboxes2, bboxes1
ious = np.zeros((cols, rows), dtype=np.float32)
exchange = True
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
for i in range(bboxes1.shape[0]):
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
y_end - y_start + 1, 0)
if mode == 'iou':
union = area1[i] + area2 - overlap
else:
union = area1[i] if not exchange else area2
ious[i, :] = overlap / union
if exchange:
ious = ious.T
return ious
class PhotoMetricDistortion:
"""Photo Metric Distortion"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, img, boxes, labels):
# random brightness
img = img.astype('float32')
if random.randint(2):
delta = random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
return img, boxes, labels
class Expand:
"""expand image"""
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range
def __call__(self, img, boxes, labels):
if random.randint(2):
return img, boxes, labels
h, w, c = img.shape
ratio = random.uniform(self.min_ratio, self.max_ratio)
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean).astype(img.dtype)
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
expand_img[top:top + h, left:left + w] = img
img = expand_img
boxes += np.tile((left, top), 2)
return img, boxes, labels
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width, 1.0)
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image of eval"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width)
img_shape = np.append(img_shape, (h_scale, w_scale))
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""impad operation for image"""
img_data = mmcv.impad(img, (config.img_height, config.img_width))
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""imnormalize operation for image"""
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flip operation for image"""
img_data = img
img_data = mmcv.imflip(img_data)
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
return (img_data, img_shape, flipped, gt_label, gt_num)
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flipped generation"""
img_data = img
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
return (img_data, img_shape, flipped, gt_label, gt_num)
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
img_data = img[:, :, ::-1]
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""transpose operation for image"""
img_data = img.transpose(2, 0, 1).copy()
img_data = img_data.astype(np.float16)
img_shape = img_shape.astype(np.float16)
gt_bboxes = gt_bboxes.astype(np.float16)
gt_label = gt_label.astype(np.int32)
gt_num = gt_num.astype(np.bool)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""photo crop operation for image"""
random_photo = PhotoMetricDistortion()
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""expand operation for image"""
expand = Expand()
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
return (img, img_shape, gt_bboxes, gt_label, gt_num)
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
image_shape = image_shape[:2]
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
input_data = resize_column_test(*input_data)
input_data = image_bgr_rgb(*input_data)
output_data = input_data
return output_data
def _data_aug(image, box, is_training):
"""Data augmentation function."""
image_bgr = image.copy()
image_bgr[:, :, 0] = image[:, :, 2]
image_bgr[:, :, 1] = image[:, :, 1]
image_bgr[:, :, 2] = image[:, :, 0]
image_shape = image_bgr.shape[:2]
gt_box = box[:, :4]
gt_label = box[:, 4]
gt_iscrowd = box[:, 5]
pad_max_number = 128
if box.shape[0] < 128:
gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant",
constant_values=0)
gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant",
constant_values=1)
else:
gt_box_new = gt_box[0:pad_max_number]
gt_label_new = gt_label[0:pad_max_number]
gt_iscrowd_new = gt_iscrowd[0:pad_max_number]
gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32)
if not is_training:
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
expand = (np.random.rand() < config.expand_ratio)
if expand:
input_data = expand_column(*input_data)
input_data = photo_crop_column(*input_data)
input_data = resize_column(*input_data)
input_data = image_bgr_rgb(*input_data)
output_data = input_data
return output_data
return _data_aug(image, box, is_training)
def create_label(is_training):
"""Create image label."""
image_files = []
image_anno_dict = {}
if is_training:
img_dirs = config.train_images.split(',')
txt_dirs = config.train_txts.split(',')
else:
img_dirs = config.test_images.split(',')
txt_dirs = config.test_txts.split(',')
for img_dir, txt_dir in zip(img_dirs, txt_dirs):
img_basenames = []
for file in os.listdir(img_dir):
# Filter git file.
if 'gif' not in file:
img_basenames.append(os.path.basename(file))
img_names = []
for item in img_basenames:
temp1, _ = os.path.splitext(item)
img_names.append((temp1, item))
for img, img_basename in img_names:
image_path = img_dir + '/' + img_basename
annos = []
# Parse annotation of dataset in paper.
if len(img) == 6 and '_' not in img_basename:
gt = open(txt_dir + '/' + img + '.txt').read().splitlines()
if img.isdigit() and int(img) > 1200:
continue
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append(
[spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1] + [
int(0)])
else:
anno_file = txt_dir + '/gt_img_' + img.split('_')[-1] + '.txt'
if not os.path.exists(anno_file):
anno_file = txt_dir + '/gt_' + img.split('_')[-1] + '.txt'
if not os.path.exists(anno_file):
anno_file = txt_dir + '/img_' + img.split('_')[-1] + '.txt'
gt = open(anno_file).read().splitlines()
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append([spt[0], spt[1], spt[2], spt[3]] + [1] + [int(0)])
image_files.append(image_path)
if annos:
image_anno_dict[image_path] = np.array(annos)
else:
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
if is_training and config.use_coco:
coco_root = config.coco_root
data_type = config.coco_train_data_type
from src.coco_text import COCO_Text
anno_json = config.cocotext_json
ct = COCO_Text(anno_json)
image_ids = ct.getImgIds(imgIds=ct.train,
catIds=[('legibility', 'legible')])
for img_id in image_ids:
image_info = ct.loadImgs(img_id)[0]
file_name = image_info['file_name'][15:]
anno_ids = ct.getAnnIds(imgIds=img_id)
anno = ct.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
for label in anno:
# if label["utf8_string"] != '':
bbox = label["bbox"]
x1, x2 = bbox[0], bbox[0] + bbox[2]
y1, y2 = bbox[1], bbox[1] + bbox[3]
annos.append([x1, y1, x2, y2] + [1] + [int(0)])
image_files.append(image_path)
if annos:
image_anno_dict[image_path] = np.array(annos)
else:
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
return image_files, image_anno_dict
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def data_to_mindrecord_byte_image(is_training=True, prefix="deeptext.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
image_files, image_anno_dict = create_label(is_training)
deeptext_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 6]},
}
writer.add_schema(deeptext_json, "deeptext_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_deeptext_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=4):
"""Creatr deeptext dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=1, shuffle=is_training)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_)
if is_training:
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
flip = (np.random.rand() < config.flip_ratio)
if flip:
ds = ds.map(operations=[normalize_op, type_cast0, horizontally_op], input_columns=["image"],
num_parallel_workers=12)
ds = ds.map(operations=flipped_generation,
input_columns=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
num_parallel_workers=12)
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=12)
else:
ds = ds.map(operations=compose_map_func,
input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=24)
# transpose_column from python to c
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
ds = ds.map(operations=[type_cast1], input_columns=["box"])
ds = ds.map(operations=[type_cast2], input_columns=["label"])
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for deeptext"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, rank_size=1):
"""dynamic learning rate generator"""
base_lr = config.base_lr
base_step = (config.base_step // rank_size) + rank_size
total_steps = int(base_step * config.total_epoch)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr

View File

@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Deeptext training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum / self.count
rcnn_loss = self.rcnn_loss_sum / self.count
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum / self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum / self.count
total_loss = rpn_loss + rcnn_loss
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
class LossNet(nn.Cell):
"""Deeptext loss method"""
def construct(self, x1, x2, x3, x4, x5, x6):
return x1 + x2
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6

View File

@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from src.config import config
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
x1 = float(bbox_pred[0])
y1 = float(bbox_pred[1])
width1 = float(bbox_pred[2] - bbox_pred[0])
height1 = float(bbox_pred[3] - bbox_pred[1])
x2 = float(bbox_ground[0])
y2 = float(bbox_ground[1])
width2 = float(bbox_ground[2] - bbox_ground[0])
height2 = float(bbox_ground[3] - bbox_ground[1])
endx = max(x1 + width1, x2 + width2)
startx = min(x1, x2)
width = width1 + width2 - (endx - startx)
endy = max(y1 + height1, y2 + height2)
starty = min(y1, y2)
height = height1 + height2 - (endy - starty)
if width <= 0 or height <= 0:
iou = 0
else:
area = width * height
area1 = width1 * height1
area2 = width2 * height2
iou = area * 1. / (area1 + area2 - area)
return iou
def metrics(pred_data):
"""Calculate precision and recall of predicted bboxes."""
num_classes = config.num_classes
count_corrects = [1e-6 for _ in range(num_classes)]
count_grounds = [1e-6 for _ in range(num_classes)]
count_preds = [1e-6 for _ in range(num_classes)]
ious = []
for i, sample in enumerate(pred_data):
gt_bboxes = sample['gt_bboxes']
gt_labels = sample['gt_labels']
print('gt_bboxes', gt_bboxes)
print('gt_labels', gt_labels)
boxes = sample['boxes']
classes = sample['labels']
print('boxes', boxes)
print('labels', classes)
# metric
count_correct = [1e-6 for _ in range(num_classes)]
count_ground = [1e-6 for _ in range(num_classes)]
count_pred = [1e-6 for _ in range(num_classes)]
for gt_label in gt_labels:
count_ground[gt_label] += 1
for box_index, box in enumerate(boxes):
bbox_pred = [box[0], box[1], box[2], box[3]]
count_pred[classes[box_index]] += 1
for gt_index, gt_label in enumerate(gt_labels):
class_ground = gt_label
if classes[box_index] == class_ground:
iou = calc_iou(bbox_pred, gt_bboxes[gt_index])
ious.append(iou)
if iou >= 0.5:
count_correct[class_ground] += 1
break
count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]
precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
return precision, recall * config.test_batch_size

View File

@ -0,0 +1,139 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train Deeptext and get checkpoint files."""
import argparse
import ast
import os
import time
import numpy as np
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
from src.lr_schedule import dynamic_lr
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.common import set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.nn import Momentum
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
np.set_printoptions(threshold=np.inf)
set_seed(1)
parser = argparse.ArgumentParser(description="Deeptext training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: False.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=5, help="Device id, default: 5.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument("--imgs_path", type=str, required=True,
help="Train images files paths, multiple paths can be separated by ','.")
parser.add_argument("--annos_path", type=str, required=True,
help="Annotations files paths of train images, multiple paths can be separated by ','.")
parser.add_argument("--mindrecord_prefix", type=str, default='Deeptext-TRAIN', help="Prefix of mindrecord.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
if __name__ == '__main__':
if args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
print("Start create dataset!")
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is DeepText.mindrecord0, 1, ... file_num.
prefix = args_opt.mindrecord_prefix
config.train_images = args_opt.imgs_path
config.train_txts = args_opt.annos_path
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
print("CHECKING MINDRECORD FILES ...")
if rank == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if os.path.isdir(config.coco_root):
if not os.path.exists(config.coco_root):
print("Please make sure config:coco_root is valid.")
raise ValueError(config.coco_root)
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
print("CHECKING MINDRECORD FILES DONE!")
loss_scale = float(config.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
dataset = create_deeptext_dataset(mindrecord_file, repeat_num=1,
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done! dataset_size = ", dataset_size)
net = Deeptext_VGG16(config=config)
net = net.set_train()
load_path = args_opt.pre_trained
if load_path != "":
param_dict = load_checkpoint(load_path)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
if config.save_checkpoint:
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
ckpoint_cb = ModelCheckpoint(prefix='deeptext', directory=save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)