!6387 psenet open source

Merge pull request !6387 from 吴书全/wsqpse0917
This commit is contained in:
mindspore-ci-bot 2020-09-19 10:57:10 +08:00 committed by Gitee
commit d26df7cdcb
22 changed files with 2003 additions and 0 deletions

View File

@@ -0,0 +1,236 @@
# Contents
- [PSENet Description](#psenet-description)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [How to use](#how-to-use)
- [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
# [PSENet Description](#contents)
With the development of convolutional neural networks, scene text detection has advanced rapidly. However, two problems still hinder its application in industry. On the one hand, most existing algorithms rely on quadrilateral bounding boxes, which cannot accurately locate text of arbitrary shape. On the other hand, two text instances lying close to each other can produce a single false detection that covers both. Segmentation-based approaches traditionally solve the first problem, but usually not the second. To address both, the Progressive Scale Expansion Network (PSENet) is proposed, which can accurately detect text instances of arbitrary shape. More specifically, PSENet generates kernels of different scales for each text instance and gradually expands the minimum-scale kernel to the complete text instance. Because there are large geometric margins between the minimum-scale kernels, the method can effectively separate text instances that lie close together, making it easier to detect arbitrary-shape text. The effectiveness of PSENet has been verified by extensive experiments on CTW1500, Total-Text, ICDAR 2015, and ICDAR 2017 MLT.
[Paper](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345
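As a quick illustration of the expansion step described above, here is a minimal NumPy/OpenCV sketch, not the repository's implementation (the shipped version is the C++ `pse` extension in `src/ETSNET/pse/adaptor.cpp`):
```python
import cv2
import numpy as np
from collections import deque

def progressive_scale_expansion(kernels):
    """kernels: binary uint8 maps of one image, ordered smallest scale first."""
    # Seed instance labels from the connected components of the smallest kernel.
    _, labels = cv2.connectedComponents(kernels[0].astype(np.uint8))
    height, width = labels.shape
    for kernel in kernels[1:]:
        # Breadth-first expansion: grow each labeled pixel into unlabeled
        # neighbors that belong to the next, larger kernel. Conflicting claims
        # are resolved first-come-first-served, as in the paper.
        frontier = deque(zip(*np.nonzero(labels)))
        while frontier:
            x, y = frontier.popleft()
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < height and 0 <= ny < width \
                        and kernel[nx, ny] and labels[nx, ny] == 0:
                    labels[nx, ny] = labels[x, y]
                    frontier.append((nx, ny))
    return labels
```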
# PSENet Example
## Description
Progressive Scale Expansion Network (PSENet) is a text detector that accurately detects arbitrary-shaped text in natural scenes.
# [Dataset](#contents)
Dataset used: [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)
A training set of 1,000 images containing about 4,500 readable words.
A test set of 500 images containing about 2,000 readable words.
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare a hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
- Install MindSpore
- Install [pybind11](https://github.com/pybind/pybind11)
- Install [OpenCV 3.4](https://docs.opencv.org/3.4.9/d7/d9f/tutorial_linux_install.html) (a setup sketch follows this list)
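One possible setup sketch for a Debian-style host; the package names are assumptions, so adjust them to your environment (the pse Makefile locates OpenCV through pkg-config):
```shell
pip install pybind11                           # headers used by src/ETSNET/pse/adaptor.cpp
sudo apt-get install pkg-config libopencv-dev  # or build OpenCV 3.4 from source as linked above
```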
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# run the distributed training example
sh scripts/run_distribute_train.sh pretrained_model.ckpt

# set up the OpenCV and pybind11 libraries (see Environment Requirements above)

# build the pse shared library
cd src/ETSNET/pse && make

# run test.py
python test.py --ckpt=pretrained_model.ckpt

# download the evaluation method from https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization
# (click the "My Methods" button, then download the Evaluation Scripts, which contain script.py)

# run the evaluation example
sh scripts/run_eval_ascend.sh
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
└── PSENet
├── README.md // descriptions about PSENet
├── scripts
        ├── run_distribute_train.sh  // shell script for distributed training
        └── run_eval_ascend.sh  // shell script for evaluation
├── src
├── __init__.py
├── generate_hccn_file.py // creating rank.json
├── ETSNET
├── __init__.py
├── base.py // convolution and BN operator
├── dice_loss.py // calculate PSENet loss value
├── etsnet.py // Subnet in PSENet
├── fpn.py // Subnet in PSENet
├── resnet50.py // Subnet in PSENet
├── pse // Subnet in PSENet
├── __init__.py
├── adaptor.cpp
├── adaptor.h
├── Makefile
├── config.py // parameter configuration
├── dataset.py // creating dataset
└── network_define.py // PSENet architecture
├── test.py // test script
└── train.py // training script
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and config.py are:
--pre_trained: path of the checkpoint used to initialize the network
             (the distributed launch script passes one; see scripts/run_distribute_train.sh).
--device_id: device ID used to train or evaluate the dataset. Ignore it
             when you use run_distribute_train.sh for distributed training.
--device_num: number of devices used when you run run_distribute_train.sh
             for distributed training.
```
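For example, a hypothetical standalone run (the checkpoint path is a placeholder):
```shell
python train.py --device_id=0 --pre_trained=pretrained_model.ckpt
```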
## [Training Process](#contents)
### Distributed Training
```
sh scripts/run_distribute_train.sh pretrained_model.ckpt
```
The above shell script runs distributed training in the background. You can view the results in the files
`device_[X]/loss.log`. The loss values are logged as follows:
```
# grep "epoch" device_*/loss.log
device_0/loss.log:epoch: 1, step: 20, loss is 0.80383
device_0/loss.log:epoch: 2, step: 40, loss is 0.77951
...
device_1/loss.log:epoch: 1, step: 20, loss is 0.78026
device_1/loss.log:epoch: 2, step: 40, loss is 0.76629
```
## [Evaluation Process](#contents)
### Eval Script for ICDAR2015
#### Usage
+ step 1: download the evaluation method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization).
+ step 2: click the "My Methods" button, then download the Evaluation Scripts.
+ step 3: it is recommended to symlink the evaluation method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. If your folder structure is different, you may need to change the corresponding paths in the eval script files.
```
sh ./scripts/run_eval_ascend.sh
```
#### Result
Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
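Here `hmean` is the harmonic mean (F1 score) of precision and recall, so the three numbers are mutually consistent:
```python
p, r = 0.814796668299853, 0.8006740491092923
print(2 * p * r / (p + r))  # ~0.80767, the reported hmean
```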
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | PSENet |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | PSENet (ResNet50 backbone)                                   |
| Resource                   | Ascend 910; CPU 2.60GHz, 56 cores; memory 314G               |
| uploaded Date | 09/15/2020 (month/day/year) |
| MindSpore Version | 1.0-alpha |
| Dataset | ICDAR2015 |
| Training Parameters | start_lr=0.1; lr_scale=0.1 |
| Optimizer | SGD |
| Loss Function              | DiceLoss                                                     |
| outputs | probability |
| Loss | 0.35 |
| Speed | 1pc: 444 ms/step; 4pcs: 446 ms/step |
| Total time | 1pc: 75.48 h; 4pcs: 18.87 h |
| Parameters (M) | 27.36 |
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/psenet |
### Inference Performance
| Parameters | PSENet |
| ------------------- | --------------------------- |
| Model Version       | PSENet (ResNet50 backbone)  |
| Resource | Ascend 910 |
| Uploaded Date | 09/15/2020 (month/day/year) |
| MindSpore Version | 1.0-alpha |
| Dataset | ICDAR2015 |
| outputs | probability |
| Accuracy | 1pc: 81%; 8pcs: 81% |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). The simple example below follows these steps:
```python
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
step_size = dataset.get_dataset_size()

# Define model
config.INFERENCE = False
net = ETSNet(config)
net.set_train()
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)
print('Load Pretrained parameters done!')

criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)

# wrap the network with the loss function and one-step training cell
net = WithLossCell(net, criterion)
net = TrainOneStepCell(net, opt)

time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossCallBack(per_print_times=20)
# set and apply checkpoint parameters
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
model = Model(net)
model.train(config.TRAIN_REPEAT_NUM, dataset, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])

# Load the trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```

View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
echo 'current_exec_path: '${current_exec_path}
if [ $# != 1 ]
then
echo "Usage: sh run_distribute_train.sh [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ ! -f $PATH1 ]
then
echo "error: PRETRAINED_PATH=$PATH1 is not a file"
exit 1
fi
python ${current_exec_path}/src/generate_hccn_file.py
ulimit -u unlimited
export DEVICE_NUM=4
export RANK_SIZE=4
export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json
for((i=0; i<${DEVICE_NUM}; i++))
do
if [ -d ${current_exec_path}/device_$i/ ]
then
if [ -d ${current_exec_path}/device_$i/checkpoints/ ]
then
rm ${current_exec_path}/device_$i/checkpoints/ -rf
fi
if [ -f ${current_exec_path}/device_$i/loss.log ]
then
rm ${current_exec_path}/device_$i/loss.log
fi
if [ -f ${current_exec_path}/device_$i/test_deep$i.log ]
then
rm ${current_exec_path}/device_$i/test_deep$i.log
fi
else
mkdir ${current_exec_path}/device_$i
fi
cd ${current_exec_path}/device_$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
python ${current_exec_path}/train.py --run_distribute --device_id $i --pre_trained $PATH1 --device_num ${DEVICE_NUM} >test_deep$i.log 2>&1 &
cd ${current_exec_path} || exit
done

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
res_path=${current_exec_path}/res/submit_ic15/
eval_tool_path=${current_exec_path}/eval_ic15/
cd ${res_path} || exit
zip ${eval_tool_path}/submit.zip ./*
cd ${eval_tool_path} || exit
python ./script.py -s=submit.zip -g=gt.zip
cd ${current_exec_path} || exit

View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', has_bias=False):
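    """Conv2d with TruncatedNormal(0.02) weight initialization, shared by all PSENet submodules."""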
init_value = TruncatedNormal(0.02)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=init_value, has_bias=has_bias)
def _bn(channels, momentum=0.1):
return nn.BatchNorm2d(channels, momentum=momentum)

View File

@@ -0,0 +1,175 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.nn.loss.loss import _Loss
class DiceLoss(_Loss):
def __init__(self, batch_size=4):
super(DiceLoss, self).__init__()
self.threshold0 = Tensor(0.5, mstype.float32)
self.zero_float32 = Tensor(0.0, mstype.float32)
self.k = int(640 * 640)
self.negative_one_int32 = Tensor(-1, mstype.int32)
self.batch_size = batch_size
self.concat = P.Concat()
self.less_equal = P.LessEqual()
self.greater = P.Greater()
self.reduce_sum = P.ReduceSum()
self.reduce_sum_keep_dims = P.ReduceSum(keep_dims=True)
self.reduce_mean = P.ReduceMean()
self.reduce_min = P.ReduceMin()
self.cast = P.Cast()
self.minimum = P.Minimum()
self.expand_dims = P.ExpandDims()
self.select = P.Select()
self.fill = P.Fill()
self.topk = P.TopK(sorted=True)
self.shape = P.Shape()
self.sigmoid = P.Sigmoid()
self.reshape = P.Reshape()
self.slice = P.Slice()
self.logical_and = P.LogicalAnd()
self.logical_or = P.LogicalOr()
self.equal = P.Equal()
self.zeros_like = P.ZerosLike()
self.add = P.TensorAdd()
self.gather = P.GatherV2()
def ohem_batch(self, scores, gt_texts, training_masks):
'''
:param scores: [N * H * W]
:param gt_texts: [N * H * W]
:param training_masks: [N * H * W]
:return: [N * H * W]
'''
selected_masks = ()
for i in range(self.batch_size):
score = self.slice(scores, (i, 0, 0), (1, 640, 640))
score = self.reshape(score, (640, 640))
gt_text = self.slice(gt_texts, (i, 0, 0), (1, 640, 640))
gt_text = self.reshape(gt_text, (640, 640))
training_mask = self.slice(training_masks, (i, 0, 0), (1, 640, 640))
training_mask = self.reshape(training_mask, (640, 640))
selected_mask = self.ohem_single(score, gt_text, training_mask)
selected_masks = selected_masks + (selected_mask,)
selected_masks = self.concat(selected_masks)
return selected_masks
def ohem_single(self, score, gt_text, training_mask):
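        # Online hard example mining for one image: keep every positive pixel
        # and at most 3x as many of the highest-scoring negative pixels.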
pos_num = self.logical_and(self.greater(gt_text, self.threshold0),
self.greater(training_mask, self.threshold0))
pos_num = self.reduce_sum(self.cast(pos_num, mstype.float32))
neg_num = self.less_equal(gt_text, self.threshold0)
neg_num = self.reduce_sum(self.cast(neg_num, mstype.float32))
neg_num = self.minimum(3 * pos_num, neg_num)
neg_num = self.cast(neg_num, mstype.int32)
neg_num = self.add(neg_num, self.negative_one_int32)
neg_mask = self.less_equal(gt_text, self.threshold0)
ignore_score = self.fill(mstype.float32, (640, 640), -1e3)
neg_score = self.select(neg_mask, score, ignore_score)
neg_score = self.reshape(neg_score, (640 * 640,))
topk_values, _ = self.topk(neg_score, self.k)
threshold = self.gather(topk_values, neg_num, 0)
selected_mask = self.logical_and(
self.logical_or(self.greater(score, threshold),
self.greater(gt_text, self.threshold0)),
self.greater(training_mask, self.threshold0))
selected_mask = self.cast(selected_mask, mstype.float32)
selected_mask = self.expand_dims(selected_mask, 0)
return selected_mask
def dice_loss(self, input_params, target, mask):
'''
:param input: [N, H, W]
:param target: [N, H, W]
:param mask: [N, H, W]
:return:
'''
input_sigmoid = self.sigmoid(input_params)
input_reshape = self.reshape(input_sigmoid, (self.batch_size, 640 * 640))
target = self.reshape(target, (self.batch_size, 640 * 640))
mask = self.reshape(mask, (self.batch_size, 640 * 640))
input_mask = input_reshape * mask
target = target * mask
a = self.reduce_sum(input_mask * target, 1)
b = self.reduce_sum(input_mask * input_mask, 1) + 0.001
c = self.reduce_sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
dice_loss = self.reduce_mean(d)
return 1 - dice_loss
def avg_losses(self, loss_list):
loss_kernel = loss_list[0]
for i in range(1, len(loss_list)):
loss_kernel += loss_list[i]
loss_kernel = loss_kernel / len(loss_list)
return loss_kernel
def construct(self, model_predict, gt_texts, gt_kernels, training_masks):
'''
:param model_predict: [N * 7 * H * W]
:param gt_texts: [N * H * W]
:param gt_kernels:[N * 6 * H * W]
:param training_masks:[N * H * W]
:return:
'''
texts = self.slice(model_predict, (0, 0, 0, 0), (self.batch_size, 1, 640, 640))
texts = self.reshape(texts, (self.batch_size, 640, 640))
selected_masks_text = self.ohem_batch(texts, gt_texts, training_masks)
loss_text = self.dice_loss(texts, gt_texts, selected_masks_text)
kernels = []
loss_kernels = []
for i in range(1, 7):
kernel = self.slice(model_predict, (0, i, 0, 0), (self.batch_size, 1, 640, 640))
kernel = self.reshape(kernel, (self.batch_size, 640, 640))
kernels.append(kernel)
mask0 = self.sigmoid(texts)
selected_masks_kernels = self.logical_and(self.greater(mask0, self.threshold0),
self.greater(training_masks, self.threshold0))
selected_masks_kernels = self.cast(selected_masks_kernels, mstype.float32)
for i in range(6):
gt_kernel = self.slice(gt_kernels, (0, i, 0, 0), (self.batch_size, 1, 640, 640))
gt_kernel = self.reshape(gt_kernel, (self.batch_size, 640, 640))
loss_kernel_i = self.dice_loss(kernels[i], gt_kernel, selected_masks_kernels)
loss_kernels.append(loss_kernel_i)
loss_kernel = self.avg_losses(loss_kernels)
loss = 0.7 * loss_text + 0.3 * loss_kernel
return loss

View File

@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
import mindspore.common.dtype as mstype
from .base import _conv, _bn
from .resnet50 import ResNet, ResidualBlock
from .fpn import FPN
class ETSNet(nn.Cell):
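    """PSENet: ResNet50 backbone, FPN neck, and a segmentation head that
    predicts the text region plus the shrunken kernels (KERNEL_NUM maps)."""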
def __init__(self, config):
super(ETSNet, self).__init__()
self.kernel_num = config.KERNEL_NUM
self.inference = config.INFERENCE
if config.INFERENCE:
self.long_size = config.INFER_LONG_SIZE
else:
self.long_size = config.TRAIN_LONG_SIZE
# backbone
self.feature_extractor = ResNet(ResidualBlock,
config.BACKBONE_LAYER_NUMS,
config.BACKBONE_IN_CHANNELS,
config.BACKBONE_OUT_CHANNELS)
# neck
self.feature_fusion = FPN(config.BACKBONE_OUT_CHANNELS,
config.NECK_OUT_CHANNEL,
self.long_size)
# head
self.conv1 = _conv(4 * config.NECK_OUT_CHANNEL,
config.NECK_OUT_CHANNEL,
kernel_size=3,
stride=1,
has_bias=True)
self.bn1 = _bn(config.NECK_OUT_CHANNEL)
self.relu1 = nn.ReLU()
self.conv2 = _conv(config.NECK_OUT_CHANNEL,
config.KERNEL_NUM,
kernel_size=1,
has_bias=True)
self._upsample = P.ResizeBilinear((self.long_size, self.long_size), align_corners=True)
if self.inference:
self.one_float32 = Tensor(1.0, mstype.float32)
self.sigmoid = P.Sigmoid()
self.greater = P.Greater()
self.logic_and = P.LogicalAnd()
print('ETSNet initialized!')
def construct(self, x):
c2, c3, c4, c5 = self.feature_extractor(x)
feature = self.feature_fusion(c2, c3, c4, c5)
output = self.conv1(feature)
output = self.relu1(self.bn1(output))
output = self.conv2(output)
output = self._upsample(output)
if self.inference:
text = output[::, 0:1:1, ::, ::]
kernels = output[::, 0:7:1, ::, ::]
score = self.sigmoid(text)
kernels = self.logic_and(self.greater(kernels, self.one_float32),
self.greater(text, self.one_float32))
return score, kernels
return output

View File

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops.operations as P
from .base import _conv, _bn
class FPN(nn.Cell):
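    """Feature pyramid: 1x1 reductions of C2-C5, top-down fusion with 3x3
    smoothing, then P3-P5 upsampled to P2's 1/4 scale and concatenated."""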
def __init__(self, in_channels, out_channel, long_size):
super(FPN, self).__init__()
self.long_size = long_size
# reduce layers
self.reduce_conv_c2 = _conv(in_channels[0], out_channel, kernel_size=1, has_bias=True)
self.reduce_bn_c2 = _bn(out_channel)
self.reduce_relu_c2 = nn.ReLU()
self.reduce_conv_c3 = _conv(in_channels[1], out_channel, kernel_size=1, has_bias=True)
self.reduce_bn_c3 = _bn(out_channel)
self.reduce_relu_c3 = nn.ReLU()
self.reduce_conv_c4 = _conv(in_channels[2], out_channel, kernel_size=1, has_bias=True)
self.reduce_bn_c4 = _bn(out_channel)
self.reduce_relu_c4 = nn.ReLU()
self.reduce_conv_c5 = _conv(in_channels[3], out_channel, kernel_size=1, has_bias=True)
self.reduce_bn_c5 = _bn(out_channel)
self.reduce_relu_c5 = nn.ReLU()
# smooth layers
self.smooth_conv_p4 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
self.smooth_bn_p4 = _bn(out_channel)
self.smooth_relu_p4 = nn.ReLU()
self.smooth_conv_p3 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
self.smooth_bn_p3 = _bn(out_channel)
self.smooth_relu_p3 = nn.ReLU()
self.smooth_conv_p2 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
self.smooth_bn_p2 = _bn(out_channel)
self.smooth_relu_p2 = nn.ReLU()
self._upsample_p4 = P.ResizeBilinear((long_size // 16, long_size // 16), align_corners=True)
self._upsample_p3 = P.ResizeBilinear((long_size // 8, long_size // 8), align_corners=True)
self._upsample_p2 = P.ResizeBilinear((long_size // 4, long_size // 4), align_corners=True)
self.concat = P.Concat(axis=1)
def construct(self, c2, c3, c4, c5):
p5 = self.reduce_conv_c5(c5)
p5 = self.reduce_relu_c5(self.reduce_bn_c5(p5))
c4 = self.reduce_conv_c4(c4)
c4 = self.reduce_relu_c4(self.reduce_bn_c4(c4))
p4 = self._upsample_p4(p5) + c4
p4 = self.smooth_conv_p4(p4)
p4 = self.smooth_relu_p4(self.smooth_bn_p4(p4))
c3 = self.reduce_conv_c3(c3)
c3 = self.reduce_relu_c3(self.reduce_bn_c3(c3))
p3 = self._upsample_p3(p4) + c3
p3 = self.smooth_conv_p3(p3)
p3 = self.smooth_relu_p3(self.smooth_bn_p3(p3))
c2 = self.reduce_conv_c2(c2)
c2 = self.reduce_relu_c2(self.reduce_bn_c2(c2))
p2 = self._upsample_p2(p3) + c2
p2 = self.smooth_conv_p2(p2)
p2 = self.smooth_relu_p2(self.smooth_bn_p2(p2))
p3 = self._upsample_p2(p3)
p4 = self._upsample_p2(p4)
p5 = self._upsample_p2(p5)
out = self.concat((p2, p3, p4, p5))
return out

View File

@@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
CXXFLAGS = -I include -std=c++11 -O3 $(shell python3 -m pybind11 --includes)
CXX_SOURCES = adaptor.cpp
OPENCV = `pkg-config --cflags --libs opencv`
LIB_SO = adaptor.so
$(LIB_SO): $(CXX_SOURCES) $(DEPS)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV)
clean:
rm -rf $(LIB_SO)

View File

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import subprocess
import os
import numpy as np
from .adaptor import pse as cpse
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
def pse(polys, min_area):
ret = np.array(cpse(polys, min_area), dtype='int32')
return ret

View File

@@ -0,0 +1,127 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ETSNET/pse/adaptor.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
#include <opencv2/opencv.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
using std::vector;
using std::queue;
using cv::Mat;
using cv::Point;
namespace py = pybind11;
namespace pse_adaptor {
    void get_kernals(const int *data, vector<int64_t> data_shape, vector<Mat> &kernals) {
for (int i = 0; i < data_shape[0]; ++i) {
Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1);
for (int x = 0; x < kernal.rows; ++x) {
for (int y = 0; y < kernal.cols; ++y) {
kernal.at<char>(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y];
}
}
kernals.emplace_back(kernal);
}
}
    void growing_text_line(const vector<Mat> &kernals, vector<vector<int>> &text_line, float min_area) {
Mat label_mat;
int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4);
        vector<int> area(label_num + 1, 0);
for (int x = 0; x < label_mat.rows; ++x) {
for (int y = 0; y < label_mat.cols; ++y) {
int label = label_mat.at<int>(x, y);
if (label == 0) continue;
area[label] += 1;
}
}
queue<Point> queue, next_queue;
for (int x = 0; x < label_mat.rows; ++x) {
vector<int> row(label_mat.cols);
for (int y = 0; y < label_mat.cols; ++y) {
int label = label_mat.at<int>(x, y);
if (label == 0) continue;
if (area[label] < min_area) continue;
Point point(x, y);
queue.push(point);
row[y] = label;
}
text_line.emplace_back(row);
}
int dx[] = {-1, 1, 0, 0};
int dy[] = {0, 0, -1, 1};
for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) {
while (!queue.empty()) {
Point point = queue.front();
queue.pop();
int x = point.x;
int y = point.y;
int label = text_line[x][y];
bool is_edge = true;
for (int d = 0; d < 4; ++d) {
int tmp_x = x + dx[d];
int tmp_y = y + dy[d];
                    if (tmp_x < 0 || tmp_x >= static_cast<int>(text_line.size())) continue;
                    if (tmp_y < 0 || tmp_y >= static_cast<int>(text_line[1].size())) continue;
if (kernals[kernal_id].at<char>(tmp_x, tmp_y) == 0) continue;
if (text_line[tmp_x][tmp_y] > 0) continue;
Point point_tmp(tmp_x, tmp_y);
queue.push(point_tmp);
text_line[tmp_x][tmp_y] = label;
is_edge = false;
}
if (is_edge) {
next_queue.push(point);
}
}
swap(queue, next_queue);
}
}
vector<vector<int>> pse(py::array_t<int, py::array::c_style | py::array::forcecast> quad_n9, float min_area) {
auto buf = quad_n9.request();
auto data = static_cast<int *>(buf.ptr);
vector<Mat> kernals;
get_kernals(data, buf.shape, kernals);
vector<vector<int>> text_line;
growing_text_line(kernals, text_line, min_area);
return text_line;
}
} // namespace pse_adaptor
PYBIND11_PLUGIN(adaptor) {
py::module m("adaptor", "pse");
m.def("pse", &pse_adaptor::pse, "pse");
return m.ptr();
}

View File

@@ -0,0 +1,15 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

View File

@@ -0,0 +1,190 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from .base import _conv, _bn
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
stride: Integer. Stride size for the initial convolutional layer. Default:1.
momentum: Float. Momentum for batchnorm layer. Default:0.1.
Returns:
Tensor, output tensor.
Examples:
        ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
momentum=0.1):
super(ResidualBlock, self).__init__()
out_chls = out_channels // self.expansion
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1)
self.bn1 = _bn(out_chls, momentum=momentum)
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1, pad_mode='pad')
self.bn2 = _bn(out_chls, momentum=momentum)
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1)
self.bn3 = _bn(out_channels, momentum=momentum)
self.relu = P.ReLU()
self.downsample = (in_channels != out_channels)
self.stride = stride
if self.downsample or self.stride != 1:
self.conv_down_sample = _conv(in_channels, out_channels,
kernel_size=1, stride=stride)
self.bn_down_sample = _bn(out_channels, momentum=momentum)
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample or self.stride != 1:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet V1 network.
Args:
block: Cell. Block for network.
layer_nums: List. Numbers of different layers.
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
Returns:
Tensor, output tensor.
Examples:
        ResNet(ResidualBlock,
               [3, 4, 6, 3],
               [64, 256, 512, 1024],
               [256, 512, 1024, 2048])
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=1)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=2)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=2)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=2)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make Layer for ResNet.
Args:
block: Cell. Resnet block.
layer_num: Integer. Layer number.
in_channel: Integer. Input channel.
out_channel: Integer. Output channel.
stride:Integer. Stride size for the initial convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
_make_layer(BasicBlock, 3, 128, 256, 2)
"""
layers = []
resblk = block(in_channel, out_channel, stride=stride)
layers.append(resblk)
for _ in range(1, layer_num):
resblk = block(out_channel, out_channel, stride=1)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pad(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
return c2, c3, c4, c5

View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@@ -0,0 +1,50 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict as ed
config = ed({
'INFER_LONG_SIZE': 1920,
'KERNEL_NUM': 7,
    'INFERENCE': True,  # True for inference, False for training
# backbone
'BACKBONE_LAYER_NUMS': [3, 4, 6, 3],
'BACKBONE_IN_CHANNELS': [64, 256, 512, 1024],
'BACKBONE_OUT_CHANNELS': [256, 512, 1024, 2048],
# neck
'NECK_OUT_CHANNEL': 256,
# dataset for train
"TRAIN_ROOT_DIR": '/autotest/lqk/modelzoo/psenet/ic15/',
"TRAIN_IS_TRANSFORM": True,
"TRAIN_LONG_SIZE": 640,
"TRAIN_DATASET_SIZE": 1000,
"TRAIN_MIN_SCALE": 0.4,
"TRAIN_BUFFER_SIZE": 8,
"TRAIN_BATCH_SIZE": 4,
"TRAIN_REPEAT_NUM": 608*4,
"TRAIN_DROP_REMAINDER": True,
"TRAIN_TOTAL_ITER": 152000,
"TRAIN_MODEL_SAVE_PATH": './checkpoints/',
# dataset for test
"TEST_ROOT_DIR": '/autotest/lqk/modelzoo/psenet/ic15/',
"TEST_DATASET_SIZE": 500,
"TEST_BUFFER_SIZE": 4,
"TEST_DROP_REMAINDER": False,
})

View File

@@ -0,0 +1,314 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import random
import cv2
import pyclipper
import numpy as np
from PIL import Image
import Polygon as plg
import mindspore.dataset.engine as de
import mindspore.dataset.vision.py_transforms as py_transforms
from src.config import config
__all__ = ['train_dataset_creator', 'test_dataset_creator']
def get_img(img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def get_imgs_names(root_dir):
img_paths = [i for i in os.listdir(root_dir)
if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]
return img_paths
def get_bboxes(img, gt_path):
h, w = img.shape[0:2]
with open(gt_path, 'r', encoding='utf-8-sig') as f:
lines = f.readlines()
bboxes = []
tags = []
for line in lines:
line = line.replace('\xef\xbb\xbf', '')
line = line.replace('\ufeff', '')
line = line.replace('\n', '')
gt = line.split(",", 8)
tag = gt[-1][0] != '#'
box = [int(gt[i]) for i in range(8)]
box = np.asarray(box) / ([w * 1.0, h * 1.0] * 4)
bboxes.append(box)
tags.append(tag)
return np.array(bboxes), tags
def random_scale(img, min_size):
h, w = img.shape[0:2]
if max(h, w) > 1280:
scale1 = 1280.0 / max(h, w)
img = cv2.resize(img, dsize=None, fx=scale1, fy=scale1)
h, w = img.shape[0:2]
random_scale1 = np.array([0.5, 1.0, 2.0, 3.0])
scale2 = np.random.choice(random_scale1)
if min(h, w) * scale2 <= min_size:
scale3 = (min_size + 10) * 1.0 / min(h, w)
img = cv2.resize(img, dsize=None, fx=scale3, fy=scale3)
else:
img = cv2.resize(img, dsize=None, fx=scale2, fy=scale2)
return img
def random_horizontal_flip(imgs):
if random.random() < 0.5:
for i, _ in enumerate(imgs):
imgs[i] = np.flip(imgs[i], axis=1).copy()
return imgs
def random_rotate(imgs):
max_angle = 10
angle = random.random() * 2 * max_angle - max_angle
for i, _ in enumerate(imgs):
img = imgs[i]
w, h = img.shape[:2]
rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
imgs[i] = img_rotation
return imgs
def random_crop(imgs, img_size):
h, w = imgs[0].shape[0:2]
th, tw = img_size
if w == tw and h == th:
return imgs
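    # With probability 5/8, bias the crop window toward a region containing text.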
if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size
tl[tl < 0] = 0
br = np.max(np.where(imgs[1] > 0), axis=1) - img_size
br[br < 0] = 0
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
else:
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
for idx, _ in enumerate(imgs):
if len(imgs[idx].shape) == 3:
imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
else:
imgs[idx] = imgs[idx][i:i + th, j:j + tw]
return imgs
def scale(img, long_size=2240):
h, w = img.shape[0:2]
scale_long = long_size * 1.0 / max(h, w)
img = cv2.resize(img, dsize=None, fx=scale_long, fy=scale_long)
return img
def dist(a, b):
return np.sqrt(np.sum((a - b) ** 2))
def perimeter(bbox):
peri = 0.0
for i in range(bbox.shape[0]):
peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
return peri
def shrink(bboxes, rate, max_shr=20):
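    # Shrink each polygon with a pyclipper offset; following the PSENet paper,
    # the margin is d = Area * (1 - r^2) / Perimeter, capped at max_shr pixels.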
rate = rate * rate
shrinked_bboxes = []
for bbox in bboxes:
area = plg.Polygon(bbox).area()
peri = perimeter(bbox)
pco = pyclipper.PyclipperOffset()
pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
offset = min((int)(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
shrinked_bbox = pco.Execute(-offset)
if not shrinked_bbox:
shrinked_bboxes.append(bbox)
continue
shrinked_bbox = np.array(shrinked_bbox)[0]
if shrinked_bbox.shape[0] <= 2:
shrinked_bboxes.append(bbox)
continue
shrinked_bboxes.append(shrinked_bbox)
return np.array(shrinked_bboxes)
class TrainDataset:
def __init__(self):
self.is_transform = config.TRAIN_IS_TRANSFORM
self.img_size = config.TRAIN_LONG_SIZE
self.kernel_num = config.KERNEL_NUM
self.min_scale = config.TRAIN_MIN_SCALE
root_dir = os.path.join(os.path.join(os.path.dirname(__file__), '..'), config.TRAIN_ROOT_DIR)
ic15_train_data_dir = root_dir + 'ch4_training_images/'
ic15_train_gt_dir = root_dir + 'ch4_training_localization_transcription_gt/'
self.img_size = self.img_size if \
(self.img_size is None or isinstance(self.img_size, tuple)) \
else (self.img_size, self.img_size)
data_dirs = [ic15_train_data_dir]
gt_dirs = [ic15_train_gt_dir]
self.all_img_paths = []
self.all_gt_paths = []
for data_dir, gt_dir in zip(data_dirs, gt_dirs):
img_names = [i for i in os.listdir(data_dir)
if os.path.splitext(i)[-1].lower()
in ['.jpg', '.jpeg', '.png']]
img_paths = []
gt_paths = []
for _, img_name in enumerate(img_names):
img_path = os.path.join(data_dir, img_name)
gt_name = 'gt_' + img_name.split('.')[0] + '.txt'
gt_path = os.path.join(gt_dir, gt_name)
img_paths.append(img_path)
gt_paths.append(gt_path)
self.all_img_paths.extend(img_paths)
self.all_gt_paths.extend(gt_paths)
def __getitem__(self, index):
img_path = self.all_img_paths[index]
gt_path = self.all_gt_paths[index]
# start0 = time.time()
img = get_img(img_path)
bboxes, tags = get_bboxes(img, gt_path)
# end0 = time.time()
# multi-scale training
if self.is_transform:
img = random_scale(img, min_size=self.img_size[0])
# get gt_text and training_mask
img_h, img_w = img.shape[0: 2]
gt_text = np.zeros((img_h, img_w), dtype=np.float32)
training_mask = np.ones((img_h, img_w), dtype=np.float32)
if bboxes.shape[0] > 0:
bboxes = np.reshape(bboxes * ([img_w, img_h] * 4), (bboxes.shape[0], -1, 2)).astype('int32')
for i in range(bboxes.shape[0]):
cv2.drawContours(gt_text, [bboxes[i]], 0, i + 1, -1)
if not tags[i]:
cv2.drawContours(training_mask, [bboxes[i]], 0, 0, -1)
# get gt_kernels
gt_kernels = []
for i in range(1, self.kernel_num):
rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i
gt_kernel = np.zeros(img.shape[0:2], dtype=np.float32)
kernel_bboxes = shrink(bboxes, rate)
for j in range(kernel_bboxes.shape[0]):
cv2.drawContours(gt_kernel, [kernel_bboxes[j]], 0, 1, -1)
gt_kernels.append(gt_kernel)
# data augmentation
if self.is_transform:
imgs = [img, gt_text, training_mask]
imgs.extend(gt_kernels)
imgs = random_horizontal_flip(imgs)
imgs = random_rotate(imgs)
imgs = random_crop(imgs, self.img_size)
img, gt_text, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:]
gt_text[gt_text > 0] = 1
gt_kernels = np.array(gt_kernels)
if self.is_transform:
img = Image.fromarray(img)
img = img.convert('RGB')
img = py_transforms.RandomColorAdjust(brightness=32.0 / 255, saturation=0.5)(img)
else:
img = Image.fromarray(img)
img = img.convert('RGB')
img = py_transforms.ToTensor()(img)
img = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
gt_text = gt_text.astype(np.float32)
gt_kernels = gt_kernels.astype(np.float32)
training_mask = training_mask.astype(np.float32)
return img, gt_text, gt_kernels, training_mask
def __len__(self):
return len(self.all_img_paths)
def IC15_TEST_Generator():
ic15_test_data_dir = config.TEST_ROOT_DIR + 'ch4_test_images/'
img_size = config.INFER_LONG_SIZE
img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
data_dirs = [ic15_test_data_dir]
all_img_paths = []
for data_dir in data_dirs:
img_names = [i for i in os.listdir(data_dir) if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]
img_paths = []
for _, img_name in enumerate(img_names):
img_path = data_dir + img_name
img_paths.append(img_path)
all_img_paths.extend(img_paths)
dataset_length = len(all_img_paths)
for index in range(dataset_length):
img_path = all_img_paths[index]
img_name = np.array(os.path.split(img_path)[-1])
img = get_img(img_path)
long_size = max(img.shape[:2])
img_resized = np.zeros((long_size, long_size, 3), np.uint8)
img_resized[:img.shape[0], :img.shape[1], :] = img
img_resized = cv2.resize(img_resized, dsize=img_size)
img_resized = Image.fromarray(img_resized)
img_resized = img_resized.convert('RGB')
img_resized = py_transforms.ToTensor()(img_resized)
img_resized = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_resized)
yield img, img_resized, img_name
def train_dataset_creator():
cv2.setNumThreads(0)
dataset = TrainDataset()
ds = de.GeneratorDataset(dataset, ['img', 'gt_text', 'gt_kernels', 'training_mask'], num_parallel_workers=8)
#ds = ds.repeat(config.TRAIN_REPEAT_NUM)
ds = ds.batch(config.TRAIN_BATCH_SIZE, drop_remainder=config.TRAIN_DROP_REMAINDER)
ds = ds.shuffle(buffer_size=config.TRAIN_BUFFER_SIZE)
return ds
def test_dataset_creator():
ds = de.GeneratorDataset(IC15_TEST_Generator, ['img', 'img_resized', 'img_name'])
ds = ds.shuffle(config.TEST_BUFFER_SIZE)
ds = ds.batch(1, drop_remainder=config.TEST_DROP_REMAINDER)
return ds

View File

@@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import socket
RANK_TABLE_SAVE_PATH = './rank_table_4p.json'
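# Build a 4-device HCCL rank table from /etc/hccn.conf; run_distribute_train.sh
# points RANK_TABLE_FILE at the generated file.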
def main():
nproc_per_node = 4
visible_devices = ['0', '1', '2', '3']
server_id = socket.gethostbyname(socket.gethostname())
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
hccn_table['board_id'] = '0x002f' # A+K
# hccn_table['board_id'] = '0x0000' # A+X
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
for instance_id in range(nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
hccn_table['status'] = 'completed'
import json
with open(RANK_TABLE_SAVE_PATH, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
if __name__ == '__main__':
if os.path.exists(RANK_TABLE_SAVE_PATH):
print('Rank table file exists.')
else:
print('Generating rank table file.')
main()
print('Rank table file generated')

View File

@@ -0,0 +1,146 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import numpy as np
__all__ = ['LossCallBack', 'WithLossCell', 'TrainOneStepCell']
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class LossCallBack(Callback):
"""
    Monitor the loss during training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, the loss is not printed.
    Args:
        per_print_times (int): print the loss every per_print_times steps. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.loss_avg = AverageMeter()
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
if cur_step_in_epoch == 1:
self.loss_avg = AverageMeter()
self.loss_avg.update(loss)
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_log = "time: %s, epoch: %s, step: %s, loss is %s" % (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())),
cb_params.cur_epoch_num,
cur_step_in_epoch,
self.loss_avg.avg)
print(loss_log)
loss_file = open("./loss.log", "a+")
loss_file.write(loss_log)
loss_file.write("\n")
loss_file.close()
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, img, gt_text, gt_kernels, training_mask):
model_predict = self._backbone(img)
return self._loss_fn(model_predict, gt_text, gt_kernels, training_mask)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
    Appends an optimizer to the training network so that the construct function
    can be called to create the backward graph.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
"""
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
# self.backbone = network._backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones(1, dtype=np.float32)) * sens)
self.reducer_flag = reduce_flag
if self.reducer_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, img, gt_text, gt_kernels, training_mask):
weights = self.weights
loss = self.network(img, gt_text, gt_kernels, training_mask)
grads = self.grad(self.network, weights)(img, gt_text, gt_kernels, training_mask, self.sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

View File

@@ -0,0 +1,156 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import math
import operator
from functools import reduce
import argparse
import time
import numpy as np
import cv2
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.dataset import test_dataset_creator
from src.ETSNET.etsnet import ETSNet
from src.ETSNET.pse import pse
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument("--ckpt", type=str, default=0, help='trained model path.')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
save_graphs_path=".")
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
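# Usage sketch (illustration only): a running average over three updates.
#     meter = AverageMeter()
#     meter.update(2.0)       # val = 2.0, avg = 2.0
#     meter.update(4.0)       # val = 4.0, avg = 3.0
#     meter.update(6.0, n=2)  # weighted by n: avg = (2 + 4 + 6*2) / 4 = 4.5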
def sort_to_clockwise(points):
    # centroid of the corner points
    center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
    # sort the corners clockwise (in image coordinates, y pointing down) by
    # their angle around the centroid, measured from -135 degrees
    clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
        math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
    return clockwise_points
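# Worked example (illustration): the axis-aligned corners
#     [(0, 0), (1, 0), (1, 1), (0, 1)]
# come back as [(1, 0), (1, 1), (0, 1), (0, 0)], i.e. clockwise in image
# coordinates starting from the top-right corner; write_result_as_txt below
# then reindexes with [3, 0, 1, 2] so the top-left corner is emitted first.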
def write_result_as_txt(img_name, bboxes, path):
if not os.path.isdir(path):
os.makedirs(path)
filename = os.path.join(path, 'res_{}.txt'.format(os.path.splitext(img_name)[0]))
lines = []
    for bbox in bboxes:
        bbox = bbox.reshape(-1, 2)
        # sort corners clockwise, then reorder so the top-left point comes
        # first (the usual ICDAR2015 submission convention)
        bbox = np.array(list(sort_to_clockwise(bbox)))[[3, 0, 1, 2]].copy().reshape(-1)
        values = [int(v) for v in bbox]
        line = "%d,%d,%d,%d,%d,%d,%d,%d\n" % tuple(values)
        lines.append(line)
with open(filename, 'w') as f:
for line in lines:
f.write(line)
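# Example (illustration): for image img_1.jpg this writes res_img_1.txt with
# one box per line in the form "x1,y1,x2,y2,x3,y3,x4,y4".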
def test():
if not os.path.isdir('./res/submit_ic15/'):
os.makedirs('./res/submit_ic15/')
if not os.path.isdir('./res/vis_ic15/'):
os.makedirs('./res/vis_ic15/')
ds = test_dataset_creator()
config.INFERENCE = True
net = ETSNet(config)
print(args.ckpt)
param_dict = load_checkpoint(args.ckpt)
load_param_into_net(net, param_dict)
print('parameters loaded!')
get_data_time = AverageMeter()
model_run_time = AverageMeter()
post_process_time = AverageMeter()
end_pts = time.time()
iters = ds.create_tuple_iterator(output_numpy=True)
count = 0
for data in iters:
count += 1
# get data
img, img_resized, img_name = data
img = img[0].astype(np.uint8).copy()
img_name = img_name[0].decode('utf-8')
get_data_pts = time.time()
get_data_time.update(get_data_pts - end_pts)
# model run
img_tensor = Tensor(img_resized, mstype.float32)
score, kernels = net(img_tensor)
score = np.squeeze(score.asnumpy())
kernels = np.squeeze(kernels.asnumpy())
model_run_pts = time.time()
model_run_time.update(model_run_pts - get_data_pts)
        # post-process: the PSE step progressively expands the smallest
        # kernel into the full text instances, producing a label map
        pred = pse(kernels, 5.0)
        scale = max(img.shape[:2]) * 1.0 / config.INFER_LONG_SIZE
        label = pred
        label_num = np.max(label) + 1
bboxes = []
        for i in range(1, label_num):
            # pixel coordinates (x, y) of the i-th text instance
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]
            # discard instances that cover too few pixels ...
            if points.shape[0] < 600:
                continue
            # ... or whose mean text score is below the threshold
            score_i = np.mean(score[label == i])
            if score_i < 0.93:
                continue
            # fit a minimum-area rotated rectangle and rescale it to the
            # original image size
            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale
            bbox = bbox.astype('int32')
            cv2.drawContours(img, [bbox], 0, (0, 255, 0), 3)
bboxes.append(bbox)
post_process_pts = time.time()
post_process_time.update(post_process_pts - model_run_pts)
        if count == 1:
            # discard timings from the first (warm-up) iteration
            get_data_time.reset()
            model_run_time.reset()
            post_process_time.reset()
end_pts = time.time()
        # save results (channels swapped from RGB to BGR for OpenCV)
        cv2.imwrite('./res/vis_ic15/{}'.format(img_name), img[:, :, [2, 1, 0]].copy())
write_result_as_txt(img_name, bboxes, './res/submit_ic15/')
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,91 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.dataset import train_dataset_creator
from src.config import config
from src.ETSNET.etsnet import ETSNet
from src.ETSNET.dice_loss import DiceLoss
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('--run_distribute', default=False, action='store_true',
help='Run distribute, default is false.')
parser.add_argument('--pre_trained', type=str, default='', help='Pretrain file path.')
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1.')
args = parser.parse_args()
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
def lr_generator(start_lr, lr_scale, total_iters):
    # step-decay schedule: multiply the learning rate by lr_scale after
    # each third of the total iterations
    lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)]
    return lrs
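# Worked example (illustration): with start_lr=1e-3, lr_scale=0.1 and
# total_iters=600, iterations 0-199 use lr 1e-3, 200-399 use 1e-4 and
# 400-599 use 1e-5.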
def train():
rank_id = 0
if args.run_distribute:
context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, parameter_broadcast=True)
init()
rank_id = get_rank()
# dataset/network/criterion/optim
ds = train_dataset_creator()
step_size = ds.get_dataset_size()
print('Create dataset done!')
config.INFERENCE = False
    net = ETSNet(config)
    net.set_train()
    # load pretrained weights only when a checkpoint path is given
    if args.pre_trained:
        param_dict = load_checkpoint(args.pre_trained)
        load_param_into_net(net, param_dict)
        print('Load pretrained parameters done!')
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
    # wrap the network with the loss and the one-step training cell
net = WithLossCell(net, criterion)
if args.run_distribute:
net = TrainOneStepCell(net, opt, reduce_flag=True, mean=True, degree=args.device_num)
else:
net = TrainOneStepCell(net, opt)
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossCallBack(per_print_times=10)
    # checkpoint config: save every 1875 steps, keep at most 2 checkpoints
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf,
directory="./ckpt_{}".format(rank_id))
model = Model(net)
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=True, callbacks=[time_cb, loss_cb, ckpoint_cb])
if __name__ == '__main__':
train()