init for ctpn

add for connect table

fix some bug

fix pylint

fix for create dataset

fix dataset bug

fix for create dataset problem

add for svt icdar2015 convert script

fix for ctpn problem

fix for vgg16
qujianwei 2021-01-28 21:46:12 +08:00
parent 78c733ffbe
commit ae27c383fa
28 changed files with 3177 additions and 0 deletions


@ -0,0 +1,293 @@
![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
<!-- TOC -->
# CTPN for Ascend
- [CTPN Description](#CTPN-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CTPN Description](#contents)
CTPN is a text detection model based on an object detection method. It improves on Faster R-CNN and combines it with a bidirectional LSTM, so CTPN is very effective for horizontal text detection. Another highlight of CTPN is that it transforms the text detection task into a series of small-scale text box detections. This idea was proposed in the paper "Detecting Text in Natural Image with Connectionist Text Proposal Network".
[Paper](https://arxiv.org/pdf/1609.03605.pdf) Zhi Tian, Weilin Huang, Tong He, Pan He, Yu Qiao, "Detecting Text in Natural Image with Connectionist Text Proposal Network", ArXiv, vol. abs/1609.03605, 2016.
# [Model architecture](#contents)
The overall network architecture uses VGG16 as the backbone, applies a bidirectional LSTM to extract context features of the small-scale text boxes, and then uses an RPN (Region Proposal Network) to predict the bounding boxes and probabilities.
[Link](https://arxiv.org/pdf/1605.07314v1.pdf)
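As a quick sanity check of the geometry this architecture implies, the snippet below (an illustration only, using the default values from src/config.py: a 960x576 input, a VGG16 stride of 16, and 14 anchor heights per position) reproduces the feature-map shape and total anchor count used in the configuration:
```python
# Anchor arithmetic implied by the default config (illustration only)
img_h, img_w, stride, anchors_per_cell = 576, 960, 16, 14
feat_h, feat_w = img_h // stride, img_w // stride
print((feat_h, feat_w))                     # (36, 60) -> config.feature_shapes
print(feat_h * feat_w * anchors_per_cell)   # 30240    -> config.num_bboxes
```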
# [Dataset](#contents)
Here we use 6 datasets for training and 1 dataset for evaluation.
- Dataset1: ICDAR 2013: Focused Scene Text
- Train: 142MB, 229 images
- Test: 110MB, 233 images
- Dataset2: ICDAR 2011: Born-Digital Images
- Train: 27.7MB, 410 images
- Dataset3: ICDAR 2015:
- Train: 89MB, 1000 images
- Dataset4: SCUT-FORU: Flickr OCR Universal Database
- Train: 388MB, 1715 images
- Dataset5: CocoText v2(Subset of MSCOCO2017):
- Train: 13GB, 63686 images
- Dataset6: SVT(The Street View Dataset)
- Train: 115MB, 349 images
# [Features](#contents)
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```shell
.
└─ctpn
├── README.md # network readme
├── eval.py # eval net
├── scripts
│   ├── eval_res.sh # calculate precision and recall
│   ├── run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p)
│   ├── run_eval_ascend.sh # launch evaluating with ascend platform
│   └── run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p)
├── src
│   ├── CTPN
│   │   ├── BoundingBoxDecode.py # bounding box decode
│   │   ├── BoundingBoxEncode.py # bounding box encode
│   │   ├── __init__.py # package init file
│   │   ├── anchor_generator.py # anchor generator
│   │   ├── bbox_assign_sample.py # proposal layer
│   │   ├── proposal_generator.py # proposal generator
│   │   ├── rpn.py # region-proposal network
│   │   └── vgg16.py # backbone
│   ├── config.py # training configuration
│   ├── convert_icdar2015.py # convert icdar2015 dataset label
│   ├── convert_svt.py # convert svt label
│   ├── create_dataset.py # create mindrecord dataset
│   ├── ctpn.py # ctpn network definition
│   ├── dataset.py # data preprocessing
│   ├── lr_schedule.py # learning rate scheduler
│   ├── network_define.py # network definition
│   └── text_connector
│       ├── __init__.py # package init file
│       ├── connect_text_lines.py # connect text lines
│       ├── detector.py # detect box
│       ├── get_successions.py # get succession proposal
│       └── utils.py # commonly used utility functions
└── train.py # train net
```
## [Training process](#contents)
### Dataset
To create the dataset, download the datasets first and preprocess them. We provide src/convert_svt.py and src/convert_icdar2015.py to convert the SVT and ICDAR2015 dataset labels. For the SVT dataset, you can process it as below:
```shell
python convert_svt.py --dataset_path=/path/img --xml_file=/path/train.xml --location_dir=/path/location
```
For the ICDAR2015 dataset, you can process it as below:
```shell
python convert_icdar2015.py --src_label_path=/path/train_label --target_label_path=/path/label
```
Then modify src/config.py to add the dataset paths. For each dataset, add its IMAGE_PATH and LABEL_PATH as a list in the config. An example is shown below:
```python
# create dataset
"coco_root": "/path/coco",
"coco_train_data_type": "train2017",
"cocotext_json": "/path/cocotext.v2.json",
"icdar11_train_path": ["/path/image/", "/path/label"],
"icdar13_train_path": ["/path/image/", "/path/label"],
"icdar15_train_path": ["/path/image/", "/path/label"],
"icdar13_test_path": ["/path/image/", "/path/label"],
"flick_train_path": ["/path/image/", "/path/label"],
"svt_train_path": ["/path/image/", "/path/label"],
"pretrain_dataset_path": "",
"finetune_dataset_path": "",
"test_dataset_path": "",
```
Then you can create the MindRecord dataset with src/create_dataset.py using the command below:
```shell
python src/create_dataset.py
```
### Usage
- Ascend:
```bash
# distribute training example(8p)
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
# standalone training
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
# evaluation:
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
```
The `pretrained_path` should be a checkpoint of VGG16 trained on ImageNet2012. The weight names in the checkpoint dict must match the network exactly, and batch_norm must be enabled when training VGG16, otherwise later steps will fail. For COCO_TEXT_PARSER_PATH, coco_text.py can be obtained from [Link](https://github.com/andreasveit/coco-text). To get the VGG16 backbone, you can use the network structure defined in src/CTPN/vgg16.py. To train the backbone, copy src/CTPN/vgg16.py under modelzoo/official/cv/vgg16/src/, and modify vgg16/train.py to use the new network definition. You can fix it as below:
```python
...
from src.vgg16 import VGG16
...
network = VGG16()
...
```
Then you can train it with ImageNet2012.
> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out since compiling time increases with the growth of model size.
>
> The launch script binds processor cores according to `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
>
> TASK_TYPE can be Pretraining or Finetune. For Pretraining, we use ICDAR2013, ICDAR2015, SVT, SCUT-FORU and CocoText v2. For Finetune, we use ICDAR2011, ICDAR2013 and SCUT-FORU to improve precision and recall, and when doing Finetune, we use the checkpoint trained in Pretraining as our PRETRAINED_PATH.
> For COCO_TEXT_PARSER_PATH, the coco_text.py parser can be obtained from [Link](https://github.com/andreasveit/coco-text).
>
### Launch
```bash
# training example
shell:
Ascend:
# distribute training example(8p)
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]
# standalone training
sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]
```
### Result
Training results will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, the training log will be redirected to `./log`, and the loss will be written to `./loss_0.log` like the following.
```python
377 epoch: 1 step: 229 ,rpn_loss: 0.00355, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00103,
399 epoch: 2 step: 229 ,rpn_loss: 0.00327,rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00093,
424 epoch: 3 step: 229 ,rpn_loss: 0.00910, rpn_cls_loss: 0.00385, rpn_reg_loss: 0.00175,
```
## [Eval process](#contents)
### Usage
You can start evaluation using python or shell scripts. The usage of the shell script is as follows:
- Ascend:
```bash
sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]
```
After evaluation, you will get several archive files named submit_ctpn-xx_xxxx.zip, whose names contain the name of your checkpoint file. To evaluate them, you can use the scripts provided by the ICDAR2013 competition; you can download the DetEval scripts from the [link](https://rrc.cvc.uab.es/?com=downloads&action=download&ch=2&f=aHR0cHM6Ly9ycmMuY3ZjLnVhYi5lcy9zdGFuZGFsb25lcy9zY3JpcHRfdGVzdF9jaDJfdDFfZTItMTU3Nzk4MzA2Ny56aXA=)
After downloading the scripts, unzip them, put them under ctpn/scripts, and use eval_res.sh to get the result. You will have files as below:
```text
gt.zip
readme.txt
rrc_evaluation_funcs_1_1.py
script.py
```
Then you can run scripts/eval_res.sh to calculate the evaluation result.
```bash
bash eval_res.sh
```
### Result
Evaluation results will be stored in the example path; you can find results like the following in `log`.
```text
{"precision": 0.90791, "recall": 0.86118, "hmean": 0.88393}
```
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | CTPN |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| uploaded Date | 02/06/2021 |
| MindSpore Version | 1.1.1 |
| Dataset | 16930 images |
| Batch_size | 2 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropyWithLogits for classification, SmoothL1Loss for bbox regression |
| Loss | ~0.04 |
| Total time (8p) | 6h |
| Scripts | [ctpn script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ctpn) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | CTPN |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G |
| Uploaded Date | 02/06/2020 |
| MindSpore Version | 1.1.1 |
| Dataset | 229 images |
| Batch_size | 1 |
| Accuracy | precision=0.9079, recall=0.8611 F-measure:0.8839 |
| Total time | 1 min |
| Model for inference | 135M (.ckpt file) |
#### Training performance results
| **Ascend** | train performance |
| :--------: | :---------------: |
| 1p | 10 img/s |
| 8p | 84 img/s |
# [Description of Random Situation](#contents)
We set seed to 1 in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for CTPN"""
import os
import argparse
import time
import numpy as np
from PIL import Image, ImageDraw
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config
from src.dataset import create_ctpn_dataset
from src.text_connector.detector import detect
set_seed(1)
parser = argparse.ArgumentParser(description="CTPN evaluation")
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
parser.add_argument("--image_path", type=str, default="", help="Image path.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
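# Example invocation (mirrors how scripts/run_eval_ascend.sh calls this file;
# the paths below are placeholders):
#   python eval.py --device_id=0 --image_path=/path/to/test_images \
#       --dataset_path=/path/to/test_dataset --checkpoint_path=/path/to/ctpn.ckpt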
def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
"""ctpn infer."""
print("ckpt path is {}".format(ckpt_path))
ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
config.batch_size = config.test_batch_size
total = ds.get_dataset_size()
print("*************total dataset size is {}".format(total))
net = CTPN(config, is_training=False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
eval_iter = 0
print("\n========================================\n")
print("Processing, please wait a moment.")
img_basenames = []
output_dir = os.path.join(os.getcwd(), "submit")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
for file in os.listdir(img_dir):
img_basenames.append(os.path.basename(file))
for data in ds.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
start = time.time()
# run net
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
gt_bboxes = gt_bboxes.asnumpy()
gt_labels = gt_labels.asnumpy()
gt_num = gt_num.asnumpy().astype(bool)
end = time.time()
proposal = output[0]
proposal_mask = output[1]
print("start to draw pic")
for j in range(config.test_batch_size):
img = img_basenames[config.test_batch_size * eval_iter + j]
all_box_tmp = proposal[j].asnumpy()
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
using_boxes_mask = all_box_tmp * all_mask_tmp
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
scores = using_boxes_mask[:, 4].astype(np.float32)
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
im = Image.open(img_dir + '/' + img)
draw = ImageDraw.Draw(im)
image_h = img_metas.asnumpy()[j][2]
image_w = img_metas.asnumpy()[j][3]
gt_boxs = gt_bboxes[j][gt_num[j], :]
for gt_box in gt_boxs:
gt_x1 = gt_box[0] / image_w
gt_y1 = gt_box[1] / image_h
gt_x2 = gt_box[2] / image_w
gt_y2 = gt_box[3] / image_h
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
fill='green', width=2)
file_name = "res_" + img.replace("jpg", "txt")
output_file = os.path.join(output_dir, file_name)
f = open(output_file, 'w')
for bbox in bboxes:
x1 = bbox[0] / image_w
y1 = bbox[1] / image_h
x2 = bbox[2] / image_w
y2 = bbox[3] / image_h
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
f.write(str_tmp)
f.write("\n")
f.close()
im.save(img)
percent = round(eval_iter / total * 100, 2)
eval_iter = eval_iter + 1
print("Iter {} cost time {}".format(eval_iter, end - start))
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
if __name__ == '__main__':
ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)


@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
for submit_file in "submit"*.zip
do
echo "eval result for ${submit_file}"
python script.py g=gt.zip s=${submit_file} o=./
echo -e ".\n"
done


@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
TASK_TYPE=$2
PATH2=$(get_real_path $3)
echo $PATH2
if [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log &
cd ..
done


@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
IMAGE_PATH=$(get_real_path $1)
DATASET_PATH=$(get_real_path $2)
CHECKPOINT_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $IMAGE_PATH ]
then
echo "error: IMAGE_PATH=$PATH1 is not a path"
exit 1
fi
if [ ! -f $DATASET_PATH ]
then
echo "error: CHECKPOINT_PATH=$DATASET_PATH is not a path"
exit 1
fi
if [ ! -d $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
for file in "${CHECKPOINT_PATH}"/*.ckpt
do
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
CHECKPOINT_FILE_PATH=$file
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
cd ./submit
file_base_name=$(basename $file)
zip -r ../../submit_${file_base_name%.*}.zip *.txt
cd ../../
done


@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
TASK_TYPE=$1
PRETRAINED_PATH=$(get_real_path $2)
echo $PRETRAINED_PATH
if [ ! -f $PRETRAINED_PATH ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log &
cd ..


@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
class BoundingBoxDecode(nn.Cell):
"""
BoundingBox decoder.
Returns:
pred_box(Tensor): decoded bounding boxes.
"""
def __init__(self):
super(BoundingBoxDecode, self).__init__()
self.split = P.Split(axis=1, output_num=4)
self.ones = 1.0
self.half = 0.5
self.log = P.Log()
self.exp = P.Exp()
self.concat = P.Concat(axis=1)
def construct(self, bboxes, deltas):
"""
bboxes(Tensor): bounding boxes.
deltas(Tensor): deltas between bounding boxes and anchors.
"""
x1, y1, x2, y2 = self.split(bboxes)
width = x2 - x1 + self.ones
height = y2 - y1 + self.ones
ctr_x = x1 + self.half * width
ctr_y = y1 + self.half * height
_, dy, _, dh = self.split(deltas)
pred_ctr_x = ctr_x
pred_ctr_y = dy * height + ctr_y
pred_w = width
pred_h = self.exp(dh) * height
x1 = pred_ctr_x - self.half * pred_w
y1 = pred_ctr_y - self.half * pred_h
x2 = pred_ctr_x + self.half * pred_w
y2 = pred_ctr_y + self.half * pred_h
pred_box = self.concat((x1, y1, x2, y2))
return pred_box
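# The block below is an optional NumPy illustration of the same vertical-only
# decode (CTPN regresses only the y-center and height; the anchor's horizontal
# extent is kept). It mirrors the formulas in construct() above and is not used
# by the network itself.
if __name__ == '__main__':
    import numpy as np

    def decode_np(bboxes, deltas):
        """NumPy mirror of BoundingBoxDecode.construct: only dy/dh are used."""
        x1, y1, x2, y2 = np.split(bboxes.astype(np.float32), 4, axis=1)
        width = x2 - x1 + 1.0
        height = y2 - y1 + 1.0
        ctr_x = x1 + 0.5 * width
        ctr_y = y1 + 0.5 * height
        _, dy, _, dh = np.split(deltas.astype(np.float32), 4, axis=1)
        pred_ctr_y = dy * height + ctr_y
        pred_h = np.exp(dh) * height
        return np.concatenate([ctr_x - 0.5 * width, pred_ctr_y - 0.5 * pred_h,
                               ctr_x + 0.5 * width, pred_ctr_y + 0.5 * pred_h], axis=1)

    anchor = np.array([[0., 0., 15., 15.]])          # 16x16 anchor
    delta = np.array([[0., 0.5, 0., np.log(2.0)]])   # shift y-center by +8px, double the height
    print(decode_np(anchor, delta))                  # [[ 0.  0. 16. 32.]]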


@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
class BoundingBoxEncode(nn.Cell):
"""
BoundingBox encoder.
Returns:
deltas(Tensor): encoded deltas between anchor boxes and ground truth boxes.
"""
def __init__(self):
super(BoundingBoxEncode, self).__init__()
self.split = P.Split(axis=1, output_num=4)
self.ones = 1.0
self.half = 0.5
self.log = P.Log()
self.concat = P.Concat(axis=1)
def construct(self, anchor_box, gt_box):
"""
anchor_box(Tensor): anchor boxes.
gt_box(Tensor): ground truth boxes.
"""
x1, y1, x2, y2 = self.split(anchor_box)
width = x2 - x1 + self.ones
height = y2 - y1 + self.ones
ctr_x = x1 + self.half * width
ctr_y = y1 + self.half * height
gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box)
gt_width = gt_x2 - gt_x1 + self.ones
gt_height = gt_y2 - gt_y1 + self.ones
ctr_gt_x = gt_x1 + self.half * gt_width
ctr_gt_y = gt_y1 + self.half * gt_height
target_dx = (ctr_gt_x - ctr_x) / width
target_dy = (ctr_gt_y - ctr_y) / height
dw = gt_width / width
dh = gt_height / height
target_dw = self.log(dw)
target_dh = self.log(dh)
deltas = self.concat((target_dx, target_dy, target_dw, target_dh))
return deltas
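# Optional worked example (illustration only) of the encode formulas above.
# For an anchor (0, 0, 15, 15) and a ground truth box (0, 8, 15, 39):
#   width = height = 16, gt_height = 32, ctr_y = 8, ctr_gt_y = 24
#   target_dy = (24 - 8) / 16 = 1.0,  target_dh = log(32 / 16) ~= 0.693
#   target_dx = 0.0,                  target_dw = log(16 / 16) = 0.0
# Only dy/dh are consumed later by BoundingBoxDecode, matching CTPN's
# vertical-only coordinate regression.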


@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for FasterRcnn."""
def __init__(self, config):
"""Anchor generator init method."""
self.base_size = config.anchor_base
self.num_anchor = config.num_anchors
self.anchor_height = config.anchor_height
self.anchor_width = config.anchor_width
self.size = self.gen_anchor_size()
self.base_anchors = self.gen_base_anchors()
def gen_base_anchors(self):
"""Generate a single anchor."""
base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32)
anchors = np.zeros((len(self.size), 4), np.int32)
index = 0
for h, w in self.size:
anchors[index] = self.scale_anchor(base_anchor, h, w)
index += 1
return anchors
def gen_anchor_size(self):
"""Generate a list of anchor size"""
size = []
for width in self.anchor_width:
for height in self.anchor_height:
size.append((height, width))
return size
def scale_anchor(self, anchor, h, w):
x_ctr = (anchor[0] + anchor[2]) * 0.5
y_ctr = (anchor[1] + anchor[3]) * 0.5
scaled_anchor = anchor.copy()
scaled_anchor[0] = x_ctr - w / 2 # xmin
scaled_anchor[2] = x_ctr + w / 2 # xmax
scaled_anchor[1] = y_ctr - h / 2 # ymin
scaled_anchor[3] = y_ctr + h / 2 # ymax
return scaled_anchor
def _meshgrid(self, x, y):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
return xx, yy
def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)
return all_anchors
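# Minimal usage sketch (illustration only): generate anchors for a tiny 2x3
# feature map. The config object below is a stand-in exposing only the fields
# this class reads; the real values live in src/config.py.
if __name__ == '__main__':
    from types import SimpleNamespace
    demo_cfg = SimpleNamespace(anchor_base=16, num_anchors=4,
                               anchor_height=[4, 8, 16, 32], anchor_width=[16])
    generator = AnchorGenerator(demo_cfg)
    demo_anchors = generator.grid_anchors((2, 3), stride=16)
    print(demo_anchors.shape)   # (2 * 3 * 4, 4) = (24, 4)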


@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from src.CTPN.BoundingBoxEncode import BoundingBoxEncode
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler definition.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSample(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals
if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = BoundingBoxEncode()
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
self.print = P.Print()
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)


@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from src.CTPN.BoundingBoxDecode import BoundingBoxDecode
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, tuple of output tensor,(proposal, mask).
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = config.use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = BoundingBoxDecode()
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
def set_train_local(self, config, training=False):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.max_num = cfg.rpn_max_num
k_num = self.num_pre
total_max_topk_input = k_num
self.topK_stage1 = k_num
self.topK_shape = (k_num, 1)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
self.shape = P.Shape()
self.print = P.Print()
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::])
proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
rpn_cls_score = self.transpose(cls_scores, self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape)
anchors = mlvl_anchors
# (H, W, A*2)
rpn_cls_score_shape = self.shape(rpn_cls_score)
rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \
rpn_cls_score_shape[1], -1, self.cls_out_channels))
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
if self.use_sigmoid_cls:
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16)
else:
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre)
topk_inds = self.reshape(topk_inds, self.topK_shape)
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape)))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks
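# For reference, a conceptual NumPy sketch of the greedy IoU-threshold NMS that
# P.NMSWithMask performs on the decoded (x1, y1, x2, y2, score) proposals above.
# This is an illustration for readers, not the operator's actual implementation,
# and it is not used by the network.
def _nms_numpy_sketch(proposals, iou_thr=0.7):
    """Greedy NMS over (N, 5) NumPy proposals; returns kept row indices."""
    x1, y1, x2, y2, scores = proposals.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= iou_thr]
    return keep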


@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for fasterRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from src.CTPN.bbox_assign_sample import BboxAssignSample
class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer
Args:
config(EasyDict) - Network construction config.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tensor, output tensor.
"""
def __init__(self,
config,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RpnRegClsBlock, self).__init__()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.shape = (-1, 2*config.hidden_size)
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
self.transpose = P.Transpose()
self.print = P.Print()
self.dropout = nn.Dropout(0.8)
def construct(self, x):
x = self.reshape(x, self.shape)
x = self.lstm_fc(x)
x1 = self.rpn_cls(x)
x1 = self.reshape(x1, self.shape1)
x1 = self.transpose(x1, (2, 1, 0))
x1 = self.reshape(x1, self.shape2)
x1 = self.transpose(x1, (1, 0, 2, 3))
x2 = self.rpn_reg(x)
x2 = self.reshape(x2, self.shape1)
x2 = self.transpose(x2, (2, 1, 0))
x2 = self.reshape(x2, self.shape2)
x2 = self.transpose(x2, (1, 0, 2, 3))
return x1, x2
class RPN(nn.Cell):
"""
Region proposal network.
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tuple, tuple of output tensor.
Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.cfg = config
self.num_bboxes = cfg_rpn.num_bboxes
self.feature_anchor_shape = cfg_rpn.feature_shapes
self.feature_anchor_shape = self.feature_anchor_shape[0] * \
self.feature_anchor_shape[1] * num_anchors * batch_size
self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 1
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
self.use_sigmoid_cls = config.use_sigmoid_cls
if config.use_sigmoid_cls:
self.reshape_shape_cls = (-1,)
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
cls_out_channels = 1
else:
self.reshape_shape_cls = (-1, cls_out_channels)
self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\
num_anchors, cls_out_channels)
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.softmax = nn.Softmax()
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
self.print = P.Print()
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network
Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
RpnRegClsBlock, the region proposal reg/cls block cell.
"""
rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels)
return rpn_layer
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
'''
inputs(Tensor): Inputs tensor from lstm.
img_metas(Tensor): Image shape.
anchor_list(Tensor): Total anchor list.
gt_bboxes(Tensor): Ground truth boxes.
gt_labels(Tensor): Ground truth labels.
gt_valids(Tensor): Whether ground truth is valid.
'''
rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
output = ()
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()
if self.training:
for i in range(self.batch_size):
valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
mstype.int32)
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
bbox_targets += (bbox_target,)
bbox_weights += (bbox_weight,)
labels += (label,)
label_weights += (label_weight,)
bbox_target_with_batchsize = self.concat(bbox_targets)
bbox_weight_with_batchsize = self.concat(bbox_weights)
label_with_batchsize = self.concat(labels)
label_weight_with_batchsize = self.concat(label_weights)
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
if self.use_sigmoid_cls:
label_ = self.cast(label_, mstype.float32)
loss_cls = self.loss_cls(rpn_cls_score, label_)
loss_cls = loss_cls * label_weight_
loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total
rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
bbox_target_ = self.cast(bbox_target_, mstype.float32)
loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg = self.sum_loss(loss_reg, (1,))
loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total
loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg)
else:
output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1)
return output


@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
def _weight_variable(shape, factor=0.01):
'''weight initialize'''
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
"""Conv2D wrapper."""
weights = 'ones'
layers = []
conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
if not weights_update:
conv.weight.requires_grad = False
layers += [conv]
layers += [_BatchNorm2dInit(out_channels)]
return nn.SequentialCell(layers)
def _fc(in_channels, out_channels):
'''fully connected layer'''
weight = _weight_variable((out_channels, in_channels))
bias = _weight_variable((out_channels,))
return nn.Dense(in_channels, out_channels, weight, bias)
class VGG16FeatureExtraction(nn.Cell):
def __init__(self, weights_update=False):
"""
VGG16 feature extraction
Args:
weights_update(bool): whether to update the weights of the first two convolution blocks, default is False.
"""
super(VGG16FeatureExtraction, self).__init__()
self.relu = nn.ReLU()
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.cast = P.Cast()
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return:
"""
x = self.cast(x, mstype.float32)
x = self.conv1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv2_1(x)
x = self.relu(x)
x = self.conv2_2(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv3_1(x)
x = self.relu(x)
x = self.conv3_2(x)
x = self.relu(x)
x = self.conv3_3(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.relu(x)
x = self.conv4_3(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv5_1(x)
x = self.relu(x)
x = self.conv5_2(x)
x = self.relu(x)
x = self.conv5_3(x)
x = self.relu(x)
return x
class VGG16Classfier(nn.Cell):
def __init__(self):
"""VGG16 classfier structure"""
super(VGG16Classfier, self).__init__()
self.flatten = P.Flatten()
self.relu = nn.ReLU()
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
self.fc2 = _fc(in_channels=4096, out_channels=4096)
self.batch_size = 32
self.reshape = P.Reshape()
def construct(self, x):
"""
:param x: shape=(B, 512, 7, 7)
:return:
"""
x = self.reshape(x, (self.batch_size, 7*7*512))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
return x
class VGG16(nn.Cell):
def __init__(self):
"""VGG16 construct for training backbone"""
super(VGG16, self).__init__()
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = VGG16Classfier()
self.fc3 = _fc(in_channels=4096, out_channels=1000)
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return: logits, shape=(B, 1000)
"""
feature_maps = self.feature_extraction(x)
x = self.max_pool(feature_maps)
x = self.classifier(x)
x = self.fc3(x)
return x
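# Optional smoke test (illustration only): run a zero batch through the backbone
# to check the expected shapes. Assumes a working MindSpore environment; the
# batch size of 32 matches the reshape hard-coded in VGG16Classfier.
if __name__ == '__main__':
    vgg_net = VGG16()
    dummy = Tensor(np.zeros((32, 3, 224, 224), np.float32))
    print(vgg_net(dummy).shape)   # (32, 1000)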


@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
pretrain_config = EasyDict({
# LR
"base_lr": 0.0009,
"warmup_step": 30000,
"warmup_ratio": 1/3.0,
"total_epoch": 100,
})
finetune_config = EasyDict({
# LR
"base_lr": 0.0005,
"warmup_step": 300,
"warmup_ratio": 1/3.0,
"total_epoch": 50,
})
# use for low case number
config = EasyDict({
"img_width": 960,
"img_height": 576,
"keep_ratio": False,
"flip_ratio": 0.0,
"photo_ratio": 0.0,
"expand_ratio": 1.0,
# anchor
"feature_shapes": (36, 60),
"num_anchors": 14,
"anchor_base": 16,
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
"anchor_width": [16],
# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 512,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 3.0,
"rpn_cls_out_channels": 2,
# bbox_assign_sampler
"neg_iou_thr": 0.5,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.001,
"num_bboxes": 30240,
"num_gts": 256,
"num_expected_neg": 512,
"num_expected_pos": 256,
#proposal
"activate_num_classes": 2,
"use_sigmoid_cls": False,
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 1000,
"rpn_proposal_max_num": 1000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 8,
# rnn structure
"input_size": 512,
"num_step": 60,
"rnn_batch_size": 36,
"hidden_size": 128,
# training
"warmup_mode": "linear",
"batch_size": 1,
"momentum": 0.9,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./",
"use_dropout": False,
"loss_scale": 1,
"weight_decay": 1e-4,
# test proposal
"rpn_nms_pre": 2000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 8,
"test_iou_thr": 0.7,
"test_max_per_img": 100,
"test_batch_size": 1,
"use_python_proposal": False,
# text proposal connection
"max_horizontal_gap": 60,
"text_proposals_min_scores": 0.7,
"text_proposals_nms_thresh": 0.2,
"min_v_overlaps": 0.7,
"min_size_sim": 0.7,
"min_ratio": 0.5,
"line_min_score": 0.9,
"text_proposals_width": 16,
"min_num_proposals": 2,
# create dataset
"coco_root": "",
"coco_train_data_type": "",
"cocotext_json": "",
"icdar11_train_path": [],
"icdar13_train_path": [],
"icdar15_train_path": [],
"icdar13_test_path": [],
"flick_train_path": [],
"svt_train_path": [],
"pretrain_dataset_path": "",
"finetune_dataset_path": "",
"test_dataset_path": "",
# training dataset
"pretraining_dataset_file": "",
"finetune_dataset_file": ""
})
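
A quick consistency check on the anchor fields above: with a 576x960 input at stride 16 the feature map is 36x60, and 14 anchor heights over a single 16-pixel width give exactly `num_bboxes` anchors per image. A plain-Python sketch:

```python
# Sanity check relating the anchor fields of `config` to each other (illustrative only).
img_height, img_width = 576, 960
anchor_base = 16
feature_h, feature_w = img_height // anchor_base, img_width // anchor_base   # 36, 60
num_anchors = 14                                                              # len(anchor_height)

assert (feature_h, feature_w) == (36, 60)             # config.feature_shapes
assert feature_h * feature_w * num_anchors == 30240   # config.num_bboxes
print(feature_h * feature_w * num_anchors)
```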

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert icdar2015 dataset label"""
import os
import argparse
def init_args():
parser = argparse.ArgumentParser('')
    parser.add_argument('-s', '--src_label_path', type=str, default='./',
                        help='Directory containing the icdar2015 train labels')
    parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
                        help='Directory where the converted icdar2015 labels are saved')
return parser.parse_args()
def convert():
args = init_args()
anno_file = os.listdir(args.src_label_path)
annos = {}
# read
for file in anno_file:
gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines()
label_list = []
label_name = os.path.basename(file)
for each_label in gt:
print(file)
spt = each_label.split(',')
print(spt)
if "###" in spt[8]:
continue
else:
x1 = min(int(spt[0]), int(spt[6]))
y1 = min(int(spt[1]), int(spt[3]))
x2 = max(int(spt[2]), int(spt[4]))
y2 = max(int(spt[5]), int(spt[7]))
label_list.append([x1, y1, x2, y2])
annos[label_name] = label_list
# write
if not os.path.exists(args.target_label_path):
os.makedirs(args.target_label_path)
for label_file, pos in annos.items():
tgt_anno_file = os.path.join(args.target_label_path, label_file)
f = open(tgt_anno_file, 'w', encoding='UTF-8-sig')
for tgt_label in pos:
str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3])
f.write(str_pos)
f.write("\n")
f.close()
if __name__ == "__main__":
convert()
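
The script keeps only legible boxes (it drops lines whose transcription contains "###") and collapses each clockwise quadrilateral into an axis-aligned rectangle. A worked example on a made-up ICDAR2015 gt line (coordinates are illustrative, not taken from the dataset):

```python
# Hypothetical gt line: x1,y1,x2,y2,x3,y3,x4,y4,transcription (clockwise from top-left).
line = "377,117,463,117,465,130,378,132,word"
spt = line.split(',')
x1 = min(int(spt[0]), int(spt[6]))   # left  edge from points 1 and 4 -> 377
y1 = min(int(spt[1]), int(spt[3]))   # top   edge from points 1 and 2 -> 117
x2 = max(int(spt[2]), int(spt[4]))   # right edge from points 2 and 3 -> 465
y2 = max(int(spt[5]), int(spt[7]))   # bottom edge from points 3 and 4 -> 132
print(x1, y1, x2, y2)                # 377 117 465 132, written out as "377,117,465,132"
```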

View File

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert svt dataset label"""
import os
import argparse
from xml.etree import ElementTree as ET
import numpy as np
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
help='Directory containing images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='Path of the SVT xml annotation file')
    parser.add_argument('-o', '--location_dir', type=str, default='./location',
                        help='Directory where the converted location label files are saved')
return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
tree = ET.parse(xml_file)
root = tree.getroot()
imgs_labels = []
for ch in root:
im_label = {}
for ch01 in ch:
if ch01.tag in "address":
continue
elif ch01.tag in 'taggedRectangles':
# multiple children
rect_list = []
for ch02 in ch01:
rect = {}
rect['location'] = ch02.attrib
rect['label'] = ch02[0].text
rect_list.append(rect)
im_label['rect'] = rect_list
else:
im_label[ch01.tag] = ch01.text
imgs_labels.append(im_label)
if save_file:
np.save("annotation_train.npy", imgs_labels)
return imgs_labels
def convert():
args = init_args()
if not os.path.exists(args.dataset_dir):
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
if not os.path.exists(args.xml_file):
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
if not os.path.exists(args.location_dir):
os.makedirs(args.location_dir)
ims_labels_dict = xml_to_dict(args.xml_file, True)
num_images = len(ims_labels_dict)
print("Converting annotation, {} images in total ".format(num_images))
for i in range(num_images):
img_label = ims_labels_dict[i]
image_name = img_label['imageName']
rects = img_label['rect']
print("processing image: {}".format(image_name))
location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
f = open(location_file_name, 'w')
for j, rect in enumerate(rects):
rect = rects[j]
location = rect['location']
h = int(location['height'])
w = int(location['width'])
x = int(location['x'])
y = int(location['y'])
pos = [x, y, x+w, y+h]
str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3])
f.write(str_pos)
f.write("\n")
f.close()
if __name__ == "__main__":
convert()

View File

@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import division
import os
import numpy as np
from PIL import Image
from mindspore.mindrecord import FileWriter
from src.config import config
def create_coco_label():
"""Create image label."""
image_files = []
image_anno_dict = {}
coco_root = config.coco_root
data_type = config.coco_train_data_type
from src.coco_text import COCO_Text
anno_json = config.cocotext_json
ct = COCO_Text(anno_json)
image_ids = ct.getImgIds(imgIds=ct.train,
catIds=[('legibility', 'legible')])
for img_id in image_ids:
image_info = ct.loadImgs(img_id)[0]
file_name = image_info['file_name'][15:]
anno_ids = ct.getAnnIds(imgIds=img_id)
anno = ct.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
im = Image.open(image_path)
width, _ = im.size
for label in anno:
bbox = label["bbox"]
bbox_width = int(bbox[2])
if 60 * bbox_width < width:
continue
x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
annos.append([x1, y1, x2, y2] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
image_files = []
image_anno_dict = {}
# read
img_basenames = []
for file in os.listdir(train_img_dirs):
        # Filter out gif files.
if 'gif' not in file:
img_basenames.append(os.path.basename(file))
img_names = []
for item in img_basenames:
temp1, _ = os.path.splitext(item)
img_names.append((temp1, item))
for img, img_basename in img_names:
image_path = train_img_dirs + '/' + img_basename
annos = []
if len(img) == 6 and '_' not in img_basename:
gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines()
if img.isdigit() and int(img) > 1200:
continue
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
image_files = []
image_anno_dict = {}
img_basenames = []
for file_name in os.listdir(train_img_dir):
if 'gif' not in file_name:
img_basenames.append(os.path.basename(file_name))
img_names = []
for item in img_basenames:
temp1, _ = os.path.splitext(item)
img_names.append((temp1, item))
for img, img_basename in img_names:
image_path = train_img_dir + '/' + img_basename
annos = []
file_name = prefix + img + ".txt"
file_path = os.path.join(train_txt_dir, file_name)
gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines()
if not gt:
continue
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_train_dataset(dataset_type):
image_files = []
image_anno_dict = {}
if dataset_type == "pretraining":
        # pretraining: coco, flick, icdar2013 train, icdar2015, svt
coco_image_files, coco_anno_dict = create_coco_label()
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
config.flick_train_path[1])
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
config.icdar13_train_path[1], "gt_img_")
icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
config.icdar15_train_path[1], "gt_")
svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
image_anno_dict = {**coco_anno_dict, **flick_anno_dict, \
**icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path, \
prefix="ctpn_pretrain.mindrecord", file_num=8)
elif dataset_type == "finetune":
# finetune: icdar2011, icdar2013 train, flick
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
config.flick_train_path[1])
icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
config.icdar11_train_path[1], "gt_")
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
config.icdar13_train_path[1], "gt_img_")
image_files = flick_image_files + icdar11_image_files + icdar13_image_files
image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path, \
prefix="ctpn_finetune.mindrecord", file_num=8)
elif dataset_type == "test":
# test: icdar2013 test
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
config.icdar13_test_path[1], "")
image_files = icdar_test_image_files
image_anno_dict = icdar_test_anno_dict
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
prefix="ctpn_test.mindrecord", file_num=1)
else:
print("dataset_type should be pretraining, finetune, test")
def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1):
"""Create MindRecord file."""
mindrecord_path = os.path.join(dst_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
ctpn_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(ctpn_json, "ctpn_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
print("img name is {}, anno is {}".format(image_name, annos))
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
if __name__ == "__main__":
create_train_dataset("pretraining")
create_train_dataset("finetune")
create_train_dataset("test")

View File

@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPTN network definition."""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from src.CTPN.rpn import RPN
from src.CTPN.anchor_generator import AnchorGenerator
from src.CTPN.proposal_generator import Proposal
from src.CTPN.vgg16 import VGG16FeatureExtraction
class BiLSTM(nn.Cell):
"""
    Define a BiLSTM network that runs one forward and one backward LSTM over the feature sequence.
    Args:
        config(EasyDict): network configuration. input_size is the feature size of each time step
            (512, the channel number of the VGG16 feature map) and hidden_size is the LSTM hidden size (128).
        is_training(bool): whether the cell is built for training. Default: True.
"""
def __init__(self, config, is_training=True):
super(BiLSTM, self).__init__()
self.is_training = is_training
self.batch_size = config.batch_size * config.rnn_batch_size
print("batch size is {} ".format(self.batch_size))
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_step = config.num_step
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / self.hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
self.w1 = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
self.w1_bw = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.reverse_seq = P.ReverseV2(axis=[0])
self.concat = P.Concat()
self.transpose = P.Transpose()
self.concat1 = P.Concat(axis=2)
self.dropout = nn.Dropout(0.7)
self.use_dropout = config.use_dropout
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
if self.use_dropout:
x = self.dropout(x)
x = self.cast(x, mstype.float16)
bw_x = self.reverse_seq(x)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq(y1_bw)
output = self.concat1((y1, y1_bw))
return output
class CTPN(nn.Cell):
"""
    Define the CTPN network: VGG16 backbone, BiLSTM over the feature sequence, and an RPN head.
    Args:
        config(EasyDict): network configuration (image size, anchor settings, RPN and RNN parameters).
        is_training(bool): whether the network is built for training. Default: True.
"""
def __init__(self, config, is_training=True):
super(CTPN, self).__init__()
self.config = config
self.is_training = is_training
self.num_step = config.num_step
self.input_size = config.input_size
self.batch_size = config.batch_size
self.hidden_size = config.hidden_size
self.vgg16_feature_extractor = VGG16FeatureExtraction()
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cast = P.Cast()
# rpn block
self.rpn_with_loss = RPN(config,
self.batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
self.anchor_generator = AnchorGenerator(config)
self.featmap_size = config.feature_shapes
self.anchor_list = self.get_anchors(self.featmap_size)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        # input image: (B, 3, 576, 960) with the default config
x = self.vgg16_feature_extractor(img_data)
x = self.conv(x)
x = self.cast(x, mstype.float16)
        # (B, 512, 36, 60)
x = self.transpose(x, (0, 2, 1, 3))
x = self.reshape(x, (-1, self.input_size, self.num_step))
x = self.transpose(x, (2, 0, 1))
        # (60, B*36, 512)
x = self.rnn(x)
        # (60, B*36, 256)
#x = self.cast(x, mstype.float32)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
gt_labels,
gt_valids)
if self.training:
return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
return proposal, proposal_mask
def get_anchors(self, featmap_size):
anchors = self.anchor_generator.grid_anchors(featmap_size)
return Tensor(anchors, mstype.float16)
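
The transpose/reshape chain in `CTPN.construct` converts the backbone feature map into the time-major sequence the BiLSTM consumes. A numpy-only sketch of the same index manipulation under the default config (576x960 input, stride 16, batch size 1):

```python
import numpy as np

batch, channels, fm_h, fm_w = 1, 512, 36, 60   # feature map after VGG16 and the 3x3 conv
x = np.zeros((batch, channels, fm_h, fm_w), dtype=np.float32)

x = x.transpose(0, 2, 1, 3)   # (1, 36, 512, 60)
x = x.reshape(-1, 512, 60)    # (36, 512, 60): feature-map rows become the RNN batch
x = x.transpose(2, 0, 1)      # (60, 36, 512)

# 60 time steps (one per feature-map column), RNN batch 36 (= batch * fm_h), feature size 512,
# matching config.num_step, config.rnn_batch_size and config.input_size.
print(x.shape)
```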

View File

@ -0,0 +1,342 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn dataset"""
from __future__ import division
import os
import numpy as np
from numpy import random
import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as CC
import mindspore.common.dtype as mstype
from mindspore.mindrecord import FileWriter
from src.config import config
class PhotoMetricDistortion:
"""Photo Metric Distortion"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, img, boxes, labels):
img = img.astype('float32')
if random.randint(2):
delta = random.uniform(-self.brightness_delta, self.brightness_delta)
img += delta
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
return img, boxes, labels
class Expand:
"""expand image"""
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range
def __call__(self, img, boxes, labels):
if random.randint(2):
return img, boxes, labels
h, w, c = img.shape
ratio = random.uniform(self.min_ratio, self.max_ratio)
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean).astype(img.dtype)
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
expand_img[top:top + h, left:left + w] = img
img = expand_img
boxes += np.tile((left, top), 2)
return img, boxes, labels
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""rescale operation for image"""
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
if img_data.shape[0] > config.img_height:
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
scale_factor = scale_factor * scale_factor2
img_shape = np.append(img_shape, scale_factor)
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes = split_gtbox_label(gt_bboxes)
if gt_bboxes.shape[0] != 0:
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width, 1.0)
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
gt_bboxes = split_gtbox_label(gt_bboxes)
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image of eval"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width)
img_shape = np.append(img_shape, (h_scale, w_scale))
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor
shape = gt_bboxes.shape
label_column = np.ones((shape[0], 1), dtype=int)
gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1)
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flipped generation"""
img_data = img
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
return (img_data, img_shape, flipped, gt_label, gt_num)
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
img_data = img[:, :, ::-1]
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""photo crop operation for image"""
random_photo = PhotoMetricDistortion()
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""expand operation for image"""
expand = Expand()
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
return (img, img_shape, gt_bboxes, gt_label, gt_num)
def split_gtbox_label(gt_bbox_total):
"""split ground truth box label"""
gtbox_list = []
box_num, _ = gt_bbox_total.shape
for i in range(box_num):
gt_bbox = gt_bbox_total[i]
if gt_bbox[0] % 16 != 0:
gt_bbox[0] = (gt_bbox[0] // 16) * 16
if gt_bbox[2] % 16 != 0:
gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16
x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16)
for x0 in x0_array:
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
return np.array(gtbox_list)
def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid):
"""pad ground truth label"""
pad_max_number = 256
gt_label = gt_bboxes[:, 4]
gt_valid = gt_bboxes[:, 4]
if gt_bboxes.shape[0] < 256:
gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \
mode="constant", constant_values=0)
gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1)
gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0)
else:
print("WARNING label num is high than 256")
gt_box = gt_bboxes[0:pad_max_number]
gt_label = gt_label[0:pad_max_number]
gt_valid = gt_valid[0:pad_max_number]
return (img, img_shape, gt_box[:, :4], gt_label, gt_valid)
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid):
image_shape = image_shape[:2]
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid
if config.keep_ratio:
input_data = rescale_column(*input_data)
else:
input_data = resize_column_test(*input_data)
input_data = pad_label(*input_data)
input_data = image_bgr_rgb(*input_data)
output_data = input_data
return output_data
def _data_aug(image, box, is_training):
"""Data augmentation function."""
image_bgr = image.copy()
image_bgr[:, :, 0] = image[:, :, 2]
image_bgr[:, :, 1] = image[:, :, 1]
image_bgr[:, :, 2] = image[:, :, 0]
image_shape = image_bgr.shape[:2]
gt_box = box[:, :4]
gt_label = box[:, 4]
gt_valid = box[:, 4]
input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid
if not is_training:
return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid)
expand = (np.random.rand() < config.expand_ratio)
if expand:
input_data = expand_column(*input_data)
input_data = photo_crop_column(*input_data)
if config.keep_ratio:
input_data = rescale_column(*input_data)
else:
input_data = resize_column(*input_data)
input_data = pad_label(*input_data)
input_data = image_bgr_rgb(*input_data)
output_data = input_data
return output_data
return _data_aug(image, box, is_training)
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict
def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
image_files, image_anno_dict = create_icdar_test_label()
ctpn_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 6]},
}
writer.add_schema(ctpn_json, "ctpn_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=4):
"""Creatr deeptext dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
num_parallel_workers=8, shuffle=is_training)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_)
if is_training:
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
num_parallel_workers=12)
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=12)
else:
ds = ds.map(operations=compose_map_func,
input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=24)
# transpose_column from python to c
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
ds = ds.map(operations=[type_cast1], input_columns=["box"])
ds = ds.map(operations=[type_cast2], input_columns=["label"])
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds
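
The CTPN-specific step in this pipeline is `split_gtbox_label`: each ground-truth box is snapped to the 16-pixel anchor grid and cut into fixed-width pieces, which is what reduces text detection to a sequence of small-box detections. A standalone re-computation of that splitting on a made-up box (values are illustrative):

```python
import numpy as np

# Same arithmetic as split_gtbox_label above, shown on one hypothetical box.
gt_bbox = [35, 40, 90, 72]           # x0, y0, x1, y1 in pixels
x0 = (gt_bbox[0] // 16) * 16         # snap the left edge down  -> 32
x1 = (gt_bbox[2] // 16 + 1) * 16     # snap the right edge up   -> 96
pieces = [[x, gt_bbox[1], x + 15, gt_bbox[3], 1] for x in range(x0, x1, 16)]
print(np.array(pieces))
# [[32 40 47 72  1]
#  [48 40 63 72  1]
#  [64 40 79 72  1]
#  [80 40 95 72  1]]
```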

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for deeptext"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.base_lr
total_steps = int(base_step * config.total_epoch)
warmup_steps = config.warmup_step
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr
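
`dynamic_lr` builds a per-step schedule: linear warmup from `base_lr * warmup_ratio` up to `base_lr`, then cosine decay toward zero (it does not reach exactly zero, because the cosine is evaluated over `total_steps` counted from step 0 rather than from the end of warmup). A minimal sketch with the pretraining settings, using the same imports as train.py and assuming 1000 steps per epoch purely for illustration:

```python
from src.config import pretrain_config
from src.lr_schedule import dynamic_lr

steps_per_epoch = 1000                          # illustrative value only
lr = dynamic_lr(pretrain_config, steps_per_epoch)

print(len(lr))                                  # steps_per_epoch * total_epoch = 100000
print(lr[0])                                    # base_lr * warmup_ratio = 0.0003
print(lr[pretrain_config.warmup_step])          # back to base_lr = 0.0009 right after warmup
print(lr[-1])                                   # small but non-zero value at the end of decay
```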

View File

@ -0,0 +1,153 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
    Monitor the loss during training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, the loss is not printed.
    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
        rank_id (int): Device rank id used to name the loss log file. Default: 0.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
rpn_reg_loss = cb_params.net_outputs[2].asnumpy()
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum / self.count
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"%
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rpn_cls_loss, rpn_reg_loss))
loss_file.write("\n")
loss_file.close()
class LossNet(nn.Cell):
"""FasterRcnn loss method"""
def construct(self, x1, x2, x3):
return x1
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss

View File

@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.text_connector.utils import clip_boxes, fit_y
from src.text_connector.get_successions import get_successions
def connect_text_lines(text_proposals, scores, size):
"""
Connect text lines
Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Predicted bbox scores.
        size(numpy.array): Image size.
    Returns:
        text_recs(numpy.array): Text boxes after connection.
"""
graph = get_successions(text_proposals, scores, size)
text_lines = np.zeros((len(graph), 5), np.float32)
for index, indices in enumerate(graph):
text_line_boxes = text_proposals[list(indices)]
x0 = np.min(text_line_boxes[:, 0])
x1 = np.max(text_line_boxes[:, 2])
offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5
lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
# the score of a text line is the average score of the scores
# of all text proposals contained in the text line
score = scores[list(indices)].sum() / float(len(indices))
text_lines[index, 0] = x0
text_lines[index, 1] = min(lt_y, rt_y)
text_lines[index, 2] = x1
text_lines[index, 3] = max(lb_y, rb_y)
text_lines[index, 4] = score
text_lines = clip_boxes(text_lines, size)
text_recs = np.zeros((len(text_lines), 9), np.float)
index = 0
for line in text_lines:
xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
text_recs[index, 0] = xmin
text_recs[index, 1] = ymin
text_recs[index, 2] = xmax
text_recs[index, 3] = ymax
text_recs[index, 4] = line[4]
index = index + 1
return text_recs

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.config import config
from src.text_connector.utils import nms
from src.text_connector.connect_text_lines import connect_text_lines
def filter_proposal(proposals, scores):
"""
    Filter text proposals by score threshold and NMS.
    Args:
        proposals(numpy.array): Text proposals.
        scores(numpy.array): Text proposal scores.
    Returns:
        keep_proposals(numpy.array): Text proposals after filtering.
        keep_scores(numpy.array): Corresponding scores after filtering.
"""
inds = np.where(scores > config.text_proposals_min_scores)[0]
keep_proposals = proposals[inds]
keep_scores = scores[inds]
sorted_inds = np.argsort(keep_scores.ravel())[::-1]
keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds]
nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh)
keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds]
return keep_proposals, keep_scores
def filter_boxes(boxes):
"""
    Filter text boxes by aspect ratio, score and minimum width.
    Args:
        boxes(numpy.array): Text boxes.
    Returns:
        inds(numpy.array): Indices of the boxes that pass the filters.
"""
heights = np.zeros((len(boxes), 1), np.float)
widths = np.zeros((len(boxes), 1), np.float)
scores = np.zeros((len(boxes), 1), np.float)
index = 0
for box in boxes:
widths[index] = abs(box[2] - box[0])
heights[index] = abs(box[3] - box[1])
scores[index] = abs(box[4])
index += 1
return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\
(widths > (config.text_proposals_width * config.min_num_proposals)))[0]
def detect(text_proposals, scores, size):
"""
Detect text boxes
Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Predicted bbox scores.
        size(numpy.array): Image size.
    Returns:
        boxes(numpy.array): Text boxes after connection and filtering.
"""
keep_proposals, keep_scores = filter_proposal(text_proposals, scores)
connect_boxes = connect_text_lines(keep_proposals, keep_scores, size)
boxes = connect_boxes[filter_boxes(connect_boxes)]
return boxes

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.config import config
from src.text_connector.utils import overlaps_v, size_similarity
def get_successions(text_proposals, scores, im_size):
"""
    Get succession relations between text proposals.
    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Predicted bbox scores.
        im_size(numpy.array): Image size.
    Returns:
        sub_graph(list): Groups of connected proposal indices.
"""
bboxes_table = [[] for _ in range(int(im_size[1]))]
for index, box in enumerate(text_proposals):
bboxes_table[int(box[0])].append(index)
graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool)
for index, box in enumerate(text_proposals):
successions_left = []
for left in range(int(box[0]) + 1, min(int(box[0]) + config.max_horizontal_gap + 1, im_size[1])):
adj_box_indices = bboxes_table[left]
for adj_box_index in adj_box_indices:
if meet_v_iou(text_proposals, adj_box_index, index):
successions_left.append(adj_box_index)
if successions_left:
break
if not successions_left:
continue
succession_index = successions_left[np.argmax(scores[successions_left])]
box_right = text_proposals[succession_index]
succession_right = []
for right in range(int(box_right[0]) - 1, max(int(box_right[0] - config.max_horizontal_gap), 0) - 1, -1):
adj_box_indices = bboxes_table[right]
for adj_box_index in adj_box_indices:
if meet_v_iou(text_proposals, adj_box_index, index):
succession_right.append(adj_box_index)
if succession_right:
break
if scores[index] >= np.max(scores[succession_right]):
graph[index, succession_index] = True
sub_graph = get_sub_graph(graph)
return sub_graph
def get_sub_graph(graph):
"""
    Collect connected proposal chains from the succession graph.
    Args:
        graph(numpy.array): Boolean succession matrix between proposals.
    Returns:
        sub_graphs(list): Lists of connected proposal indices.
"""
sub_graphs = []
for index in range(graph.shape[0]):
if not graph[:, index].any() and graph[index, :].any():
v = index
sub_graphs.append([v])
while graph[v, :].any():
v = np.where(graph[v, :])[0][0]
sub_graphs[-1].append(v)
return sub_graphs
def meet_v_iou(text_proposals, index1, index2):
"""
    Check whether two text proposals meet the vertical IoU and size-similarity thresholds.
    Args:
        text_proposals(numpy.array): Text proposals.
        index1(int): Index of the first text proposal.
        index2(int): Index of the second text proposal.
    Returns:
        bool: True if vertical overlap and size similarity both exceed their thresholds.
"""
heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
return overlaps_v(text_proposals, index1, index2) >= config.min_v_overlaps and \
size_similarity(heights, index1, index2) >= config.min_size_sim

View File

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
def threshold(coords, min_, max_):
return np.maximum(np.minimum(coords, max_), min_)
def clip_boxes(boxes, im_shape):
"""
Clip boxes to image boundaries.
Args:
        boxes(numpy.array): Bounding boxes.
        im_shape(numpy.array): Image shape.
    Return:
        boxes(numpy.array): Bounding boxes after clipping.
"""
boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1)
boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1)
return boxes
def overlaps_v(text_proposals, index1, index2):
"""
Calculate vertical overlap ratio.
Args:
        text_proposals(numpy.array): Text proposals.
index1(int): First text proposal.
index2(int): Second text proposal.
Return:
overlap(float32): vertical overlap.
"""
h1 = text_proposals[index1][3] - text_proposals[index1][1] + 1
h2 = text_proposals[index2][3] - text_proposals[index2][1] + 1
y0 = max(text_proposals[index2][1], text_proposals[index1][1])
y1 = min(text_proposals[index2][3], text_proposals[index1][3])
return max(0, y1 - y0 + 1) / min(h1, h2)
def size_similarity(heights, index1, index2):
"""
Calculate vertical size similarity ratio.
Args:
        heights(numpy.array): Text proposal heights.
index1(int): First text proposal.
index2(int): Second text proposal.
Return:
        similarity(float32): Vertical size similarity.
"""
h1 = heights[index1]
h2 = heights[index2]
return min(h1, h2) / max(h1, h2)
def fit_y(X, Y, x1, x2):
if np.sum(X == X[0]) == len(X):
return Y[0], Y[0]
p = np.poly1d(np.polyfit(X, Y, 1))
return p(x1), p(x2)
def nms(bboxs, thresh):
"""
Args:
text_proposals(numpy.array): tex proposals
index1(int): text_proposal index
tindex2(int): text proposal index
"""
x1, y1, x2, y2, scores = np.split(bboxs, 5, axis=1)
x1 = bboxs[:, 0]
y1 = bboxs[:, 1]
x2 = bboxs[:, 2]
y2 = bboxs[:, 3]
scores = bboxs[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
num_dets = bboxs.shape[0]
suppressed = np.zeros(num_dets, dtype=np.int32)
keep = []
for _i in range(num_dets):
i = order[_i]
if suppressed[i] == 1:
continue
keep.append(i)
x1_i = x1[i]
y1_i = y1[i]
x2_i = x2[i]
y2_i = y2[i]
area_i = areas[i]
for _j in range(_i + 1, num_dets):
j = order[_j]
if suppressed[j] == 1:
continue
x1_j = max(x1_i, x1[j])
y1_j = max(y1_i, y1[j])
x2_j = min(x2_i, x2[j])
y2_j = min(y2_i, y2[j])
w = max(0.0, x2_j - x1_j + 1)
h = max(0.0, y2_j - y1_j + 1)
inter = w*h
overlap = inter / (area_i+areas[j]-inter)
if overlap >= thresh:
suppressed[j] = 1
return keep
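
A small usage sketch for the `nms` helper above, with made-up boxes: two heavily overlapping proposals and one separate proposal, using the same 0.2 threshold as `config.text_proposals_nms_thresh`.

```python
import numpy as np
from src.text_connector.utils import nms   # same import path as used in detect.py

boxes = np.array([
    [10, 10, 60, 30, 0.95],    # kept: highest score
    [12, 11, 62, 31, 0.80],    # suppressed: IoU with the first box is far above 0.2
    [100, 10, 150, 30, 0.60],  # kept: no overlap with the others
], dtype=np.float32)

print(nms(boxes, 0.2))   # expected: [0, 2]
```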

View File

@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train CTPN and get checkpoint files."""
import os
import time
import argparse
import ast
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config, pretrain_config, finetune_config
from src.dataset import create_ctpn_dataset
from src.lr_schedule import dynamic_lr
from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell
set_seed(1)
parser = argparse.ArgumentParser(description="CTPN training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument("--task_type", type=str, default="Pretraining",\
choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True)
if __name__ == '__main__':
if args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
if args_opt.task_type == "Pretraining":
print("Start to do pretraining")
mindrecord_file = config.pretraining_dataset_file
training_cfg = pretrain_config
else:
print("Start to do finetune")
mindrecord_file = config.finetune_dataset_file
training_cfg = finetune_config
print("CHECKING MINDRECORD FILES ...")
while not os.path.exists(mindrecord_file + ".db"):
time.sleep(5)
print("CHECKING MINDRECORD FILES DONE!")
loss_scale = float(config.loss_scale)
    # When creating MindDataset, use the first mindrecord file, such as ctpn_pretrain.mindrecord0.
dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
net = CTPN(config=config, is_training=True)
net = net.set_train()
load_path = args_opt.pre_trained
if args_opt.task_type == "Pretraining":
print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained))
param_dict = load_checkpoint(load_path)
for item in list(param_dict.keys()):
if not item.startswith('vgg16_feature_extractor'):
param_dict.pop(item)
load_param_into_net(net, param_dict)
else:
if load_path != "":
print("load pretrain ckpt {}".format(args_opt.pre_trained))
param_dict = load_checkpoint(load_path)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
if config.save_checkpoint:
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
model = Model(net)
model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)