forked from mindspore-Ecosystem/mindspore

init for ctpn

add connect table; fix some bugs; fix pylint; fix dataset creation; add convert scripts for SVT and ICDAR2015; fix CTPN and VGG16 issues

This commit is contained in:
parent 78c733ffbe
commit ae27c383fa

@ -0,0 +1,293 @@
![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)

<!-- TOC -->

# CTPN for Ascend

- [CTPN Description](#ctpn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [CTPN Description](#contents)

CTPN is a text detection model built on an object detection framework. It improves on Faster R-CNN by adding a bidirectional LSTM, which makes it very effective for horizontal text detection. Another highlight of CTPN is that it reframes text detection as detecting a sequence of fine-scale, fixed-width text proposals. The idea was proposed in the paper "Detecting Text in Natural Image with Connectionist Text Proposal Network".

[Paper](https://arxiv.org/pdf/1609.03605.pdf) Zhi Tian, Weilin Huang, Tong He, Pan He, Yu Qiao, "Detecting Text in Natural Image with Connectionist Text Proposal Network", ArXiv, vol. abs/1609.03605, 2016.
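The core trick, reconstructed here from the paper (and matching `src/CTPN/BoundingBoxEncode.py` below), is that each fixed-width proposal regresses only its vertical geometry relative to an anchor:

```latex
v_c = \frac{c_y - c_y^a}{h^a}, \qquad v_h = \log\frac{h}{h^a}
```

where $c_y$ and $h$ are the center y-coordinate and height of the predicted box, and $c_y^a$ and $h^a$ are those of the anchor; the x-coordinate and the fixed proposal width are left unchanged.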
# [Model Architecture](#contents)

The overall network uses VGG16 as the backbone, applies a bidirectional LSTM to extract context features for the fine-scale text proposals, and then uses an RPN (Region Proposal Network) to predict the bounding boxes and their probabilities.

[Link](https://arxiv.org/pdf/1605.07314v1.pdf)
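Because every proposal shares the same fixed width, the anchor set reduces to one anchor per height. The following self-contained sketch illustrates the mechanism with the height schedule from the paper (10 heights from 11 to roughly 273, each 1/0.7 times the previous); the values actually used by this implementation live in `src/config.py` and are consumed by `src/CTPN/anchor_generator.py`, so treat these numbers as illustrative:

```python
import numpy as np

# Paper-style anchor heights from 11 to ~273 (each 1/0.7x the previous).
HEIGHTS = [round(11 / 0.7 ** k) for k in range(10)]
WIDTH = 16  # all CTPN anchors share a fixed width

def base_anchors(base_size=16):
    """One (x1, y1, x2, y2) anchor per height, centered in a base_size cell."""
    ctr = (base_size - 1) * 0.5
    return np.array([[ctr - WIDTH / 2, ctr - h / 2, ctr + WIDTH / 2, ctr + h / 2]
                     for h in HEIGHTS])

print(base_anchors().shape)  # (10, 4)
```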
# [Dataset](#contents)

Here we use six datasets for training and one dataset for evaluation.

- Dataset 1: ICDAR 2013: Focused Scene Text
    - Train: 142MB, 229 images
    - Test: 110MB, 233 images
- Dataset 2: ICDAR 2011: Born-Digital Images
    - Train: 27.7MB, 410 images
- Dataset 3: ICDAR 2015
    - Train: 89MB, 1000 images
- Dataset 4: SCUT-FORU: Flickr OCR Universal Database
    - Train: 388MB, 1715 images
- Dataset 5: COCO-Text v2 (subset of MSCOCO 2017)
    - Train: 13GB, 63686 images
- Dataset 6: SVT (The Street View Text Dataset)
    - Train: 115MB, 349 images

# [Features](#contents)

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```shell
.
└─ctpn
  ├── README.md                            # network readme
  ├── eval.py                              # eval net
  ├── scripts
  │   ├── eval_res.sh                      # calculate precision and recall
  │   ├── run_distribute_train_ascend.sh   # launch distributed training on Ascend (8p)
  │   ├── run_eval_ascend.sh               # launch evaluation on Ascend
  │   └── run_standalone_train_ascend.sh   # launch standalone training on Ascend (1p)
  ├── src
  │   ├── CTPN
  │   │   ├── BoundingBoxDecode.py         # bounding box decode
  │   │   ├── BoundingBoxEncode.py         # bounding box encode
  │   │   ├── __init__.py                  # package init file
  │   │   ├── anchor_generator.py          # anchor generator
  │   │   ├── bbox_assign_sample.py        # proposal layer
  │   │   ├── proposal_generator.py        # proposal generator
  │   │   ├── rpn.py                       # region proposal network
  │   │   └── vgg16.py                     # backbone
  │   ├── config.py                        # training configuration
  │   ├── convert_icdar2015.py             # convert icdar2015 dataset labels
  │   ├── convert_svt.py                   # convert svt labels
  │   ├── create_dataset.py                # create mindrecord dataset
  │   ├── ctpn.py                          # ctpn network definition
  │   ├── dataset.py                       # data preprocessing
  │   ├── lr_schedule.py                   # learning rate scheduler
  │   ├── network_define.py                # network definition
  │   └── text_connector
  │       ├── __init__.py                  # package init file
  │       ├── connect_text_lines.py        # connect text lines
  │       ├── detector.py                  # detect boxes
  │       ├── get_successions.py           # get succession proposals
  │       └── utils.py                     # commonly used helper functions
  └── train.py                             # train net
```
## [Training Process](#contents)

### Dataset

To create the dataset, download the datasets above and convert their labels first. We provide src/convert_svt.py and src/convert_icdar2015.py to convert the SVT and ICDAR2015 labels. For the SVT dataset, run:

```shell
python convert_svt.py --dataset_path=/path/img --xml_file=/path/train.xml --location_dir=/path/location
```

For the ICDAR2015 dataset, run:

```shell
python convert_icdar2015.py --src_label_path=/path/train_label --target_label_path=/path/label
```

Then modify src/config.py to add the dataset paths. For each dataset, add its IMAGE_PATH and LABEL_PATH as a list in the config. An example is shown below:

```python
# create dataset
"coco_root": "/path/coco",
"coco_train_data_type": "train2017",
"cocotext_json": "/path/cocotext.v2.json",
"icdar11_train_path": ["/path/image/", "/path/label"],
"icdar13_train_path": ["/path/image/", "/path/label"],
"icdar15_train_path": ["/path/image/", "/path/label"],
"icdar13_test_path": ["/path/image/", "/path/label"],
"flick_train_path": ["/path/image/", "/path/label"],
"svt_train_path": ["/path/image/", "/path/label"],
"pretrain_dataset_path": "",
"finetune_dataset_path": "",
"test_dataset_path": "",
```

Then you can create the dataset with src/create_dataset.py using the command below:

```shell
python src/create_dataset.py
```
### Usage

- Ascend:

```bash
# distributed training example (8p)
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
# standalone training
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
# evaluation
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
```

The `PRETRAINED_PATH` should be a checkpoint of VGG16 trained on ImageNet2012. The weight names in the checkpoint dict must match the network exactly, and batch_norm must be enabled when training VGG16; otherwise later steps will fail. To get the VGG16 backbone, use the network structure defined in src/CTPN/vgg16.py. To train the backbone, copy src/CTPN/vgg16.py under modelzoo/official/cv/vgg16/src/, and modify vgg16/train.py to use the new structure, as below:

```python
...
from src.vgg16 import VGG16
...
network = VGG16()
...
```

Then you can train it with ImageNet2012.

> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection may time out, since compile time increases with model size.
>
> The scripts bind processor cores according to `device_num` and the total number of processors. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
>
> TASK_TYPE is either Pretraining or Finetune. For Pretraining, we use ICDAR2013, ICDAR2015, SVT, SCUT-FORU and COCO-Text v2. For Finetune, we use ICDAR2011, ICDAR2013 and SCUT-FORU to improve precision and recall, and the checkpoint produced by Pretraining is used as PRETRAINED_PATH.
>
> COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).
>
### Launch

```bash
# training example
shell:
Ascend:
# distributed training example (8p)
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
# standalone training
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
```

### Result

Training results will be stored in the example path. Checkpoints are stored at `ckpt_path` by default, the training log is redirected to `./log`, and the loss is written to `./loss_0.log` like the following:

```text
377 epoch: 1 step: 229, rpn_loss: 0.00355, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00103,
399 epoch: 2 step: 229, rpn_loss: 0.00327, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00093,
424 epoch: 3 step: 229, rpn_loss: 0.00910, rpn_cls_loss: 0.00385, rpn_reg_loss: 0.00175,
```
## [Evaluation Process](#contents)

### Usage

You can start evaluation using python or shell scripts. The usage of shell scripts is as follows:

- Ascend:

```bash
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
```

After evaluation, you will get several archive files named submit_ctpn-xx_xxxx.zip, one per checkpoint file. To score them, you can use the evaluation scripts of the ICDAR2013 benchmark; download the Deteval scripts from the [link](https://rrc.cvc.uab.es/?com=downloads&action=download&ch=2&f=aHR0cHM6Ly9ycmMuY3ZjLnVhYi5lcy9zdGFuZGFsb25lcy9zY3JpcHRfdGVzdF9jaDJfdDFfZTItMTU3Nzk4MzA2Ny56aXA=).
After downloading the scripts, unzip them and put them under ctpn/scripts, then use eval_res.sh to get the result. You will get files as below:

```text
gt.zip
readme.txt
rrc_evaluation_funcs_1_1.py
script.py
```

Then you can run scripts/eval_res.sh to calculate the evaluation result:

```bash
bash eval_res.sh
```
### Result

The evaluation result will be stored in the example path; you can find a result like the following in `log`:

```text
{"precision": 0.90791, "recall": 0.86118, "hmean": 0.88393}
```
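For reference, `hmean` is the harmonic mean (F-measure) of precision and recall, and the reported numbers are self-consistent:

```latex
\mathrm{hmean} = \frac{2PR}{P + R} = \frac{2 \times 0.90791 \times 0.86118}{0.90791 + 0.86118} \approx 0.88393
```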
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters          | Ascend                                                                             |
| ------------------- | ---------------------------------------------------------------------------------- |
| Model Version       | CTPN                                                                               |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G                                    |
| Uploaded Date       | 02/06/2021                                                                         |
| MindSpore Version   | 1.1.1                                                                              |
| Dataset             | 16930 images                                                                       |
| Batch_size          | 2                                                                                  |
| Training Parameters | src/config.py                                                                      |
| Optimizer           | Momentum                                                                           |
| Loss Function       | SoftmaxCrossEntropyWithLogits for classification, SmoothL1Loss for bbox regression |
| Loss                | ~0.04                                                                              |
| Total time (8p)     | 6h                                                                                 |
| Scripts             | [ctpn script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ctpn) |

### Inference Performance

| Parameters          | Ascend                                            |
| ------------------- | ------------------------------------------------- |
| Model Version       | CTPN                                              |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G   |
| Uploaded Date       | 02/06/2021                                        |
| MindSpore Version   | 1.1.1                                             |
| Dataset             | 229 images                                        |
| Batch_size          | 1                                                 |
| Accuracy            | precision=0.9079, recall=0.8611, F-measure=0.8839 |
| Total time          | 1 min                                             |
| Model for inference | 135M (.ckpt file)                                 |

### Training Performance Results

| **Ascend** | train performance |
| :--------: | :---------------: |
|     1p     |     10 img/s      |
|     8p     |     84 img/s      |
# [Description of Random Situation](#contents)

We set the random seed to 1 in train.py.
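Concretely, the seed is fixed through MindSpore's global seed API; eval.py below does the same:

```python
from mindspore.common import set_seed

set_seed(1)  # fix MindSpore's global random seed for reproducibility
```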
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for CTPN"""
import os
import argparse
import time
import numpy as np
from PIL import Image, ImageDraw
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config
from src.dataset import create_ctpn_dataset
from src.text_connector.detector import detect
set_seed(1)

parser = argparse.ArgumentParser(description="CTPN evaluation")
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
parser.add_argument("--image_path", type=str, default="", help="Image path.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
    """ctpn infer."""
    print("ckpt path is {}".format(ckpt_path))
    ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
    config.batch_size = config.test_batch_size
    total = ds.get_dataset_size()
    print("*************total dataset size is {}".format(total))
    net = CTPN(config, is_training=False)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    eval_iter = 0

    print("\n========================================\n")
    print("Processing, please wait a moment.")
    img_basenames = []
    output_dir = os.path.join(os.getcwd(), "submit")
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    for file in os.listdir(img_dir):
        img_basenames.append(os.path.basename(file))
    for data in ds.create_dict_iterator():
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']

        start = time.time()
        # run net
        output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
        gt_bboxes = gt_bboxes.asnumpy()
        gt_labels = gt_labels.asnumpy()
        gt_num = gt_num.asnumpy().astype(bool)
        end = time.time()
        proposal = output[0]
        proposal_mask = output[1]
        print("start to draw pic")
        for j in range(config.test_batch_size):
            img = img_basenames[config.test_batch_size * eval_iter + j]
            all_box_tmp = proposal[j].asnumpy()
            all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
            using_boxes_mask = all_box_tmp * all_mask_tmp
            textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
            scores = using_boxes_mask[:, 4].astype(np.float32)
            shape = img_metas.asnumpy()[0][:2].astype(np.int32)
            # connect the fine-scale proposals into text lines
            bboxes = detect(textsegs, scores[:, np.newaxis], shape)
            im = Image.open(img_dir + '/' + img)
            draw = ImageDraw.Draw(im)
            # img_metas is expected to hold (h, w, scale_h, scale_w); dividing by
            # the scales maps boxes back to the original image size
            image_h = img_metas.asnumpy()[j][2]
            image_w = img_metas.asnumpy()[j][3]
            gt_boxs = gt_bboxes[j][gt_num[j], :]
            for gt_box in gt_boxs:
                gt_x1 = gt_box[0] / image_w
                gt_y1 = gt_box[1] / image_h
                gt_x2 = gt_box[2] / image_w
                gt_y2 = gt_box[3] / image_h
                draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],
                          fill='green', width=2)
            # write the detected boxes in the ICDAR submission format: res_<image>.txt
            file_name = "res_" + img.replace("jpg", "txt")
            output_file = os.path.join(output_dir, file_name)
            f = open(output_file, 'w')
            for bbox in bboxes:
                x1 = bbox[0] / image_w
                y1 = bbox[1] / image_h
                x2 = bbox[2] / image_w
                y2 = bbox[3] / image_h
                draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
                str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
                f.write(str_tmp)
                f.write("\n")
            f.close()
            im.save(img)
        percent = round(eval_iter / total * 100, 2)
        eval_iter = eval_iter + 1
        print("Iter {} cost time {}".format(eval_iter, end - start))
        print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')

if __name__ == '__main__':
    ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)
@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
for submit_file in "submit"*.zip
do
    echo "eval result for ${submit_file}"
    python script.py -g=gt.zip -s=${submit_file} -o=./
    echo -e ".\n"
done
@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 3 ]
then
    echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1

if [ ! -f $PATH1 ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi
TASK_TYPE=$2
PATH2=$(get_real_path $3)
echo $PATH2
if [ ! -f $PATH2 ]
then
    echo "error: PRETRAINED_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log &
    cd ..
done
@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
IMAGE_PATH=$(get_real_path $1)
DATASET_PATH=$(get_real_path $2)
CHECKPOINT_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $DATASET_PATH
echo $CHECKPOINT_PATH

if [ ! -d $IMAGE_PATH ]
then
    echo "error: IMAGE_PATH=$IMAGE_PATH is not a directory"
    exit 1
fi

if [ ! -f $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a file"
    exit 1
fi

if [ ! -d $CHECKPOINT_PATH ]
then
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
# evaluate every checkpoint under CHECKPOINT_PATH and pack each result for Deteval
for file in "${CHECKPOINT_PATH}"/*.ckpt
do
    if [ -d "eval" ];
    then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp *.sh ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env > env.log
    CHECKPOINT_FILE_PATH=$file
    echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
    python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log
    echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
    cd ./submit
    file_base_name=$(basename $file)
    zip -r ../../submit_${file_base_name%.*}.zip *.txt
    cd ../../
done
@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

TASK_TYPE=$1
PRETRAINED_PATH=$(get_real_path $2)
echo $PRETRAINED_PATH
if [ ! -f $PRETRAINED_PATH ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log &
cd ..
@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P


class BoundingBoxDecode(nn.Cell):
    """
    BoundingBox decoder.

    Returns:
        pred_box(Tensor): decoded bounding boxes.
    """
    def __init__(self):
        super(BoundingBoxDecode, self).__init__()
        self.split = P.Split(axis=1, output_num=4)
        self.ones = 1.0
        self.half = 0.5
        self.log = P.Log()
        self.exp = P.Exp()
        self.concat = P.Concat(axis=1)

    def construct(self, bboxes, deltas):
        """
        bboxes(Tensor): anchor boxes.
        deltas(Tensor): deltas between bounding boxes and anchors.
        """
        x1, y1, x2, y2 = self.split(bboxes)
        width = x2 - x1 + self.ones
        height = y2 - y1 + self.ones
        ctr_x = x1 + self.half * width
        ctr_y = y1 + self.half * height
        # CTPN proposals have a fixed width, so only the vertical center and
        # height are decoded; x and width pass through unchanged.
        _, dy, _, dh = self.split(deltas)
        pred_ctr_x = ctr_x
        pred_ctr_y = dy * height + ctr_y
        pred_w = width
        pred_h = self.exp(dh) * height

        x1 = pred_ctr_x - self.half * pred_w
        y1 = pred_ctr_y - self.half * pred_h
        x2 = pred_ctr_x + self.half * pred_w
        y2 = pred_ctr_y + self.half * pred_h
        pred_box = self.concat((x1, y1, x2, y2))
        return pred_box
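# Numeric sanity check (illustrative only, not executed by the network):
# anchor (0, 0, 15, 15) gives width = height = 16 and center (8, 8); with
# deltas dy = 0.5 and dh = log(2), the decoded center y becomes 16 and the
# decoded height 32, so the output box is (0.0, 0.0, 16.0, 32.0) while x
# and width stay fixed.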
@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P


class BoundingBoxEncode(nn.Cell):
    """
    BoundingBox encoder.

    Returns:
        deltas(Tensor): encoded deltas between ground truth boxes and anchors.
    """
    def __init__(self):
        super(BoundingBoxEncode, self).__init__()
        self.split = P.Split(axis=1, output_num=4)
        self.ones = 1.0
        self.half = 0.5
        self.log = P.Log()
        self.concat = P.Concat(axis=1)

    def construct(self, anchor_box, gt_box):
        """
        anchor_box(Tensor): anchor boxes.
        gt_box(Tensor): ground truth boxes.
        """
        x1, y1, x2, y2 = self.split(anchor_box)
        width = x2 - x1 + self.ones
        height = y2 - y1 + self.ones
        ctr_x = x1 + self.half * width
        ctr_y = y1 + self.half * height
        gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box)
        gt_width = gt_x2 - gt_x1 + self.ones
        gt_height = gt_y2 - gt_y1 + self.ones
        ctr_gt_x = gt_x1 + self.half * gt_width
        ctr_gt_y = gt_y1 + self.half * gt_height

        # center offsets normalized by the anchor size, log-space scale deltas
        target_dx = (ctr_gt_x - ctr_x) / width
        target_dy = (ctr_gt_y - ctr_y) / height
        dw = gt_width / width
        dh = gt_height / height
        target_dw = self.log(dw)
        target_dh = self.log(dh)
        deltas = self.concat((target_dx, target_dy, target_dw, target_dh))
        return deltas
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np


class AnchorGenerator():
    """Anchor generator for FasterRcnn."""
    def __init__(self, config):
        """Anchor generator init method."""
        self.base_size = config.anchor_base
        self.num_anchor = config.num_anchors
        self.anchor_height = config.anchor_height
        self.anchor_width = config.anchor_width
        self.size = self.gen_anchor_size()
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Generate a single anchor per (height, width) pair."""
        base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32)
        anchors = np.zeros((len(self.size), 4), np.int32)
        index = 0
        for h, w in self.size:
            anchors[index] = self.scale_anchor(base_anchor, h, w)
            index += 1
        return anchors

    def gen_anchor_size(self):
        """Generate the list of anchor sizes."""
        size = []
        for width in self.anchor_width:
            for height in self.anchor_height:
                size.append((height, width))
        return size

    def scale_anchor(self, anchor, h, w):
        """Scale the base anchor to height h and width w around its center."""
        x_ctr = (anchor[0] + anchor[2]) * 0.5
        y_ctr = (anchor[1] + anchor[3]) * 0.5
        scaled_anchor = anchor.copy()
        scaled_anchor[0] = x_ctr - w / 2  # xmin
        scaled_anchor[2] = x_ctr + w / 2  # xmax
        scaled_anchor[1] = y_ctr - h / 2  # ymin
        scaled_anchor[3] = y_ctr + h / 2  # ymax
        return scaled_anchor

    def _meshgrid(self, x, y):
        """Generate grid."""
        xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
        yy = np.repeat(y, len(x))
        return xx, yy

    def grid_anchors(self, featmap_size, stride=16):
        """Generate the full anchor list by shifting the base anchors over the feature map."""
        base_anchors = self.base_anchors
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
        shifts = shifts.astype(base_anchors.dtype)
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.reshape(-1, 4)
        return all_anchors
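# Shape note (illustrative): with N base anchors and an (H, W) feature map,
# grid_anchors returns (H * W * N, 4) anchors -- e.g. 10 heights on a 16x16
# map at stride 16 yield 2560 anchors, one fixed-width anchor per
# (cell, height) pair.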
@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from src.CTPN.BoundingBoxEncode import BoundingBoxEncode


class BboxAssignSample(nn.Cell):
    """
    Bbox assigner and sampler definition.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tensor, output tensor.
        bbox_targets: bbox location, (batch_size, num_bboxes, 4)
        bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
        labels: label for every bboxes, (batch_size, num_bboxes, 1)
        label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

    Examples:
        BboxAssignSample(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSample, self).__init__()
        cfg = config
        self.batch_size = batch_size

        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
        self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
        self.zero_thr = Tensor(0.0, mstype.float16)

        self.num_bboxes = num_bboxes
        self.num_gts = cfg.num_gts
        self.num_expected_pos = cfg.num_expected_pos
        self.num_expected_neg = cfg.num_expected_neg
        self.add_gt_as_proposals = add_gt_as_proposals

        if self.add_gt_as_proposals:
            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))

        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = BoundingBoxEncode()
        self.scatterNdUpdate = P.ScatterNdUpdate()
        self.scatterNd = P.ScatterNd()
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()

        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))

        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=bool))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
        self.print = P.Print()

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        # mask out invalid gt boxes and anchors before computing overlaps
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
                      (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
                 (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
        overlaps = self.iou(bboxes, gt_bboxes_i)
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)
        # anchors whose best IoU lies in [0, neg_iou_thr) are assigned negative
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
                              self.less(max_overlaps_w_gt, self.neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
        # anchors whose best IoU reaches pos_iou_thr are assigned positive
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
                            max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        assigned_gt_inds4 = assigned_gt_inds3
        # additionally, every gt keeps its best-overlapping anchor (>= min_pos_iou) as positive
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])

            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
                         self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
        # randomly sample up to num_expected_pos positives and num_expected_neg negatives
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

        num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
        num_pos = self.sum_inds(num_pos, -1)
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
        # encode regression targets only for the sampled positives
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
        total_index = self.concat((pos_index, neg_index))
        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
               labels_total, self.cast(label_weights_total, mstype.bool_)
@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from src.CTPN.BoundingBoxDecode import BoundingBoxDecode


class Proposal(nn.Cell):
    """
    Proposal subnet.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_classes (int): Class number.
        use_sigmoid_cls (bool): Select sigmoid or softmax function.
        target_means (tuple): Means for encode function. Default: (.0, .0, .0, .0).
        target_stds (tuple): Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).

    Returns:
        Tuple, tuple of output tensor, (proposal, mask).

    Examples:
        Proposal(config=config, batch_size=1, num_classes=81, use_sigmoid_cls=True, \
                 target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
    """
    def __init__(self,
                 config,
                 batch_size,
                 num_classes,
                 use_sigmoid_cls,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)
                 ):
        super(Proposal, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = config.use_sigmoid_cls

        if self.use_sigmoid_cls:
            self.cls_out_channels = 1
            self.activation = P.Sigmoid()
            self.reshape_shape = (-1, 1)
        else:
            self.cls_out_channels = num_classes
            self.activation = P.Softmax(axis=1)
            self.reshape_shape = (-1, 2)

        if self.cls_out_channels <= 0:
            raise ValueError('num_classes={} is too small'.format(num_classes))

        self.num_pre = cfg.rpn_proposal_nms_pre
        self.min_box_size = cfg.rpn_proposal_min_bbox_size
        self.nms_thr = cfg.rpn_proposal_nms_thr
        self.nms_post = cfg.rpn_proposal_nms_post
        self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
        self.max_num = cfg.rpn_proposal_max_num

        # Op Define
        self.squeeze = P.Squeeze()
        self.reshape = P.Reshape()
        self.cast = P.Cast()

        self.feature_shapes = cfg.feature_shapes

        self.transpose_shape = (1, 2, 0)

        self.decode = BoundingBoxDecode()

        self.nms = P.NMSWithMask(self.nms_thr)
        self.concat_axis0 = P.Concat(axis=0)
        self.concat_axis1 = P.Concat(axis=1)
        self.split = P.Split(axis=1, output_num=5)
        self.min = P.Minimum()
        self.gatherND = P.GatherNd()
        self.slice = P.Slice()
        self.select = P.Select()
        self.greater = P.Greater()
        self.transpose = P.Transpose()
        self.tile = P.Tile()
        self.set_train_local(config, training=True)

        self.multi_10 = Tensor(10.0, mstype.float16)

    def set_train_local(self, config, training=False):
        """Set training flag."""
        self.training_local = training
        cfg = config
        self.topK_stage1 = ()
        self.topK_shape = ()
        total_max_topk_input = 0
        if not self.training_local:
            self.num_pre = cfg.rpn_nms_pre
            self.min_box_size = cfg.rpn_min_bbox_min_size
            self.nms_thr = cfg.rpn_nms_thr
            self.nms_post = cfg.rpn_nms_post
            self.max_num = cfg.rpn_max_num
        k_num = self.num_pre
        total_max_topk_input = k_num
        self.topK_stage1 = k_num
        self.topK_shape = (k_num, 1)

        self.topKv2 = P.TopK(sorted=True)
        self.topK_shape_stage2 = (self.max_num, 1)
        self.min_float_num = -65536.0
        self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
        self.shape = P.Shape()
        self.print = P.Print()

    def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
        proposals_tuple = ()
        masks_tuple = ()
        for img_id in range(self.batch_size):
            rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::])
            rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::])
            proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list)
            proposals_tuple += (proposals,)
            masks_tuple += (masks,)
        return proposals_tuple, masks_tuple

    def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
        """Get proposal boundingbox."""
        mlvl_proposals = ()
        mlvl_mask = ()

        rpn_cls_score = self.transpose(cls_scores, self.transpose_shape)
        rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape)
        anchors = mlvl_anchors

        # (H, W, A*2)
        rpn_cls_score_shape = self.shape(rpn_cls_score)
        rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \
                        rpn_cls_score_shape[1], -1, self.cls_out_channels))
        rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
        rpn_cls_score = self.activation(rpn_cls_score)
        if self.use_sigmoid_cls:
            rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16)
        else:
            rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16)

        rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)

        # keep the num_pre highest-scoring predictions before NMS
        scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre)

        topk_inds = self.reshape(topk_inds, self.topK_shape)

        bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
        anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)

        # decode deltas against their anchors, then run NMS on (box, score)
        proposals_decode = self.decode(anchors_sorted, bboxes_sorted)

        proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape)))
        proposals, _, mask_valid = self.nms(proposals_decode)

        mlvl_proposals = mlvl_proposals + (proposals,)
        mlvl_mask = mlvl_mask + (mask_valid,)

        proposals = self.concat_axis0(mlvl_proposals)
        masks = self.concat_axis0(mlvl_mask)

        _, _, _, _, scores = self.split(proposals)
        scores = self.squeeze(scores)
        topk_mask = self.cast(self.topK_mask, mstype.float16)
        scores_using = self.select(masks, scores, topk_mask)

        # final top-k over NMS survivors (suppressed entries carry -65536)
        _, topk_inds = self.topKv2(scores_using, self.max_num)

        topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
        proposals = self.gatherND(proposals, topk_inds)
        masks = self.gatherND(masks, topk_inds)
        return proposals, masks
@ -0,0 +1,228 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""RPN for fasterRCNN"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from src.CTPN.bbox_assign_sample import BboxAssignSample
|
||||
|
||||
class RpnRegClsBlock(nn.Cell):
|
||||
"""
|
||||
Rpn reg cls block for rpn layer
|
||||
|
||||
Args:
|
||||
config(EasyDict) - Network construction config.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
in_channels,
|
||||
feat_channels,
|
||||
num_anchors,
|
||||
cls_out_channels):
|
||||
super(RpnRegClsBlock, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, 2*config.hidden_size)
|
||||
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
|
||||
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
|
||||
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
|
||||
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
|
||||
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
|
||||
self.transpose = P.Transpose()
|
||||
self.print = P.Print()
|
||||
self.dropout = nn.Dropout(0.8)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.reshape(x, self.shape)
|
||||
x = self.lstm_fc(x)
|
||||
x1 = self.rpn_cls(x)
|
||||
x1 = self.reshape(x1, self.shape1)
|
||||
x1 = self.transpose(x1, (2, 1, 0))
|
||||
x1 = self.reshape(x1, self.shape2)
|
||||
x1 = self.transpose(x1, (1, 0, 2, 3))
|
||||
x2 = self.rpn_reg(x)
|
||||
x2 = self.reshape(x2, self.shape1)
|
||||
x2 = self.transpose(x2, (2, 1, 0))
|
||||
x2 = self.reshape(x2, self.shape2)
|
||||
x2 = self.transpose(x2, (1, 0, 2, 3))
|
||||
return x1, x2
|
||||
|
||||
class RPN(nn.Cell):
|
||||
"""
|
||||
ROI proposal network..
|
||||
|
||||
Args:
|
||||
config (dict) - Config.
|
||||
batch_size (int) - Batchsize.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor.
|
||||
|
||||
Examples:
|
||||
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
|
||||
num_anchors=3, cls_out_channels=512)
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
batch_size,
|
||||
in_channels,
|
||||
feat_channels,
|
||||
num_anchors,
|
||||
cls_out_channels):
|
||||
super(RPN, self).__init__()
|
||||
cfg_rpn = config
|
||||
self.cfg = config
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.feature_anchor_shape = cfg_rpn.feature_shapes
|
||||
self.feature_anchor_shape = self.feature_anchor_shape[0] * \
|
||||
self.feature_anchor_shape[1] * num_anchors * batch_size
|
||||
self.num_anchors = num_anchors
|
||||
self.batch_size = batch_size
|
||||
self.test_batch_size = cfg_rpn.test_batch_size
|
||||
self.num_layers = 1
|
||||
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
|
||||
self.use_sigmoid_cls = config.use_sigmoid_cls
|
||||
if config.use_sigmoid_cls:
|
||||
self.reshape_shape_cls = (-1,)
|
||||
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
|
||||
cls_out_channels = 1
|
||||
else:
|
||||
self.reshape_shape_cls = (-1, cls_out_channels)
|
||||
self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
|
||||
self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\
|
||||
num_anchors, cls_out_channels)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.fill = P.Fill()
|
||||
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
|
||||
|
||||
self.trans_shape = (0, 2, 3, 1)
|
||||
|
||||
self.reshape_shape_reg = (-1, 4)
|
||||
self.softmax = nn.Softmax()
|
||||
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
|
||||
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
|
||||
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
|
||||
self.CheckValid = P.CheckValid()
|
||||
self.sum_loss = P.ReduceSum()
|
||||
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
self.tile = P.Tile()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
|
||||
self.print = P.Print()
|
||||
|
||||
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
|
||||
"""
|
||||
make rpn layer for rpn proposal network
|
||||
|
||||
Args:
|
||||
num_layers (int) - layer num.
|
||||
in_channels (int) - Input channels of shared convolution.
|
||||
feat_channels (int) - Output channels of shared convolution.
|
||||
num_anchors (int) - The anchor number.
|
||||
cls_out_channels (int) - Output channels of classification convolution.
|
||||
|
||||
Returns:
|
||||
List, list of RpnRegClsBlock cells.
|
||||
"""
|
||||
rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels)
|
||||
return rpn_layer
|
||||
|
||||
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
|
||||
'''
|
||||
inputs(Tensor): Inputs tensor from lstm.
|
||||
img_metas(Tensor): Image shape.
|
||||
anchor_list(Tensor): Total anchor list.
|
||||
gt_bboxes(Tensor): Ground truth boxes.
|
||||
gt_labels(Tensor): Ground truth labels.
|
||||
gt_valids(Tensor): Whether ground truth is valid.
|
||||
'''
|
||||
rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
|
||||
rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
|
||||
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
|
||||
rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
|
||||
rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
|
||||
output = ()
|
||||
bbox_targets = ()
|
||||
bbox_weights = ()
|
||||
labels = ()
|
||||
label_weights = ()
|
||||
if self.training:
|
||||
for i in range(self.batch_size):
|
||||
valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
|
||||
mstype.int32)
|
||||
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
|
||||
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
|
||||
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
|
||||
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
|
||||
gt_labels_i,
|
||||
self.cast(valid_flag_list,
|
||||
mstype.bool_),
|
||||
anchor_list, gt_valids_i)
|
||||
bbox_weight = self.cast(bbox_weight, mstype.float16)
|
||||
label_weight = self.cast(label_weight, mstype.float16)
|
||||
bbox_targets += (bbox_target,)
|
||||
bbox_weights += (bbox_weight,)
|
||||
labels += (label,)
|
||||
label_weights += (label_weight,)
|
||||
bbox_target_with_batchsize = self.concat(bbox_targets)
|
||||
bbox_weight_with_batchsize = self.concat(bbox_weights)
|
||||
label_with_batchsize = self.concat(labels)
|
||||
label_weight_with_batchsize = self.concat(label_weights)
|
||||
|
||||
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
|
||||
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
|
||||
label_ = F.stop_gradient(label_with_batchsize)
|
||||
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
|
||||
rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
|
||||
if self.use_sigmoid_cls:
|
||||
label_ = self.cast(label_, mstype.float32)
|
||||
loss_cls = self.loss_cls(rpn_cls_score, label_)
|
||||
loss_cls = loss_cls * label_weight_
|
||||
loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total
|
||||
rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
|
||||
bbox_target_ = self.cast(bbox_target_, mstype.float32)
|
||||
loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
|
||||
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4))
|
||||
loss_reg = loss_reg * bbox_weight_
|
||||
loss_reg = self.sum_loss(loss_reg, (1,))
|
||||
loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total
|
||||
loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
|
||||
output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg)
|
||||
else:
|
||||
output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1)
|
||||
return output
|
|
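To make the loss masking in RPN.construct above concrete, here is a minimal NumPy sketch (an illustration only, not the MindSpore graph): per-anchor classification losses are multiplied by the sampler's 0/1 weight mask and normalized by the expected sample count.

import numpy as np

def masked_rpn_cls_loss(per_anchor_loss, label_weights, num_expected_total):
    # per_anchor_loss and label_weights have one entry per anchor;
    # anchors the sampler ignored carry weight 0 and drop out of the sum.
    return (per_anchor_loss * label_weights).sum() / num_expected_total

loss = np.random.rand(30240).astype(np.float32)              # config.num_bboxes anchors
weights = (np.random.rand(30240) > 0.9).astype(np.float32)   # sampled anchors only
print(masked_rpn_cls_loss(loss, weights, 512.0))             # num_expected_neg * batch_size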
@ -0,0 +1,177 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
'''weight initialize'''
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
|
||||
"""Batchnorm2D wrapper."""
|
||||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
|
||||
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
|
||||
beta_init=beta_init, moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
|
||||
"""Conv2D wrapper."""
|
||||
weights = 'ones'
|
||||
layers = []
|
||||
conv = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=weights, has_bias=False)
|
||||
if not weights_update:
|
||||
conv.weight.requires_grad = False
|
||||
layers += [conv]
|
||||
layers += [_BatchNorm2dInit(out_channels)]
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
def _fc(in_channels, out_channels):
|
||||
'''full connection layer'''
|
||||
weight = _weight_variable((out_channels, in_channels))
|
||||
bias = _weight_variable((out_channels,))
|
||||
return nn.Dense(in_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
class VGG16FeatureExtraction(nn.Cell):
|
||||
def __init__(self, weights_update=False):
|
||||
"""
|
||||
VGG16 feature extraction
|
||||
|
||||
Args:
|
||||
weights_update(bool): whether to update the weights of the first two conv blocks, default is False.
|
||||
"""
|
||||
super(VGG16FeatureExtraction, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
|
||||
self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\
|
||||
padding=1, weights_update=weights_update)
|
||||
|
||||
self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
|
||||
self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
|
||||
self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
|
||||
|
||||
self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
|
||||
self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 3, 224, 224)
|
||||
:return: feature map, shape=(B, 512, 14, 14)
|
||||
"""
|
||||
x = self.cast(x, mstype.float32)
|
||||
x = self.conv1_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv2_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv3_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv3_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv3_3(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv4_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv4_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv4_3(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.conv5_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv5_2(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv5_3(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class VGG16Classfier(nn.Cell):
|
||||
def __init__(self):
|
||||
"""VGG16 classfier structure"""
|
||||
super(VGG16Classfier, self).__init__()
|
||||
self.flatten = P.Flatten()
|
||||
self.relu = nn.ReLU()
|
||||
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
|
||||
self.fc2 = _fc(in_channels=4096, out_channels=4096)
|
||||
self.batch_size = 32
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 512, 7, 7)
|
||||
:return: shape=(B, 4096)
|
||||
"""
|
||||
x = self.reshape(x, (self.batch_size, 7*7*512))
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class VGG16(nn.Cell):
|
||||
def __init__(self):
|
||||
"""VGG16 construct for training backbone"""
|
||||
super(VGG16, self).__init__()
|
||||
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.classifier = VGG16Classfier()
|
||||
self.fc3 = _fc(in_channels=4096, out_channels=1000)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 3, 224, 224)
|
||||
:return: logits, shape=(B, 1000)
|
||||
"""
|
||||
feature_maps = self.feature_extraction(x)
|
||||
x = self.max_pool(feature_maps)
|
||||
x = self.classifier(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
|
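A quick sanity check (a sketch, nothing framework-specific): with the four stride-2 max pools in VGG16FeatureExtraction, the config-sized 576x960 input maps to the 36x60 feature grid that the anchor settings below assume.

def vgg16_feat_shape(h, w, num_pools=4):
    # four stride-2 max pools between conv1 and conv5
    for _ in range(num_pools):
        h, w = h // 2, w // 2
    return h, w

assert vgg16_feat_shape(576, 960) == (36, 60)   # matches config.feature_shapes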
@ -0,0 +1,133 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network parameters."""
|
||||
from easydict import EasyDict
|
||||
pretrain_config = EasyDict({
|
||||
# LR
|
||||
"base_lr": 0.0009,
|
||||
"warmup_step": 30000,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"total_epoch": 100,
|
||||
})
|
||||
finetune_config = EasyDict({
|
||||
# LR
|
||||
"base_lr": 0.0005,
|
||||
"warmup_step": 300,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"total_epoch": 50,
|
||||
})
|
||||
|
||||
# use for low case number
|
||||
config = EasyDict({
|
||||
"img_width": 960,
|
||||
"img_height": 576,
|
||||
"keep_ratio": False,
|
||||
"flip_ratio": 0.0,
|
||||
"photo_ratio": 0.0,
|
||||
"expand_ratio": 1.0,
|
||||
|
||||
# anchor
|
||||
"feature_shapes": (36, 60),
|
||||
"num_anchors": 14,
|
||||
"anchor_base": 16,
|
||||
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
|
||||
"anchor_width": [16],
|
||||
|
||||
# rpn
|
||||
"rpn_in_channels": 256,
|
||||
"rpn_feat_channels": 512,
|
||||
"rpn_loss_cls_weight": 1.0,
|
||||
"rpn_loss_reg_weight": 3.0,
|
||||
"rpn_cls_out_channels": 2,
|
||||
|
||||
# bbox_assign_sampler
|
||||
"neg_iou_thr": 0.5,
|
||||
"pos_iou_thr": 0.7,
|
||||
"min_pos_iou": 0.001,
|
||||
"num_bboxes": 30240,
|
||||
"num_gts": 256,
|
||||
"num_expected_neg": 512,
|
||||
"num_expected_pos": 256,
|
||||
|
||||
# proposal
|
||||
"activate_num_classes": 2,
|
||||
"use_sigmoid_cls": False,
|
||||
|
||||
# train proposal
|
||||
"rpn_proposal_nms_across_levels": False,
|
||||
"rpn_proposal_nms_pre": 2000,
|
||||
"rpn_proposal_nms_post": 1000,
|
||||
"rpn_proposal_max_num": 1000,
|
||||
"rpn_proposal_nms_thr": 0.7,
|
||||
"rpn_proposal_min_bbox_size": 8,
|
||||
|
||||
# rnn structure
|
||||
"input_size": 512,
|
||||
"num_step": 60,
|
||||
"rnn_batch_size": 36,
|
||||
"hidden_size": 128,
|
||||
|
||||
# training
|
||||
"warmup_mode": "linear",
|
||||
"batch_size": 1,
|
||||
"momentum": 0.9,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 10,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./",
|
||||
"use_dropout": False,
|
||||
"loss_scale": 1,
|
||||
"weight_decay": 1e-4,
|
||||
|
||||
# test proposal
|
||||
"rpn_nms_pre": 2000,
|
||||
"rpn_nms_post": 1000,
|
||||
"rpn_max_num": 1000,
|
||||
"rpn_nms_thr": 0.7,
|
||||
"rpn_min_bbox_min_size": 8,
|
||||
"test_iou_thr": 0.7,
|
||||
"test_max_per_img": 100,
|
||||
"test_batch_size": 1,
|
||||
"use_python_proposal": False,
|
||||
|
||||
# text proposal connection
|
||||
"max_horizontal_gap": 60,
|
||||
"text_proposals_min_scores": 0.7,
|
||||
"text_proposals_nms_thresh": 0.2,
|
||||
"min_v_overlaps": 0.7,
|
||||
"min_size_sim": 0.7,
|
||||
"min_ratio": 0.5,
|
||||
"line_min_score": 0.9,
|
||||
"text_proposals_width": 16,
|
||||
"min_num_proposals": 2,
|
||||
|
||||
# create dataset
|
||||
"coco_root": "",
|
||||
"coco_train_data_type": "",
|
||||
"cocotext_json": "",
|
||||
"icdar11_train_path": [],
|
||||
"icdar13_train_path": [],
|
||||
"icdar15_train_path": [],
|
||||
"icdar13_test_path": [],
|
||||
"flick_train_path": [],
|
||||
"svt_train_path": [],
|
||||
"pretrain_dataset_path": "",
|
||||
"finetune_dataset_path": "",
|
||||
"test_dataset_path": "",
|
||||
|
||||
# training dataset
|
||||
"pretraining_dataset_file": "",
|
||||
"finetune_dataset_file": ""
|
||||
})
|
|
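The anchor-related fields above are coupled: num_bboxes must equal the product of the feature map size and the anchor count. A quick check of the values in this config:

feature_shapes = (36, 60)   # values from the config above
num_anchors = 14
assert feature_shapes[0] * feature_shapes[1] * num_anchors == 30240  # num_bboxes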
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert icdar2015 dataset label"""
|
||||
import os
|
||||
import argparse
|
||||
def init_args():
|
||||
parser = argparse.ArgumentParser('')
|
||||
parser.add_argument('-s', '--src_label_path', type=str, default='./',
|
||||
help='Directory containing the icdar2015 training labels')
|
||||
parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
|
||||
help='Directory where the converted icdar2015 labels are saved')
|
||||
return parser.parse_args()
|
||||
|
||||
def convert():
|
||||
args = init_args()
|
||||
anno_file = os.listdir(args.src_label_path)
|
||||
annos = {}
|
||||
# read
|
||||
for file in anno_file:
|
||||
gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines()
|
||||
label_list = []
|
||||
label_name = os.path.basename(file)
|
||||
for each_label in gt:
|
||||
print(file)
|
||||
spt = each_label.split(',')
|
||||
print(spt)
|
||||
if "###" in spt[8]:
|
||||
continue
|
||||
else:
|
||||
x1 = min(int(spt[0]), int(spt[6]))
|
||||
y1 = min(int(spt[1]), int(spt[3]))
|
||||
x2 = max(int(spt[2]), int(spt[4]))
|
||||
y2 = max(int(spt[5]), int(spt[7]))
|
||||
label_list.append([x1, y1, x2, y2])
|
||||
annos[label_name] = label_list
|
||||
# write
|
||||
if not os.path.exists(args.target_label_path):
|
||||
os.makedirs(args.target_label_path)
|
||||
for label_file, pos in annos.items():
|
||||
tgt_anno_file = os.path.join(args.target_label_path, label_file)
|
||||
f = open(tgt_anno_file, 'w', encoding='UTF-8-sig')
|
||||
for tgt_label in pos:
|
||||
str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3])
|
||||
f.write(str_pos)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
|
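To make the conversion above concrete, here is the same arithmetic run on one hypothetical icdar2015 ground-truth line (8 corner coordinates plus a transcription field):

line = "377,117,463,117,465,130,378,130,Genaxis"
spt = line.split(',')
x1 = min(int(spt[0]), int(spt[6]))
y1 = min(int(spt[1]), int(spt[3]))
x2 = max(int(spt[2]), int(spt[4]))
y2 = max(int(spt[5]), int(spt[7]))
print("{},{},{},{}".format(x1, y1, x2, y2))   # -> 377,117,465,130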
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert svt dataset label"""
|
||||
import os
|
||||
import argparse
|
||||
from xml.etree import ElementTree as ET
|
||||
import numpy as np
|
||||
|
||||
def init_args():
|
||||
parser = argparse.ArgumentParser('')
|
||||
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
|
||||
help='Directory containing images')
|
||||
parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
|
||||
help='Path to the svt annotation xml file')
|
||||
parser.add_argument('-o', '--location_dir', type=str, default='./location',
|
||||
help='Directory where the converted location files are saved')
|
||||
return parser.parse_args()
|
||||
|
||||
def xml_to_dict(xml_file, save_file=False):
|
||||
tree = ET.parse(xml_file)
|
||||
root = tree.getroot()
|
||||
imgs_labels = []
|
||||
|
||||
for ch in root:
|
||||
im_label = {}
|
||||
for ch01 in ch:
|
||||
if ch01.tag in "address":
|
||||
continue
|
||||
elif ch01.tag == 'taggedRectangles':
|
||||
# multiple children
|
||||
rect_list = []
|
||||
for ch02 in ch01:
|
||||
rect = {}
|
||||
rect['location'] = ch02.attrib
|
||||
rect['label'] = ch02[0].text
|
||||
rect_list.append(rect)
|
||||
im_label['rect'] = rect_list
|
||||
else:
|
||||
im_label[ch01.tag] = ch01.text
|
||||
imgs_labels.append(im_label)
|
||||
|
||||
if save_file:
|
||||
np.save("annotation_train.npy", imgs_labels)
|
||||
|
||||
return imgs_labels
|
||||
|
||||
def convert():
|
||||
args = init_args()
|
||||
if not os.path.exists(args.dataset_dir):
|
||||
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
|
||||
|
||||
if not os.path.exists(args.xml_file):
|
||||
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
|
||||
|
||||
if not os.path.exists(args.location_dir):
|
||||
os.makedirs(args.location_dir)
|
||||
|
||||
ims_labels_dict = xml_to_dict(args.xml_file, True)
|
||||
num_images = len(ims_labels_dict)
|
||||
print("Converting annotation, {} images in total ".format(num_images))
|
||||
for i in range(num_images):
|
||||
img_label = ims_labels_dict[i]
|
||||
image_name = img_label['imageName']
|
||||
rects = img_label['rect']
|
||||
print("processing image: {}".format(image_name))
|
||||
location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
|
||||
f = open(location_file_name, 'w')
|
||||
for rect in rects:
|
||||
location = rect['location']
|
||||
h = int(location['height'])
|
||||
w = int(location['width'])
|
||||
x = int(location['x'])
|
||||
y = int(location['y'])
|
||||
pos = [x, y, x+w, y+h]
|
||||
str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3])
|
||||
f.write(str_pos)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
|
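A self-contained sketch of the XML handling in xml_to_dict above, run on a minimal, hypothetical svt-style record:

from xml.etree import ElementTree as ET

xml = """<tagset><image>
  <imageName>img/14_03.jpg</imageName>
  <taggedRectangles>
    <taggedRectangle height="75" width="236" x="375" y="253"><tag>LIVING</tag></taggedRectangle>
  </taggedRectangles>
</image></tagset>"""
root = ET.fromstring(xml)
for rect in root.iter('taggedRectangle'):
    loc = rect.attrib
    x, y = int(loc['x']), int(loc['y'])
    w, h = int(loc['width']), int(loc['height'])
    print("{},{},{},{}".format(x, y, x + w, y + h))   # -> 375,253,611,328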
@ -0,0 +1,177 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from __future__ import division
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
|
||||
def create_coco_label():
|
||||
"""Create image label."""
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
coco_root = config.coco_root
|
||||
data_type = config.coco_train_data_type
|
||||
from src.coco_text import COCO_Text
|
||||
anno_json = config.cocotext_json
|
||||
ct = COCO_Text(anno_json)
|
||||
image_ids = ct.getImgIds(imgIds=ct.train,
|
||||
catIds=[('legibility', 'legible')])
|
||||
for img_id in image_ids:
|
||||
image_info = ct.loadImgs(img_id)[0]
|
||||
file_name = image_info['file_name'][15:]
|
||||
anno_ids = ct.getAnnIds(imgIds=img_id)
|
||||
anno = ct.loadAnns(anno_ids)
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
annos = []
|
||||
im = Image.open(image_path)
|
||||
width, _ = im.size
|
||||
for label in anno:
|
||||
bbox = label["bbox"]
|
||||
bbox_width = int(bbox[2])
|
||||
if 60 * bbox_width < width:
|
||||
continue
|
||||
x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
|
||||
y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
|
||||
annos.append([x1, y1, x2, y2] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
# read
|
||||
img_basenames = []
|
||||
for file in os.listdir(train_img_dirs):
|
||||
# Filter out gif files.
|
||||
if 'gif' not in file:
|
||||
img_basenames.append(os.path.basename(file))
|
||||
img_names = []
|
||||
for item in img_basenames:
|
||||
temp1, _ = os.path.splitext(item)
|
||||
img_names.append((temp1, item))
|
||||
for img, img_basename in img_names:
|
||||
image_path = train_img_dirs + '/' + img_basename
|
||||
annos = []
|
||||
if len(img) == 6 and '_' not in img_basename:
|
||||
gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines()
|
||||
if img.isdigit() and int(img) > 1200:
|
||||
continue
|
||||
for img_each_label in gt:
|
||||
spt = img_each_label.replace(',', '').split(' ')
|
||||
if ' ' not in img_each_label:
|
||||
spt = img_each_label.split(',')
|
||||
annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
img_basenames = []
|
||||
for file_name in os.listdir(train_img_dir):
|
||||
if 'gif' not in file_name:
|
||||
img_basenames.append(os.path.basename(file_name))
|
||||
img_names = []
|
||||
for item in img_basenames:
|
||||
temp1, _ = os.path.splitext(item)
|
||||
img_names.append((temp1, item))
|
||||
for img, img_basename in img_names:
|
||||
image_path = train_img_dir + '/' + img_basename
|
||||
annos = []
|
||||
file_name = prefix + img + ".txt"
|
||||
file_path = os.path.join(train_txt_dir, file_name)
|
||||
gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines()
|
||||
if not gt:
|
||||
continue
|
||||
for img_each_label in gt:
|
||||
spt = img_each_label.replace(',', '').split(' ')
|
||||
if ' ' not in img_each_label:
|
||||
spt = img_each_label.split(',')
|
||||
annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
|
||||
if annos:
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def create_train_dataset(dataset_type):
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
if dataset_type == "pretraining":
|
||||
# pretraining: coco, flick, icdar2013 train, icdar2015, svt
|
||||
coco_image_files, coco_anno_dict = create_coco_label()
|
||||
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
|
||||
config.flick_train_path[1])
|
||||
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
|
||||
config.icdar13_train_path[1], "gt_img_")
|
||||
icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
|
||||
config.icdar15_train_path[1], "gt_")
|
||||
svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
|
||||
image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
|
||||
image_anno_dict = {**coco_anno_dict, **flick_anno_dict, \
|
||||
**icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path, \
|
||||
prefix="ctpn_pretrain.mindrecord", file_num=8)
|
||||
elif dataset_type == "finetune":
|
||||
# finetune: icdar2011, icdar2013 train, flick
|
||||
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
|
||||
config.flick_train_path[1])
|
||||
icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
|
||||
config.icdar11_train_path[1], "gt_")
|
||||
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
|
||||
config.icdar13_train_path[1], "gt_img_")
|
||||
image_files = flick_image_files + icdar11_image_files + icdar13_image_files
|
||||
image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path, \
|
||||
prefix="ctpn_finetune.mindrecord", file_num=8)
|
||||
elif dataset_type == "test":
|
||||
# test: icdar2013 test
|
||||
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
|
||||
config.icdar13_test_path[1], "")
|
||||
image_files = icdar_test_image_files
|
||||
image_anno_dict = icdar_test_anno_dict
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
|
||||
prefix="ctpn_test.mindrecord", file_num=1)
|
||||
else:
|
||||
print("dataset_type should be pretraining, finetune, test")
|
||||
|
||||
def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_path = os.path.join(dst_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
|
||||
ctpn_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(ctpn_json, "ctpn_json")
|
||||
for image_name in image_files:
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
print("img name is {}, anno is {}".format(image_name, annos))
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_train_dataset("pretraining")
|
||||
create_train_dataset("finetune")
|
||||
create_train_dataset("test")
|
|
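One detail of create_coco_label above that is easy to misread: the condition 60 * bbox_width < width drops a box when it is narrower than 1/60 of the image width. A tiny standalone sketch:

def keep_box(bbox_width, image_width):
    # mirrors the CocoText filter: keep boxes at least width/60 wide
    return 60 * bbox_width >= image_width

print(keep_box(5, 640))    # False: 60 * 5 = 300 < 640
print(keep_box(20, 640))   # True: 60 * 20 = 1200 >= 640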
@ -0,0 +1,148 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""CPTN network definition."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.CTPN.rpn import RPN
|
||||
from src.CTPN.anchor_generator import AnchorGenerator
|
||||
from src.CTPN.proposal_generator import Proposal
|
||||
from src.CTPN.vgg16 import VGG16FeatureExtraction
|
||||
|
||||
class BiLSTM(nn.Cell):
|
||||
"""
|
||||
Define a BiLSTM network, which runs a forward and a backward LSTM layer over the input sequence.
|
||||
|
||||
Args:
|
||||
config(EasyDict): network configuration, providing input_size, hidden_size,
|
||||
num_step and rnn_batch_size for the LSTM layers.
|
||||
is_training(bool): whether the network is built for training, default is True.
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
super(BiLSTM, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.batch_size = config.batch_size * config.rnn_batch_size
|
||||
print("batch size is {} ".format(self.batch_size))
|
||||
self.input_size = config.input_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_step = config.num_step
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
k = (1 / self.hidden_size) ** 0.5
|
||||
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.w1 = Parameter(np.random.uniform(-k, k, \
|
||||
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
|
||||
self.w1_bw = Parameter(np.random.uniform(-k, k, \
|
||||
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
|
||||
|
||||
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
|
||||
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
|
||||
|
||||
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
|
||||
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.reverse_seq = P.ReverseV2(axis=[0])
|
||||
self.concat = P.Concat()
|
||||
self.transpose = P.Transpose()
|
||||
self.concat1 = P.Concat(axis=2)
|
||||
self.dropout = nn.Dropout(0.7)
|
||||
self.use_dropout = config.use_dropout
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
def construct(self, x):
|
||||
if self.use_dropout:
|
||||
x = self.dropout(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
bw_x = self.reverse_seq(x)
|
||||
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
|
||||
y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
|
||||
y1_bw = self.reverse_seq(y1_bw)
|
||||
output = self.concat1((y1, y1_bw))
|
||||
return output
|
||||
|
||||
class CTPN(nn.Cell):
|
||||
"""
|
||||
Define CTPN network
|
||||
|
||||
Args:
|
||||
config(EasyDict): network configuration, providing num_step, input_size,
|
||||
batch_size, hidden_size and the RPN settings.
|
||||
is_training(bool): whether the network is built for training, default is True.
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
super(CTPN, self).__init__()
|
||||
self.config = config
|
||||
self.is_training = is_training
|
||||
self.num_step = config.num_step
|
||||
self.input_size = config.input_size
|
||||
self.batch_size = config.batch_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction()
|
||||
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
|
||||
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
|
||||
# rpn block
|
||||
self.rpn_with_loss = RPN(config,
|
||||
self.batch_size,
|
||||
config.rpn_in_channels,
|
||||
config.rpn_feat_channels,
|
||||
config.num_anchors,
|
||||
config.rpn_cls_out_channels)
|
||||
self.anchor_generator = AnchorGenerator(config)
|
||||
self.featmap_size = config.feature_shapes
|
||||
self.anchor_list = self.get_anchors(self.featmap_size)
|
||||
self.proposal_generator_test = Proposal(config,
|
||||
config.test_batch_size,
|
||||
config.activate_num_classes,
|
||||
config.use_sigmoid_cls)
|
||||
self.proposal_generator_test.set_train_local(config, False)
|
||||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
|
||||
# (1,3,600,900)
|
||||
x = self.vgg16_feature_extractor(img_data)
|
||||
x = self.conv(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
# (1, 512, 38, 57)
|
||||
x = self.transpose(x, (0, 2, 1, 3))
|
||||
x = self.reshape(x, (-1, self.input_size, self.num_step))
|
||||
x = self.transpose(x, (2, 0, 1))
|
||||
# (57, 38, 512)
|
||||
x = self.rnn(x)
|
||||
# (57, 38, 256)
|
||||
#x = self.cast(x, mstype.float32)
|
||||
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
|
||||
img_metas,
|
||||
self.anchor_list,
|
||||
gt_bboxes,
|
||||
gt_labels,
|
||||
gt_valids)
|
||||
if self.training:
|
||||
return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
|
||||
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
|
||||
return proposal, proposal_mask
|
||||
|
||||
def get_anchors(self, featmap_size):
|
||||
anchors = self.anchor_generator.grid_anchors(featmap_size)
|
||||
return Tensor(anchors, mstype.float16)
|
|
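A NumPy sketch of the tensor plumbing in CTPN.construct above: the (N, C, H, W) convolution output is turned into a time-major sequence for the BiLSTM, with the image width acting as the time axis (shapes follow the config: input_size=512, num_step=60).

import numpy as np

n, c, h, w = 1, 512, 36, 60          # config-sized example
x = np.zeros((n, c, h, w), dtype=np.float32)
x = x.transpose(0, 2, 1, 3)          # (N, H, C, W)
x = x.reshape(-1, c, w)              # (N*H, input_size, num_step)
x = x.transpose(2, 0, 1)             # (num_step, rnn batch, input_size)
assert x.shape == (60, 36, 512)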
@ -0,0 +1,342 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FasterRcnn dataset"""
|
||||
from __future__ import division
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy import random
|
||||
import mmcv
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as CC
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
|
||||
class PhotoMetricDistortion:
|
||||
"""Photo Metric Distortion"""
|
||||
|
||||
def __init__(self,
|
||||
brightness_delta=32,
|
||||
contrast_range=(0.5, 1.5),
|
||||
saturation_range=(0.5, 1.5),
|
||||
hue_delta=18):
|
||||
self.brightness_delta = brightness_delta
|
||||
self.contrast_lower, self.contrast_upper = contrast_range
|
||||
self.saturation_lower, self.saturation_upper = saturation_range
|
||||
self.hue_delta = hue_delta
|
||||
|
||||
def __call__(self, img, boxes, labels):
|
||||
img = img.astype('float32')
|
||||
if random.randint(2):
|
||||
delta = random.uniform(-self.brightness_delta, self.brightness_delta)
|
||||
img += delta
|
||||
mode = random.randint(2)
|
||||
if mode == 1:
|
||||
if random.randint(2):
|
||||
alpha = random.uniform(self.contrast_lower,
|
||||
self.contrast_upper)
|
||||
img *= alpha
|
||||
# convert color from BGR to HSV
|
||||
img = mmcv.bgr2hsv(img)
|
||||
# random saturation
|
||||
if random.randint(2):
|
||||
img[..., 1] *= random.uniform(self.saturation_lower,
|
||||
self.saturation_upper)
|
||||
# random hue
|
||||
if random.randint(2):
|
||||
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
|
||||
img[..., 0][img[..., 0] > 360] -= 360
|
||||
img[..., 0][img[..., 0] < 0] += 360
|
||||
# convert color from HSV to BGR
|
||||
img = mmcv.hsv2bgr(img)
|
||||
# random contrast
|
||||
if mode == 0:
|
||||
if random.randint(2):
|
||||
alpha = random.uniform(self.contrast_lower,
|
||||
self.contrast_upper)
|
||||
img *= alpha
|
||||
# randomly swap channels
|
||||
if random.randint(2):
|
||||
img = img[..., random.permutation(3)]
|
||||
return img, boxes, labels
|
||||
|
||||
class Expand:
|
||||
"""expand image"""
|
||||
|
||||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
|
||||
if to_rgb:
|
||||
self.mean = mean[::-1]
|
||||
else:
|
||||
self.mean = mean
|
||||
self.min_ratio, self.max_ratio = ratio_range
|
||||
|
||||
def __call__(self, img, boxes, labels):
|
||||
if random.randint(2):
|
||||
return img, boxes, labels
|
||||
h, w, c = img.shape
|
||||
ratio = random.uniform(self.min_ratio, self.max_ratio)
|
||||
expand_img = np.full((int(h * ratio), int(w * ratio), c),
|
||||
self.mean).astype(img.dtype)
|
||||
left = int(random.uniform(0, w * ratio - w))
|
||||
top = int(random.uniform(0, h * ratio - h))
|
||||
expand_img[top:top + h, left:left + w] = img
|
||||
img = expand_img
|
||||
boxes += np.tile((left, top), 2)
|
||||
return img, boxes, labels
|
||||
|
||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""rescale operation for image"""
|
||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
||||
if img_data.shape[0] > config.img_height:
|
||||
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
|
||||
scale_factor = scale_factor * scale_factor2
|
||||
img_shape = np.append(img_shape, scale_factor)
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
gt_bboxes = split_gtbox_label(gt_bboxes)
|
||||
if gt_bboxes.shape[0] != 0:
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image"""
|
||||
img_data = img
|
||||
img_data, w_scale, h_scale = mmcv.imresize(
|
||||
img_data, (config.img_width, config.img_height), return_scale=True)
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
img_shape = (config.img_height, config.img_width, 1.0)
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
gt_bboxes = split_gtbox_label(gt_bboxes)
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
|
||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""resize operation for image of eval"""
|
||||
img_data = img
|
||||
img_data, w_scale, h_scale = mmcv.imresize(
|
||||
img_data, (config.img_width, config.img_height), return_scale=True)
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
img_shape = (config.img_height, config.img_width)
|
||||
img_shape = np.append(img_shape, (h_scale, w_scale))
|
||||
img_shape = np.asarray(img_shape, dtype=np.float32)
|
||||
gt_bboxes = gt_bboxes * scale_factor
|
||||
shape = gt_bboxes.shape
|
||||
label_column = np.ones((shape[0], 1), dtype=int)
|
||||
gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1)
|
||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""flipped generation"""
|
||||
img_data = img
|
||||
flipped = gt_bboxes.copy()
|
||||
_, w, _ = img_data.shape
|
||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
|
||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
|
||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
||||
|
||||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
img_data = img[:, :, ::-1]
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""photo crop operation for image"""
|
||||
random_photo = PhotoMetricDistortion()
|
||||
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
|
||||
|
||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||
"""expand operation for image"""
|
||||
expand = Expand()
|
||||
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
|
||||
|
||||
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
||||
|
||||
def split_gtbox_label(gt_bbox_total):
|
||||
"""split ground truth box label"""
|
||||
gtbox_list = []
|
||||
box_num, _ = gt_bbox_total.shape
|
||||
for i in range(box_num):
|
||||
gt_bbox = gt_bbox_total[i]
|
||||
if gt_bbox[0] % 16 != 0:
|
||||
gt_bbox[0] = (gt_bbox[0] // 16) * 16
|
||||
if gt_bbox[2] % 16 != 0:
|
||||
gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16
|
||||
x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16)
|
||||
for x0 in x0_array:
|
||||
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
|
||||
return np.array(gtbox_list)
|
||||
|
||||
def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid):
|
||||
"""pad ground truth label"""
|
||||
pad_max_number = 256
|
||||
gt_label = gt_bboxes[:, 4]
|
||||
gt_valid = gt_bboxes[:, 4]
|
||||
if gt_bboxes.shape[0] < pad_max_number:
|
||||
gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \
|
||||
mode="constant", constant_values=0)
|
||||
gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1)
|
||||
gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0)
|
||||
else:
|
||||
print("WARNING label num is high than 256")
|
||||
gt_box = gt_bboxes[0:pad_max_number]
|
||||
gt_label = gt_label[0:pad_max_number]
|
||||
gt_valid = gt_valid[0:pad_max_number]
|
||||
return (img, img_shape, gt_box[:, :4], gt_label, gt_valid)
|
||||
|
||||
def preprocess_fn(image, box, is_training):
|
||||
"""Preprocess function for dataset."""
|
||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid):
|
||||
image_shape = image_shape[:2]
|
||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid
|
||||
if config.keep_ratio:
|
||||
input_data = rescale_column(*input_data)
|
||||
else:
|
||||
input_data = resize_column_test(*input_data)
|
||||
input_data = pad_label(*input_data)
|
||||
input_data = image_bgr_rgb(*input_data)
|
||||
output_data = input_data
|
||||
return output_data
|
||||
|
||||
def _data_aug(image, box, is_training):
|
||||
"""Data augmentation function."""
|
||||
image_bgr = image.copy()
|
||||
image_bgr[:, :, 0] = image[:, :, 2]
|
||||
image_bgr[:, :, 1] = image[:, :, 1]
|
||||
image_bgr[:, :, 2] = image[:, :, 0]
|
||||
image_shape = image_bgr.shape[:2]
|
||||
gt_box = box[:, :4]
|
||||
gt_label = box[:, 4]
|
||||
gt_valid = box[:, 4]
|
||||
input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid
|
||||
if not is_training:
|
||||
return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid)
|
||||
expand = (np.random.rand() < config.expand_ratio)
|
||||
if expand:
|
||||
input_data = expand_column(*input_data)
|
||||
input_data = photo_crop_column(*input_data)
|
||||
if config.keep_ratio:
|
||||
input_data = rescale_column(*input_data)
|
||||
else:
|
||||
input_data = resize_column(*input_data)
|
||||
input_data = pad_label(*input_data)
|
||||
input_data = image_bgr_rgb(*input_data)
|
||||
output_data = input_data
|
||||
return output_data
|
||||
|
||||
return _data_aug(image, box, is_training)
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
for anno_str in annos_str:
|
||||
anno = list(map(int, anno_str.strip().split(',')))
|
||||
annos.append(anno)
|
||||
return annos
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
if not os.path.isdir(image_dir):
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
if not os.path.isfile(anno_path):
|
||||
raise RuntimeError("Annotation file is not valid.")
|
||||
|
||||
with open(anno_path, "rb") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
line_split = str(line_str).split(' ')
|
||||
file_name = line_split[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
if os.path.isfile(image_path):
|
||||
image_anno_dict[image_path] = anno_parser(line_split[1:])
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
image_files, image_anno_dict = create_icdar_test_label()
|
||||
ctpn_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 6]},
|
||||
}
|
||||
writer.add_schema(ctpn_json, "ctpn_json")
|
||||
for image_name in image_files:
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
|
||||
is_training=True, num_parallel_workers=4):
|
||||
"""Creatr deeptext dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
|
||||
num_parallel_workers=8, shuffle=is_training)
|
||||
decode = C.Decode()
|
||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
|
||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||
hwc_to_chw = C.HWC2CHW()
|
||||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
|
||||
type_cast0 = CC.TypeCast(mstype.float32)
|
||||
type_cast1 = CC.TypeCast(mstype.float16)
|
||||
type_cast2 = CC.TypeCast(mstype.int32)
|
||||
type_cast3 = CC.TypeCast(mstype.bool_)
|
||||
if is_training:
|
||||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
else:
|
||||
ds = ds.map(operations=compose_map_func,
|
||||
input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
|
||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=24)
|
||||
# cast the columns produced by the python preprocess function
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["box"])
|
||||
ds = ds.map(operations=[type_cast2], input_columns=["label"])
|
||||
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
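For reference, split_gtbox_label above snaps each ground-truth box's x-range to multiples of 16 and slices it into 16-pixel-wide strips; a standalone NumPy sketch of the same arithmetic:

import numpy as np

gt = np.array([[20.0, 10.0, 70.0, 40.0]])                    # one box, x from 20 to 70
x0 = (gt[0][0] // 16) * 16                                   # snap left edge down -> 16.0
x2 = (gt[0][2] // 16 + 1) * 16                               # snap right edge up  -> 80.0
strips = [[x, gt[0][1], x + 15, gt[0][3], 1] for x in np.arange(x0, x2, 16)]
print(strips)                                                # four strips: x = 16, 32, 48, 64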
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for deeptext"""
|
||||
import math
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
|
||||
base = float(current_step - warmup_steps) / float(decay_steps)
|
||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
|
||||
return learning_rate
|
||||
|
||||
def dynamic_lr(config, base_step):
|
||||
"""dynamic learning rate generator"""
|
||||
base_lr = config.base_lr
|
||||
total_steps = int(base_step * config.total_epoch)
|
||||
warmup_steps = config.warmup_step
|
||||
lr = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
|
||||
else:
|
||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
|
||||
return lr
|
|
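A small usage sketch of dynamic_lr above; the EasyDict fields mirror pretrain_config and the step count is hypothetical. The schedule warms up linearly from base_lr * warmup_ratio, then follows a cosine decay.

from easydict import EasyDict

cfg = EasyDict({"base_lr": 0.0009, "warmup_step": 3,
                "warmup_ratio": 1 / 3.0, "total_epoch": 2})
lr = dynamic_lr(cfg, base_step=5)   # 5 steps per epoch -> 10 lr values
print(lr[0])                        # warmup start: base_lr * warmup_ratio = 0.0003
print(lr[-1])                       # near the end of the cosine decay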
@ -0,0 +1,153 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FasterRcnn training network wrapper."""
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF, training is terminated.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0, the loss is not printed.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print the loss once every given number of steps. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1, rank_id=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.count = 0
|
||||
self.rpn_loss_sum = 0
|
||||
self.rpn_cls_loss_sum = 0
|
||||
self.rpn_reg_loss_sum = 0
|
||||
self.rank_id = rank_id
|
||||
|
||||
global time_stamp_init, time_stamp_first
|
||||
if not time_stamp_init:
|
||||
time_stamp_first = time.time()
|
||||
time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
rpn_loss = cb_params.net_outputs[0].asnumpy()
|
||||
rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
|
||||
rpn_reg_loss = cb_params.net_outputs[2].asnumpy()
|
||||
|
||||
self.count += 1
|
||||
self.rpn_loss_sum += float(rpn_loss)
|
||||
self.rpn_cls_loss_sum += float(rpn_cls_loss)
|
||||
self.rpn_reg_loss_sum += float(rpn_reg_loss)
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if self.count >= 1:
|
||||
global time_stamp_first
|
||||
time_stamp_current = time.time()
|
||||
rpn_loss = self.rpn_loss_sum / self.count
|
||||
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
|
||||
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
|
||||
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
|
||||
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"%
|
||||
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
rpn_loss, rpn_cls_loss, rpn_reg_loss))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
|
||||
class LossNet(nn.Cell):
|
||||
"""FasterRcnn loss method"""
|
||||
def construct(self, x1, x2, x3):
|
||||
return x1
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to compute loss.
|
||||
|
||||
Args:
|
||||
backbone (Cell): The target network to wrap.
|
||||
loss_fn (Cell): The loss function used to compute loss.
|
||||
"""
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
||||
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""
|
||||
Get the backbone network.
|
||||
|
||||
Returns:
|
||||
Cell, return backbone network.
|
||||
"""
|
||||
return self._backbone
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Network training package class.
|
||||
|
||||
Append an optimizer to the training network. After that, the construct function
|
||||
can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
network_backbone (Cell): The forward network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The loss scaling factor applied to the backward sensitivity. Default value is 1.0.
|
||||
reduce_flag (bool): The reduce flag. Default value is False.
|
||||
mean (bool): Allreduce method. Default value is False.
|
||||
degree (int): Device number. Default value is None.
|
||||
"""
|
||||
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.backbone = network_backbone
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
|
||||
self.reduce_flag = reduce_flag
|
||||
if reduce_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
||||
weights = self.weights
|
||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
||||
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
|
||||
if self.reduce_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
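
For orientation, here is a minimal sketch of how these cells compose; `net` and `opt` are assumed to be the CTPN network and a MindSpore optimizer as built in train.py further down, so this mirrors (rather than defines) the actual wiring:

loss = LossNet()
net_with_loss = WithLossCell(net, loss)  # forward pass reduced to the total rpn_loss
train_net = TrainOneStepCell(net_with_loss, net, opt, sens=1.0)
# Each call performs one optimizer step and returns
# (rpn_loss, rpn_cls_loss, rpn_reg_loss):
# losses = train_net(images, img_shape, gt_bboxes, gt_labels, gt_valid)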
@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.text_connector.utils import clip_boxes, fit_y
from src.text_connector.get_successions import get_successions


def connect_text_lines(text_proposals, scores, size):
    """
    Connect text proposals into text lines.

    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Bbox prediction scores.
        size(numpy.array): Image size.
    Returns:
        text_recs(numpy.array): Text boxes after connection.
    """
    graph = get_successions(text_proposals, scores, size)
    text_lines = np.zeros((len(graph), 5), np.float32)
    for index, indices in enumerate(graph):
        text_line_boxes = text_proposals[list(indices)]
        x0 = np.min(text_line_boxes[:, 0])
        x1 = np.max(text_line_boxes[:, 2])

        # Shrink the fitting interval by half a proposal width on each side.
        offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5

        # Fit the top and bottom edges of the line with a least-squares fit.
        lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
        lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

        # The score of a text line is the average score of
        # all text proposals contained in the text line.
        score = scores[list(indices)].sum() / float(len(indices))

        text_lines[index, 0] = x0
        text_lines[index, 1] = min(lt_y, rt_y)
        text_lines[index, 2] = x1
        text_lines[index, 3] = max(lb_y, rb_y)
        text_lines[index, 4] = score

    text_lines = clip_boxes(text_lines, size)

    # Only the axis-aligned corners and the score are filled here;
    # the remaining columns stay zero.
    text_recs = np.zeros((len(text_lines), 9), float)
    index = 0
    for line in text_lines:
        xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
        text_recs[index, 0] = xmin
        text_recs[index, 1] = ymin
        text_recs[index, 2] = xmax
        text_recs[index, 3] = ymax
        text_recs[index, 4] = line[4]
        index = index + 1
    return text_recs
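
The edge fitting above reduces to a one-dimensional least-squares fit; a self-contained sketch of just that step, using toy coordinates rather than dataset values:

import numpy as np

# Three stacked proposals along one line: x-left coords and top-edge y coords.
xs = np.array([10.0, 26.0, 42.0])
tops = np.array([5.0, 6.0, 7.0])
p = np.poly1d(np.polyfit(xs, tops, 1))  # the same fit that fit_y performs
print(p(10.0), p(50.0))                 # top edge evaluated at the line ends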
@@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.config import config
from src.text_connector.utils import nms
from src.text_connector.connect_text_lines import connect_text_lines


def filter_proposal(proposals, scores):
    """
    Filter text proposals by score and NMS.

    Args:
        proposals(numpy.array): Text proposals.
        scores(numpy.array): Text proposal scores.
    Returns:
        keep_proposals(numpy.array): Text proposals after filtering.
        keep_scores(numpy.array): Scores of the kept proposals.
    """
    inds = np.where(scores > config.text_proposals_min_scores)[0]
    keep_proposals = proposals[inds]
    keep_scores = scores[inds]
    # Sort by score in descending order before NMS.
    sorted_inds = np.argsort(keep_scores.ravel())[::-1]
    keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds]
    nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh)
    keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds]
    return keep_proposals, keep_scores


def filter_boxes(boxes):
    """
    Filter text boxes by aspect ratio, score, and width.

    Args:
        boxes(numpy.array): Text boxes.
    Returns:
        numpy.array: Indices of the boxes kept after filtering.
    """
    heights = np.zeros((len(boxes), 1), float)
    widths = np.zeros((len(boxes), 1), float)
    scores = np.zeros((len(boxes), 1), float)
    index = 0
    for box in boxes:
        widths[index] = abs(box[2] - box[0])
        heights[index] = abs(box[3] - box[1])
        scores[index] = abs(box[4])
        index += 1
    return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\
                    (widths > (config.text_proposals_width * config.min_num_proposals)))[0]


def detect(text_proposals, scores, size):
    """
    Detect text boxes.

    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Bbox prediction scores.
        size(numpy.array): Image size.
    Returns:
        boxes(numpy.array): Text boxes after connection and filtering.
    """
    keep_proposals, keep_scores = filter_proposal(text_proposals, scores)
    connect_boxes = connect_text_lines(keep_proposals, keep_scores, size)
    boxes = connect_boxes[filter_boxes(connect_boxes)]
    return boxes
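
A hedged sketch of how detect is typically invoked at inference time; the module path and the toy proposals are assumptions for illustration, and the actual thresholds come from src.config:

import numpy as np
from src.text_connector.detect import detect  # assumed module path within this repo

# proposals: (N, 4) boxes in (x0, y0, x1, y1); scores: (N, 1); size: (H, W).
proposals = np.array([[10., 20., 26., 40.], [26., 21., 42., 41.]], np.float32)
scores = np.array([[0.95], [0.90]], np.float32)
boxes = detect(proposals, scores, size=np.array([576, 960]))
print(boxes)  # connected text lines that survive ratio/score/width filtering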
@@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.config import config
from src.text_connector.utils import overlaps_v, size_similarity


def get_successions(text_proposals, scores, im_size):
    """
    Build the succession graph over text proposals.

    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Bbox prediction scores.
        im_size(numpy.array): Image size.
    Returns:
        sub_graph(list): Connected proposal chains.
    """
    bboxes_table = [[] for _ in range(int(im_size[1]))]
    for index, box in enumerate(text_proposals):
        bboxes_table[int(box[0])].append(index)
    graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)
    for index, box in enumerate(text_proposals):
        successions_left = []
        # Scan to the right for the nearest vertically-overlapping proposal.
        for left in range(int(box[0]) + 1, min(int(box[0]) + config.max_horizontal_gap + 1, int(im_size[1]))):
            adj_box_indices = bboxes_table[left]
            for adj_box_index in adj_box_indices:
                if meet_v_iou(text_proposals, adj_box_index, index):
                    successions_left.append(adj_box_index)
            if successions_left:
                break
        if not successions_left:
            continue
        succession_index = successions_left[np.argmax(scores[successions_left])]
        box_right = text_proposals[succession_index]
        succession_right = []
        # Scan back to the left from the successor to check the pairing is mutual.
        for right in range(int(box_right[0]) - 1, max(int(box_right[0] - config.max_horizontal_gap), 0) - 1, -1):
            adj_box_indices = bboxes_table[right]
            for adj_box_index in adj_box_indices:
                if meet_v_iou(text_proposals, adj_box_index, index):
                    succession_right.append(adj_box_index)
            if succession_right:
                break
        if scores[index] >= np.max(scores[succession_right]):
            graph[index, succession_index] = True
    sub_graph = get_sub_graph(graph)
    return sub_graph


def get_sub_graph(graph):
    """
    Extract connected chains from the succession graph.

    Args:
        graph(numpy.array): Proposal adjacency graph.
    Returns:
        sub_graphs(list): Chains of proposal indices after connection.
    """
    sub_graphs = []
    for index in range(graph.shape[0]):
        # A chain starts at a node with no predecessor but at least one successor.
        if not graph[:, index].any() and graph[index, :].any():
            v = index
            sub_graphs.append([v])
            while graph[v, :].any():
                v = np.where(graph[v, :])[0][0]
                sub_graphs[-1].append(v)
    return sub_graphs


def meet_v_iou(text_proposals, index1, index2):
    """
    Check whether two proposals overlap enough vertically.

    Args:
        text_proposals(numpy.array): Text proposals.
        index1(int): First text proposal index.
        index2(int): Second text proposal index.
    Returns:
        bool: True if both the vertical overlap and the height similarity
            exceed their configured thresholds.
    """
    heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
    return overlaps_v(text_proposals, index1, index2) >= config.min_v_overlaps and \
        size_similarity(heights, index1, index2) >= config.min_size_sim
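
get_sub_graph simply walks maximal chains through a boolean adjacency matrix; a small self-contained run on a hand-built graph (not model output) shows the behavior:

import numpy as np
from src.text_connector.get_successions import get_sub_graph

# Chain 0 -> 1 -> 2 plus an isolated node 3.
g = np.zeros((4, 4), bool)
g[0, 1] = True
g[1, 2] = True
print(get_sub_graph(g))  # [[0, 1, 2]] -- node 3 starts no chain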
@@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


def threshold(coords, min_, max_):
    """Clamp coords to the closed interval [min_, max_]."""
    return np.maximum(np.minimum(coords, max_), min_)


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.

    Args:
        boxes(numpy.array): Bounding boxes.
        im_shape(numpy.array): Image shape.

    Return:
        boxes(numpy.array): Bounding boxes after clipping.
    """
    boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1)
    boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1)
    return boxes


def overlaps_v(text_proposals, index1, index2):
    """
    Calculate the vertical overlap ratio.

    Args:
        text_proposals(numpy.array): Text proposals.
        index1(int): First text proposal index.
        index2(int): Second text proposal index.

    Return:
        overlap(float32): Vertical overlap.
    """
    h1 = text_proposals[index1][3] - text_proposals[index1][1] + 1
    h2 = text_proposals[index2][3] - text_proposals[index2][1] + 1
    y0 = max(text_proposals[index2][1], text_proposals[index1][1])
    y1 = min(text_proposals[index2][3], text_proposals[index1][3])
    return max(0, y1 - y0 + 1) / min(h1, h2)
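
As a quick sanity check, a toy call to overlaps_v with two proposals that share half their height (coordinates invented for illustration):

import numpy as np
from src.text_connector.utils import overlaps_v

props = np.array([[0, 0, 15, 9], [16, 5, 31, 14]], np.float32)
print(overlaps_v(props, 0, 1))  # (9 - 5 + 1) / 10 = 0.5
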
def size_similarity(heights, index1, index2):
    """
    Calculate the vertical size similarity ratio.

    Args:
        heights(numpy.array): Text proposal heights.
        index1(int): First text proposal index.
        index2(int): Second text proposal index.

    Return:
        similarity(float32): Height similarity in (0, 1].
    """
    h1 = heights[index1]
    h2 = heights[index2]
    return min(h1, h2) / max(h1, h2)


def fit_y(X, Y, x1, x2):
    """Fit a line through (X, Y) and evaluate it at x1 and x2."""
    # If all X coordinates coincide, the line is vertical; return a constant y.
    if np.sum(X == X[0]) == len(X):
        return Y[0], Y[0]
    p = np.poly1d(np.polyfit(X, Y, 1))
    return p(x1), p(x2)


def nms(bboxs, thresh):
    """
    Suppress boxes whose IoU with a higher-scoring kept box exceeds thresh.

    Args:
        bboxs(numpy.array): Boxes as (x1, y1, x2, y2, score) rows.
        thresh(float): IoU threshold.
    Returns:
        keep(list): Indices of the kept boxes, in descending score order.
    """
    x1 = bboxs[:, 0]
    y1 = bboxs[:, 1]
    x2 = bboxs[:, 2]
    y2 = bboxs[:, 3]
    scores = bboxs[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    num_dets = bboxs.shape[0]
    suppressed = np.zeros(num_dets, dtype=np.int32)
    keep = []
    for _i in range(num_dets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        keep.append(i)
        x1_i = x1[i]
        y1_i = y1[i]
        x2_i = x2[i]
        y2_i = y2[i]
        area_i = areas[i]
        for _j in range(_i + 1, num_dets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            # Intersection rectangle of box i and box j.
            x1_j = max(x1_i, x1[j])
            y1_j = max(y1_i, y1[j])
            x2_j = min(x2_i, x2[j])
            y2_j = min(y2_i, y2[j])
            w = max(0.0, x2_j - x1_j + 1)
            h = max(0.0, y2_j - y1_j + 1)
            inter = w * h
            overlap = inter / (area_i + areas[j] - inter)
            if overlap >= thresh:
                suppressed[j] = 1
    return keep
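
A quick self-contained check of nms on toy boxes (values chosen for illustration): the second box overlaps the first with IoU of about 0.70 and is suppressed at a 0.5 threshold:

import numpy as np
from src.text_connector.utils import nms

boxes = np.array([[0, 0, 10, 10, 0.9],    # kept (highest score)
                  [1, 1, 11, 11, 0.8],    # IoU ~0.70 with box 0 -> suppressed
                  [20, 20, 30, 30, 0.7]], np.float32)
print(nms(boxes, thresh=0.5))  # [0, 2]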
@@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""train CTPN and get checkpoint files."""
import os
import time
import argparse
import ast
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config, pretrain_config, finetune_config
from src.dataset import create_ctpn_dataset
from src.lr_schedule import dynamic_lr
from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell

set_seed(1)

parser = argparse.ArgumentParser(description="CTPN training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distributed, default: false.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument("--task_type", type=str, default="Pretraining",\
    choices=['Pretraining', 'Finetune'], help="Task type, default: Pretraining.")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True)

if __name__ == '__main__':
    if args_opt.run_distribute:
        rank = args_opt.rank_id
        device_num = args_opt.device_num
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
    else:
        rank = 0
        device_num = 1
    if args_opt.task_type == "Pretraining":
        print("Start to do pretraining")
        mindrecord_file = config.pretraining_dataset_file
        training_cfg = pretrain_config
    else:
        print("Start to do finetune")
        mindrecord_file = config.finetune_dataset_file
        training_cfg = finetune_config

    print("CHECKING MINDRECORD FILES ...")
    while not os.path.exists(mindrecord_file + ".db"):
        time.sleep(5)

    print("CHECKING MINDRECORD FILES DONE!")

    loss_scale = float(config.loss_scale)

    # When creating MindDataset, use the first mindrecord file, such as ctpn_pretrain.mindrecord0.
    dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\
        batch_size=config.batch_size, device_num=device_num, rank_id=rank)
    dataset_size = dataset.get_dataset_size()
    net = CTPN(config=config, is_training=True)
    net = net.set_train()

    load_path = args_opt.pre_trained
    if args_opt.task_type == "Pretraining":
        print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained))
        param_dict = load_checkpoint(load_path)
        # Keep only the VGG16 feature-extractor weights from the checkpoint.
        for item in list(param_dict.keys()):
            if not item.startswith('vgg16_feature_extractor'):
                param_dict.pop(item)
        load_param_into_net(net, param_dict)
    else:
        if load_path != "":
            print("load pretrain ckpt {}".format(args_opt.pre_trained))
            param_dict = load_checkpoint(load_path)
            load_param_into_net(net, param_dict)
    loss = LossNet()
    lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
    opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\
        weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
    if args_opt.run_distribute:
        net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
                               mean=True, degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
        save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
        ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig)
        cb += [ckpoint_cb]

    model = Model(net)
    model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
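
For reference, a single-device pretraining run of this script looks like `python train.py --device_id=0 --task_type=Pretraining --pre_trained=/path/to/vgg16.ckpt`; the checkpoint path is illustrative, and the script only requires that the file contain `vgg16_feature_extractor` weights. Distributed runs additionally pass `--run_distribute=True`, `--device_num`, and `--rank_id`.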