forked from mindspore-Ecosystem/mindspore
commit
d26df7cdcb
|
@ -0,0 +1,236 @@
|
|||
# Contents
|
||||
|
||||
- [PSENet Description](#PSENet-description)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [How to use](#how-to-use)
|
||||
- [Inference](#inference)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Transfer Learning](#transfer-learning)
|
||||
|
||||
|
||||
# [PSENet Description](#contents)
|
||||
With the development of convolutional neural network, scene text detection technology has been developed rapidly. However, there are still two problems in this algorithm, which hinders its application in industry. On the one hand, most of the existing algorithms require quadrilateral bounding boxes to accurately locate arbitrary shape text. On the other hand, two adjacent instances of text can cause error detection overwriting both instances. Traditionally, a segmentation-based approach can solve the first problem, but usually not the second. To solve these two problems, a new PSENet (PSENet) is proposed, which can accurately detect arbitrary shape text instances. More specifically, PSENet generates different scale kernels for each text instance and gradually expands the minimum scale kernel to a text instance with full shape. Because of the large geometric margins between the minimum scale kernels, our method can effectively segment closed text instances, making it easier to detect arbitrary shape text instances. The effectiveness of PSENet has been verified by numerous experiments on CTW1500, full text, ICDAR 2015, and ICDAR 2017 MLT.
|
||||
|
||||
[Paper](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345
|
||||
|
||||
|
||||
# PSENet Example
|
||||
## Description
|
||||
Progressive Scale Expansion Network (PSENet) is a text detector which is able to well detect the arbitrary-shape text in natural scene.
|
||||
|
||||
# [Dataset](#contents)
|
||||
Dataset used: [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)
|
||||
A training set of 1000 images containing about 4500 readable words
|
||||
A testing set containing about 2000 readable words
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
- install Mindspore
|
||||
- install [pyblind11](https://github.com/pybind/pybind11)
|
||||
- install [Opencv3.4](https://docs.opencv.org/3.4.9/d7/d9f/tutorial_linux_install.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
```python
|
||||
# run distributed training example
|
||||
sh scripts/run_distribute_train.sh pretrained_model.ckpt
|
||||
|
||||
#setup opencv library
|
||||
download pyblind11, opencv3.4,setup opencv3.4
|
||||
|
||||
#make so file
|
||||
run src/ETSNET/pse/Makefile; make libadaptor.so
|
||||
|
||||
#run test.py
|
||||
python test.py --ckpt=pretrained_model.ckpt
|
||||
|
||||
#download eval method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization).
|
||||
#click "My Methods" button,then download Evaluation Scripts
|
||||
download script.py
|
||||
# run evaluation example
|
||||
sh scripts/run_eval_ascend.sh
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
```
|
||||
└── PSENet
|
||||
├── README.md // descriptions about PSENet
|
||||
├── scripts
|
||||
├── run_distribute_train.sh // shell script for distributed
|
||||
└── eval_ic15.sh // shell script for evaluation
|
||||
├── src
|
||||
├── __init__.py
|
||||
├── generate_hccn_file.py // creating rank.json
|
||||
├── ETSNET
|
||||
├── __init__.py
|
||||
├── base.py // convolution and BN operator
|
||||
├── dice_loss.py // calculate PSENet loss value
|
||||
├── etsnet.py // Subnet in PSENet
|
||||
├── fpn.py // Subnet in PSENet
|
||||
├── resnet50.py // Subnet in PSENet
|
||||
├── pse // Subnet in PSENet
|
||||
├── __init__.py
|
||||
├── adaptor.cpp
|
||||
├── adaptor.h
|
||||
├── Makefile
|
||||
├── config.py // parameter configuration
|
||||
├── dataset.py // creating dataset
|
||||
└── network_define.py // PSENet architecture
|
||||
├── test.py // test script
|
||||
└── train.py // training script
|
||||
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
```python
|
||||
Major parameters in train.py and config.py are:
|
||||
|
||||
--pre_trained: Whether training from scratch or training based on the
|
||||
pre-trained model.Optional values are True, False.
|
||||
--device_id: Device ID used to train or evaluate the dataset. Ignore it
|
||||
when you use train.sh for distributed training.
|
||||
--device_num: devices used when you use train.sh for distributed training.
|
||||
|
||||
```
|
||||
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Distributed Training
|
||||
```
|
||||
sh scripts/run_distribute_train.sh pretrained_model.ckpt
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file
|
||||
`device[X]/log`. The loss value will be achieved as follows:
|
||||
|
||||
```
|
||||
# grep "epoch: " device_*/loss.log
|
||||
device_0/log:epoch: 1, step: 20, loss is 0.80383
|
||||
device_0/log:epcoh: 2, step: 40, loss is 0.77951
|
||||
...
|
||||
device_1/log:epoch: 1, step: 20, loss is 0.78026
|
||||
device_1/log:epcoh: 2, step: 40, loss is 0.76629
|
||||
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Eval Script for ICDAR2015
|
||||
#### Usage
|
||||
+ step 1: download eval method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization).
|
||||
+ step 2: click "My Methods" button,then download Evaluation Scripts.
|
||||
+ step 3: it is recommended to symlink the eval method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. if your folder structure is different,you may need to change the corresponding paths in eval script files.
|
||||
```
|
||||
sh ./script/run_eval_ascend.sh.sh
|
||||
```
|
||||
#### Result
|
||||
Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
|
||||
|
||||
|
||||
# [Model Description](#contents)
|
||||
## [Performance](#contents)
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | PSENet |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | Inception V1 |
|
||||
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
|
||||
| uploaded Date | 09/15/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0-alpha |
|
||||
| Dataset | ICDAR2015 |
|
||||
| Training Parameters | start_lr=0.1; lr_scale=0.1 |
|
||||
| Optimizer | SGD |
|
||||
| Loss Function | LossCallBack |
|
||||
| outputs | probability |
|
||||
| Loss | 0.35 |
|
||||
| Speed | 1pc: 444 ms/step; 4pcs: 446 ms/step |
|
||||
| Total time | 1pc: 75.48 h; 4pcs: 18.87 h |
|
||||
| Parameters (M) | 27.36 |
|
||||
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/psenet |
|
||||
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | PSENet |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | Inception V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/15/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0-alpha |
|
||||
| Dataset | ICDAR2015 |
|
||||
| outputs | probability |
|
||||
| Accuracy | 1pc: 81%; 8pcs: 81% |
|
||||
|
||||
## [How to use](#contents)
|
||||
|
||||
### Inference
|
||||
|
||||
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
|
||||
|
||||
```
|
||||
# Load unseen dataset for inference
|
||||
dataset = dataset.create_dataset(cfg.data_path, 1, False)
|
||||
|
||||
# Define model
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(args.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('Load Pretrained parameters done!')
|
||||
|
||||
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
|
||||
|
||||
lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
|
||||
|
||||
# warp model
|
||||
net = WithLossCell(net, criterion)
|
||||
net = TrainOneStepCell(net, opt)
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossCallBack(per_print_times=20)
|
||||
# set and apply parameters of check point
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
|
||||
|
||||
# Load pre-trained model
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
# Make predictions on the unseen dataset
|
||||
acc = model.eval(dataset)
|
||||
print("accuracy: ", acc)
|
||||
```
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo 'current_exec_path: '${current_exec_path}
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [PRETRAINED_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
python ${current_exec_path}/src/generate_hccn_file.py
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=4
|
||||
export RANK_SIZE=4
|
||||
export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
if [ -d ${current_exec_path}/device_$i/ ]
|
||||
then
|
||||
if [ -d ${current_exec_path}/device_$i/checkpoints/ ]
|
||||
then
|
||||
rm ${current_exec_path}/device_$i/checkpoints/ -rf
|
||||
fi
|
||||
|
||||
if [ -f ${current_exec_path}/device_$i/loss.log ]
|
||||
then
|
||||
rm ${current_exec_path}/device_$i/loss.log
|
||||
fi
|
||||
|
||||
if [ -f ${current_exec_path}/device_$i/test_deep$i.log ]
|
||||
then
|
||||
rm ${current_exec_path}/device_$i/test_deep$i.log
|
||||
fi
|
||||
else
|
||||
mkdir ${current_exec_path}/device_$i
|
||||
fi
|
||||
|
||||
cd ${current_exec_path}/device_$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python ${current_exec_path}/train.py --run_distribute --device_id $i --pre_trained $PATH1 --device_num ${DEVICE_NUM} >test_deep$i.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
res_path=${current_exec_path}/res/submit_ic15/
|
||||
eval_tool_path=${current_exec_path}/eval_ic15/
|
||||
|
||||
cd ${res_path} || exit
|
||||
zip ${eval_tool_path}/submit.zip ./*
|
||||
cd ${eval_tool_path} || exit
|
||||
python ./script.py -s=submit.zip -g=gt.zip
|
||||
cd ${current_exec_path} || exit
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', has_bias=False):
|
||||
init_value = TruncatedNormal(0.02)
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=init_value, has_bias=has_bias)
|
||||
|
||||
def _bn(channels, momentum=0.1):
|
||||
return nn.BatchNorm2d(channels, momentum=momentum)
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
|
||||
class DiceLoss(_Loss):
|
||||
def __init__(self, batch_size=4):
|
||||
super(DiceLoss, self).__init__()
|
||||
|
||||
self.threshold0 = Tensor(0.5, mstype.float32)
|
||||
self.zero_float32 = Tensor(0.0, mstype.float32)
|
||||
self.k = int(640 * 640)
|
||||
self.negative_one_int32 = Tensor(-1, mstype.int32)
|
||||
self.batch_size = batch_size
|
||||
self.concat = P.Concat()
|
||||
self.less_equal = P.LessEqual()
|
||||
self.greater = P.Greater()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_sum_keep_dims = P.ReduceSum(keep_dims=True)
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reduce_min = P.ReduceMin()
|
||||
self.cast = P.Cast()
|
||||
self.minimum = P.Minimum()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.select = P.Select()
|
||||
self.fill = P.Fill()
|
||||
self.topk = P.TopK(sorted=True)
|
||||
self.shape = P.Shape()
|
||||
self.sigmoid = P.Sigmoid()
|
||||
self.reshape = P.Reshape()
|
||||
self.slice = P.Slice()
|
||||
self.logical_and = P.LogicalAnd()
|
||||
self.logical_or = P.LogicalOr()
|
||||
self.equal = P.Equal()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
self.add = P.TensorAdd()
|
||||
self.gather = P.GatherV2()
|
||||
|
||||
def ohem_batch(self, scores, gt_texts, training_masks):
|
||||
'''
|
||||
|
||||
:param scores: [N * H * W]
|
||||
:param gt_texts: [N * H * W]
|
||||
:param training_masks: [N * H * W]
|
||||
:return: [N * H * W]
|
||||
'''
|
||||
selected_masks = ()
|
||||
for i in range(self.batch_size):
|
||||
score = self.slice(scores, (i, 0, 0), (1, 640, 640))
|
||||
score = self.reshape(score, (640, 640))
|
||||
|
||||
gt_text = self.slice(gt_texts, (i, 0, 0), (1, 640, 640))
|
||||
gt_text = self.reshape(gt_text, (640, 640))
|
||||
|
||||
training_mask = self.slice(training_masks, (i, 0, 0), (1, 640, 640))
|
||||
training_mask = self.reshape(training_mask, (640, 640))
|
||||
|
||||
selected_mask = self.ohem_single(score, gt_text, training_mask)
|
||||
selected_masks = selected_masks + (selected_mask,)
|
||||
|
||||
selected_masks = self.concat(selected_masks)
|
||||
return selected_masks
|
||||
|
||||
def ohem_single(self, score, gt_text, training_mask):
|
||||
pos_num = self.logical_and(self.greater(gt_text, self.threshold0),
|
||||
self.greater(training_mask, self.threshold0))
|
||||
pos_num = self.reduce_sum(self.cast(pos_num, mstype.float32))
|
||||
|
||||
neg_num = self.less_equal(gt_text, self.threshold0)
|
||||
neg_num = self.reduce_sum(self.cast(neg_num, mstype.float32))
|
||||
neg_num = self.minimum(3 * pos_num, neg_num)
|
||||
neg_num = self.cast(neg_num, mstype.int32)
|
||||
|
||||
neg_num = self.add(neg_num, self.negative_one_int32)
|
||||
neg_mask = self.less_equal(gt_text, self.threshold0)
|
||||
ignore_score = self.fill(mstype.float32, (640, 640), -1e3)
|
||||
neg_score = self.select(neg_mask, score, ignore_score)
|
||||
neg_score = self.reshape(neg_score, (640 * 640,))
|
||||
|
||||
topk_values, _ = self.topk(neg_score, self.k)
|
||||
threshold = self.gather(topk_values, neg_num, 0)
|
||||
|
||||
selected_mask = self.logical_and(
|
||||
self.logical_or(self.greater(score, threshold),
|
||||
self.greater(gt_text, self.threshold0)),
|
||||
self.greater(training_mask, self.threshold0))
|
||||
|
||||
selected_mask = self.cast(selected_mask, mstype.float32)
|
||||
selected_mask = self.expand_dims(selected_mask, 0)
|
||||
|
||||
return selected_mask
|
||||
|
||||
def dice_loss(self, input_params, target, mask):
|
||||
'''
|
||||
|
||||
:param input: [N, H, W]
|
||||
:param target: [N, H, W]
|
||||
:param mask: [N, H, W]
|
||||
:return:
|
||||
'''
|
||||
|
||||
input_sigmoid = self.sigmoid(input_params)
|
||||
|
||||
input_reshape = self.reshape(input_sigmoid, (self.batch_size, 640 * 640))
|
||||
target = self.reshape(target, (self.batch_size, 640 * 640))
|
||||
mask = self.reshape(mask, (self.batch_size, 640 * 640))
|
||||
|
||||
input_mask = input_reshape * mask
|
||||
target = target * mask
|
||||
|
||||
a = self.reduce_sum(input_mask * target, 1)
|
||||
b = self.reduce_sum(input_mask * input_mask, 1) + 0.001
|
||||
c = self.reduce_sum(target * target, 1) + 0.001
|
||||
d = (2 * a) / (b + c)
|
||||
dice_loss = self.reduce_mean(d)
|
||||
return 1 - dice_loss
|
||||
|
||||
def avg_losses(self, loss_list):
|
||||
loss_kernel = loss_list[0]
|
||||
for i in range(1, len(loss_list)):
|
||||
loss_kernel += loss_list[i]
|
||||
loss_kernel = loss_kernel / len(loss_list)
|
||||
return loss_kernel
|
||||
|
||||
def construct(self, model_predict, gt_texts, gt_kernels, training_masks):
|
||||
'''
|
||||
|
||||
:param model_predict: [N * 7 * H * W]
|
||||
:param gt_texts: [N * H * W]
|
||||
:param gt_kernels:[N * 6 * H * W]
|
||||
:param training_masks:[N * H * W]
|
||||
:return:
|
||||
'''
|
||||
texts = self.slice(model_predict, (0, 0, 0, 0), (self.batch_size, 1, 640, 640))
|
||||
texts = self.reshape(texts, (self.batch_size, 640, 640))
|
||||
selected_masks_text = self.ohem_batch(texts, gt_texts, training_masks)
|
||||
loss_text = self.dice_loss(texts, gt_texts, selected_masks_text)
|
||||
|
||||
kernels = []
|
||||
loss_kernels = []
|
||||
for i in range(1, 7):
|
||||
kernel = self.slice(model_predict, (0, i, 0, 0), (self.batch_size, 1, 640, 640))
|
||||
kernel = self.reshape(kernel, (self.batch_size, 640, 640))
|
||||
kernels.append(kernel)
|
||||
|
||||
mask0 = self.sigmoid(texts)
|
||||
selected_masks_kernels = self.logical_and(self.greater(mask0, self.threshold0),
|
||||
self.greater(training_masks, self.threshold0))
|
||||
selected_masks_kernels = self.cast(selected_masks_kernels, mstype.float32)
|
||||
|
||||
for i in range(6):
|
||||
gt_kernel = self.slice(gt_kernels, (0, i, 0, 0), (self.batch_size, 1, 640, 640))
|
||||
gt_kernel = self.reshape(gt_kernel, (self.batch_size, 640, 640))
|
||||
loss_kernel_i = self.dice_loss(kernels[i], gt_kernel, selected_masks_kernels)
|
||||
loss_kernels.append(loss_kernel_i)
|
||||
loss_kernel = self.avg_losses(loss_kernels)
|
||||
|
||||
loss = 0.7 * loss_text + 0.3 * loss_kernel
|
||||
return loss
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from .base import _conv, _bn
|
||||
from .resnet50 import ResNet, ResidualBlock
|
||||
from .fpn import FPN
|
||||
|
||||
|
||||
class ETSNet(nn.Cell):
|
||||
def __init__(self, config):
|
||||
super(ETSNet, self).__init__()
|
||||
self.kernel_num = config.KERNEL_NUM
|
||||
self.inference = config.INFERENCE
|
||||
if config.INFERENCE:
|
||||
self.long_size = config.INFER_LONG_SIZE
|
||||
else:
|
||||
self.long_size = config.TRAIN_LONG_SIZE
|
||||
|
||||
# backbone
|
||||
self.feature_extractor = ResNet(ResidualBlock,
|
||||
config.BACKBONE_LAYER_NUMS,
|
||||
config.BACKBONE_IN_CHANNELS,
|
||||
config.BACKBONE_OUT_CHANNELS)
|
||||
|
||||
# neck
|
||||
self.feature_fusion = FPN(config.BACKBONE_OUT_CHANNELS,
|
||||
config.NECK_OUT_CHANNEL,
|
||||
self.long_size)
|
||||
|
||||
# head
|
||||
self.conv1 = _conv(4 * config.NECK_OUT_CHANNEL,
|
||||
config.NECK_OUT_CHANNEL,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
has_bias=True)
|
||||
self.bn1 = _bn(config.NECK_OUT_CHANNEL)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv2 = _conv(config.NECK_OUT_CHANNEL,
|
||||
config.KERNEL_NUM,
|
||||
kernel_size=1,
|
||||
has_bias=True)
|
||||
self._upsample = P.ResizeBilinear((self.long_size, self.long_size), align_corners=True)
|
||||
|
||||
if self.inference:
|
||||
self.one_float32 = Tensor(1.0, mstype.float32)
|
||||
self.sigmoid = P.Sigmoid()
|
||||
self.greater = P.Greater()
|
||||
self.logic_and = P.LogicalAnd()
|
||||
|
||||
print('ETSNet initialized!')
|
||||
|
||||
def construct(self, x):
|
||||
c2, c3, c4, c5 = self.feature_extractor(x)
|
||||
|
||||
feature = self.feature_fusion(c2, c3, c4, c5)
|
||||
|
||||
output = self.conv1(feature)
|
||||
output = self.relu1(self.bn1(output))
|
||||
output = self.conv2(output)
|
||||
output = self._upsample(output)
|
||||
|
||||
if self.inference:
|
||||
text = output[::, 0:1:1, ::, ::]
|
||||
kernels = output[::, 0:7:1, ::, ::]
|
||||
score = self.sigmoid(text)
|
||||
kernels = self.logic_and(self.greater(kernels, self.one_float32),
|
||||
self.greater(text, self.one_float32))
|
||||
return score, kernels
|
||||
return output
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
from .base import _conv, _bn
|
||||
|
||||
class FPN(nn.Cell):
|
||||
def __init__(self, in_channels, out_channel, long_size):
|
||||
super(FPN, self).__init__()
|
||||
|
||||
self.long_size = long_size
|
||||
|
||||
# reduce layers
|
||||
self.reduce_conv_c2 = _conv(in_channels[0], out_channel, kernel_size=1, has_bias=True)
|
||||
self.reduce_bn_c2 = _bn(out_channel)
|
||||
self.reduce_relu_c2 = nn.ReLU()
|
||||
|
||||
self.reduce_conv_c3 = _conv(in_channels[1], out_channel, kernel_size=1, has_bias=True)
|
||||
self.reduce_bn_c3 = _bn(out_channel)
|
||||
self.reduce_relu_c3 = nn.ReLU()
|
||||
|
||||
self.reduce_conv_c4 = _conv(in_channels[2], out_channel, kernel_size=1, has_bias=True)
|
||||
self.reduce_bn_c4 = _bn(out_channel)
|
||||
self.reduce_relu_c4 = nn.ReLU()
|
||||
|
||||
self.reduce_conv_c5 = _conv(in_channels[3], out_channel, kernel_size=1, has_bias=True)
|
||||
self.reduce_bn_c5 = _bn(out_channel)
|
||||
self.reduce_relu_c5 = nn.ReLU()
|
||||
|
||||
# smooth layers
|
||||
self.smooth_conv_p4 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
|
||||
self.smooth_bn_p4 = _bn(out_channel)
|
||||
self.smooth_relu_p4 = nn.ReLU()
|
||||
|
||||
self.smooth_conv_p3 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
|
||||
self.smooth_bn_p3 = _bn(out_channel)
|
||||
self.smooth_relu_p3 = nn.ReLU()
|
||||
|
||||
self.smooth_conv_p2 = _conv(out_channel, out_channel, kernel_size=3, has_bias=True)
|
||||
self.smooth_bn_p2 = _bn(out_channel)
|
||||
self.smooth_relu_p2 = nn.ReLU()
|
||||
|
||||
self._upsample_p4 = P.ResizeBilinear((long_size // 16, long_size // 16), align_corners=True)
|
||||
self._upsample_p3 = P.ResizeBilinear((long_size // 8, long_size // 8), align_corners=True)
|
||||
self._upsample_p2 = P.ResizeBilinear((long_size // 4, long_size // 4), align_corners=True)
|
||||
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, c2, c3, c4, c5):
|
||||
p5 = self.reduce_conv_c5(c5)
|
||||
p5 = self.reduce_relu_c5(self.reduce_bn_c5(p5))
|
||||
|
||||
c4 = self.reduce_conv_c4(c4)
|
||||
c4 = self.reduce_relu_c4(self.reduce_bn_c4(c4))
|
||||
p4 = self._upsample_p4(p5) + c4
|
||||
p4 = self.smooth_conv_p4(p4)
|
||||
p4 = self.smooth_relu_p4(self.smooth_bn_p4(p4))
|
||||
|
||||
c3 = self.reduce_conv_c3(c3)
|
||||
c3 = self.reduce_relu_c3(self.reduce_bn_c3(c3))
|
||||
p3 = self._upsample_p3(p4) + c3
|
||||
p3 = self.smooth_conv_p3(p3)
|
||||
p3 = self.smooth_relu_p3(self.smooth_bn_p3(p3))
|
||||
|
||||
c2 = self.reduce_conv_c2(c2)
|
||||
c2 = self.reduce_relu_c2(self.reduce_bn_c2(c2))
|
||||
p2 = self._upsample_p2(p3) + c2
|
||||
p2 = self.smooth_conv_p2(p2)
|
||||
p2 = self.smooth_relu_p2(self.smooth_bn_p2(p2))
|
||||
|
||||
p3 = self._upsample_p2(p3)
|
||||
p4 = self._upsample_p2(p4)
|
||||
p5 = self._upsample_p2(p5)
|
||||
|
||||
out = self.concat((p2, p3, p4, p5))
|
||||
|
||||
return out
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
CXXFLAGS = -I include -std=c++11 -O3
|
||||
CXX_SOURCES = adaptor.cpp
|
||||
OPENCV = `pkg-config --cflags --libs opencv`
|
||||
|
||||
LIB_SO = adaptor.so
|
||||
|
||||
$(LIB_SO): $(CXX_SOURCES) $(DEPS)
|
||||
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV)
|
||||
|
||||
clean:
|
||||
rm -rf $(LIB_SO)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import numpy as np
|
||||
from .adaptor import pse as cpse
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
def pse(polys, min_area):
|
||||
ret = np.array(cpse(polys, min_area), dtype='int32')
|
||||
return ret
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ETSNET/pse/adaptor.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <opencv2/core/core.hpp>
|
||||
#include <opencv2/highgui/highgui.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
|
||||
using std::vector;
|
||||
using cv::vector;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace pse_adaptor {
|
||||
void get_kernals(const int *data, vector<int> data_shape, const vector<Mat> &kernals) {
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1);
|
||||
for (int x = 0; x < kernal.rows; ++x) {
|
||||
for (int y = 0; y < kernal.cols; ++y) {
|
||||
kernal.at<char>(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y];
|
||||
}
|
||||
}
|
||||
kernals.emplace_back(kernal);
|
||||
}
|
||||
}
|
||||
|
||||
void growing_text_line(const vector, const vector<vector<>> &text_line, float min_area) {
|
||||
Mat label_mat;
|
||||
int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4);
|
||||
vector<int> area(label_num + 1, 0)
|
||||
memset(area, 0, sizeof(area));
|
||||
for (int x = 0; x < label_mat.rows; ++x) {
|
||||
for (int y = 0; y < label_mat.cols; ++y) {
|
||||
int label = label_mat.at<int>(x, y);
|
||||
if (label == 0) continue;
|
||||
area[label] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
queue<Point> queue, next_queue;
|
||||
for (int x = 0; x < label_mat.rows; ++x) {
|
||||
vector<int> row(label_mat.cols);
|
||||
for (int y = 0; y < label_mat.cols; ++y) {
|
||||
int label = label_mat.at<int>(x, y);
|
||||
|
||||
if (label == 0) continue;
|
||||
if (area[label] < min_area) continue;
|
||||
|
||||
Point point(x, y);
|
||||
queue.push(point);
|
||||
row[y] = label;
|
||||
}
|
||||
text_line.emplace_back(row);
|
||||
}
|
||||
|
||||
int dx[] = {-1, 1, 0, 0};
|
||||
int dy[] = {0, 0, -1, 1};
|
||||
|
||||
for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) {
|
||||
while (!queue.empty()) {
|
||||
Point point = queue.front();
|
||||
queue.pop();
|
||||
int x = point.x;
|
||||
int y = point.y;
|
||||
int label = text_line[x][y];
|
||||
bool is_edge = true;
|
||||
for (int d = 0; d < 4; ++d) {
|
||||
int tmp_x = x + dx[d];
|
||||
int tmp_y = y + dy[d];
|
||||
|
||||
if (tmp_x < 0 || tmp_x >= (static_cast)<int>text_line.size()) continue;
|
||||
if (tmp_y < 0 || tmp_y >= (static_cast)<int>text_line[1].size()) continue;
|
||||
if (kernals[kernal_id].at<char>(tmp_x, tmp_y) == 0) continue;
|
||||
if (text_line[tmp_x][tmp_y] > 0) continue;
|
||||
|
||||
Point point_tmp(tmp_x, tmp_y);
|
||||
queue.push(point_tmp);
|
||||
text_line[tmp_x][tmp_y] = label;
|
||||
is_edge = false;
|
||||
}
|
||||
|
||||
if (is_edge) {
|
||||
next_queue.push(point);
|
||||
}
|
||||
}
|
||||
swap(queue, next_queue);
|
||||
}
|
||||
}
|
||||
|
||||
vector<vector<int>> pse(py::array_t<int, py::array::c_style | py::array::forcecast> quad_n9, float min_area) {
|
||||
auto buf = quad_n9.request();
|
||||
auto data = static_cast<int *>(buf.ptr);
|
||||
vector<Mat> kernals;
|
||||
get_kernals(data, buf.shape, kernals);
|
||||
vector<vector<int>> text_line;
|
||||
growing_text_line(kernals, text_line, min_area);
|
||||
|
||||
return text_line;
|
||||
}
|
||||
} // namespace pse_adaptor
|
||||
|
||||
PYBIND11_PLUGIN(adaptor) {
|
||||
py::module m("adaptor", "pse");
|
||||
m.def("pse", &pse_adaptor::pse, "pse");
|
||||
return m.ptr();
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from .base import _conv, _bn
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
stride: Integer. Stride size for the initial convolutional layer. Default:1.
|
||||
momentum: Float. Momentum for batchnorm layer. Default:0.1.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
ResidualBlock(3,256,stride=2,down_sample=True)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=1,
|
||||
momentum=0.1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
out_chls = out_channels // self.expansion
|
||||
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1)
|
||||
self.bn1 = _bn(out_chls, momentum=momentum)
|
||||
|
||||
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1, pad_mode='pad')
|
||||
self.bn2 = _bn(out_chls, momentum=momentum)
|
||||
|
||||
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1)
|
||||
self.bn3 = _bn(out_channels, momentum=momentum)
|
||||
|
||||
self.relu = P.ReLU()
|
||||
self.downsample = (in_channels != out_channels)
|
||||
self.stride = stride
|
||||
if self.downsample or self.stride != 1:
|
||||
self.conv_down_sample = _conv(in_channels, out_channels,
|
||||
kernel_size=1, stride=stride)
|
||||
self.bn_down_sample = _bn(out_channels, momentum=momentum)
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample or self.stride != 1:
|
||||
identity = self.conv_down_sample(identity)
|
||||
identity = self.bn_down_sample(identity)
|
||||
|
||||
out = self.add(out, identity)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
"""
|
||||
ResNet V1 network.
|
||||
|
||||
Args:
|
||||
block: Cell. Block for network.
|
||||
layer_nums: List. Numbers of different layers.
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
num_classes: Integer. Class number. Default:100.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
ResNet(ResidualBlock,
|
||||
[3, 4, 6, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
100)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
out_channels):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of "
|
||||
"layer_num, inchannel, outchannel list must be 4!")
|
||||
|
||||
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
|
||||
self.bn1 = _bn(64)
|
||||
self.relu = P.ReLU()
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=in_channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=1)
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
out_channel=out_channels[1],
|
||||
stride=2)
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=2)
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=2)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
|
||||
"""
|
||||
Make Layer for ResNet.
|
||||
|
||||
Args:
|
||||
block: Cell. Resnet block.
|
||||
layer_num: Integer. Layer number.
|
||||
in_channel: Integer. Input channel.
|
||||
out_channel: Integer. Output channel.
|
||||
stride:Integer. Stride size for the initial convolutional layer.
|
||||
|
||||
Returns:
|
||||
SequentialCell, the output layer.
|
||||
|
||||
Examples:
|
||||
_make_layer(BasicBlock, 3, 128, 256, 2)
|
||||
"""
|
||||
layers = []
|
||||
|
||||
resblk = block(in_channel, out_channel, stride=stride)
|
||||
layers.append(resblk)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
resblk = block(out_channel, out_channel, stride=1)
|
||||
layers.append(resblk)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.pad(x)
|
||||
c1 = self.maxpool(x)
|
||||
|
||||
c2 = self.layer1(c1)
|
||||
c3 = self.layer2(c2)
|
||||
c4 = self.layer3(c3)
|
||||
c5 = self.layer4(c4)
|
||||
|
||||
return c2, c3, c4, c5
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
'INFER_LONG_SIZE': 1920,
|
||||
'KERNEL_NUM': 7,
|
||||
'INFERENCE': True, # INFER MODE\TRAIN MODE
|
||||
|
||||
# backbone
|
||||
'BACKBONE_LAYER_NUMS': [3, 4, 6, 3],
|
||||
'BACKBONE_IN_CHANNELS': [64, 256, 512, 1024],
|
||||
'BACKBONE_OUT_CHANNELS': [256, 512, 1024, 2048],
|
||||
|
||||
# neck
|
||||
'NECK_OUT_CHANNEL': 256,
|
||||
|
||||
# dataset for train
|
||||
"TRAIN_ROOT_DIR": '/autotest/lqk/modelzoo/psenet/ic15/',
|
||||
"TRAIN_IS_TRANSFORM": True,
|
||||
"TRAIN_LONG_SIZE": 640,
|
||||
"TRAIN_DATASET_SIZE": 1000,
|
||||
"TRAIN_MIN_SCALE": 0.4,
|
||||
"TRAIN_BUFFER_SIZE": 8,
|
||||
"TRAIN_BATCH_SIZE": 4,
|
||||
"TRAIN_REPEAT_NUM": 608*4,
|
||||
"TRAIN_DROP_REMAINDER": True,
|
||||
"TRAIN_TOTAL_ITER": 152000,
|
||||
"TRAIN_MODEL_SAVE_PATH": './checkpoints/',
|
||||
|
||||
# dataset for test
|
||||
"TEST_ROOT_DIR": '/autotest/lqk/modelzoo/psenet/ic15/',
|
||||
"TEST_DATASET_SIZE": 500,
|
||||
"TEST_BUFFER_SIZE": 4,
|
||||
"TEST_DROP_REMAINDER": False,
|
||||
})
|
|
@ -0,0 +1,314 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import pyclipper
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import Polygon as plg
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.py_transforms as py_transforms
|
||||
|
||||
from src.config import config
|
||||
|
||||
__all__ = ['train_dataset_creator', 'test_dataset_creator']
|
||||
|
||||
def get_img(img_path):
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
def get_imgs_names(root_dir):
|
||||
img_paths = [i for i in os.listdir(root_dir)
|
||||
if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]
|
||||
return img_paths
|
||||
|
||||
def get_bboxes(img, gt_path):
|
||||
h, w = img.shape[0:2]
|
||||
with open(gt_path, 'r', encoding='utf-8-sig') as f:
|
||||
lines = f.readlines()
|
||||
bboxes = []
|
||||
tags = []
|
||||
for line in lines:
|
||||
line = line.replace('\xef\xbb\xbf', '')
|
||||
line = line.replace('\ufeff', '')
|
||||
line = line.replace('\n', '')
|
||||
gt = line.split(",", 8)
|
||||
tag = gt[-1][0] != '#'
|
||||
box = [int(gt[i]) for i in range(8)]
|
||||
box = np.asarray(box) / ([w * 1.0, h * 1.0] * 4)
|
||||
bboxes.append(box)
|
||||
tags.append(tag)
|
||||
return np.array(bboxes), tags
|
||||
|
||||
def random_scale(img, min_size):
|
||||
h, w = img.shape[0:2]
|
||||
if max(h, w) > 1280:
|
||||
scale1 = 1280.0 / max(h, w)
|
||||
img = cv2.resize(img, dsize=None, fx=scale1, fy=scale1)
|
||||
|
||||
h, w = img.shape[0:2]
|
||||
random_scale1 = np.array([0.5, 1.0, 2.0, 3.0])
|
||||
scale2 = np.random.choice(random_scale1)
|
||||
if min(h, w) * scale2 <= min_size:
|
||||
scale3 = (min_size + 10) * 1.0 / min(h, w)
|
||||
img = cv2.resize(img, dsize=None, fx=scale3, fy=scale3)
|
||||
else:
|
||||
img = cv2.resize(img, dsize=None, fx=scale2, fy=scale2)
|
||||
return img
|
||||
|
||||
def random_horizontal_flip(imgs):
|
||||
if random.random() < 0.5:
|
||||
for i, _ in enumerate(imgs):
|
||||
imgs[i] = np.flip(imgs[i], axis=1).copy()
|
||||
return imgs
|
||||
|
||||
def random_rotate(imgs):
|
||||
max_angle = 10
|
||||
angle = random.random() * 2 * max_angle - max_angle
|
||||
for i, _ in enumerate(imgs):
|
||||
img = imgs[i]
|
||||
w, h = img.shape[:2]
|
||||
rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
|
||||
img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
|
||||
imgs[i] = img_rotation
|
||||
return imgs
|
||||
|
||||
def random_crop(imgs, img_size):
|
||||
h, w = imgs[0].shape[0:2]
|
||||
th, tw = img_size
|
||||
if w == tw and h == th:
|
||||
return imgs
|
||||
|
||||
if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
|
||||
tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size
|
||||
tl[tl < 0] = 0
|
||||
br = np.max(np.where(imgs[1] > 0), axis=1) - img_size
|
||||
br[br < 0] = 0
|
||||
br[0] = min(br[0], h - th)
|
||||
br[1] = min(br[1], w - tw)
|
||||
|
||||
i = random.randint(tl[0], br[0])
|
||||
j = random.randint(tl[1], br[1])
|
||||
else:
|
||||
i = random.randint(0, h - th)
|
||||
j = random.randint(0, w - tw)
|
||||
|
||||
for idx, _ in enumerate(imgs):
|
||||
if len(imgs[idx].shape) == 3:
|
||||
imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
|
||||
else:
|
||||
imgs[idx] = imgs[idx][i:i + th, j:j + tw]
|
||||
return imgs
|
||||
|
||||
def scale(img, long_size=2240):
|
||||
h, w = img.shape[0:2]
|
||||
scale_long = long_size * 1.0 / max(h, w)
|
||||
img = cv2.resize(img, dsize=None, fx=scale_long, fy=scale_long)
|
||||
return img
|
||||
|
||||
def dist(a, b):
|
||||
return np.sqrt(np.sum((a - b) ** 2))
|
||||
|
||||
def perimeter(bbox):
|
||||
peri = 0.0
|
||||
for i in range(bbox.shape[0]):
|
||||
peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
|
||||
return peri
|
||||
|
||||
def shrink(bboxes, rate, max_shr=20):
|
||||
rate = rate * rate
|
||||
shrinked_bboxes = []
|
||||
for bbox in bboxes:
|
||||
area = plg.Polygon(bbox).area()
|
||||
peri = perimeter(bbox)
|
||||
|
||||
pco = pyclipper.PyclipperOffset()
|
||||
pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
offset = min((int)(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
|
||||
|
||||
shrinked_bbox = pco.Execute(-offset)
|
||||
if not shrinked_bbox:
|
||||
shrinked_bboxes.append(bbox)
|
||||
continue
|
||||
|
||||
shrinked_bbox = np.array(shrinked_bbox)[0]
|
||||
if shrinked_bbox.shape[0] <= 2:
|
||||
shrinked_bboxes.append(bbox)
|
||||
continue
|
||||
|
||||
shrinked_bboxes.append(shrinked_bbox)
|
||||
|
||||
return np.array(shrinked_bboxes)
|
||||
|
||||
class TrainDataset:
|
||||
def __init__(self):
|
||||
self.is_transform = config.TRAIN_IS_TRANSFORM
|
||||
self.img_size = config.TRAIN_LONG_SIZE
|
||||
self.kernel_num = config.KERNEL_NUM
|
||||
self.min_scale = config.TRAIN_MIN_SCALE
|
||||
|
||||
root_dir = os.path.join(os.path.join(os.path.dirname(__file__), '..'), config.TRAIN_ROOT_DIR)
|
||||
ic15_train_data_dir = root_dir + 'ch4_training_images/'
|
||||
ic15_train_gt_dir = root_dir + 'ch4_training_localization_transcription_gt/'
|
||||
|
||||
self.img_size = self.img_size if \
|
||||
(self.img_size is None or isinstance(self.img_size, tuple)) \
|
||||
else (self.img_size, self.img_size)
|
||||
|
||||
data_dirs = [ic15_train_data_dir]
|
||||
gt_dirs = [ic15_train_gt_dir]
|
||||
|
||||
self.all_img_paths = []
|
||||
self.all_gt_paths = []
|
||||
|
||||
for data_dir, gt_dir in zip(data_dirs, gt_dirs):
|
||||
img_names = [i for i in os.listdir(data_dir)
|
||||
if os.path.splitext(i)[-1].lower()
|
||||
in ['.jpg', '.jpeg', '.png']]
|
||||
|
||||
img_paths = []
|
||||
gt_paths = []
|
||||
for _, img_name in enumerate(img_names):
|
||||
img_path = os.path.join(data_dir, img_name)
|
||||
gt_name = 'gt_' + img_name.split('.')[0] + '.txt'
|
||||
gt_path = os.path.join(gt_dir, gt_name)
|
||||
img_paths.append(img_path)
|
||||
gt_paths.append(gt_path)
|
||||
|
||||
self.all_img_paths.extend(img_paths)
|
||||
self.all_gt_paths.extend(gt_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.all_img_paths[index]
|
||||
gt_path = self.all_gt_paths[index]
|
||||
|
||||
# start0 = time.time()
|
||||
img = get_img(img_path)
|
||||
bboxes, tags = get_bboxes(img, gt_path)
|
||||
# end0 = time.time()
|
||||
|
||||
# multi-scale training
|
||||
if self.is_transform:
|
||||
img = random_scale(img, min_size=self.img_size[0])
|
||||
|
||||
# get gt_text and training_mask
|
||||
img_h, img_w = img.shape[0: 2]
|
||||
gt_text = np.zeros((img_h, img_w), dtype=np.float32)
|
||||
training_mask = np.ones((img_h, img_w), dtype=np.float32)
|
||||
if bboxes.shape[0] > 0:
|
||||
bboxes = np.reshape(bboxes * ([img_w, img_h] * 4), (bboxes.shape[0], -1, 2)).astype('int32')
|
||||
for i in range(bboxes.shape[0]):
|
||||
cv2.drawContours(gt_text, [bboxes[i]], 0, i + 1, -1)
|
||||
if not tags[i]:
|
||||
cv2.drawContours(training_mask, [bboxes[i]], 0, 0, -1)
|
||||
|
||||
# get gt_kernels
|
||||
gt_kernels = []
|
||||
for i in range(1, self.kernel_num):
|
||||
rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i
|
||||
gt_kernel = np.zeros(img.shape[0:2], dtype=np.float32)
|
||||
kernel_bboxes = shrink(bboxes, rate)
|
||||
for j in range(kernel_bboxes.shape[0]):
|
||||
cv2.drawContours(gt_kernel, [kernel_bboxes[j]], 0, 1, -1)
|
||||
gt_kernels.append(gt_kernel)
|
||||
|
||||
# data augmentation
|
||||
if self.is_transform:
|
||||
imgs = [img, gt_text, training_mask]
|
||||
imgs.extend(gt_kernels)
|
||||
imgs = random_horizontal_flip(imgs)
|
||||
imgs = random_rotate(imgs)
|
||||
imgs = random_crop(imgs, self.img_size)
|
||||
img, gt_text, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:]
|
||||
|
||||
gt_text[gt_text > 0] = 1
|
||||
gt_kernels = np.array(gt_kernels)
|
||||
|
||||
if self.is_transform:
|
||||
img = Image.fromarray(img)
|
||||
img = img.convert('RGB')
|
||||
img = py_transforms.RandomColorAdjust(brightness=32.0 / 255, saturation=0.5)(img)
|
||||
else:
|
||||
img = Image.fromarray(img)
|
||||
img = img.convert('RGB')
|
||||
|
||||
img = py_transforms.ToTensor()(img)
|
||||
img = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
|
||||
|
||||
gt_text = gt_text.astype(np.float32)
|
||||
gt_kernels = gt_kernels.astype(np.float32)
|
||||
training_mask = training_mask.astype(np.float32)
|
||||
|
||||
return img, gt_text, gt_kernels, training_mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_img_paths)
|
||||
|
||||
def IC15_TEST_Generator():
|
||||
ic15_test_data_dir = config.TEST_ROOT_DIR + 'ch4_test_images/'
|
||||
img_size = config.INFER_LONG_SIZE
|
||||
|
||||
img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
|
||||
|
||||
data_dirs = [ic15_test_data_dir]
|
||||
all_img_paths = []
|
||||
|
||||
for data_dir in data_dirs:
|
||||
img_names = [i for i in os.listdir(data_dir) if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]
|
||||
|
||||
img_paths = []
|
||||
for _, img_name in enumerate(img_names):
|
||||
img_path = data_dir + img_name
|
||||
img_paths.append(img_path)
|
||||
|
||||
all_img_paths.extend(img_paths)
|
||||
|
||||
dataset_length = len(all_img_paths)
|
||||
|
||||
for index in range(dataset_length):
|
||||
img_path = all_img_paths[index]
|
||||
img_name = np.array(os.path.split(img_path)[-1])
|
||||
img = get_img(img_path)
|
||||
|
||||
long_size = max(img.shape[:2])
|
||||
img_resized = np.zeros((long_size, long_size, 3), np.uint8)
|
||||
img_resized[:img.shape[0], :img.shape[1], :] = img
|
||||
img_resized = cv2.resize(img_resized, dsize=img_size)
|
||||
|
||||
img_resized = Image.fromarray(img_resized)
|
||||
img_resized = img_resized.convert('RGB')
|
||||
img_resized = py_transforms.ToTensor()(img_resized)
|
||||
img_resized = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_resized)
|
||||
|
||||
yield img, img_resized, img_name
|
||||
|
||||
def train_dataset_creator():
|
||||
cv2.setNumThreads(0)
|
||||
dataset = TrainDataset()
|
||||
ds = de.GeneratorDataset(dataset, ['img', 'gt_text', 'gt_kernels', 'training_mask'], num_parallel_workers=8)
|
||||
#ds = ds.repeat(config.TRAIN_REPEAT_NUM)
|
||||
ds = ds.batch(config.TRAIN_BATCH_SIZE, drop_remainder=config.TRAIN_DROP_REMAINDER)
|
||||
ds = ds.shuffle(buffer_size=config.TRAIN_BUFFER_SIZE)
|
||||
return ds
|
||||
|
||||
def test_dataset_creator():
|
||||
ds = de.GeneratorDataset(IC15_TEST_Generator, ['img', 'img_resized', 'img_name'])
|
||||
ds = ds.shuffle(config.TEST_BUFFER_SIZE)
|
||||
ds = ds.batch(1, drop_remainder=config.TEST_DROP_REMAINDER)
|
||||
return ds
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import os
|
||||
import socket
|
||||
|
||||
RANK_TABLE_SAVE_PATH = './rank_table_4p.json'
|
||||
|
||||
|
||||
def main():
|
||||
nproc_per_node = 4
|
||||
|
||||
visible_devices = ['0', '1', '2', '3']
|
||||
|
||||
server_id = socket.gethostbyname(socket.gethostname())
|
||||
|
||||
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
|
||||
device_ips = {}
|
||||
for hccn_item in hccn_configs:
|
||||
hccn_item = hccn_item.strip()
|
||||
if hccn_item.startswith('address_'):
|
||||
device_id, device_ip = hccn_item.split('=')
|
||||
device_id = device_id.split('_')[1]
|
||||
device_ips[device_id] = device_ip
|
||||
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
|
||||
|
||||
hccn_table = {}
|
||||
hccn_table['board_id'] = '0x002f' # A+K
|
||||
# hccn_table['board_id'] = '0x0000' # A+X
|
||||
|
||||
hccn_table['chip_info'] = '910'
|
||||
hccn_table['deploy_mode'] = 'lab'
|
||||
hccn_table['group_count'] = '1'
|
||||
hccn_table['group_list'] = []
|
||||
instance_list = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
instance = {}
|
||||
instance['devices'] = []
|
||||
device_id = visible_devices[instance_id]
|
||||
device_ip = device_ips[device_id]
|
||||
instance['devices'].append({
|
||||
'device_id': device_id,
|
||||
'device_ip': device_ip,
|
||||
})
|
||||
instance['rank_id'] = str(instance_id)
|
||||
instance['server_id'] = server_id
|
||||
instance_list.append(instance)
|
||||
hccn_table['group_list'].append({
|
||||
'device_num': str(nproc_per_node),
|
||||
'server_num': '1',
|
||||
'group_name': '',
|
||||
'instance_count': str(nproc_per_node),
|
||||
'instance_list': instance_list,
|
||||
})
|
||||
hccn_table['para_plane_nic_location'] = 'device'
|
||||
hccn_table['para_plane_nic_name'] = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
eth_id = visible_devices[instance_id]
|
||||
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
|
||||
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
|
||||
hccn_table['status'] = 'completed'
|
||||
import json
|
||||
with open(RANK_TABLE_SAVE_PATH, 'w') as table_fp:
|
||||
json.dump(hccn_table, table_fp, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if os.path.exists(RANK_TABLE_SAVE_PATH):
|
||||
print('Rank table file exists.')
|
||||
else:
|
||||
print('Generating rank table file.')
|
||||
main()
|
||||
print('Rank table file generated')
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import time
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['LossCallBack', 'WithLossCell', 'TrainOneStepCell']
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.loss_avg = AverageMeter()
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
|
||||
if cur_step_in_epoch == 1:
|
||||
self.loss_avg = AverageMeter()
|
||||
|
||||
self.loss_avg.update(loss)
|
||||
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
loss_log = "time: %s, epoch: %s, step: %s, loss is %s" % (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())),
|
||||
cb_params.cur_epoch_num,
|
||||
cur_step_in_epoch,
|
||||
self.loss_avg.avg)
|
||||
print(loss_log)
|
||||
loss_file = open("./loss.log", "a+")
|
||||
loss_file.write(loss_log)
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to compute loss.
|
||||
|
||||
Args:
|
||||
backbone (Cell): The target network to wrap.
|
||||
loss_fn (Cell): The loss function used to compute loss.
|
||||
"""
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, img, gt_text, gt_kernels, training_mask):
|
||||
model_predict = self._backbone(img)
|
||||
return self._loss_fn(model_predict, gt_text, gt_kernels, training_mask)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""
|
||||
Get the backbone network.
|
||||
|
||||
Returns:
|
||||
Cell, return backbone network.
|
||||
"""
|
||||
return self._backbone
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Network training package class.
|
||||
|
||||
Append an optimizer to the training network after that the construct function
|
||||
can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default value is 1.0.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
# self.backbone = network._backbone
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.sens = Tensor((np.ones(1, dtype=np.float32)) * sens)
|
||||
self.reducer_flag = reduce_flag
|
||||
if self.reducer_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, img, gt_text, gt_kernels, training_mask):
|
||||
weights = self.weights
|
||||
loss = self.network(img, gt_text, gt_kernels, training_mask)
|
||||
grads = self.grad(self.network, weights)(img, gt_text, gt_kernels, training_mask, self.sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
|
@ -0,0 +1,156 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
import cv2
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import test_dataset_creator
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.ETSNET.pse import pse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Hyperparams')
|
||||
parser.add_argument("--ckpt", type=str, default=0, help='trained model path.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".")
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def sort_to_clockwise(points):
|
||||
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
|
||||
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
|
||||
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
|
||||
return clockwise_points
|
||||
|
||||
def write_result_as_txt(img_name, bboxes, path):
|
||||
if not os.path.isdir(path):
|
||||
os.makedirs(path)
|
||||
filename = os.path.join(path, 'res_{}.txt'.format(os.path.splitext(img_name)[0]))
|
||||
lines = []
|
||||
for _, bbox in enumerate(bboxes):
|
||||
bbox = bbox.reshape(-1, 2)
|
||||
bbox = np.array(list(sort_to_clockwise(bbox)))[[3, 0, 1, 2]].copy().reshape(-1)
|
||||
values = [int(v) for v in bbox]
|
||||
line = "%d,%d,%d,%d,%d,%d,%d,%d\n" % tuple(values)
|
||||
lines.append(line)
|
||||
with open(filename, 'w') as f:
|
||||
for line in lines:
|
||||
f.write(line)
|
||||
|
||||
def test():
|
||||
if not os.path.isdir('./res/submit_ic15/'):
|
||||
os.makedirs('./res/submit_ic15/')
|
||||
if not os.path.isdir('./res/vis_ic15/'):
|
||||
os.makedirs('./res/vis_ic15/')
|
||||
ds = test_dataset_creator()
|
||||
|
||||
config.INFERENCE = True
|
||||
net = ETSNet(config)
|
||||
print(args.ckpt)
|
||||
param_dict = load_checkpoint(args.ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded!')
|
||||
|
||||
get_data_time = AverageMeter()
|
||||
model_run_time = AverageMeter()
|
||||
post_process_time = AverageMeter()
|
||||
|
||||
end_pts = time.time()
|
||||
iters = ds.create_tuple_iterator(output_numpy=True)
|
||||
count = 0
|
||||
for data in iters:
|
||||
count += 1
|
||||
# get data
|
||||
img, img_resized, img_name = data
|
||||
img = img[0].astype(np.uint8).copy()
|
||||
img_name = img_name[0].decode('utf-8')
|
||||
|
||||
get_data_pts = time.time()
|
||||
get_data_time.update(get_data_pts - end_pts)
|
||||
|
||||
# model run
|
||||
img_tensor = Tensor(img_resized, mstype.float32)
|
||||
score, kernels = net(img_tensor)
|
||||
score = np.squeeze(score.asnumpy())
|
||||
kernels = np.squeeze(kernels.asnumpy())
|
||||
|
||||
model_run_pts = time.time()
|
||||
model_run_time.update(model_run_pts - get_data_pts)
|
||||
|
||||
# post-process
|
||||
pred = pse(kernels, 5.0)
|
||||
scale = max(img.shape[:2]) * 1.0 / config.INFER_LONG_SIZE
|
||||
label = pred
|
||||
label_num = np.max(label) + 1
|
||||
bboxes = []
|
||||
|
||||
for i in range(1, label_num):
|
||||
points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]
|
||||
if points.shape[0] < 600:
|
||||
continue
|
||||
|
||||
score_i = np.mean(score[label == i])
|
||||
if score_i < 0.93:
|
||||
continue
|
||||
|
||||
rect = cv2.minAreaRect(points)
|
||||
bbox = cv2.boxPoints(rect) * scale
|
||||
bbox = bbox.astype('int32')
|
||||
cv2.drawContours(img, [bbox], 0, (0, 255, 0), 3)
|
||||
bboxes.append(bbox)
|
||||
|
||||
post_process_pts = time.time()
|
||||
post_process_time.update(post_process_pts - model_run_pts)
|
||||
|
||||
if count == 1:
|
||||
get_data_time.reset()
|
||||
model_run_time.reset()
|
||||
post_process_time.reset()
|
||||
|
||||
end_pts = time.time()
|
||||
|
||||
# save res
|
||||
cv2.imwrite('./res/vis_ic15/{}'.format(img_name), img[:, :, [2, 1, 0]].copy())
|
||||
write_result_as_txt(img_name, bboxes, './res/submit_ic15/')
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import train_dataset_creator
|
||||
from src.config import config
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.ETSNET.dice_loss import DiceLoss
|
||||
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
|
||||
|
||||
parser = argparse.ArgumentParser(description='Hyperparams')
|
||||
parser.add_argument('--run_distribute', default=False, action='store_true',
|
||||
help='Run distribute, default is false.')
|
||||
parser.add_argument('--pre_trained', type=str, default='', help='Pretrain file path.')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1.')
|
||||
args = parser.parse_args()
|
||||
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
def lr_generator(start_lr, lr_scale, total_iters):
|
||||
lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)]
|
||||
return lrs
|
||||
|
||||
def train():
|
||||
rank_id = 0
|
||||
if args.run_distribute:
|
||||
context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True, parameter_broadcast=True)
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
|
||||
# dataset/network/criterion/optim
|
||||
ds = train_dataset_creator()
|
||||
step_size = ds.get_dataset_size()
|
||||
print('Create dataset done!')
|
||||
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(args.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('Load Pretrained parameters done!')
|
||||
|
||||
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
|
||||
|
||||
lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
|
||||
|
||||
# warp model
|
||||
net = WithLossCell(net, criterion)
|
||||
if args.run_distribute:
|
||||
net = TrainOneStepCell(net, opt, reduce_flag=True, mean=True, degree=args.device_num)
|
||||
else:
|
||||
net = TrainOneStepCell(net, opt)
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossCallBack(per_print_times=10)
|
||||
# set and apply parameters of check point config.TRAIN_MODEL_SAVE_PATH
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf,
|
||||
directory="./ckpt_{}".format(rank_id))
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=True, callbacks=[time_cb, loss_cb, ckpoint_cb])
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
Loading…
Reference in New Issue