forked from mindspore-Ecosystem/mindspore
commit
f0a9cb7c20
|
@ -0,0 +1,293 @@
|
|||
![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
# CTPN for Ascend
|
||||
|
||||
- [CTPN Description](#CTPN-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [CTPN Description](#contents)
|
||||
|
||||
CTPN is a text detection model based on object detection method. It improves Faster R-CNN and combines with bidirectional LSTM, so ctpn is very effective for horizontal text detection. Another highlight of ctpn is to transform the text detection task into a series of small-scale text box detection.This idea was proposed in the paper "Detecting Text in Natural Image with Connectionist Text Proposal Network".
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1609.03605.pdf) Zhi Tian, Weilin Huang, Tong He, Pan He, Yu Qiao, "Detecting Text in Natural Image with Connectionist Text Proposal Network", ArXiv, vol. abs/1609.03605, 2016.
|
||||
|
||||
# [Model architecture](#contents)
|
||||
|
||||
The overall network architecture contains a VGG16 as backbone, and use bidirection lstm to extract context feature of the small-scale text box, then it used the RPN(RegionProposal Network) to predict the boundding box and probability.
|
||||
|
||||
[Link](https://arxiv.org/pdf/1605.07314v1.pdf)
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Here we used 6 datasets for training, and 1 datasets for Evaluation.
|
||||
|
||||
- Dataset1: ICDAR 2013: Focused Scene Text
|
||||
- Train: 142MB, 229 images
|
||||
- Test: 110MB, 233 images
|
||||
- Dataset2: ICDAR 2011: Born-Digital Images
|
||||
- Train: 27.7MB, 410 images
|
||||
- Dataset3: ICDAR 2015:
|
||||
- Train:89MB, 1000 images
|
||||
- Dataset4: SCUT-FORU: Flickr OCR Universal Database
|
||||
- Train: 388MB, 1715 images
|
||||
- Dataset5: CocoText v2(Subset of MSCOCO2017):
|
||||
- Train: 13GB, 63686 images
|
||||
- Dataset6: SVT(The Street View Dataset)
|
||||
- Train: 115MB, 349 images
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script description](#contents)
|
||||
|
||||
## [Script and sample code](#contents)
|
||||
|
||||
```shell
|
||||
.
|
||||
└─ctpn
|
||||
├── README.md # network readme
|
||||
├── eval.py # eval net
|
||||
├── scripts
|
||||
│ ├── eval_res.sh # calculate precision and recall
|
||||
│ ├── run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p)
|
||||
│ ├── run_eval_ascend.sh # launch evaluating with ascend platform
|
||||
│ └── run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p)
|
||||
├── src
|
||||
│ ├── CTPN
|
||||
│ │ ├── BoundingBoxDecode.py # bounding box decode
|
||||
│ │ ├── BoundingBoxEncode.py # bounding box encode
|
||||
│ │ ├── __init__.py # package init file
|
||||
│ │ ├── anchor_generator.py # anchor generator
|
||||
│ │ ├── bbox_assign_sample.py # proposal layer
|
||||
│ │ ├── proposal_generator.py # proposla generator
|
||||
│ │ ├── rpn.py # region-proposal network
|
||||
│ │ └── vgg16.py # backbone
|
||||
│ ├── config.py # training configuration
|
||||
│ ├── convert_icdar2015.py # convert icdar2015 dataset label
|
||||
│ ├── convert_svt.py # convert svt label
|
||||
│ ├── create_dataset.py # create mindrecord dataset
|
||||
│ ├── ctpn.py # ctpn network definition
|
||||
│ ├── dataset.py # data proprocessing
|
||||
│ ├── lr_schedule.py # learning rate scheduler
|
||||
│ ├── network_define.py # network definition
|
||||
│ └── text_connector
|
||||
│ ├── __init__.py # package init file
|
||||
│ ├── connect_text_lines.py # connect text lines
|
||||
│ ├── detector.py # detect box
|
||||
│ ├── get_successions.py # get succession proposal
|
||||
│ └── utils.py # some functions which is commonly used
|
||||
└── train.py # train net
|
||||
|
||||
```
|
||||
|
||||
## [Training process](#contents)
|
||||
|
||||
### Dataset
|
||||
|
||||
To create dataset, download the dataset first and deal with it.We provided src/convert_svt.py and src/convert_icdar2015.py to deal with svt and icdar2015 dataset label.For svt dataset, you can deal with it as below:
|
||||
|
||||
```shell
|
||||
python convert_svt.py --dataset_path=/path/img --xml_file=/path/train.xml --location_dir=/path/location
|
||||
```
|
||||
|
||||
For ICDAR2015 dataset, you can deal with it
|
||||
|
||||
```shell
|
||||
python convert_icdar2015.py --src_label_path=/path/train_label --target_label_path=/path/label
|
||||
```
|
||||
|
||||
Then modify the src/config.py and add the dataset path.For each path, add IMAGE_PATH and LABEL_PATH into a list in config.An example is show as blow:
|
||||
|
||||
```python
|
||||
# create dataset
|
||||
"coco_root": "/path/coco",
|
||||
"coco_train_data_type": "train2017",
|
||||
"cocotext_json": "/path/cocotext.v2.json",
|
||||
"icdar11_train_path": ["/path/image/", "/path/label"],
|
||||
"icdar13_train_path": ["/path/image/", "/path/label"],
|
||||
"icdar15_train_path": ["/path/image/", "/path/label"],
|
||||
"icdar13_test_path": ["/path/image/", "/path/label"],
|
||||
"flick_train_path": ["/path/image/", "/path/label"],
|
||||
"svt_train_path": ["/path/image/", "/path/label"],
|
||||
"pretrain_dataset_path": "",
|
||||
"finetune_dataset_path": "",
|
||||
"test_dataset_path": "",
|
||||
```
|
||||
|
||||
Then you can create dataset with src/create_dataset.py with the command as below:
|
||||
|
||||
```shell
|
||||
python src/create_dataset.py
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
- Ascend:
|
||||
|
||||
```bash
|
||||
# distribute training example(8p)
|
||||
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
|
||||
# standalone training
|
||||
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
|
||||
# evaluation:
|
||||
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. The name of weight in dict should be totally the same, also the batch_norm should be enabled in the trainig of vgg16, otherwise fails in further steps.COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).To get the vgg16 backbone, you can use the network structure defined in src/CTPN/vgg16.py.To train the backbone, copy the src/CTPN/vgg16.py under modelzoo/official/cv/vgg16/src/, and modify the vgg16/train.py to suit the new construction.You can fix it as below:
|
||||
|
||||
```python
|
||||
...
|
||||
from src.vgg16 import VGG16
|
||||
...
|
||||
network = VGG16()
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
Then you can train it with ImageNet2012.
|
||||
> Notes:
|
||||
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
|
||||
>
|
||||
> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh`
|
||||
>
|
||||
> TASK_TYPE contains Pretraining and Finetune. For Pretraining, we use ICDAR2013, ICDAR2015, SVT, SCUT-FORU, CocoText v2. For Finetune, we use ICDAR2011,
|
||||
ICDAR2013, SCUT-FORU to improve precision and recall, and when doing Finetune, we use the checkpoint training in Pretrain as our PRETRAINED_PATH.
|
||||
> COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).
|
||||
>
|
||||
|
||||
### Launch
|
||||
|
||||
```bash
|
||||
# training example
|
||||
shell:
|
||||
Ascend:
|
||||
# distribute training example(8p)
|
||||
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
|
||||
# standalone training
|
||||
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training log will be redirected to `./log`, also the loss will be redirected to `./loss_0.log` like followings.
|
||||
|
||||
```python
|
||||
377 epoch: 1 step: 229 ,rpn_loss: 0.00355, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00103,
|
||||
399 epoch: 2 step: 229 ,rpn_loss: 0.00327,rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00093,
|
||||
424 epoch: 3 step: 229 ,rpn_loss: 0.00910, rpn_cls_loss: 0.00385, rpn_reg_loss: 0.00175,
|
||||
```
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
You can start training using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
- Ascend:
|
||||
|
||||
```bash
|
||||
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
After eval, you can get serval archive file named submit_ctpn-xx_xxxx.zip, which contains the name of your checkpoint file.To evalulate it, you can use the scripts provided by the ICDAR2013 network, you can download the Deteval scripts from the [link](https://rrc.cvc.uab.es/?com=downloads&action=download&ch=2&f=aHR0cHM6Ly9ycmMuY3ZjLnVhYi5lcy9zdGFuZGFsb25lcy9zY3JpcHRfdGVzdF9jaDJfdDFfZTItMTU3Nzk4MzA2Ny56aXA=)
|
||||
After download the scripts, unzip it and put it under ctpn/scripts and use eval_res.sh to get the result.You will get files as below:
|
||||
|
||||
```text
|
||||
gt.zip
|
||||
readme.txt
|
||||
rrc_evalulation_funcs_1_1.py
|
||||
script.py
|
||||
```
|
||||
|
||||
Then you can run the scripts/eval_res.sh to calculate the evalulation result.
|
||||
|
||||
```base
|
||||
bash eval_res.sh
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Evaluation result will be stored in the example path, you can find result like the followings in `log`.
|
||||
|
||||
```text
|
||||
{"precision": 0.90791, "recall": 0.86118, "hmean": 0.88393}
|
||||
```
|
||||
|
||||
# [Model description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
| Model Version | CTPN |
|
||||
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
|
||||
| uploaded Date | 02/06/2021 |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | 16930 images |
|
||||
| Batch_size | 2 |
|
||||
| Training Parameters | src/config.py |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits for classification, SmoothL2Loss for bbox regression|
|
||||
| Loss | ~0.04 |
|
||||
| Total time (8p) | 6h |
|
||||
| Scripts | [ctpn script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ctpn) |
|
||||
|
||||
#### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | CTPN |
|
||||
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
|
||||
| Uploaded Date | 02/06/2020 |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | 229 images |
|
||||
| Batch_size | 1 |
|
||||
| Accuracy | precision=0.9079, recall=0.8611 F-measure:0.8839 |
|
||||
| Total time | 1 min |
|
||||
| Model for inference | 135M (.ckpt file) |
|
||||
|
||||
#### Training performance results
|
||||
|
||||
| **Ascend** | train performance |
|
||||
| :--------: | :---------------: |
|
||||
| 1p | 10 img/s |
|
||||
|
||||
| **Ascend** | train performance |
|
||||
| :--------: | :---------------: |
|
||||
| 8p | 84 img/s |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We set seed to 1 in train.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Evaluation for CTPN"""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.ctpn import CTPN
|
||||
from src.config import config
|
||||
from src.dataset import create_ctpn_dataset
|
||||
from src.text_connector.detector import detect
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="CTPN evaluation")
|
||||
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
|
||||
parser.add_argument("--image_path", type=str, default="", help="Image path.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
args_opt = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
|
||||
def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
|
||||
"""ctpn infer."""
|
||||
print("ckpt path is {}".format(ckpt_path))
|
||||
ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
|
||||
config.batch_size = config.test_batch_size
|
||||
total = ds.get_dataset_size()
|
||||
print("*************total dataset size is {}".format(total))
|
||||
net = CTPN(config, is_training=False)
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
eval_iter = 0
|
||||
|
||||
print("\n========================================\n")
|
||||
print("Processing, please wait a moment.")
|
||||
img_basenames = []
|
||||
output_dir = os.path.join(os.getcwd(), "submit")
|
||||
if not os.path.exists(output_dir):
|
||||
os.mkdir(output_dir)
|
||||
for file in os.listdir(img_dir):
|
||||
img_basenames.append(os.path.basename(file))
|
||||
for data in ds.create_dict_iterator():
|
||||
img_data = data['image']
|
||||
img_metas = data['image_shape']
|
||||
gt_bboxes = data['box']
|
||||
gt_labels = data['label']
|
||||
gt_num = data['valid_num']
|
||||
|
||||
start = time.time()
|
||||
# run net
|
||||
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
|
||||
gt_bboxes = gt_bboxes.asnumpy()
|
||||
gt_labels = gt_labels.asnumpy()
|
||||
gt_num = gt_num.asnumpy().astype(bool)
|
||||
end = time.time()
|
||||
proposal = output[0]
|
||||
proposal_mask = output[1]
|
||||
print("start to draw pic")
|
||||
for j in range(config.test_batch_size):
|
||||
img = img_basenames[config.test_batch_size * eval_iter + j]
|
||||
all_box_tmp = proposal[j].asnumpy()
|
||||
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
|
||||
using_boxes_mask = all_box_tmp * all_mask_tmp
|
||||
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
|
||||
scores = using_boxes_mask[:, 4].astype(np.float32)
|
||||
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
|
||||
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
|
||||
from PIL import Image, ImageDraw
|
||||
im = Image.open(img_dir + '/' + img)
|
||||
draw = ImageDraw.Draw(im)
|
||||
image_h = img_metas.asnumpy()[j][2]
|
||||
image_w = img_metas.asnumpy()[j][3]
|
||||
gt_boxs = gt_bboxes[j][gt_num[j], :]
|
||||
for gt_box in gt_boxs:
|
||||
gt_x1 = gt_box[0] / image_w
|
||||
gt_y1 = gt_box[1] / image_h
|
||||
gt_x2 = gt_box[2] / image_w
|
||||
gt_y2 = gt_box[3] / image_h
|
||||
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
|
||||
fill='green', width=2)
|
||||
file_name = "res_" + img.replace("jpg", "txt")
|
||||
output_file = os.path.join(output_dir, file_name)
|
||||
f = open(output_file, 'w')
|
||||
for bbox in bboxes:
|
||||
x1 = bbox[0] / image_w
|
||||
y1 = bbox[1] / image_h
|
||||
x2 = bbox[2] / image_w
|
||||
y2 = bbox[3] / image_h
|
||||
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
|
||||
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
|
||||
f.write(str_tmp)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
im.save(img)
|
||||
percent = round(eval_iter / total * 100, 2)
|
||||
eval_iter = eval_iter + 1
|
||||
print("Iter {} cost time {}".format(eval_iter, end - start))
|
||||
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
|
||||
|
||||
if __name__ == '__main__':
|
||||
ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)
|
|
@ -0,0 +1,21 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
for submit_file in "submit"*.zip
|
||||
do
|
||||
echo "eval result for ${submit_file}"
|
||||
python script.py –g=gt.zip –s=${submit_file} –o=./
|
||||
echo -e ".\n"
|
||||
done
|
|
@ -0,0 +1,67 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
TASK_TYPE=$2
|
||||
PATH2=$(get_real_path $3)
|
||||
echo $PATH2
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,80 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
IMAGE_PATH=$(get_real_path $1)
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
CHECKPOINT_PATH=$(get_real_path $3)
|
||||
echo $IMAGE_PATH
|
||||
echo $DATASET_PATH
|
||||
echo $CHECKPOINT_PATH
|
||||
|
||||
if [ ! -d $IMAGE_PATH ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$PATH1 is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$DATASET_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $CHECKPOINT_PATH ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
for file in "${CHECKPOINT_PATH}"/*.ckpt
|
||||
do
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval
|
||||
env > env.log
|
||||
CHECKPOINT_FILE_PATH=$file
|
||||
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log
|
||||
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
cd ./submit
|
||||
file_base_name=$(basename $file)
|
||||
zip -r ../../submit_${file_base_name%.*}.zip *.txt
|
||||
cd ../../
|
||||
done
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
TASK_TYPE=$1
|
||||
PRETRAINED_PATH=$(get_real_path $2)
|
||||
echo $PRETRAINED_PATH
|
||||
if [ ! -f $PRETRAINED_PATH ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log &
|
||||
cd ..
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class BoundingBoxDecode(nn.Cell):
|
||||
"""
|
||||
BoundintBox Decoder.
|
||||
|
||||
Returns:
|
||||
pred_box(Tensor): decoder bounding boxes.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BoundingBoxDecode, self).__init__()
|
||||
self.split = P.Split(axis=1, output_num=4)
|
||||
self.ones = 1.0
|
||||
self.half = 0.5
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, bboxes, deltas):
|
||||
"""
|
||||
boxes(Tensor): boundingbox.
|
||||
deltas(Tensor): delta between boundingboxs and anchors.
|
||||
"""
|
||||
x1, y1, x2, y2 = self.split(bboxes)
|
||||
width = x2 - x1 + self.ones
|
||||
height = y2 - y1 + self.ones
|
||||
ctr_x = x1 + self.half * width
|
||||
ctr_y = y1 + self.half * height
|
||||
_, dy, _, dh = self.split(deltas)
|
||||
pred_ctr_x = ctr_x
|
||||
pred_ctr_y = dy * height + ctr_y
|
||||
pred_w = width
|
||||
pred_h = self.exp(dh) * height
|
||||
|
||||
x1 = pred_ctr_x - self.half * pred_w
|
||||
y1 = pred_ctr_y - self.half * pred_h
|
||||
x2 = pred_ctr_x + self.half * pred_w
|
||||
y2 = pred_ctr_y + self.half * pred_h
|
||||
pred_box = self.concat((x1, y1, x2, y2))
|
||||
return pred_box
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class BoundingBoxEncode(nn.Cell):
|
||||
"""
|
||||
BoundintBox Decoder.
|
||||
|
||||
Returns:
|
||||
pred_box(Tensor): decoder bounding boxes.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BoundingBoxEncode, self).__init__()
|
||||
self.split = P.Split(axis=1, output_num=4)
|
||||
self.ones = 1.0
|
||||
self.half = 0.5
|
||||
self.log = P.Log()
|
||||
self.concat = P.Concat(axis=1)
|
||||
def construct(self, anchor_box, gt_box):
|
||||
"""
|
||||
boxes(Tensor): boundingbox.
|
||||
deltas(Tensor): delta between boundingboxs and anchors.
|
||||
"""
|
||||
x1, y1, x2, y2 = self.split(anchor_box)
|
||||
width = x2 - x1 + self.ones
|
||||
height = y2 - y1 + self.ones
|
||||
ctr_x = x1 + self.half * width
|
||||
ctr_y = y1 + self.half * height
|
||||
gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box)
|
||||
gt_width = gt_x2 - gt_x1 + self.ones
|
||||
gt_height = gt_y2 - gt_y1 + self.ones
|
||||
ctr_gt_x = gt_x1 + self.half * gt_width
|
||||
ctr_gt_y = gt_y1 + self.half * gt_height
|
||||
|
||||
target_dx = (ctr_gt_x - ctr_x) / width
|
||||
target_dy = (ctr_gt_y - ctr_y) / height
|
||||
dw = gt_width / width
|
||||
dh = gt_height / height
|
||||
target_dw = self.log(dw)
|
||||
target_dh = self.log(dh)
|
||||
deltas = self.concat((target_dx, target_dy, target_dw, target_dh))
|
||||
return deltas
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn anchor generator."""
|
||||
import numpy as np
|
||||
class AnchorGenerator():
|
||||
"""Anchor generator for FasterRcnn."""
|
||||
def __init__(self, config):
|
||||
"""Anchor generator init method."""
|
||||
self.base_size = config.anchor_base
|
||||
self.num_anchor = config.num_anchors
|
||||
self.anchor_height = config.anchor_height
|
||||
self.anchor_width = config.anchor_width
|
||||
self.size = self.gen_anchor_size()
|
||||
self.base_anchors = self.gen_base_anchors()
|
||||
|
||||
def gen_base_anchors(self):
|
||||
"""Generate a single anchor."""
|
||||
base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32)
|
||||
anchors = np.zeros((len(self.size), 4), np.int32)
|
||||
index = 0
|
||||
for h, w in self.size:
|
||||
anchors[index] = self.scale_anchor(base_anchor, h, w)
|
||||
index += 1
|
||||
return anchors
|
||||
|
||||
def gen_anchor_size(self):
|
||||
"""Generate a list of anchor size"""
|
||||
size = []
|
||||
for width in self.anchor_width:
|
||||
for height in self.anchor_height:
|
||||
size.append((height, width))
|
||||
return size
|
||||
|
||||
def scale_anchor(self, anchor, h, w):
|
||||
x_ctr = (anchor[0] + anchor[2]) * 0.5
|
||||
y_ctr = (anchor[1] + anchor[3]) * 0.5
|
||||
scaled_anchor = anchor.copy()
|
||||
scaled_anchor[0] = x_ctr - w / 2 # xmin
|
||||
scaled_anchor[2] = x_ctr + w / 2 # xmax
|
||||
scaled_anchor[1] = y_ctr - h / 2 # ymin
|
||||
scaled_anchor[3] = y_ctr + h / 2 # ymax
|
||||
return scaled_anchor
|
||||
|
||||
def _meshgrid(self, x, y):
|
||||
"""Generate grid."""
|
||||
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
|
||||
yy = np.repeat(y, len(x))
|
||||
return xx, yy
|
||||
|
||||
def grid_anchors(self, featmap_size, stride=16):
|
||||
"""Generate anchor list."""
|
||||
base_anchors = self.base_anchors
|
||||
feat_h, feat_w = featmap_size
|
||||
shift_x = np.arange(0, feat_w) * stride
|
||||
shift_y = np.arange(0, feat_h) * stride
|
||||
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
||||
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
|
||||
shifts = shifts.astype(base_anchors.dtype)
|
||||
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
|
||||
all_anchors = all_anchors.reshape(-1, 4)
|
||||
return all_anchors
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn positive and negative sample screening for RPN."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.CTPN.BoundingBoxEncode import BoundingBoxEncode
|
||||
|
||||
|
||||
class BboxAssignSample(nn.Cell):
|
||||
"""
|
||||
Bbox assigner and sampler definition.
|
||||
|
||||
Args:
|
||||
config (dict): Config.
|
||||
batch_size (int): Batchsize.
|
||||
num_bboxes (int): The anchor nums.
|
||||
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
|
||||
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
|
||||
labels: label for every bboxes, (batch_size, num_bboxes, 1)
|
||||
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
|
||||
|
||||
Examples:
|
||||
BboxAssignSample(config, 2, 1024, True)
|
||||
"""
|
||||
|
||||
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
|
||||
super(BboxAssignSample, self).__init__()
|
||||
cfg = config
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
|
||||
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
|
||||
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
|
||||
self.zero_thr = Tensor(0.0, mstype.float16)
|
||||
|
||||
self.num_bboxes = num_bboxes
|
||||
self.num_gts = cfg.num_gts
|
||||
self.num_expected_pos = cfg.num_expected_pos
|
||||
self.num_expected_neg = cfg.num_expected_neg
|
||||
self.add_gt_as_proposals = add_gt_as_proposals
|
||||
|
||||
if self.add_gt_as_proposals:
|
||||
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
|
||||
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.max_gt = P.ArgMaxWithValue(axis=0)
|
||||
self.max_anchor = P.ArgMaxWithValue(axis=1)
|
||||
self.sum_inds = P.ReduceSum()
|
||||
self.iou = P.IOU()
|
||||
self.greaterequal = P.GreaterEqual()
|
||||
self.greater = P.Greater()
|
||||
self.select = P.Select()
|
||||
self.gatherND = P.GatherNd()
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
self.logicaland = P.LogicalAnd()
|
||||
self.less = P.Less()
|
||||
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
|
||||
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
|
||||
self.reshape = P.Reshape()
|
||||
self.equal = P.Equal()
|
||||
self.bounding_box_encode = BoundingBoxEncode()
|
||||
self.scatterNdUpdate = P.ScatterNdUpdate()
|
||||
self.scatterNd = P.ScatterNd()
|
||||
self.logicalnot = P.LogicalNot()
|
||||
self.tile = P.Tile()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
|
||||
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
|
||||
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
|
||||
|
||||
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
|
||||
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
|
||||
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
|
||||
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
|
||||
self.print = P.Print()
|
||||
|
||||
|
||||
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
|
||||
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
|
||||
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
|
||||
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
|
||||
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
|
||||
overlaps = self.iou(bboxes, gt_bboxes_i)
|
||||
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
|
||||
_, max_overlaps_w_ac = self.max_anchor(overlaps)
|
||||
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
|
||||
self.less(max_overlaps_w_gt, self.neg_iou_thr))
|
||||
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
|
||||
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
|
||||
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
|
||||
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
|
||||
assigned_gt_inds4 = assigned_gt_inds3
|
||||
for j in range(self.num_gts):
|
||||
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
|
||||
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
|
||||
|
||||
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
|
||||
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
|
||||
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
|
||||
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
|
||||
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
|
||||
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
|
||||
pos_check_valid = self.sum_inds(pos_check_valid, -1)
|
||||
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
|
||||
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
|
||||
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
|
||||
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
|
||||
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
|
||||
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
|
||||
|
||||
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
|
||||
num_pos = self.sum_inds(num_pos, -1)
|
||||
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
|
||||
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
|
||||
pos_bboxes_ = self.gatherND(bboxes, pos_index)
|
||||
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
|
||||
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
|
||||
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
|
||||
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
|
||||
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
|
||||
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
|
||||
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
|
||||
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
|
||||
total_index = self.concat((pos_index, neg_index))
|
||||
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
|
||||
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
|
||||
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
|
||||
labels_total, self.cast(label_weights_total, mstype.bool_)
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn proposal generator."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from src.CTPN.BoundingBoxDecode import BoundingBoxDecode
|
||||
|
||||
class Proposal(nn.Cell):
|
||||
"""
|
||||
Proposal subnet.
|
||||
|
||||
Args:
|
||||
config (dict): Config.
|
||||
batch_size (int): Batchsize.
|
||||
num_classes (int) - Class number.
|
||||
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
|
||||
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
|
||||
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor,(proposal, mask).
|
||||
|
||||
Examples:
|
||||
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
|
||||
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
batch_size,
|
||||
num_classes,
|
||||
use_sigmoid_cls,
|
||||
target_means=(.0, .0, .0, .0),
|
||||
target_stds=(1.0, 1.0, 1.0, 1.0)
|
||||
):
|
||||
super(Proposal, self).__init__()
|
||||
cfg = config
|
||||
self.batch_size = batch_size
|
||||
self.num_classes = num_classes
|
||||
self.target_means = target_means
|
||||
self.target_stds = target_stds
|
||||
self.use_sigmoid_cls = config.use_sigmoid_cls
|
||||
|
||||
if self.use_sigmoid_cls:
|
||||
self.cls_out_channels = 1
|
||||
self.activation = P.Sigmoid()
|
||||
self.reshape_shape = (-1, 1)
|
||||
else:
|
||||
self.cls_out_channels = num_classes
|
||||
self.activation = P.Softmax(axis=1)
|
||||
self.reshape_shape = (-1, 2)
|
||||
|
||||
if self.cls_out_channels <= 0:
|
||||
raise ValueError('num_classes={} is too small'.format(num_classes))
|
||||
|
||||
self.num_pre = cfg.rpn_proposal_nms_pre
|
||||
self.min_box_size = cfg.rpn_proposal_min_bbox_size
|
||||
self.nms_thr = cfg.rpn_proposal_nms_thr
|
||||
self.nms_post = cfg.rpn_proposal_nms_post
|
||||
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
|
||||
self.max_num = cfg.rpn_proposal_max_num
|
||||
|
||||
# Op Define
|
||||
self.squeeze = P.Squeeze()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
self.feature_shapes = cfg.feature_shapes
|
||||
|
||||
self.transpose_shape = (1, 2, 0)
|
||||
|
||||
self.decode = BoundingBoxDecode()
|
||||
|
||||
self.nms = P.NMSWithMask(self.nms_thr)
|
||||
self.concat_axis0 = P.Concat(axis=0)
|
||||
self.concat_axis1 = P.Concat(axis=1)
|
||||
self.split = P.Split(axis=1, output_num=5)
|
||||
self.min = P.Minimum()
|
||||
self.gatherND = P.GatherNd()
|
||||
self.slice = P.Slice()
|
||||
self.select = P.Select()
|
||||
self.greater = P.Greater()
|
||||
self.transpose = P.Transpose()
|
||||
self.tile = P.Tile()
|
||||
self.set_train_local(config, training=True)
|
||||
|
||||
self.multi_10 = Tensor(10.0, mstype.float16)
|
||||
|
||||
def set_train_local(self, config, training=False):
|
||||
"""Set training flag."""
|
||||
self.training_local = training
|
||||
cfg = config
|
||||
self.topK_stage1 = ()
|
||||
self.topK_shape = ()
|
||||
total_max_topk_input = 0
|
||||
if not self.training_local:
|
||||
self.num_pre = cfg.rpn_nms_pre
|
||||
self.min_box_size = cfg.rpn_min_bbox_min_size
|
||||
self.nms_thr = cfg.rpn_nms_thr
|
||||
self.nms_post = cfg.rpn_nms_post
|
||||
self.max_num = cfg.rpn_max_num
|
||||
k_num = self.num_pre
|
||||
total_max_topk_input = k_num
|
||||
self.topK_stage1 = k_num
|
||||
self.topK_shape = (k_num, 1)
|
||||
|
||||
self.topKv2 = P.TopK(sorted=True)
|
||||
self.topK_shape_stage2 = (self.max_num, 1)
|
||||
self.min_float_num = -65536.0
|
||||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
|
||||
self.shape = P.Shape()
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
|
||||
proposals_tuple = ()
|
||||
masks_tuple = ()
|
||||
for img_id in range(self.batch_size):
|
||||
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::])
|
||||
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::])
|
||||
proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list)
|
||||
proposals_tuple += (proposals,)
|
||||
masks_tuple += (masks,)
|
||||
return proposals_tuple, masks_tuple
|
||||
|
||||
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
|
||||
"""Get proposal boundingbox."""
|
||||
mlvl_proposals = ()
|
||||
mlvl_mask = ()
|
||||
|
||||
rpn_cls_score = self.transpose(cls_scores, self.transpose_shape)
|
||||
rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape)
|
||||
anchors = mlvl_anchors
|
||||
|
||||
# (H, W, A*2)
|
||||
rpn_cls_score_shape = self.shape(rpn_cls_score)
|
||||
rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \
|
||||
rpn_cls_score_shape[1], -1, self.cls_out_channels))
|
||||
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
|
||||
rpn_cls_score = self.activation(rpn_cls_score)
|
||||
if self.use_sigmoid_cls:
|
||||
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16)
|
||||
else:
|
||||
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16)
|
||||
|
||||
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
|
||||
|
||||
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre)
|
||||
|
||||
topk_inds = self.reshape(topk_inds, self.topK_shape)
|
||||
|
||||
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
|
||||
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
|
||||
|
||||
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
|
||||
|
||||
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape)))
|
||||
proposals, _, mask_valid = self.nms(proposals_decode)
|
||||
|
||||
mlvl_proposals = mlvl_proposals + (proposals,)
|
||||
mlvl_mask = mlvl_mask + (mask_valid,)
|
||||
|
||||
proposals = self.concat_axis0(mlvl_proposals)
|
||||
masks = self.concat_axis0(mlvl_mask)
|
||||
|
||||
_, _, _, _, scores = self.split(proposals)
|
||||
scores = self.squeeze(scores)
|
||||
topk_mask = self.cast(self.topK_mask, mstype.float16)
|
||||
scores_using = self.select(masks, scores, topk_mask)
|
||||
|
||||
_, topk_inds = self.topKv2(scores_using, self.max_num)
|
||||
|
||||
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
|
||||
proposals = self.gatherND(proposals, topk_inds)
|
||||
masks = self.gatherND(masks, topk_inds)
|
||||
return proposals, masks
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""RPN for fasterRCNN"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from src.CTPN.bbox_assign_sample import BboxAssignSample
|
||||
|
||||
class RpnRegClsBlock(nn.Cell):
|
||||
"""
|
||||
Rpn reg cls block for rpn layer
|
||||
|
||||
Args:
|
||||
config(EasyDict) - Network construction config.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
in_channels,
|
||||
feat_channels,
|
||||
num_anchors,
|
||||
cls_out_channels):
|
||||
super(RpnRegClsBlock, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, 2*config.hidden_size)
|
||||
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
|
||||
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
|
||||
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
|
||||
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
|
||||
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
|
||||
self.transpose = P.Transpose()
|
||||
self.print = P.Print()
|
||||
self.dropout = nn.Dropout(0.8)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.reshape(x, self.shape)
|
||||
x = self.lstm_fc(x)
|
||||
x1 = self.rpn_cls(x)
|
||||
x1 = self.reshape(x1, self.shape1)
|
||||
x1 = self.transpose(x1, (2, 1, 0))
|
||||
x1 = self.reshape(x1, self.shape2)
|
||||
x1 = self.transpose(x1, (1, 0, 2, 3))
|
||||
x2 = self.rpn_reg(x)
|
||||
x2 = self.reshape(x2, self.shape1)
|
||||
x2 = self.transpose(x2, (2, 1, 0))
|
||||
x2 = self.reshape(x2, self.shape2)
|
||||
x2 = self.transpose(x2, (1, 0, 2, 3))
|
||||
return x1, x2
|
||||
|
||||
class RPN(nn.Cell):
|
||||
"""
|
||||
ROI proposal network..
|
||||
|
||||
Args:
|
||||
config (dict) - Config.
|
||||
batch_size (int) - Batchsize.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor.
|
||||
|
||||
Examples:
|
||||
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
|
||||
num_anchors=3, cls_out_channels=512)
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
batch_size,
|
||||
in_channels,
|
||||
feat_channels,
|
||||
num_anchors,
|
||||
cls_out_channels):
|
||||
super(RPN, self).__init__()
|
||||
cfg_rpn = config
|
||||
self.cfg = config
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.feature_anchor_shape = cfg_rpn.feature_shapes
|
||||
self.feature_anchor_shape = self.feature_anchor_shape[0] * \
|
||||
self.feature_anchor_shape[1] * num_anchors * batch_size
|
||||
self.num_anchors = num_anchors
|
||||
self.batch_size = batch_size
|
||||
self.test_batch_size = cfg_rpn.test_batch_size
|
||||
self.num_layers = 1
|
||||
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
|
||||
self.use_sigmoid_cls = config.use_sigmoid_cls
|
||||
if config.use_sigmoid_cls:
|
||||
self.reshape_shape_cls = (-1,)
|
||||
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
|
||||
cls_out_channels = 1
|
||||
else:
|
||||
self.reshape_shape_cls = (-1, cls_out_channels)
|
||||
self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
|
||||
self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\
|
||||
num_anchors, cls_out_channels)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.fill = P.Fill()
|
||||
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
|
||||
|
||||
self.trans_shape = (0, 2, 3, 1)
|
||||
|
||||
self.reshape_shape_reg = (-1, 4)
|
||||
self.softmax = nn.Softmax()
|
||||
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
|
||||
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
|
||||
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
|
||||
self.CheckValid = P.CheckValid()
|
||||
self.sum_loss = P.ReduceSum()
|
||||
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
self.tile = P.Tile()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.print = P.Print()
|
||||
|
||||
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
|
||||
"""
|
||||
make rpn layer for rpn proposal network
|
||||
|
||||
Args:
|
||||
num_layers (int) - layer num.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
List, list of RpnRegClsBlock cells.
|
||||
"""
|
||||
rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels)
|
||||
return rpn_layer
|
||||
|
||||
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
|
||||
'''
|
||||
inputs(Tensor): Inputs tensor from lstm.
|
||||
img_metas(Tensor): Image shape.
|
||||
anchor_list(Tensor): Total anchor list.
|
||||
gt_labels(Tensor): Ground truth labels.
|
||||
gt_valids(Tensor): Whether ground truth is valid.
|
||||
'''
|
||||
rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
|
||||
rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
|
||||
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
|
||||
rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
|
||||
rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
|
||||
output = ()
|
||||
bbox_targets = ()
|
||||
bbox_weights = ()
|
||||
labels = ()
|
||||
label_weights = ()
|
||||
if self.training:
|
||||
for i in range(self.batch_size):
|
||||
valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
|
||||
mstype.int32)
|
||||
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
|
||||
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
|
||||
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
|
||||
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
|
||||
gt_labels_i,
|
||||
self.cast(valid_flag_list,
|
||||
mstype.bool_),
|
||||
anchor_list, gt_valids_i)
|
||||
bbox_weight = self.cast(bbox_weight, mstype.float16)
|
||||
label_weight = self.cast(label_weight, mstype.float16)
|
||||
bbox_targets += (bbox_target,)
|
||||
bbox_weights += (bbox_weight,)
|
||||
labels += (label,)
|
||||
label_weights += (label_weight,)
|
||||
bbox_target_with_batchsize = self.concat(bbox_targets)
|
||||
bbox_weight_with_batchsize = self.concat(bbox_weights)
|
||||
label_with_batchsize = self.concat(labels)
|
||||
label_weight_with_batchsize = self.concat(label_weights)
|
||||
|
||||
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
|
||||
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
|
||||
label_ = F.stop_gradient(label_with_batchsize)
|
||||
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
|
||||
rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
|
||||
if self.use_sigmoid_cls:
|
||||
label_ = self.cast(label_, mstype.float32)
|
||||
loss_cls = self.loss_cls(rpn_cls_score, label_)
|
||||
loss_cls = loss_cls * label_weight_
|
||||
loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total
|
||||
rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
|
||||
bbox_target_ = self.cast(bbox_target_, mstype.float32)
|
||||
loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
|
||||
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4))
|
||||
loss_reg = loss_reg * bbox_weight_
|
||||
loss_reg = self.sum_loss(loss_reg, (1,))
|
||||
loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total
|
||||
loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
|
||||
output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg)
|
||||
else:
|
||||
output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1)
|
||||
return output
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
''''weight initialize'''
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
|
||||
"""Batchnorm2D wrapper."""
|
||||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
|
||||
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
|
||||
beta_init=beta_init, moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
|
||||
"""Conv2D wrapper."""
|
||||
weights = 'ones'
|
||||
layers = []
|
||||
conv = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=weights, has_bias=False)
|
||||
if not weights_update:
|
||||
conv.weight.requires_grad = False
|
||||
layers += [conv]
|
||||
layers += [_BatchNorm2dInit(out_channels)]
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
def _fc(in_channels, out_channels):
|
||||
'''full connection layer'''
|
||||
weight = _weight_variable((out_channels, in_channels))
|
||||
bias = _weight_variable((out_channels,))
|
||||
return nn.Dense(in_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
class VGG16FeatureExtraction(nn.Cell):
|
||||
def __init__(self, weights_update=False):
|
||||
"""
|
||||
VGG16 feature extraction
|
||||
|
||||
Args:
|
||||
weights_updata(bool): whether update weights for top two layers, default is False.
|
||||
"""
|
||||
super(VGG16FeatureExtraction, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
|
||||
self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
|
||||
self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
|
||||
self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
|
||||
self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
|
||||
|
||||
self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
|
||||
self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 3, 224, 224)
|
||||
:return:
|
||||
"""
|
||||
x = self.cast(x, mstype.float32)
|
||||
x = self.conv1_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv2_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv3_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv3_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv3_3(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv4_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv4_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv4_3(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv5_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv5_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv5_3(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class VGG16Classfier(nn.Cell):
|
||||
def __init__(self):
|
||||
"""VGG16 classfier structure"""
|
||||
super(VGG16Classfier, self).__init__()
|
||||
self.flatten = P.Flatten()
|
||||
self.relu = nn.ReLU()
|
||||
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
|
||||
self.fc2 = _fc(in_channels=4096, out_channels=4096)
|
||||
self.batch_size = 32
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 512, 7, 7)
|
||||
:return:
|
||||
"""
|
||||
x = self.reshape(x, (self.batch_size, 7*7*512))
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class VGG16(nn.Cell):
|
||||
def __init__(self):
|
||||
"""VGG16 construct for training backbone"""
|
||||
super(VGG16, self).__init__()
|
||||
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.classifier = VGG16Classfier()
|
||||
self.fc3 = _fc(in_channels=4096, out_channels=1000)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 3, 224, 224)
|
||||
:return: logits, shape=(B, 1000)
|
||||
"""
|
||||
feature_maps = self.feature_extraction(x)
|
||||
x = self.max_pool(feature_maps)
|
||||
x = self.classifier(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network parameters."""
|
||||
from easydict import EasyDict
|
||||
pretrain_config = EasyDict({
|
||||
# LR
|
||||
"base_lr": 0.0009,
|
||||
"warmup_step": 30000,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"total_epoch": 100,
|
||||
})
|
||||
finetune_config = EasyDict({
|
||||
# LR
|
||||
"base_lr": 0.0005,
|
||||
"warmup_step": 300,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"total_epoch": 50,
|
||||
})
|
||||
|
||||
# use for low case number
|
||||
config = EasyDict({
|
||||
"img_width": 960,
|
||||
"img_height": 576,
|
||||
"keep_ratio": False,
|
||||
"flip_ratio": 0.0,
|
||||
"photo_ratio": 0.0,
|
||||
"expand_ratio": 1.0,
|
||||
|
||||
# anchor
|
||||
"feature_shapes": (36, 60),
|
||||
"num_anchors": 14,
|
||||
"anchor_base": 16,
|
||||
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
|
||||
"anchor_width": [16],
|
||||
|
||||
# rpn
|
||||
"rpn_in_channels": 256,
|
||||
"rpn_feat_channels": 512,
|
||||
"rpn_loss_cls_weight": 1.0,
|
||||
"rpn_loss_reg_weight": 3.0,
|
||||
"rpn_cls_out_channels": 2,
|
||||
|
||||
# bbox_assign_sampler
|
||||
"neg_iou_thr": 0.5,
|
||||
"pos_iou_thr": 0.7,
|
||||
"min_pos_iou": 0.001,
|
||||
"num_bboxes": 30240,
|
||||
"num_gts": 256,
|
||||
"num_expected_neg": 512,
|
||||
"num_expected_pos": 256,
|
||||
|
||||
#proposal
|
||||
"activate_num_classes": 2,
|
||||
"use_sigmoid_cls": False,
|
||||
|
||||
# train proposal
|
||||
"rpn_proposal_nms_across_levels": False,
|
||||
"rpn_proposal_nms_pre": 2000,
|
||||
"rpn_proposal_nms_post": 1000,
|
||||
"rpn_proposal_max_num": 1000,
|
||||
"rpn_proposal_nms_thr": 0.7,
|
||||
"rpn_proposal_min_bbox_size": 8,
|
||||
|
||||
# rnn structure
|
||||
"input_size": 512,
|
||||
"num_step": 60,
|
||||
"rnn_batch_size": 36,
|
||||
"hidden_size": 128,
|
||||
|
||||
# training
|
||||
"warmup_mode": "linear",
|
||||
"batch_size": 1,
|
||||
"momentum": 0.9,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 10,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./",
|
||||
"use_dropout": False,
|
||||
"loss_scale": 1,
|
||||
"weight_decay": 1e-4,
|
||||
|
||||
# test proposal
|
||||
"rpn_nms_pre": 2000,
|
||||
"rpn_nms_post": 1000,
|
||||
"rpn_max_num": 1000,
|
||||
"rpn_nms_thr": 0.7,
|
||||
"rpn_min_bbox_min_size": 8,
|
||||
"test_iou_thr": 0.7,
|
||||
"test_max_per_img": 100,
|
||||
"test_batch_size": 1,
|
||||
"use_python_proposal": False,
|
||||
|
||||
# text proposal connection
|
||||
"max_horizontal_gap": 60,
|
||||
"text_proposals_min_scores": 0.7,
|
||||
"text_proposals_nms_thresh": 0.2,
|
||||
"min_v_overlaps": 0.7,
|
||||
"min_size_sim": 0.7,
|
||||
"min_ratio": 0.5,
|
||||
"line_min_score": 0.9,
|
||||
"text_proposals_width": 16,
|
||||
"min_num_proposals": 2,
|
||||
|
||||
# create dataset
|
||||
"coco_root": "",
|
||||
"coco_train_data_type": "",
|
||||
"cocotext_json": "",
|
||||
"icdar11_train_path": [],
|
||||
"icdar13_train_path": [],
|
||||
"icdar15_train_path": [],
|
||||
"icdar13_test_path": [],
|
||||
"flick_train_path": [],
|
||||
"svt_train_path": [],
|
||||
"pretrain_dataset_path": "",
|
||||
"finetune_dataset_path": "",
|
||||
"test_dataset_path": "",
|
||||
|
||||
# training dataset
|
||||
"pretraining_dataset_file": "",
|
||||
"finetune_dataset_file": ""
|
||||
})
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert icdar2015 dataset label"""
|
||||
import os
|
||||
import argparse
|
||||
def init_args():
|
||||
parser = argparse.ArgumentParser('')
|
||||
parser.add_argument('-s', '--src_label_path', type=str, default='./',
|
||||
help='Directory containing icdar2015 train label')
|
||||
parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
|
||||
help='Directory where save the icdar2015 label after convert')
|
||||
return parser.parse_args()
|
||||
|
||||
def convert():
|
||||
args = init_args()
|
||||
anno_file = os.listdir(args.src_label_path)
|
||||
annos = {}
|
||||
# read
|
||||
for file in anno_file:
|
||||
gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines()
|
||||
label_list = []
|
||||
label_name = os.path.basename(file)
|
||||
for each_label in gt:
|
||||
print(file)
|
||||
spt = each_label.split(',')
|
||||
print(spt)
|
||||
if "###" in spt[8]:
|
||||
continue
|
||||
else:
|
||||
x1 = min(int(spt[0]), int(spt[6]))
|
||||
y1 = min(int(spt[1]), int(spt[3]))
|
||||
x2 = max(int(spt[2]), int(spt[4]))
|
||||
y2 = max(int(spt[5]), int(spt[7]))
|
||||
label_list.append([x1, y1, x2, y2])
|
||||
annos[label_name] = label_list
|
||||
# write
|
||||
if not os.path.exists(args.target_label_path):
|
||||
os.makedirs(args.target_label_path)
|
||||
for label_file, pos in annos.items():
|
||||
tgt_anno_file = os.path.join(args.target_label_path, label_file)
|
||||
f = open(tgt_anno_file, 'w', encoding='UTF-8-sig')
|
||||
for tgt_label in pos:
|
||||
str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3])
|
||||
f.write(str_pos)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert svt dataset label"""
|
||||
import os
|
||||
import argparse
|
||||
from xml.etree import ElementTree as ET
|
||||
import numpy as np
|
||||
|
||||
def init_args():
|
||||
parser = argparse.ArgumentParser('')
|
||||
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
|
||||
help='Directory containing images')
|
||||
parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
|
||||
help='Directory where character dictionaries for the dataset were stored')
|
||||
parser.add_argument('-o', '--location_dir', type=str, default='./location',
|
||||
help='Directory where ord map dictionaries for the dataset were stored')
|
||||
return parser.parse_args()
|
||||
|
||||
def xml_to_dict(xml_file, save_file=False):
|
||||
tree = ET.parse(xml_file)
|
||||
root = tree.getroot()
|
||||
imgs_labels = []
|
||||
|
||||
for ch in root:
|
||||
im_label = {}
|
||||
for ch01 in ch:
|
||||
if ch01.tag in "address":
|
||||
continue
|
||||
elif ch01.tag in 'taggedRectangles':
|
||||
# multiple children
|
||||
rect_list = []
|
||||
for ch02 in ch01:
|
||||
rect = {}
|
||||
rect['location'] = ch02.attrib
|
||||
rect['label'] = ch02[0].text
|
||||
rect_list.append(rect)
|
||||
im_label['rect'] = rect_list
|
||||
else:
|
||||
im_label[ch01.tag] = ch01.text
|
||||
imgs_labels.append(im_label)
|
||||
|
||||
if save_file:
|
||||
np.save("annotation_train.npy", imgs_labels)
|
||||
|
||||
return imgs_labels
|
||||
|
||||
def convert():
|
||||
args = init_args()
|
||||
if not os.path.exists(args.dataset_dir):
|
||||
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
|
||||
|
||||
if not os.path.exists(args.xml_file):
|
||||
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
|
||||
|
||||
if not os.path.exists(args.location_dir):
|
||||
os.makedirs(args.location_dir)
|
||||
|
||||
ims_labels_dict = xml_to_dict(args.xml_file, True)
|
||||
num_images = len(ims_labels_dict)
|
||||
print("Converting annotation, {} images in total ".format(num_images))
|
||||
for i in range(num_images):
|
||||
img_label = ims_labels_dict[i]
|
||||
image_name = img_label['imageName']
|
||||
rects = img_label['rect']
|
||||
print("processing image: {}".format(image_name))
|
||||
location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
|
||||
f = open(location_file_name, 'w')
|
||||
for j, rect in enumerate(rects):
|
||||
rect = rects[j]
|
||||
location = rect['location']
|
||||
h = int(location['height'])
|
||||
w = int(location['width'])
|
||||
x = int(location['x'])
|
||||
y = int(location['y'])
|
||||
pos = [x, y, x+w, y+h]
|
||||
str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3])
|
||||
f.write(str_pos)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from __future__ import division
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
|
||||
def create_coco_label():
|
||||
"""Create image label."""
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
coco_root = config.coco_root
|
||||
data_type = config.coco_train_data_type
|
||||
from src.coco_text import COCO_Text
|
||||
anno_json = config.cocotext_json
|
||||
ct = COCO_Text(anno_json)
|
||||
image_ids = ct.getImgIds(imgIds=ct.train,
|
||||
catIds=[('legibility', 'legible')])
|
||||
for img_id in image_ids:
|
||||
image_info = ct.loadImgs(img_id)[0]
|
||||
file_name = image_info['file_name'][15:]
|
||||
anno_ids = ct.getAnnIds(imgIds=img_id)
|
||||
anno = ct.loadAnns(anno_ids)
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
annos = []
|
||||
im = Image.open(image_path)
|
||||
width, _ = im.size
|
||||
for label in anno:
|
||||
bbox = label["bbox"]
|
||||
bbox_width = int(bbox[2])
|
||||
if 60 * bbox_width < width:
|
||||
continue
|
||||
x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
|
||||
y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
|
||||
annos.append([x1, y1, x2, y2] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
# read
|
||||
img_basenames = []
|
||||
for file in os.listdir(train_img_dirs):
|
||||
# Filter git file.
|
||||
if 'gif' not in file:
|
||||
img_basenames.append(os.path.basename(file))
|
||||
img_names = []
|
||||
for item in img_basenames:
|
||||
temp1, _ = os.path.splitext(item)
|
||||
img_names.append((temp1, item))
|
||||
for img, img_basename in img_names:
|
||||
image_path = train_img_dirs + '/' + img_basename
|
||||
annos = []
|
||||
if len(img) == 6 and '_' not in img_basename:
|
||||
gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines()
|
||||
if img.isdigit() and int(img) > 1200:
|
||||
continue
|
||||
for img_each_label in gt:
|
||||
spt = img_each_label.replace(',', '').split(' ')
|
||||
if ' ' not in img_each_label:
|
||||
spt = img_each_label.split(',')
|
||||
annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
img_basenames = []
|
||||
for file_name in os.listdir(train_img_dir):
|
||||
if 'gif' not in file_name:
|
||||
img_basenames.append(os.path.basename(file_name))
|
||||
img_names = []
|
||||
for item in img_basenames:
|
||||
temp1, _ = os.path.splitext(item)
|
||||
img_names.append((temp1, item))
|
||||
for img, img_basename in img_names:
|
||||
image_path = train_img_dir + '/' + img_basename
|
||||
annos = []
|
||||
file_name = prefix + img + ".txt"
|
||||
file_path = os.path.join(train_txt_dir, file_name)
|
||||
gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines()
|
||||
if not gt:
|
||||
continue
|
||||
for img_each_label in gt:
|
||||
spt = img_each_label.replace(',', '').split(' ')
|
||||
if ' ' not in img_each_label:
|
||||
spt = img_each_label.split(',')
|
||||
annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_train_dataset(dataset_type):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
if dataset_type == "pretraining":
|
||||
# pretrianing: coco, flick, icdar2013 train, icdar2015, svt
|
||||
coco_image_files, coco_anno_dict = create_coco_label()
|
||||
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
|
||||
config.flick_train_path[1])
|
||||
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
|
||||
config.icdar13_train_path[1], "gt_img_")
|
||||
icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
|
||||
config.icdar15_train_path[1], "gt_")
|
||||
svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
|
||||
image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
|
||||
image_anno_dict = {**coco_anno_dict, **flick_anno_dict, \
|
||||
**icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path, \
|
||||
prefix="ctpn_pretrain.mindrecord", file_num=8)
|
||||
elif dataset_type == "finetune":
|
||||
# finetune: icdar2011, icdar2013 train, flick
|
||||
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
|
||||
config.flick_train_path[1])
|
||||
icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
|
||||
config.icdar11_train_path[1], "gt_")
|
||||
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
|
||||
config.icdar13_train_path[1], "gt_img_")
|
||||
image_files = flick_image_files + icdar11_image_files + icdar13_image_files
|
||||
image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path, \
|
||||
prefix="ctpn_finetune.mindrecord", file_num=8)
|
||||
elif dataset_type == "test":
|
||||
# test: icdar2013 test
|
||||
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
|
||||
config.icdar13_test_path[1], "")
|
||||
image_files = icdar_test_image_files
|
||||
image_anno_dict = icdar_test_anno_dict
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
|
||||
prefix="ctpn_test.mindrecord", file_num=1)
|
||||
else:
|
||||
print("dataset_type should be pretraining, finetune, test")
|
||||
|
||||
def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_path = os.path.join(dst_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
|
||||
ctpn_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(ctpn_json, "ctpn_json")
|
||||
for image_name in image_files:
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
print("img name is {}, anno is {}".format(image_name, annos))
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_train_dataset("pretraining")
|
||||
create_train_dataset("finetune")
|
||||
create_train_dataset("test")
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""CPTN network definition."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.CTPN.rpn import RPN
|
||||
from src.CTPN.anchor_generator import AnchorGenerator
|
||||
from src.CTPN.proposal_generator import Proposal
|
||||
from src.CTPN.vgg16 import VGG16FeatureExtraction
|
||||
|
||||
class BiLSTM(nn.Cell):
|
||||
"""
|
||||
Define a BiLSTM network which contains two LSTM layers
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
super(BiLSTM, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.batch_size = config.batch_size * config.rnn_batch_size
|
||||
print("batch size is {} ".format(self.batch_size))
|
||||
self.input_size = config.input_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_step = config.num_step
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
k = (1 / self.hidden_size) ** 0.5
|
||||
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.w1 = Parameter(np.random.uniform(-k, k, \
|
||||
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
|
||||
self.w1_bw = Parameter(np.random.uniform(-k, k, \
|
||||
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
|
||||
|
||||
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
|
||||
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
|
||||
|
||||
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
|
||||
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.reverse_seq = P.ReverseV2(axis=[0])
|
||||
self.concat = P.Concat()
|
||||
self.transpose = P.Transpose()
|
||||
self.concat1 = P.Concat(axis=2)
|
||||
self.dropout = nn.Dropout(0.7)
|
||||
self.use_dropout = config.use_dropout
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
def construct(self, x):
|
||||
if self.use_dropout:
|
||||
x = self.dropout(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
bw_x = self.reverse_seq(x)
|
||||
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
|
||||
y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
|
||||
y1_bw = self.reverse_seq(y1_bw)
|
||||
output = self.concat1((y1, y1_bw))
|
||||
return output
|
||||
|
||||
class CTPN(nn.Cell):
|
||||
"""
|
||||
Define CTPN network
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
super(CTPN, self).__init__()
|
||||
self.config = config
|
||||
self.is_training = is_training
|
||||
self.num_step = config.num_step
|
||||
self.input_size = config.input_size
|
||||
self.batch_size = config.batch_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction()
|
||||
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
|
||||
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
|
||||
# rpn block
|
||||
self.rpn_with_loss = RPN(config,
|
||||
self.batch_size,
|
||||
config.rpn_in_channels,
|
||||
config.rpn_feat_channels,
|
||||
config.num_anchors,
|
||||
config.rpn_cls_out_channels)
|
||||
self.anchor_generator = AnchorGenerator(config)
|
||||
self.featmap_size = config.feature_shapes
|
||||
self.anchor_list = self.get_anchors(self.featmap_size)
|
||||
self.proposal_generator_test = Proposal(config,
|
||||
config.test_batch_size,
|
||||
config.activate_num_classes,
|
||||
config.use_sigmoid_cls)
|
||||
self.proposal_generator_test.set_train_local(config, False)
|
||||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
|
||||
# (1,3,600,900)
|
||||
x = self.vgg16_feature_extractor(img_data)
|
||||
x = self.conv(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
# (1, 512, 38, 57)
|
||||
x = self.transpose(x, (0, 2, 1, 3))
|
||||
x = self.reshape(x, (-1, self.input_size, self.num_step))
|
||||
x = self.transpose(x, (2, 0, 1))
|
||||
# (57, 38, 512)
|
||||
x = self.rnn(x)
|
||||
# (57, 38, 256)
|
||||
#x = self.cast(x, mstype.float32)
|
||||
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
|
||||
img_metas,
|
||||
self.anchor_list,
|
||||
gt_bboxes,
|
||||
gt_labels,
|
||||
gt_valids)
|
||||
if self.training:
|
||||
return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
|
||||
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
|
||||
return proposal, proposal_mask
|
||||
|
||||
def get_anchors(self, featmap_size):
|
||||
anchors = self.anchor_generator.grid_anchors(featmap_size)
|
||||
return Tensor(anchors, mstype.float16)
|
|
@ -0,0 +1,342 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FasterRcnn dataset"""
|
||||
from __future__ import division
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy import random
|
||||
import mmcv
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as CC
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
|
||||
class PhotoMetricDistortion:
|
||||
"""Photo Metric Distortion"""
|
||||
|
||||
def __init__(self,
|
||||
brightness_delta=32,
|
||||
contrast_range=(0.5, 1.5),
|
||||
saturation_range=(0.5, 1.5),
|
||||
hue_delta=18):
|
||||
self.brightness_delta = brightness_delta
|
||||
self.contrast_lower, self.contrast_upper = contrast_range
|
||||
self.saturation_lower, self.saturation_upper = saturation_range
|
||||
self.hue_delta = hue_delta
|
||||
|
||||
def __call__(self, img, boxes, labels):
|
||||
img = img.astype('float32')
|
||||
if random.randint(2):
|
||||
delta = random.uniform(-self.brightness_delta, self.brightness_delta)
|
||||
img += delta
|
||||
mode = random.randint(2)
|
||||
if mode == 1:
|
||||
if random.randint(2):
|
||||
alpha = random.uniform(self.contrast_lower,
|
||||
self.contrast_upper)
|
||||
img *= alpha
|
||||
# convert color from BGR to HSV
|
||||
img = mmcv.bgr2hsv(img)
|
||||
# random saturation
|
||||
if random.randint(2):
|
||||
img[..., 1] *= random.uniform(self.saturation_lower,
|
||||
self.saturation_upper)
|
||||
# random hue
|
||||
if random.randint(2):
|
||||
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
|
||||
img[..., 0][img[..., 0] > 360] -= 360
|
||||
img[..., 0][img[..., 0] < 0] += 360
|
||||
# convert color from HSV to BGR
|
||||
img = mmcv.hsv2bgr(img)
|
||||
# random contrast
|
||||
if mode == 0:
|
||||
if random.randint(2):
|
||||
alpha = random.uniform(self.contrast_lower,
|
||||
self.contrast_upper)
|
||||
img *= alpha
|
||||
# randomly swap channels
|
||||
if random.randint(2):
|
||||
img = img[..., random.permutation(3)]
|
||||
return img, boxes, labels
|
||||
|
||||
class Expand:
|
||||
"""expand image"""
|
||||
|
||||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
|
||||
if to_rgb:
|
||||
self.mean = mean[::-1]
|
||||
else:
|
||||
self.mean = mean
|
||||
self.min_ratio, self.max_ratio = ratio_range
|
||||
|
||||
def __call__(self, img, boxes, labels):
|
||||
if random.randint(2):
|
||||
return img, boxes, labels
|
||||
h, w, c = img.shape
|
||||
ratio = random.uniform(self.min_ratio, self.max_ratio)
|
||||
expand_img = np.full((int(h * ratio), int(w * ratio), c),
|
||||
self.mean).astype(img.dtype)
|
||||
left = int(random.uniform(0, w * ratio - w))
|
||||
top = int(random.uniform(0, h * ratio - h))
|
||||
expand_img[top:top + h, left:left + w] = img
|
||||
img = expand_img
|
||||
boxes += np.tile((left, top), 2)
|
||||
return img, boxes, labels
|
||||
|
||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""rescale operation for image"""
|
||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
||||
if img_data.shape[0] > config.img_height:
|
||||
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
|
||||
scale_factor = scale_factor * scale_factor2
|
||||
img_shape = np.append(img_shape, scale_factor)
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
gt_bboxes = split_gtbox_label(gt_bboxes)
|
||||
if gt_bboxes.shape[0] != 0:
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image"""
|
||||
img_data = img
|
||||
img_data, w_scale, h_scale = mmcv.imresize(
|
||||
img_data, (config.img_width, config.img_height), return_scale=True)
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
img_shape = (config.img_height, config.img_width, 1.0)
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
gt_bboxes = split_gtbox_label(gt_bboxes)
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image of eval"""
|
||||
img_data = img
|
||||
img_data, w_scale, h_scale = mmcv.imresize(
|
||||
img_data, (config.img_width, config.img_height), return_scale=True)
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
img_shape = (config.img_height, config.img_width)
|
||||
img_shape = np.append(img_shape, (h_scale, w_scale))
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
shape = gt_bboxes.shape
|
||||
label_column = np.ones((shape[0], 1), dtype=int)
|
||||
gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1)
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""flipped generation"""
|
||||
img_data = img
|
||||
flipped = gt_bboxes.copy()
|
||||
_, w, _ = img_data.shape
|
||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
|
||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
|
||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
||||
|
||||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
img_data = img[:, :, ::-1]
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""photo crop operation for image"""
|
||||
random_photo = PhotoMetricDistortion()
|
||||
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""expand operation for image"""
|
||||
expand = Expand()
|
||||
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
|
||||
|
||||
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def split_gtbox_label(gt_bbox_total):
|
||||
"""split ground truth box label"""
|
||||
gtbox_list = []
|
||||
box_num, _ = gt_bbox_total.shape
|
||||
for i in range(box_num):
|
||||
gt_bbox = gt_bbox_total[i]
|
||||
if gt_bbox[0] % 16 != 0:
|
||||
gt_bbox[0] = (gt_bbox[0] // 16) * 16
|
||||
if gt_bbox[2] % 16 != 0:
|
||||
gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16
|
||||
x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16)
|
||||
for x0 in x0_array:
|
||||
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
|
||||
return np.array(gtbox_list)
|
||||
|
||||
def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid):
|
||||
"""pad ground truth label"""
|
||||
pad_max_number = 256
|
||||
gt_label = gt_bboxes[:, 4]
|
||||
gt_valid = gt_bboxes[:, 4]
|
||||
if gt_bboxes.shape[0] < 256:
|
||||
gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \
|
||||
mode="constant", constant_values=0)
|
||||
gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1)
|
||||
gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0)
|
||||
else:
|
||||
print("WARNING label num is high than 256")
|
||||
gt_box = gt_bboxes[0:pad_max_number]
|
||||
gt_label = gt_label[0:pad_max_number]
|
||||
gt_valid = gt_valid[0:pad_max_number]
|
||||
return (img, img_shape, gt_box[:, :4], gt_label, gt_valid)
|
||||
|
||||
def preprocess_fn(image, box, is_training):
|
||||
"""Preprocess function for dataset."""
|
||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid):
|
||||
image_shape = image_shape[:2]
|
||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid
|
||||
if config.keep_ratio:
|
||||
input_data = rescale_column(*input_data)
|
||||
else:
|
||||
input_data = resize_column_test(*input_data)
|
||||
input_data = pad_label(*input_data)
|
||||
input_data = image_bgr_rgb(*input_data)
|
||||
output_data = input_data
|
||||
return output_data
|
||||
|
||||
def _data_aug(image, box, is_training):
|
||||
"""Data augmentation function."""
|
||||
image_bgr = image.copy()
|
||||
image_bgr[:, :, 0] = image[:, :, 2]
|
||||
image_bgr[:, :, 1] = image[:, :, 1]
|
||||
image_bgr[:, :, 2] = image[:, :, 0]
|
||||
image_shape = image_bgr.shape[:2]
|
||||
gt_box = box[:, :4]
|
||||
gt_label = box[:, 4]
|
||||
gt_valid = box[:, 4]
|
||||
input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid
|
||||
if not is_training:
|
||||
return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid)
|
||||
expand = (np.random.rand() < config.expand_ratio)
|
||||
if expand:
|
||||
input_data = expand_column(*input_data)
|
||||
input_data = photo_crop_column(*input_data)
|
||||
if config.keep_ratio:
|
||||
input_data = rescale_column(*input_data)
|
||||
else:
|
||||
input_data = resize_column(*input_data)
|
||||
input_data = pad_label(*input_data)
|
||||
input_data = image_bgr_rgb(*input_data)
|
||||
output_data = input_data
|
||||
return output_data
|
||||
|
||||
return _data_aug(image, box, is_training)
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
for anno_str in annos_str:
|
||||
anno = list(map(int, anno_str.strip().split(',')))
|
||||
annos.append(anno)
|
||||
return annos
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
if not os.path.isdir(image_dir):
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
if not os.path.isfile(anno_path):
|
||||
raise RuntimeError("Annotation file is not valid.")
|
||||
|
||||
with open(anno_path, "rb") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
line_split = str(line_str).split(' ')
|
||||
file_name = line_split[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
if os.path.isfile(image_path):
|
||||
image_anno_dict[image_path] = anno_parser(line_split[1:])
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
image_files, image_anno_dict = create_icdar_test_label()
|
||||
ctpn_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 6]},
|
||||
}
|
||||
writer.add_schema(ctpn_json, "ctpn_json")
|
||||
for image_name in image_files:
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
|
||||
is_training=True, num_parallel_workers=4):
|
||||
"""Creatr deeptext dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
|
||||
num_parallel_workers=8, shuffle=is_training)
|
||||
decode = C.Decode()
|
||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
|
||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||
hwc_to_chw = C.HWC2CHW()
|
||||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
|
||||
type_cast0 = CC.TypeCast(mstype.float32)
|
||||
type_cast1 = CC.TypeCast(mstype.float16)
|
||||
type_cast2 = CC.TypeCast(mstype.int32)
|
||||
type_cast3 = CC.TypeCast(mstype.bool_)
|
||||
if is_training:
|
||||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
else:
|
||||
ds = ds.map(operations=compose_map_func,
|
||||
input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
|
||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=24)
|
||||
# transpose_column from python to c
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["box"])
|
||||
ds = ds.map(operations=[type_cast2], input_columns=["label"])
|
||||
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for deeptext"""
|
||||
import math
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
|
||||
base = float(current_step - warmup_steps) / float(decay_steps)
|
||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
|
||||
return learning_rate
|
||||
|
||||
def dynamic_lr(config, base_step):
|
||||
"""dynamic learning rate generator"""
|
||||
base_lr = config.base_lr
|
||||
total_steps = int(base_step * config.total_epoch)
|
||||
warmup_steps = config.warmup_step
|
||||
lr = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
|
||||
else:
|
||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
|
||||
return lr
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn training network wrapper."""
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1, rank_id=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.count = 0
|
||||
self.rpn_loss_sum = 0
|
||||
self.rpn_cls_loss_sum = 0
|
||||
self.rpn_reg_loss_sum = 0
|
||||
self.rank_id = rank_id
|
||||
|
||||
global time_stamp_init, time_stamp_first
|
||||
if not time_stamp_init:
|
||||
time_stamp_first = time.time()
|
||||
time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
rpn_loss = cb_params.net_outputs[0].asnumpy()
|
||||
rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
|
||||
rpn_reg_loss = cb_params.net_outputs[2].asnumpy()
|
||||
|
||||
self.count += 1
|
||||
self.rpn_loss_sum += float(rpn_loss)
|
||||
self.rpn_cls_loss_sum += float(rpn_cls_loss)
|
||||
self.rpn_reg_loss_sum += float(rpn_reg_loss)
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if self.count >= 1:
|
||||
global time_stamp_first
|
||||
time_stamp_current = time.time()
|
||||
rpn_loss = self.rpn_loss_sum / self.count
|
||||
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
|
||||
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
|
||||
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
|
||||
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"%
|
||||
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
rpn_loss, rpn_cls_loss, rpn_reg_loss))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
|
||||
class LossNet(nn.Cell):
|
||||
"""FasterRcnn loss method"""
|
||||
def construct(self, x1, x2, x3):
|
||||
return x1
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to compute loss.
|
||||
|
||||
Args:
|
||||
backbone (Cell): The target network to wrap.
|
||||
loss_fn (Cell): The loss function used to compute loss.
|
||||
"""
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
||||
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""
|
||||
Get the backbone network.
|
||||
|
||||
Returns:
|
||||
Cell, return backbone network.
|
||||
"""
|
||||
return self._backbone
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Network training package class.
|
||||
|
||||
Append an optimizer to the training network after that the construct function
|
||||
can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
network_backbone (Cell): The forward network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default value is 1.0.
|
||||
reduce_flag (bool): The reduce flag. Default value is False.
|
||||
mean (bool): Allreduce method. Default value is False.
|
||||
degree (int): Device number. Default value is None.
|
||||
"""
|
||||
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.backbone = network_backbone
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
|
||||
self.reduce_flag = reduce_flag
|
||||
if reduce_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
||||
weights = self.weights
|
||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
||||
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
|
||||
if self.reduce_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
|
|
@ -0,0 +1,65 @@
|
|||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================import numpy as np
|
||||
import numpy as np
|
||||
from src.text_connector.utils import clip_boxes, fit_y
|
||||
from src.text_connector.get_successions import get_successions
|
||||
|
||||
def connect_text_lines(text_proposals, scores, size):
|
||||
"""
|
||||
Connect text lines
|
||||
|
||||
Args:
|
||||
text_proposals(numpy.array): Predict text proposals.
|
||||
scores(numpy.array): Bbox predicts scores.
|
||||
size(numpy.array): Image size.
|
||||
Returns:
|
||||
text_recs(numpy.array): Text boxes after connect.
|
||||
"""
|
||||
graph = get_successions(text_proposals, scores, size)
|
||||
text_lines = np.zeros((len(graph), 5), np.float32)
|
||||
for index, indices in enumerate(graph):
|
||||
text_line_boxes = text_proposals[list(indices)]
|
||||
x0 = np.min(text_line_boxes[:, 0])
|
||||
x1 = np.max(text_line_boxes[:, 2])
|
||||
|
||||
offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5
|
||||
|
||||
lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
|
||||
lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
|
||||
|
||||
# the score of a text line is the average score of the scores
|
||||
# of all text proposals contained in the text line
|
||||
score = scores[list(indices)].sum() / float(len(indices))
|
||||
|
||||
text_lines[index, 0] = x0
|
||||
text_lines[index, 1] = min(lt_y, rt_y)
|
||||
text_lines[index, 2] = x1
|
||||
text_lines[index, 3] = max(lb_y, rb_y)
|
||||
text_lines[index, 4] = score
|
||||
|
||||
text_lines = clip_boxes(text_lines, size)
|
||||
|
||||
text_recs = np.zeros((len(text_lines), 9), np.float)
|
||||
index = 0
|
||||
for line in text_lines:
|
||||
xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
|
||||
text_recs[index, 0] = xmin
|
||||
text_recs[index, 1] = ymin
|
||||
text_recs[index, 2] = xmax
|
||||
text_recs[index, 3] = ymax
|
||||
text_recs[index, 4] = line[4]
|
||||
index = index + 1
|
||||
return text_recs
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
from src.text_connector.utils import nms
|
||||
from src.text_connector.connect_text_lines import connect_text_lines
|
||||
|
||||
def filter_proposal(proposals, scores):
|
||||
"""
|
||||
Filter text proposals
|
||||
|
||||
Args:
|
||||
proposals(numpy.array): Text proposals.
|
||||
Returns:
|
||||
proposals(numpy.array): Text proposals after filter.
|
||||
"""
|
||||
inds = np.where(scores > config.text_proposals_min_scores)[0]
|
||||
keep_proposals = proposals[inds]
|
||||
keep_scores = scores[inds]
|
||||
sorted_inds = np.argsort(keep_scores.ravel())[::-1]
|
||||
keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds]
|
||||
nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh)
|
||||
keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds]
|
||||
return keep_proposals, keep_scores
|
||||
|
||||
def filter_boxes(boxes):
|
||||
"""
|
||||
Filter text boxes
|
||||
|
||||
Args:
|
||||
boxes(numpy.array): Text boxes.
|
||||
Returns:
|
||||
boxes(numpy.array): Text boxes after filter.
|
||||
"""
|
||||
heights = np.zeros((len(boxes), 1), np.float)
|
||||
widths = np.zeros((len(boxes), 1), np.float)
|
||||
scores = np.zeros((len(boxes), 1), np.float)
|
||||
index = 0
|
||||
for box in boxes:
|
||||
widths[index] = abs(box[2] - box[0])
|
||||
heights[index] = abs(box[3] - box[1])
|
||||
scores[index] = abs(box[4])
|
||||
index += 1
|
||||
return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\
|
||||
(widths > (config.text_proposals_width * config.min_num_proposals)))[0]
|
||||
|
||||
def detect(text_proposals, scores, size):
|
||||
"""
|
||||
Detect text boxes
|
||||
|
||||
Args:
|
||||
text_proposals(numpy.array): Predict text proposals.
|
||||
scores(numpy.array): Bbox predicts scores.
|
||||
size(numpy.array): Image size.
|
||||
Returns:
|
||||
boxes(numpy.array): Text boxes after connect.
|
||||
"""
|
||||
keep_proposals, keep_scores = filter_proposal(text_proposals, scores)
|
||||
connect_boxes = connect_text_lines(keep_proposals, keep_scores, size)
|
||||
boxes = connect_boxes[filter_boxes(connect_boxes)]
|
||||
return boxes
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
from src.text_connector.utils import overlaps_v, size_similarity
|
||||
|
||||
def get_successions(text_proposals, scores, im_size):
|
||||
"""
|
||||
Get successions text boxes.
|
||||
|
||||
Args:
|
||||
text_proposals(numpy.array): Predict text proposals.
|
||||
scores(numpy.array): Bbox predicts scores.
|
||||
size(numpy.array): Image size.
|
||||
Returns:
|
||||
sub_graph(list): Proposals graph.
|
||||
"""
|
||||
bboxes_table = [[] for _ in range(int(im_size[1]))]
|
||||
for index, box in enumerate(text_proposals):
|
||||
bboxes_table[int(box[0])].append(index)
|
||||
graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool)
|
||||
for index, box in enumerate(text_proposals):
|
||||
successions_left = []
|
||||
for left in range(int(box[0]) + 1, min(int(box[0]) + config.max_horizontal_gap + 1, im_size[1])):
|
||||
adj_box_indices = bboxes_table[left]
|
||||
for adj_box_index in adj_box_indices:
|
||||
if meet_v_iou(text_proposals, adj_box_index, index):
|
||||
successions_left.append(adj_box_index)
|
||||
if successions_left:
|
||||
break
|
||||
if not successions_left:
|
||||
continue
|
||||
succession_index = successions_left[np.argmax(scores[successions_left])]
|
||||
box_right = text_proposals[succession_index]
|
||||
succession_right = []
|
||||
for right in range(int(box_right[0]) - 1, max(int(box_right[0] - config.max_horizontal_gap), 0) - 1, -1):
|
||||
adj_box_indices = bboxes_table[right]
|
||||
for adj_box_index in adj_box_indices:
|
||||
if meet_v_iou(text_proposals, adj_box_index, index):
|
||||
succession_right.append(adj_box_index)
|
||||
if succession_right:
|
||||
break
|
||||
if scores[index] >= np.max(scores[succession_right]):
|
||||
graph[index, succession_index] = True
|
||||
sub_graph = get_sub_graph(graph)
|
||||
return sub_graph
|
||||
|
||||
def get_sub_graph(graph):
|
||||
"""
|
||||
Get successions text boxes.
|
||||
|
||||
Args:
|
||||
graph(numpy.array): proposal graph
|
||||
Returns:
|
||||
sub_graph(list): Proposals graph after connect.
|
||||
"""
|
||||
sub_graphs = []
|
||||
for index in range(graph.shape[0]):
|
||||
if not graph[:, index].any() and graph[index, :].any():
|
||||
v = index
|
||||
sub_graphs.append([v])
|
||||
while graph[v, :].any():
|
||||
v = np.where(graph[v, :])[0][0]
|
||||
sub_graphs[-1].append(v)
|
||||
return sub_graphs
|
||||
|
||||
def meet_v_iou(text_proposals, index1, index2):
|
||||
"""
|
||||
Calculate vertical iou.
|
||||
|
||||
Args:
|
||||
text_proposals(numpy.array): tex proposals
|
||||
index1(int): text_proposal index
|
||||
tindex2(int): text proposal index
|
||||
Returns:
|
||||
sub_graph(list): Proposals graph after connect.
|
||||
"""
|
||||
heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
|
||||
return overlaps_v(text_proposals, index1, index2) >= config.min_v_overlaps and \
|
||||
size_similarity(heights, index1, index2) >= config.min_size_sim
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
|
||||
def threshold(coords, min_, max_):
|
||||
return np.maximum(np.minimum(coords, max_), min_)
|
||||
|
||||
def clip_boxes(boxes, im_shape):
|
||||
"""
|
||||
Clip boxes to image boundaries.
|
||||
|
||||
Args:
|
||||
boxes(numpy.array):bounding box.
|
||||
im_shape(numpy.array): image shape.
|
||||
|
||||
Return:
|
||||
boxes(numpy.array):boundding box after clip.
|
||||
"""
|
||||
boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1)
|
||||
boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1)
|
||||
return boxes
|
||||
|
||||
def overlaps_v(text_proposals, index1, index2):
|
||||
"""
|
||||
Calculate vertical overlap ratio.
|
||||
|
||||
Args:
|
||||
text_proposals(numpy.array): Text proposlas.
|
||||
index1(int): First text proposal.
|
||||
index2(int): Second text proposal.
|
||||
|
||||
Return:
|
||||
overlap(float32): vertical overlap.
|
||||
"""
|
||||
h1 = text_proposals[index1][3] - text_proposals[index1][1] + 1
|
||||
h2 = text_proposals[index2][3] - text_proposals[index2][1] + 1
|
||||
y0 = max(text_proposals[index2][1], text_proposals[index1][1])
|
||||
y1 = min(text_proposals[index2][3], text_proposals[index1][3])
|
||||
return max(0, y1 - y0 + 1) / min(h1, h2)
|
||||
|
||||
def size_similarity(heights, index1, index2):
|
||||
"""
|
||||
Calculate vertical size similarity ratio.
|
||||
|
||||
Args:
|
||||
heights(numpy.array): Text proposlas heights.
|
||||
index1(int): First text proposal.
|
||||
index2(int): Second text proposal.
|
||||
|
||||
Return:
|
||||
overlap(float32): vertical overlap.
|
||||
"""
|
||||
h1 = heights[index1]
|
||||
h2 = heights[index2]
|
||||
return min(h1, h2) / max(h1, h2)
|
||||
|
||||
def fit_y(X, Y, x1, x2):
|
||||
if np.sum(X == X[0]) == len(X):
|
||||
return Y[0], Y[0]
|
||||
p = np.poly1d(np.polyfit(X, Y, 1))
|
||||
return p(x1), p(x2)
|
||||
|
||||
def nms(bboxs, thresh):
|
||||
"""
|
||||
Args:
|
||||
text_proposals(numpy.array): tex proposals
|
||||
index1(int): text_proposal index
|
||||
tindex2(int): text proposal index
|
||||
"""
|
||||
x1, y1, x2, y2, scores = np.split(bboxs, 5, axis=1)
|
||||
x1 = bboxs[:, 0]
|
||||
y1 = bboxs[:, 1]
|
||||
x2 = bboxs[:, 2]
|
||||
y2 = bboxs[:, 3]
|
||||
scores = bboxs[:, 4]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
num_dets = bboxs.shape[0]
|
||||
suppressed = np.zeros(num_dets, dtype=np.int32)
|
||||
keep = []
|
||||
for _i in range(num_dets):
|
||||
i = order[_i]
|
||||
if suppressed[i] == 1:
|
||||
continue
|
||||
keep.append(i)
|
||||
x1_i = x1[i]
|
||||
y1_i = y1[i]
|
||||
x2_i = x2[i]
|
||||
y2_i = y2[i]
|
||||
area_i = areas[i]
|
||||
for _j in range(_i + 1, num_dets):
|
||||
j = order[_j]
|
||||
if suppressed[j] == 1:
|
||||
continue
|
||||
x1_j = max(x1_i, x1[j])
|
||||
y1_j = max(y1_i, y1[j])
|
||||
x2_j = min(x2_i, x2[j])
|
||||
y2_j = min(y2_i, y2[j])
|
||||
w = max(0.0, x2_j - x1_j + 1)
|
||||
h = max(0.0, y2_j - y1_j + 1)
|
||||
inter = w*h
|
||||
overlap = inter / (area_i+areas[j]-inter)
|
||||
if overlap >= thresh:
|
||||
suppressed[j] = 1
|
||||
return keep
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""train CTPN and get checkpoint files."""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.common import set_seed
|
||||
from src.ctpn import CTPN
|
||||
from src.config import config, pretrain_config, finetune_config
|
||||
from src.dataset import create_ctpn_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="CTPN training")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
|
||||
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
|
||||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
|
||||
parser.add_argument("--task_type", type=str, default="Pretraining",\
|
||||
choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.run_distribute:
|
||||
rank = args_opt.rank_id
|
||||
device_num = args_opt.device_num
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
if args_opt.task_type == "Pretraining":
|
||||
print("Start to do pretraining")
|
||||
mindrecord_file = config.pretraining_dataset_file
|
||||
training_cfg = pretrain_config
|
||||
else:
|
||||
print("Start to do finetune")
|
||||
mindrecord_file = config.finetune_dataset_file
|
||||
training_cfg = finetune_config
|
||||
|
||||
print("CHECKING MINDRECORD FILES ...")
|
||||
while not os.path.exists(mindrecord_file + ".db"):
|
||||
time.sleep(5)
|
||||
|
||||
print("CHECKING MINDRECORD FILES DONE!")
|
||||
|
||||
loss_scale = float(config.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as ctpn_pretrain.mindrecord0.
|
||||
dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\
|
||||
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
net = CTPN(config=config, is_training=True)
|
||||
net = net.set_train()
|
||||
|
||||
load_path = args_opt.pre_trained
|
||||
if args_opt.task_type == "Pretraining":
|
||||
print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained))
|
||||
param_dict = load_checkpoint(load_path)
|
||||
for item in list(param_dict.keys()):
|
||||
if not item.startswith('vgg16_feature_extractor'):
|
||||
param_dict.pop(item)
|
||||
load_param_into_net(net, param_dict)
|
||||
else:
|
||||
if load_path != "":
|
||||
print("load pretrain ckpt {}".format(args_opt.pre_trained))
|
||||
param_dict = load_checkpoint(load_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
loss = LossNet()
|
||||
lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
|
||||
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
if args_opt.run_distribute:
|
||||
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
|
||||
mean=True, degree=device_num)
|
||||
else:
|
||||
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
|
||||
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
loss_cb = LossCallBack(rank_id=rank)
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig)
|
||||
cb += [ckpoint_cb]
|
||||
|
||||
model = Model(net)
|
||||
model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
|
Loading…
Reference in New Issue