!14008 Add AdvancedEast model.

From: @windaway
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-05-11 14:51:42 +08:00 committed by Gitee
commit 5820159fcc
31 changed files with 2509 additions and 0 deletions

View File

View File

View File

View File

View File

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

View File

@ -0,0 +1,194 @@
# AdvancedEAST
## introduction
AdvancedEAST is inspired by EAST [EAST:An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
The architecture of AdvancedEAST is showed below![AdvancedEast network arch](AdvancedEast.network.png).
This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie/AdvancedEAST)(preprocess, network architecture, predict) and [BaoWentz/AdvancedEAST-PyTorch](https://github.com/BaoWentz/AdvancedEAST-PyTorch)(performance).
## Environment
* euleros v2r7 x86_64
* python 3.7.5
## Dependences
* mindspore==1.1
* shapely==1.7.1
* numpy==1.19.4
* tqdm==4.36.1
## Project files
* configuration of file
cfg.py, control parameters
* pre-process data:
preprocess.py, resize image
* label data:
label.py,produce label info
* define network
model.py and VGG.py
* define loss function
losses.py
* execute training
advanced_east.py and dataset.py
* predict
predict.py and nms.py
* scoring
score.py
* logging
logger.py
```shell
.
└──advanced_east
├── README.md
├── scripts
├── run_distribute_train_gpu.sh # launch ascend distributed training(8 pcs)
├── run_standalone_train_ascend.sh # launch ascend standalone training(1 pcs)
├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs)
└── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs)
├── src
├── cfg.py # parameter configuration
├── dataset.py # data preprocessing
├── label.py # produce label info
├── logger.py # generate learning rate for each step
├── model.py # define network
├── nms.py # non-maximum suppression
├── predict.py # predict boxes
├── preprocess.py # pre-process data
└── score.py # scoring
├── export.py # export model for inference
├── prepare_data.py # exec data preprocessing
├── eval.py # eval net
├── train.py # train net
└── train_mindrecord.py # train net on user specified mindrecord
```
## dataset
ICPR MTWI 2018 challenge 2Text detection of network image[Link](https://tianchi.aliyun.com/competition/entrance/231651/introduction). It is not available to download dataset on the origin webpage,
the dataset is now provided by the author of the original project[Baiduyun link](https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA) password: ye9y. There are 10000 images and corresponding label
information in total in the dataset, which is divided into 2 directories with 9000 and 1000 samples respectively. In the origin training setting, training set and validation set are partitioned at the ratio
of 9:1. If you want to use your own dataset, please modify the configuration of dataset in /src/config.py. The organization of dataset file is listed as below
> ```bash
> .
> └─data_dir
> ├─images # dataset
> └─txt # vertex of text boxes
> ```
Some parameters in config.py
```python
'validation_split_ratio': 0.1, # ratio of validation dataset
'total_img': 10000, # total number of samples in dataset
'data_dir': './icpr/', # dir of dataset
'train_fname': 'train.txt', # the file which stores the images file name in training dataset
'val_fname': 'val.txt', # the file which stores the images file name in validation dataset
'mindsrecord_train_file': 'advanced-east.mindrecord', # mindsrecord of training dataset
'mindsrecord_test_file': 'advanced-east-val.mindrecord', # mindsrecord of validation dataset
'origin_image_dir_name': 'images_9000/', # dir which stores the original images.
'train_image_dir_name': 'images_train/', # dir which stores the preprocessed images.
'origin_txt_dir_name': 'txt_9000/', # dir which stores the original text verteices.
'train_label_dir_name': 'labels_train/', # dir which stores the preprocessed text verteices.
```
## Run the project
### Data preprocess
Resize all the images to fixed size, and convert the label information(the vertex of text box) into the format used in training and evaluation, then the Mindsrecord files are generated.
```bash
python preparedata.py
```
### Training
Prepare the VGG16 pre-training model. Due to copyright restrictions, please go to https://github.com/machrisaa/tensorflow-vgg to download the VGG16 pre-training model and place it in the src folder.
single 1p
```bash
python train.py --device_target="Ascend" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
```
single 1pspecific size
```bash
python train_mindrecord.py --device_target="Ascend" --is_distributed=0 --device_id=2 --size=256 > output.train.log 2>&1 &
```
multi Ascends
```bash
# running on distributed environment8p
bash scripts/run_distribute_train.sh
```
The detailed training parameters are in /src/config.py。
multi GPUs
```bash
# running on distributed environment8p
bash scripts/run_distribute_train_gpu.sh
```
The detailed training parameters are in /src/config.py。
config.py
```bash
'initial_epoch': 0, # epoch to init
'learning_rate': 1e-4, # learning rate when initialization
'decay': 5e-4, # weightdecay parameter
'epsilon': 1e-4, # the value of epsilon in loss computation
'batch_size': 8, # batch size
'lambda_inside_score_loss': 4.0, # coef of inside_score_loss
'lambda_side_vertex_code_loss': 1.0, # coef of vertex_code_loss
"lambda_side_vertex_coord_loss": 1.0, # coef of vertex_coord_loss
'max_train_img_size': 448, # max size of training images
'max_predict_img_size': 448, # max size of the images to predict
'ckpt_save_max': 10, # maximum of ckpt in dir
'saved_model_file_path': './saved_model/', # dir of saved model
```
## Evaluation
### Evaluate
The above python command will run in the background, you can view the results through the file output.eval.log. You will get the accuracy as following:
## performance
### Training performance
The performance listed below are acquired with the default configurations in /src/config.py
| Parameters | single Ascend |
| -------------------------- | ---------------------------------------------- |
| Model Version | AdvancedEAST |
| Resources | Ascend 910 |
| MindSpore Version | 1.1 |
| Dataset | MTWI-2018 |
| Training Parameters | epoch=18, batch_size = 8, lr=1e-3 |
| Optimizer | AdamWeightDecay |
| Loss Function | QuadLoss |
| Outputs | matrix with size of 3x64x64,3x96x96,3x112x112 |
| Loss | 0.1 |
| Total Time | 28mins, 60mins, 90mins |
| Checkpoints | 173MB.ckpt file |
### Evaluation Performance
On the default
| Parameters | single Ascend |
| ------------------- | --------------------------- |
| Model Version | AdvancedEAST |
| Resources | Ascend 910 |
| MindSpore Version | 1.1 |
| Dataset | 1000 images |
| batch_size | 8 |
| Outputs | precision, recall, F score |
| performance | 94.35, 55.45, 66.31 |

View File

@ -0,0 +1,180 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################eval advanced_east on dataset########################
"""
import argparse
import datetime
import os
import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.logger import get_logger
from src.predict import predict
from src.score import eval_pre_rec_f1
from src.config import config as cfg
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net, AdvancedEast
from src.preprocess import resize_image
set_seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('adveast evaling')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
# logging and checkpoint related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
parser.add_argument('--ckpt', type=str, default='Epoch_C0-6_4500.ckpt', help='ckpt to load')
parser.add_argument('--method', type=str, default='score', choices=['loss', 'score', 'pred'], help='evaluation')
parser.add_argument('--path', type=str, help='image path of prediction')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt.data_dir = cfg.data_dir
args_opt.batch_size = 8
args_opt.train_image_dir_name = cfg.data_dir + cfg.train_image_dir_name
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
args_opt.train_label_dir_name = cfg.data_dir + cfg.train_label_dir_name
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
args_opt.results_dir = cfg.results_dir
args_opt.val_fname = cfg.val_fname
args_opt.pixel_threshold = cfg.pixel_threshold
args_opt.max_predict_img_size = cfg.max_predict_img_size
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
return args_opt
def eval_loss(arg):
"""get network and init"""
loss_net, train_net = get_AdvancedEast_net()
print(os.path.join(arg.saved_model_file_path, arg.ckpt))
load_param_into_net(train_net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
train_net.set_train(False)
loss = 0
idx = 0
for item in dataset.create_tuple_iterator():
loss += loss_net(item[0], item[1])
idx += 1
print(loss / idx)
def eval_score(arg):
"""get network and init"""
net = AdvancedEast()
load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
net.set_train(False)
obj = eval_pre_rec_f1()
with open(os.path.join(arg.data_dir, arg.val_fname), 'r') as f_val:
f_list = f_val.readlines()
img_h, img_w = arg.max_predict_img_size, arg.max_predict_img_size
x = np.zeros((arg.batch_size, 3, img_h, img_w), dtype=np.float32)
batch_list = np.arange(0, len(f_list), arg.batch_size)
for idx in batch_list:
gt_list = []
for i in range(idx, min(idx + arg.batch_size, len(f_list))):
item = f_list[i]
img_filename = str(item).strip().split(',')[0][:-4]
img_path = os.path.join(arg.train_image_dir_name, img_filename) + '.jpg'
img = Image.open(img_path)
d_wight, d_height = resize_image(img, arg.max_predict_img_size)
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
img = np.asarray(img)
img = img / 1.
mean = np.array((123.68, 116.779, 103.939)).reshape([1, 1, 3])
img = ((img - mean)).astype(np.float32)
img = img.transpose((2, 0, 1))
x[i - idx] = img
# predict(east, img_path, threshold, cfg.pixel_threshold)
gt_list.append(np.load(os.path.join(arg.train_label_dir_name, img_filename) + '.npy'))
if idx + arg.batch_size >= len(f_list):
x = x[:len(f_list) - idx]
y = net(Tensor(x))
obj.add(y, gt_list)
print(obj.val())
def pred(arg):
"""pred"""
img_path = arg.path
net = AdvancedEast()
load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
predict(net, img_path, arg.pixel_threshold)
if __name__ == '__main__':
args = parse_args()
device_num = int(os.environ.get("DEVICE_NUM", 1))
if args.is_distributed:
if args.device_target == "Ascend":
init()
context.set_context(device_id=args.device_id)
elif args.device_target == "GPU":
init()
args.rank = get_rank()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
print(os.path.join(args.data_dir, args.mindsrecord_test_file))
dataset, batch_num = load_adEAST_dataset(os.path.join(args.data_dir,
args.mindsrecord_test_file),
batch_size=args.batch_size,
device_num=device_num, rank_id=args.rank, is_training=False,
num_parallel_workers=device_num)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
method_dict = {'loss': eval_loss, 'score': eval_score, 'pred': pred}
method_dict[args.method](args)

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
python export.py
"""
import argparse
import numpy as np
from src.model import AdvancedEast
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
parser = argparse.ArgumentParser(description='adveast export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, default="./saved_model/Fin0_9-10_2250.ckpt", help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="./out.om", help="output file name.")
parser.add_argument('--width', type=int, default=256, help='input width')
parser.add_argument('--height', type=int, default=256, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = AdvancedEast()
assert args.ckpt_file is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train advanced_east on dataset########################
"""
import argparse
import os
from mindspore.common import set_seed
from src.label import process_label, process_label_size
from src.config import config as cfg
from src.dataset import transImage2Mind, transImage2Mind_size
from src.preprocess import preprocess, preprocess_size
set_seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore data prepare')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=3, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args()
return args_opt
if __name__ == '__main__':
args = parse_args()
# create data
preprocess()
process_label()
preprocess_size(256)
process_label_size(256)
preprocess_size(384)
process_label_size(384)
preprocess_size(448)
process_label_size(448)
# preprocess_size(512)
# process_label_size(512)
# preprocess_size(640)
# process_label_size(640)
# preprocess_size(736)
# process_label_size(736)
# preprocess_size(768)
# process_label_size(768)
mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file)
transImage2Mind(mindrecord_filename)
mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_test_file)
transImage2Mind(mindrecord_filename, True)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
transImage2Mind_size(mindrecord_filename, 256)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
transImage2Mind_size(mindrecord_filename, 256, True)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
transImage2Mind_size(mindrecord_filename, 384)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
transImage2Mind_size(mindrecord_filename, 384, True)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
transImage2Mind_size(mindrecord_filename, 448)
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
transImage2Mind_size(mindrecord_filename, 448, True)

View File

@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
PATH1=$1
ulimit -u unlimited
export DEVICE_NUM=2
export RANK_SIZE=2
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train_mindrecord.py --device_target Ascend --is_distributed 1 --device_id $i --data_path /disk1/adenew/icpr/advanced-east_448.mindrecord > log.txt 2>&1 &
cd ..
done

View File

@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_ascend.sh"
echo "for example: bash run_standalone_train_ascend.sh"
echo "=============================================================================================================="
python train_mindrecord.py \
--device_target="Ascend" > output.train.log 2>&1 &

View File

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
configs
"""
from easydict import EasyDict
config = EasyDict({
'initial_epoch': 2,
'epoch_num': 6,
'learning_rate': 1e-4,
'decay': 3e-5,
'epsilon': 1e-4,
'batch_size': 2,
'ckpt_interval': 2,
'lambda_inside_score_loss': 4.0,
'lambda_side_vertex_code_loss': 1.0,
"lambda_side_vertex_coord_loss": 1.0,
'max_train_img_size': 448,
'max_predict_img_size': 448,
'train_img_size': [256, 384, 448, 512, 640, 736, 768],
'predict_img_size': [256, 384, 448, 512, 640, 736, 768],
'ckpt_save_max': 500,
'validation_split_ratio': 0.1,
'total_img': 10000,
'data_dir': './icpr/',
'val_data_dir': './icpr/images/',
'train_fname': 'train.txt',
'train_fname_var': 'train_',
'val_fname': 'val.txt',
'val_fname_var': 'val_',
'mindsrecord_train_file': 'advanced-east.mindrecord',
'mindsrecord_test_file': 'advanced-east-val.mindrecord',
'results_dir': './results/',
'origin_image_dir_name': 'images/',
'train_image_dir_name': 'images_train/',
'origin_txt_dir_name': 'txt/',
'train_label_dir_name': 'labels_train/',
'train_image_dir_name_var': 'images_train_',
'mindsrecord_train_file_var': 'advanced-east_',
'train_label_dir_name_var': 'labels_train_',
'mindsrecord_val_file_var': 'advanced-east-val_',
'show_gt_image_dir_name': 'show_gt_images/',
'show_act_image_dir_name': 'show_act_images/',
'saved_model_file_path': './saved_model/',
'last_model_name': '_.ckpt',
'pixel_threshold': 0.8,
'iou_threshold': 0.1,
'feature_layers_range': range(5, 1, -1),
'feature_layers_num': len(range(5, 1, -1)),
'pixel_size': 2 ** range(5, 1, -1)[-1],
'quiet': True,
'side_vertex_pixel_threshold': 0.8,
'trunc_threshold': 0.1,
'predict_cut_text_line': False,
'predict_write2txt': True,
'shrink_ratio': 0.2,
'shrink_side_ratio': 0.6,
'gen_origin_img': True,
'draw_gt_quad': False,
'draw_act_quad': False,
'vgg_npy': '/disk1/ade/vgg16.npy',
})

View File

@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset.
"""
import os
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as vision
from mindspore.mindrecord import FileWriter
import numpy as np
from PIL import Image, ImageFile
import src.config as cfg
ImageFile.LOAD_TRUNCATED_IMAGES = True
cfg = cfg.config
def gen(batch_size=cfg.batch_size, is_val=False):
"""generate label"""
img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
x = np.zeros((batch_size, 3, img_h, img_w), dtype=np.float32)
pixel_num_h = img_h // 4
pixel_num_w = img_w // 4
y = np.zeros((batch_size, 7, pixel_num_h, pixel_num_w), dtype=np.float32)
if is_val:
with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
f_list = f_val.readlines()
else:
with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
f_list = f_train.readlines()
while True:
batch_list = np.arange(0, len(f_list), batch_size)
for idx in batch_list:
for idx2 in range(idx, idx + batch_size):
# random gen an image name
if idx2 < len(f_list):
img = f_list[idx2]
else:
img = np.random.choice(f_list)
img_filename = str(img).strip().split(',')[0]
# load img and img anno
img_path = os.path.join(cfg.data_dir,
cfg.train_image_dir_name,
img_filename)
img = Image.open(img_path)
img = np.asarray(img)
img = img / 127.5 - 1
x[idx2 - idx] = img.transpose((2, 0, 1))
gt_file = os.path.join(cfg.data_dir,
cfg.train_label_dir_name,
img_filename[:-4] + '_gt.npy')
y[idx2 - idx] = np.load(gt_file).transpose((2, 0, 1))
yield x, y
def transImage2Mind(mindrecord_filename, is_val=False):
"""transfer the image to mindrecord"""
if os.path.exists(mindrecord_filename):
os.remove(mindrecord_filename)
os.remove(mindrecord_filename + ".db")
writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, 112, 112]}}
writer.add_schema(cv_schema, "advancedEast dataset")
if is_val:
with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
f_list = f_val.readlines()
else:
with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
f_list = f_train.readlines()
data = []
for item in f_list:
img_filename = str(item).strip().split(',')[0]
img_path = os.path.join(cfg.data_dir,
cfg.train_image_dir_name,
img_filename)
print(img_path)
with open(img_path, 'rb') as f:
img = f.read()
gt_file = os.path.join(cfg.data_dir,
cfg.train_label_dir_name,
img_filename[:-4] + '_gt.npy')
labels = np.load(gt_file)
labels = np.transpose(labels, (2, 0, 1))
data.append({"image": img,
"label": np.array(labels, np.float32)})
if len(data) == 32:
writer.write_raw_data(data)
print('len(data):{}'.format(len(data)))
data = []
if data:
writer.write_raw_data(data)
writer.commit()
def transImage2Mind_size(mindrecord_filename, width=256, is_val=False):
"""transfer the image to mindrecord at specified size"""
mindrecord_filename = mindrecord_filename + str(width) + '.mindrecord'
if os.path.exists(mindrecord_filename):
os.remove(mindrecord_filename)
os.remove(mindrecord_filename + ".db")
writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, width // 4, width // 4]}}
writer.add_schema(cv_schema, "advancedEast dataset")
if is_val:
with open(os.path.join(cfg.data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
f_list = f_val.readlines()
else:
with open(os.path.join(cfg.data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
f_list = f_train.readlines()
data = []
for item in f_list:
img_filename = str(item).strip().split(',')[0]
img_path = os.path.join(cfg.data_dir,
cfg.train_image_dir_name_var + str(width),
img_filename)
with open(img_path, 'rb') as f:
img = f.read()
gt_file = os.path.join(cfg.data_dir,
cfg.train_label_dir_name_var + str(width),
img_filename[:-4] + '_gt.npy')
labels = np.load(gt_file)
labels = np.transpose(labels, (2, 0, 1))
data.append({"image": img,
"label": np.array(labels, np.float32)})
if len(data) == 32:
writer.write_raw_data(data)
print('len(data):{}'.format(len(data)))
data = []
if data:
writer.write_raw_data(data)
writer.commit()
def load_adEAST_dataset(mindrecord_file, batch_size=64, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=8):
"""load mindrecord"""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "label"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=8, shuffle=is_training)
hwc_to_chw = vision.HWC2CHW()
cd = vision.Decode()
ds = ds.map(operations=cd, input_columns=["image"])
rc = vision.RandomColorAdjust(brightness=0.1, contrast=0.2, saturation=0.2)
vn = vision.Normalize(mean=(123.68, 116.779, 103.939), std=(1., 1., 1.))
ds = ds.map(operations=[rc, vn, hwc_to_chw], input_columns=["image"], num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
batch_num = ds.get_dataset_size()
ds = ds.shuffle(batch_num)
return ds, batch_num

View File

@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
labeling
"""
import os
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
from src.config import config as cfg
def point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
"""test box in or not"""
if (p_min[0] <= px <= p_max[0]) and (p_min[1] <= py <= p_max[1]):
xy_list = np.zeros((4, 2))
xy_list[:3, :] = quad_xy_list[1:4, :] - quad_xy_list[:3, :]
xy_list[3] = quad_xy_list[0, :] - quad_xy_list[3, :]
yx_list = np.zeros((4, 2))
yx_list[:, :] = quad_xy_list[:, -1:-3:-1]
a = xy_list * ([py, px] - yx_list)
b = a[:, 0] - a[:, 1]
if np.amin(b) >= 0 or np.amax(b) <= 0:
return True
return False
def point_inside_of_nth_quad(px, py, xy_list, shrink_1, long_edge):
"""test point in which box"""
nth = -1
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
for ith in range(2):
quad_xy_list = np.concatenate((
np.reshape(xy_list[vs[long_edge][ith][0]], (1, 2)),
np.reshape(shrink_1[vs[long_edge][ith][1]], (1, 2)),
np.reshape(shrink_1[vs[long_edge][ith][2]], (1, 2)),
np.reshape(xy_list[vs[long_edge][ith][3]], (1, 2))), axis=0)
p_min = np.amin(quad_xy_list, axis=0)
p_max = np.amax(quad_xy_list, axis=0)
if point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
if nth == -1:
nth = ith
else:
nth = -1
break
return nth
def shrink(xy_list, ratio=cfg.shrink_ratio):
"""shrink"""
if ratio == 0.0:
return xy_list, xy_list
diff_1to3 = xy_list[:3, :] - xy_list[1:4, :]
diff_4 = xy_list[3:4, :] - xy_list[0:1, :]
diff = np.concatenate((diff_1to3, diff_4), axis=0)
dis = np.sqrt(np.sum(np.square(diff), axis=-1))
# determine which are long or short edges
long_edge = int(np.argmax(np.sum(np.reshape(dis, (2, 2)), axis=0)))
short_edge = 1 - long_edge
# cal r length array
r = [np.minimum(dis[i], dis[(i + 1) % 4]) for i in range(4)]
# cal theta array
diff_abs = np.abs(diff)
diff_abs[:, 0] += cfg.epsilon
theta = np.arctan(diff_abs[:, 1] / diff_abs[:, 0])
# shrink two long edges
temp_new_xy_list = np.copy(xy_list)
shrink_edge(xy_list, temp_new_xy_list, long_edge, r, theta, ratio)
shrink_edge(xy_list, temp_new_xy_list, long_edge + 2, r, theta, ratio)
# shrink two short edges
new_xy_list = np.copy(temp_new_xy_list)
shrink_edge(temp_new_xy_list, new_xy_list, short_edge, r, theta, ratio)
shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio)
return temp_new_xy_list, new_xy_list, long_edge # 缩短后的长边,缩短后的短边,长边下标
def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio):
"""shrink edge"""
if ratio == 0.0:
return
start_point = edge # 边的起始点下标0或1
end_point = (edge + 1) % 4
long_start_sign_x = np.sign(
xy_list[end_point, 0] - xy_list[start_point, 0])
new_xy_list[start_point, 0] = \
xy_list[start_point, 0] + \
long_start_sign_x * ratio * r[start_point] * np.cos(theta[start_point])
long_start_sign_y = np.sign(
xy_list[end_point, 1] - xy_list[start_point, 1])
new_xy_list[start_point, 1] = \
xy_list[start_point, 1] + \
long_start_sign_y * ratio * r[start_point] * np.sin(theta[start_point])
# long edge one, end point
long_end_sign_x = -1 * long_start_sign_x
new_xy_list[end_point, 0] = \
xy_list[end_point, 0] + \
long_end_sign_x * ratio * r[end_point] * np.cos(theta[start_point])
long_end_sign_y = -1 * long_start_sign_y
new_xy_list[end_point, 1] = \
xy_list[end_point, 1] + \
long_end_sign_y * ratio * r[end_point] * np.sin(theta[start_point])
def precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax,
jmin, jmax, p_min, p_max, gt, long_edge, draw):
"""precess list"""
for i in range(imin, imax):
for j in range(jmin, jmax):
px = (j + 0.5) * cfg.pixel_size
py = (i + 0.5) * cfg.pixel_size
if point_inside_of_quad(px, py,
shrink_xy_list, p_min, p_max):
gt[i, j, 0] = 1
line_width, line_color = 1, 'red'
ith = point_inside_of_nth_quad(px, py,
xy_list,
shrink_1,
long_edge)
vs = [[[3, 0], [1, 2]], [[0, 1], [2, 3]]]
if ith in range(2):
gt[i, j, 1] = 1
if ith == 0:
line_width, line_color = 2, 'yellow'
else:
line_width, line_color = 2, 'green'
gt[i, j, 2:3] = ith
gt[i, j, 3:5] = \
xy_list[vs[long_edge][ith][0]] - [px, py]
gt[i, j, 5:] = \
xy_list[vs[long_edge][ith][1]] - [px, py]
draw.line([(px - 0.5 * cfg.pixel_size,
py - 0.5 * cfg.pixel_size),
(px + 0.5 * cfg.pixel_size,
py - 0.5 * cfg.pixel_size),
(px + 0.5 * cfg.pixel_size,
py + 0.5 * cfg.pixel_size),
(px - 0.5 * cfg.pixel_size,
py + 0.5 * cfg.pixel_size),
(px - 0.5 * cfg.pixel_size,
py - 0.5 * cfg.pixel_size)],
width=line_width, fill=line_color)
return gt
def process_label(data_dir=cfg.data_dir):
"""process label"""
with open(os.path.join(data_dir, cfg.val_fname), 'r') as f_val:
f_list = f_val.readlines()
with open(os.path.join(data_dir, cfg.train_fname), 'r') as f_train:
f_list.extend(f_train.readlines())
for line, _ in zip(f_list, tqdm(range(len(f_list)))):
line_cols = str(line).strip().split(',')
img_name, width, height = \
line_cols[0].strip(), int(line_cols[1].strip()), \
int(line_cols[2].strip())
gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
xy_list_array = np.load(os.path.join(train_label_dir,
img_name[:-4] + '.npy'))
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
with Image.open(os.path.join(train_image_dir, img_name)) as im:
draw = ImageDraw.Draw(im)
for xy_list in xy_list_array:
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
p_min = np.amin(shrink_xy_list, axis=0)
p_max = np.amax(shrink_xy_list, axis=0)
# floor of the float
ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
# +1 for ceil of the float and +1 for include the end
ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
imin = np.maximum(0, ji_min[1])
imax = np.minimum(height // cfg.pixel_size, ji_max[1])
jmin = np.maximum(0, ji_min[0])
jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax,
p_min, p_max, gt, long_edge, draw)
act_image_dir = os.path.join(cfg.data_dir,
cfg.show_act_image_dir_name)
if cfg.draw_act_quad:
im.save(os.path.join(act_image_dir, img_name))
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
np.save(os.path.join(train_label_dir,
img_name[:-4] + '_gt.npy'), gt)
def process_label_size(width=256, data_dir=cfg.data_dir):
"""process label at specific size"""
with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
f_list = f_val.readlines()
with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
f_list.extend(f_train.readlines())
for line, _ in zip(f_list, tqdm(range(len(f_list)))):
line_cols = str(line).strip().split(',')
img_name, width, height = \
line_cols[0].strip(), int(line_cols[1].strip()), \
int(line_cols[2].strip())
gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
xy_list_array = np.load(os.path.join(train_label_dir,
img_name[:-4] + '.npy'))
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
with Image.open(os.path.join(train_image_dir, img_name)) as im:
draw = ImageDraw.Draw(im)
for xy_list in xy_list_array:
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
p_min = np.amin(shrink_xy_list, axis=0)
p_max = np.amax(shrink_xy_list, axis=0)
# floor of the float
ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
# +1 for ceil of the float and +1 for include the end
ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
imin = np.maximum(0, ji_min[1])
imax = np.minimum(height // cfg.pixel_size, ji_max[1])
jmin = np.maximum(0, ji_min[0])
jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax, p_min, p_max, gt,
long_edge,
draw)
act_image_dir = os.path.join(cfg.data_dir,
cfg.show_act_image_dir_name)
if cfg.draw_act_quad:
im.save(os.path.join(act_image_dir, img_name))
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
np.save(os.path.join(train_label_dir,
img_name[:-4] + '_gt.npy'), gt)

View File

@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
""" get logger"""
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,273 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer
from mindspore.nn.optim import AdamWeightDecay
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import numpy as np
from src.config import config as cfg
class AdvancedEast(nn.Cell):
"""
EAST network definition.
"""
def __init__(self):
super(AdvancedEast, self).__init__()
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
def get_var(name, idx):
value = vgg_dict[name][idx]
if idx == 0:
value = np.transpose(value, [3, 2, 0, 1])
var = Tensor(value)
return var
def get_conv_var(name):
filters = get_var(name, 0)
biases = get_var(name, 1)
return filters, biases
class VGG_Conv(nn.Cell):
"""
VGG16 network definition.
"""
def __init__(self, name):
super(VGG_Conv, self).__init__()
filters, conv_biases = get_conv_var(name)
out_channels, in_channels, filter_size, _ = filters.shape
self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
self.bias_add = P.BiasAdd()
self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
name='weight')
self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
self.relu = P.ReLU()
self.gn = nn.GroupNorm(32, out_channels)
def construct(self, x):
output = self.conv2d(x, self.weight)
output = self.bias_add(output, self.bias)
output = self.gn(output)
output = self.relu(output)
return output
self.conv1_1 = VGG_Conv('conv1_1')
self.conv1_2 = VGG_Conv('conv1_2')
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2_1 = VGG_Conv('conv2_1')
self.conv2_2 = VGG_Conv('conv2_2')
self.pool2 = nn.MaxPool2d(2, 2)
self.conv3_1 = VGG_Conv('conv3_1')
self.conv3_2 = VGG_Conv('conv3_2')
self.conv3_3 = VGG_Conv('conv3_3')
self.pool3 = nn.MaxPool2d(2, 2)
self.conv4_1 = VGG_Conv('conv4_1')
self.conv4_2 = VGG_Conv('conv4_2')
self.conv4_3 = VGG_Conv('conv4_3')
self.pool4 = nn.MaxPool2d(2, 2)
self.conv5_1 = VGG_Conv('conv5_1')
self.conv5_2 = VGG_Conv('conv5_2')
self.conv5_3 = VGG_Conv('conv5_3')
self.pool5 = nn.MaxPool2d(2, 2)
self.merging1 = self.merging(i=2)
self.merging2 = self.merging(i=3)
self.merging3 = self.merging(i=4)
self.last_bn = nn.GroupNorm(16, 32)
self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform')
self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
self.op_concat = P.Concat(axis=1)
self.relu = P.ReLU()
def merging(self, i=2):
"""
def merge layer
"""
in_size = {'2': 1024, '3': 384, '4': 192}
layers = [
nn.Conv2d(in_size[str(i)], 128 // 2 ** (i - 2), kernel_size=1, stride=1, has_bias=True,
weight_init='XavierUniform'),
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
nn.ReLU(),
nn.Conv2d(128 // 2 ** (i - 2), 128 // 2 ** (i - 2), kernel_size=3, stride=1, has_bias=True,
weight_init='XavierUniform'),
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
nn.ReLU()]
return nn.SequentialCell(layers)
def construct(self, x):
"""
forward func
"""
f4 = self.conv1_1(x)
f4 = self.conv1_2(f4)
f4 = self.pool1(f4)
f4 = self.conv2_1(f4)
f4 = self.conv2_2(f4)
f4 = self.pool2(f4)
f3 = self.conv3_1(f4)
f3 = self.conv3_2(f3)
f3 = self.conv3_3(f3)
f3 = self.pool3(f3)
f2 = self.conv4_1(f3)
f2 = self.conv4_2(f2)
f2 = self.conv4_3(f2)
f2 = self.pool4(f2)
f1 = self.conv5_1(f2)
f1 = self.conv5_2(f1)
f1 = self.conv5_3(f1)
f1 = self.pool5(f1)
h1 = f1
_, _, h_, w_ = P.Shape()(h1)
H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1)
concat1 = self.op_concat((H1, f2))
h2 = self.merging1(concat1)
_, _, h_, w_ = P.Shape()(h2)
H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2)
concat2 = self.op_concat((H2, f3))
h3 = self.merging2(concat2)
_, _, h_, w_ = P.Shape()(h3)
H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3)
concat3 = self.op_concat((H3, f4))
h4 = self.merging3(concat3)
before_output = self.relu(self.last_bn(self.conv_last(h4)))
inside_score = self.inside_score_conv(before_output)
side_v_angle = self.side_v_angle_conv(before_output)
side_v_coord = self.side_v_coord_conv(before_output)
east_detect = self.op_concat((inside_score, side_v_coord, side_v_angle))
return east_detect
def dice_loss(gt_score, pred_score):
"""dice_loss1"""
inter = P.ReduceSum()(gt_score * pred_score)
union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5
return 1. - (2 * (inter / union))
def dice_loss2(gt_score, pred_score, mask):
"""dice_loss2"""
inter = P.ReduceSum()(gt_score * pred_score * mask)
union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5
return 1. - (2 * (inter / union))
def quad_loss(y_true, y_pred,
lambda_inside_score_loss=0.2,
lambda_side_vertex_code_loss=0.1,
lambda_side_vertex_coord_loss=1.0,
epsilon=1e-4):
"""quad loss"""
y_true = P.Transpose()(y_true, (0, 2, 3, 1))
y_pred = P.Transpose()(y_pred, (0, 2, 3, 1))
logits = y_pred[:, :, :, :1]
labels = y_true[:, :, :, :1]
predicts = P.Sigmoid()(logits)
inside_score_loss = dice_loss(labels, predicts)
inside_score_loss = inside_score_loss * lambda_inside_score_loss
# loss for side_vertex_code
vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2])
vertex_labelsp = y_true[:, :, :, 1:2]
vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3])
vertex_labelsn = y_true[:, :, :, 2:3]
labels2 = y_true[:, :, :, 1:2]
side_vertex_code_lossp = dice_loss2(vertex_labelsp, vertex_logitsp, labels)
side_vertex_code_lossn = dice_loss2(vertex_labelsn, vertex_logitsn, labels2)
side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss
# loss for side_vertex_coord delta
g_hat = y_pred[:, :, :, 3:] # N*W*H*8
g_true = y_true[:, :, :, 3:]
vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32)
# vertex_weights=y_true[:, :, :,1]
pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights)
side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / (
P.ReduceSum()(vertex_weights) + epsilon)
side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss
# print(inside_score_loss, side_vertex_code_loss, side_vertex_coord_loss)
return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss
def smooth_l1_loss(prediction_tensor, target_tensor, weights):
"""smooth l1 loss"""
n_q = P.Reshape()(quad_norm(target_tensor), weights.shape)
diff = P.SmoothL1Loss()(prediction_tensor, target_tensor)
pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights
return pixel_wise_smooth_l1norm
def quad_norm(g_true, epsilon=1e-4):
""" quad norm"""
shape = g_true.shape
delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2))
diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :]
square = diff * diff
distance = P.Sqrt()(P.ReduceSum()(square, -1))
distance = distance * 4.0
distance = distance + epsilon
return P.Reshape()(distance, (shape[0], shape[1], shape[2]))
class EastWithLossCell(nn.Cell):
"""get loss cell"""
def __init__(self, network, config=None):
super(EastWithLossCell, self).__init__()
self.East_network = network
self.cat = P.Concat(axis=1)
def construct(self, image, label):
y_pred = self.East_network(image)
loss = quad_loss(label, y_pred)
return loss
class TrainStepWrap(nn.Cell):
"""train net warper"""
def __init__(self, network, steps, config=None):
super(TrainStepWrap, self).__init__()
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = AdamWeightDecay(self.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = 3.0
def construct(self, image, label):
weights = self.weights
loss = self.network(image, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(image, label, sens)
return F.depend(loss, self.optimizer(grads))
def get_AdvancedEast_net(configure=None, steps=1, mode=True):
"""
Get network of wide&deep model.
"""
AdvancedEast_net = AdvancedEast()
loss_net = EastWithLossCell(AdvancedEast_net, configure)
train_net = TrainStepWrap(loss_net, steps, configure)
return loss_net, train_net

View File

@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NMS module
"""
import numpy as np
from src.config import config as cfg
def should_merge(region, i, j):
"""
test merge
"""
neighbor = {(i, j - 1)}
return not region.isdisjoint(neighbor)
def region_neighbor(region_set):
"""
cal the neighbor of the region
"""
region_pixels = np.array(list(region_set))
j_min = np.amin(region_pixels, axis=0)[1] - 1
j_max = np.amax(region_pixels, axis=0)[1] + 1
i_m = np.amin(region_pixels, axis=0)[0] + 1
region_pixels[:, 0] += 1
neighbor = {(region_pixels[n, 0], region_pixels[n, 1]) for n in
range(len(region_pixels))}
neighbor.add((i_m, j_min))
neighbor.add((i_m, j_max))
return neighbor
def region_group(region_list):
"""
group regions
"""
S = [i for i in range(len(region_list))]
D = []
while S:
m = S.pop(0)
if not S:
# S has only one element, put it to D
D.append([m])
else:
D.append(rec_region_merge(region_list, m, S))
return D
def rec_region_merge(region_list, m, S):
"""
merge regions
"""
rows = [m]
tmp = []
for n in S:
if not region_neighbor(region_list[m]).isdisjoint(region_list[n]) or \
not region_neighbor(region_list[n]).isdisjoint(region_list[m]):
tmp.append(n)
for d in tmp:
S.remove(d)
for e in tmp:
rows.extend(rec_region_merge(region_list, e, S))
return rows
def nms(predict, activation_pixels, threshold=cfg.side_vertex_pixel_threshold):
"""
perform nms on results
"""
region_list = []
for i, j in zip(activation_pixels[0], activation_pixels[1]):
merge = False
for k, value in enumerate(region_list):
if should_merge(value, i, j):
region_list[k].add((i, j))
merge = True
if not merge:
region_list.append({(i, j)})
D = region_group(region_list)
quad_list = np.zeros((len(D), 4, 2))
score_list = np.zeros((len(D), 4))
for group, g_th in zip(D, range(len(D))):
total_score = np.zeros((4, 2))
for row in group:
for ij in region_list[row]:
score = predict[ij[0], ij[1], 1]
if score >= threshold:
ith_score = predict[ij[0], ij[1], 2:3]
if not (cfg.trunc_threshold <= ith_score < 1 -
cfg.trunc_threshold):
ith = int(np.around(ith_score))
total_score[ith * 2:(ith + 1) * 2] += score
px = (ij[1] + 0.5) * cfg.pixel_size
py = (ij[0] + 0.5) * cfg.pixel_size
p_v = [px, py] + np.reshape(predict[ij[0], ij[1], 3:7],
(2, 2))
quad_list[g_th, ith * 2:(ith + 1) * 2] += score * p_v
score_list[g_th] = total_score[:, 0]
quad_list[g_th] /= (total_score + cfg.epsilon)
return score_list, quad_list

View File

@ -0,0 +1,176 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################predict the qual line of images ########################
"""
import argparse
import numpy as np
from PIL import Image, ImageDraw
from mindspore import Tensor
from src.label import point_inside_of_quad
from src.config import config as cfg
from src.nms import nms
from src.preprocess import resize_image
def sigmoid(x):
"""`y = 1 / (1 + exp(-x))`"""
return 1 / (1 + np.exp(-x))
def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
"""
cut text line
"""
geo /= [scale_ratio_w, scale_ratio_h]
p_min = np.amin(geo, axis=0)
p_max = np.amax(geo, axis=0)
min_xy = p_min.astype(int)
max_xy = p_max.astype(int) + 2
sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy()
for m in range(min_xy[1], max_xy[1]):
for n in range(min_xy[0], max_xy[0]):
if not point_inside_of_quad(n, m, geo, p_min, p_max):
sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255
sub_im = Image.fromarray(sub_im_arr)
sub_im.save(img_path + '_subim%d.jpg' % s)
def predict(east_detect, img_path, pixel_threshold, quiet=True):
"""
predict to txt and image
"""
img = Image.open(img_path)
d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
d_wight = max(d_wight, d_height)
d_height = max(d_wight, d_height)
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
img = np.asarray(img)
img = img / 127.5 - 1
img = img.transpose((2, 0, 1))
x = Tensor(np.expand_dims(img, axis=0), "float32")
y = east_detect(x).asnumpy()
y = np.squeeze(y, axis=0)
if y.shape[0] == 7:
y = y.transpose((1, 2, 0)) # CHW->HWC
y[:, :, :3] = sigmoid(y[:, :, :3])
cond = np.greater_equal(y[:, :, 0], pixel_threshold)
activation_pixels = np.where(cond)
quad_scores, quad_after_nms = nms(y, activation_pixels)
with Image.open(img_path) as im:
im_array = np.asarray(im.convert('RGB'))
d_wight, d_height = resize_image(im, cfg.max_predict_img_size)
scale_ratio_w = d_wight / im.width
scale_ratio_h = d_height / im.height
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
quad_im = im.copy()
draw = ImageDraw.Draw(im)
for i, j in zip(activation_pixels[0], activation_pixels[1]):
px = (j + 0.5) * cfg.pixel_size
py = (i + 0.5) * cfg.pixel_size
line_width, line_color = 1, 'red'
if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:
if y[i, j, 2] < cfg.trunc_threshold:
line_width, line_color = 2, 'yellow'
elif y[i, j, 2] >= 1 - cfg.trunc_threshold:
line_width, line_color = 2, 'green'
draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
(px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
(px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
(px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)],
width=line_width, fill=line_color)
im.save(img_path + '_act.jpg')
quad_draw = ImageDraw.Draw(quad_im)
txt_items = []
for score, geo, s in zip(quad_scores, quad_after_nms,
range(len(quad_scores))):
print(np.amin(score))
if np.amin(score) > 0:
quad_draw.line([tuple(geo[0]),
tuple(geo[1]),
tuple(geo[2]),
tuple(geo[3]),
tuple(geo[0])], width=2, fill='red')
if cfg.predict_cut_text_line:
cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array,
img_path, s)
rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
txt_item = ','.join(map(str, rescaled_geo_list))
txt_items.append(txt_item + '\n')
elif not quiet:
print('quad invalid with vertex num less then 4.')
quad_im.save(img_path + '_predict.jpg')
if cfg.predict_write2txt and txt_items:
with open(img_path[:-4] + '.txt', 'w') as f_txt:
f_txt.writelines(txt_items)
def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
"""
predict to txt
"""
img = Image.open(img_path)
d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
scale_ratio_w = d_wight / img.width
scale_ratio_h = d_height / img.height
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
img = np.asarray(img)
img = img / 127.5 - 1
img = img.transpose((2, 0, 1))
x = np.expand_dims(img, axis=0)
y = east_detect(x).asnumpy()
y = np.squeeze(y, axis=0)
if y.shape[0] == 7:
y = y.transpose((1, 2, 0)) # CHW->HWC
y[:, :, :3] = sigmoid(y[:, :, :3])
cond = np.greater_equal(y[:, :, 0], pixel_threshold)
activation_pixels = np.where(cond)
quad_scores, quad_after_nms = nms(y, activation_pixels)
txt_items = []
for score, geo in zip(quad_scores, quad_after_nms):
if np.amin(score) > 0:
rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
txt_item = ','.join(map(str, rescaled_geo_list))
txt_items.append(txt_item + '\n')
elif not quiet:
print('quad invalid with vertex num less then 4.')
if cfg.predict_write2txt and txt_items:
with open(txt_path, 'w') as f_txt:
f_txt.writelines(txt_items)
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser()
parser.add_argument('--path', '-p',
default='demo/004.jpg',
help='image path')
parser.add_argument('--threshold', '-t',
default=cfg.pixel_threshold,
help='pixel activation threshold')
return parser.parse_args()

View File

@ -0,0 +1,311 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################preprocess images########################
"""
import os
import random
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageDraw
from src.label import shrink
from src.config import config as cfg
def batch_reorder_vertexes(xy_list_array):
"""
batch xy to batch vertices
"""
reorder_xy_list_array = np.zeros_like(xy_list_array)
for xy_list, i in zip(xy_list_array, range(len(xy_list_array))):
reorder_xy_list_array[i] = reorder_vertexes(xy_list)
return reorder_xy_list_array
def reorder_vertexes(xy_list):
"""
xy to vertices
"""
reorder_xy_list = np.zeros_like(xy_list)
# determine the first point with the smallest x,
# if two has same x, choose that with smallest y,
ordered = np.argsort(xy_list, axis=0)
xmin1_index = ordered[0, 0]
xmin2_index = ordered[1, 0]
if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]:
if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]:
reorder_xy_list[0] = xy_list[xmin1_index]
first_v = xmin1_index
else:
reorder_xy_list[0] = xy_list[xmin2_index]
first_v = xmin2_index
else:
reorder_xy_list[0] = xy_list[xmin1_index]
first_v = xmin1_index
# connect the first point to others, the third point on the other side of
# the line with the middle slope
others = list(range(4))
others.remove(first_v)
k = np.zeros((len(others),))
for index, i in zip(others, range(len(others))):
k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \
/ (xy_list[index, 0] - xy_list[first_v, 0] + cfg.epsilon)
k_mid = np.argsort(k)[1]
third_v = others[k_mid]
reorder_xy_list[2] = xy_list[third_v]
# determine the second point which on the bigger side of the middle line
others.remove(third_v)
b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
second_v, fourth_v = 0, 0
for index, i in zip(others, range(len(others))):
# delta = y - (k * x + b)
delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
if delta_y > 0:
second_v = index
else:
fourth_v = index
reorder_xy_list[1] = xy_list[second_v]
reorder_xy_list[3] = xy_list[fourth_v]
# compare slope of 13 and 24, determine the final order
k13 = k[k_mid]
k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / (
xy_list[second_v, 0] - xy_list[fourth_v, 0] + cfg.epsilon)
if k13 < k24:
tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1]
for i in range(2, -1, -1):
reorder_xy_list[i + 1] = reorder_xy_list[i]
reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y
return reorder_xy_list
def resize_image(im, max_img_size=cfg.max_train_img_size):
"""
resize image
"""
im_width = np.minimum(im.width, max_img_size)
if im_width == max_img_size < im.width:
im_height = int((im_width / im.width) * im.height)
else:
im_height = im.height
o_height = np.minimum(im_height, max_img_size)
if o_height == max_img_size < im_height:
o_width = int((o_height / im_height) * im_width)
else:
o_width = im_width
d_wight = o_width - (o_width % 32)
d_height = o_height - (o_height % 32)
return d_wight, d_height
def preprocess():
"""
preprocess data
"""
data_dir = cfg.data_dir
origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
if not os.path.exists(train_image_dir):
os.mkdir(train_image_dir)
if not os.path.exists(train_label_dir):
os.mkdir(train_label_dir)
draw_gt_quad = cfg.draw_gt_quad
show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
if not os.path.exists(show_gt_image_dir):
os.mkdir(show_gt_image_dir)
show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
if not os.path.exists(show_act_image_dir):
os.mkdir(show_act_image_dir)
o_img_list = os.listdir(origin_image_dir)
print('found %d origin images.' % len(o_img_list))
train_val_set = []
for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
continue
# d_wight, d_height = resize_image(im)
d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size
scale_ratio_w = d_wight / im.width
scale_ratio_h = d_height / im.height
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
show_gt_im = im.copy()
# draw on the img
draw = ImageDraw.Draw(show_gt_im)
with open(os.path.join(origin_txt_dir,
o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
anno_list = f.readlines()
xy_list_array = np.zeros((len(anno_list), 4, 2))
for anno, i in zip(anno_list, range(len(anno_list))):
anno_colums = anno.strip().split(',')
anno_array = np.array(anno_colums)
xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
xy_list = reorder_vertexes(xy_list)
xy_list_array[i] = xy_list
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
if draw_gt_quad:
draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
tuple(xy_list[2]), tuple(xy_list[3]),
tuple(xy_list[0])
],
width=2, fill='green')
draw.line([tuple(shrink_xy_list[0]),
tuple(shrink_xy_list[1]),
tuple(shrink_xy_list[2]),
tuple(shrink_xy_list[3]),
tuple(shrink_xy_list[0])
],
width=2, fill='blue')
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
for q_th in range(2):
draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
tuple(shrink_1[vs[long_edge][q_th][1]]),
tuple(shrink_1[vs[long_edge][q_th][2]]),
tuple(xy_list[vs[long_edge][q_th][3]]),
tuple(xy_list[vs[long_edge][q_th][4]])],
width=3, fill='yellow')
if cfg.gen_origin_img:
im.save(os.path.join(train_image_dir, o_img_fname))
np.save(os.path.join(
train_label_dir,
o_img_fname[:-4] + '.npy'),
xy_list_array)
if draw_gt_quad:
show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
train_val_set.append('{},{},{}\n'.format(o_img_fname,
d_wight,
d_height))
train_img_list = os.listdir(train_image_dir)
print('found %d train images.' % len(train_img_list))
train_label_list = os.listdir(train_label_dir)
print('found %d train labels.' % len(train_label_list))
random.shuffle(train_val_set)
val_count = int(cfg.validation_split_ratio * len(train_val_set))
with open(os.path.join(data_dir, cfg.val_fname), 'a') as f_val:
f_val.writelines(train_val_set[:val_count])
with open(os.path.join(data_dir, cfg.train_fname), 'a') as f_train:
f_train.writelines(train_val_set[val_count:])
def preprocess_size(width=256):
"""
preprocess data at specific size
"""
if width not in cfg.train_img_size:
raise NotImplementedError
data_dir = cfg.data_dir
origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name_var + str(width))
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
if not os.path.exists(train_image_dir):
os.mkdir(train_image_dir)
if not os.path.exists(train_label_dir):
os.mkdir(train_label_dir)
draw_gt_quad = cfg.draw_gt_quad
show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
if not os.path.exists(show_gt_image_dir):
os.mkdir(show_gt_image_dir)
show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
if not os.path.exists(show_act_image_dir):
os.mkdir(show_act_image_dir)
o_img_list = os.listdir(origin_image_dir)
print('found %d origin images.' % len(o_img_list))
train_val_set = []
for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
continue
# d_wight, d_height = resize_image(im)
d_wight, d_height = width, width
scale_ratio_w = d_wight / im.width
scale_ratio_h = d_height / im.height
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
show_gt_im = im.copy()
# draw on the img
draw = ImageDraw.Draw(show_gt_im)
with open(os.path.join(origin_txt_dir,
o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
anno_list = f.readlines()
xy_list_array = np.zeros((len(anno_list), 4, 2))
for anno, i in zip(anno_list, range(len(anno_list))):
anno_colums = anno.strip().split(',')
anno_array = np.array(anno_colums)
xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
xy_list = reorder_vertexes(xy_list)
xy_list_array[i] = xy_list
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
if draw_gt_quad:
draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
tuple(xy_list[2]), tuple(xy_list[3]),
tuple(xy_list[0])
],
width=2, fill='green')
draw.line([tuple(shrink_xy_list[0]),
tuple(shrink_xy_list[1]),
tuple(shrink_xy_list[2]),
tuple(shrink_xy_list[3]),
tuple(shrink_xy_list[0])
],
width=2, fill='blue')
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
for q_th in range(2):
draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
tuple(shrink_1[vs[long_edge][q_th][1]]),
tuple(shrink_1[vs[long_edge][q_th][2]]),
tuple(xy_list[vs[long_edge][q_th][3]]),
tuple(xy_list[vs[long_edge][q_th][4]])],
width=3, fill='yellow')
if cfg.gen_origin_img:
im.save(os.path.join(train_image_dir, o_img_fname))
np.save(os.path.join(
train_label_dir,
o_img_fname[:-4] + '.npy'),
xy_list_array)
if draw_gt_quad:
show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
train_val_set.append('{},{},{}\n'.format(o_img_fname,
d_wight,
d_height))
train_img_list = os.listdir(train_image_dir)
print('found %d train images.' % len(train_img_list))
train_label_list = os.listdir(train_label_dir)
print('found %d train labels.' % len(train_label_list))
# random.shuffle(train_val_set)
val_count = int(cfg.validation_split_ratio * len(train_val_set))
with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'a') as f_val:
f_val.writelines(train_val_set[:val_count])
with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'a') as f_train:
f_train.writelines(train_val_set[val_count:])

View File

@ -0,0 +1,151 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################calculate precision ########################
"""
import numpy as np
from shapely.geometry import Polygon
from src.config import config as cfg
from src.nms import nms
# from src.preprocess import resize_image
class Averager():
"""Compute average for torch.Tensor, used for loss average."""
def __init__(self):
self.reset()
def add(self, v):
count = v.data.numel()
v = v.data.sum()
self.n_count += count
self.sum += v
def reset(self):
self.n_count = 0
self.sum = 0
def val(self):
res = 0
if self.n_count != 0:
res = self.sum / float(self.n_count)
return res
class eval_pre_rec_f1():
'''eval batch'''
def __init__(self):
self.pixel_threshold = float(cfg.pixel_threshold)
self.reset()
def reset(self):
self.img_num = 0
self.pre = 0
self.rec = 0
self.f1_score = 0
def val(self):
mpre = self.pre / self.img_num * 100
mrec = self.rec / self.img_num * 100
mf1_score = self.f1_score / self.img_num * 100
return mpre, mrec, mf1_score
def sigmoid(self, x):
"""`y = 1 / (1 + exp(-x))`"""
return 1 / (1 + np.exp(-x))
def get_iou(self, g, p):
"""`get_iou`"""
g = Polygon(g)
p = Polygon(p)
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
union = g.area + p.area - inter
if union == 0:
return 0
return inter / union
def eval_one(self, quad_scores, quad_after_nms, gt_xy, quiet=cfg.quiet):
"""`eval_one`"""
num_gts = len(gt_xy)
quad_scores_no_zero = []
quad_after_nms_no_zero = []
for score, geo in zip(quad_scores, quad_after_nms):
if np.amin(score) > 0:
quad_scores_no_zero.append(sum(score))
quad_after_nms_no_zero.append(geo)
elif not quiet:
print('quad invalid with vertex num less then 4.')
continue
num_quads = len(quad_after_nms_no_zero)
if num_quads == 0:
return 0, 0, 0
quad_flag = np.zeros(num_quads)
gt_flag = np.zeros(num_gts)
# print(num_quads, '-------', num_gts)
quad_scores_no_zero = np.array(quad_scores_no_zero)
scores_idx = np.argsort(quad_scores_no_zero)[::-1]
for i in range(num_quads):
idx = scores_idx[i]
# score = quad_scores_no_zero[idx]
geo = quad_after_nms_no_zero[idx]
for j in range(num_gts):
if gt_flag[j] == 0:
gt_geo = gt_xy[j]
iou = self.get_iou(geo, gt_geo)
if iou >= cfg.iou_threshold:
gt_flag[j] = 1
quad_flag[i] = 1
tp = np.sum(quad_flag)
fp = num_quads - tp
fn = num_gts - tp
pre = tp / (tp + fp)
rec = tp / (tp + fn)
if pre + rec == 0:
f1_score = 0
else:
f1_score = 2 * pre * rec / (pre + rec)
# print(pre, '---', rec, '---', f1_score)
return pre, rec, f1_score
def add(self, out, gt_xy_list):
"""`add`"""
self.img_num += len(gt_xy_list)
ys = out.asnumpy() # (N, 7, 64, 64)
print(ys.shape)
if ys.shape[1] == 7:
ys = ys.transpose((0, 2, 3, 1)) # NCHW->NHWC
for y, gt_xy in zip(ys, gt_xy_list):
y[:, :, :3] = self.sigmoid(y[:, :, :3])
cond = np.greater_equal(y[:, :, 0], self.pixel_threshold)
activation_pixels = np.where(cond)
quad_scores, quad_after_nms = nms(y, activation_pixels)
print(quad_scores)
lens = len(quad_after_nms)
if lens == 0 or (sum(sum(quad_scores)) == 0):
if not cfg.quiet:
print('NMS')
continue
else:
pre, rec, f1_score = self.eval_one(quad_scores, quad_after_nms, gt_xy)
print(pre, rec, f1_score)
self.pre += pre
self.rec += rec
self.f1_score += f1_score

View File

@ -0,0 +1,188 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train advanced_east on dataset########################
"""
import argparse
import datetime
import os
from mindspore import context, Model
from mindspore.communication.management import init, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.common import set_seed
from mindspore.nn.optim import AdamWeightDecay
from src.logger import get_logger
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net
from src.config import config as cfg
set_seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore adveast training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
# network related
parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
# logging and checkpoint related
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt.initial_epoch = cfg.initial_epoch
args_opt.epoch_num = cfg.epoch_num
args_opt.learning_rate = cfg.learning_rate
args_opt.decay = cfg.decay
args_opt.batch_size = cfg.batch_size
args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
args_opt.ckpt_save_max = cfg.ckpt_save_max
args_opt.data_dir = cfg.data_dir
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
args_opt.train_image_dir_name = cfg.train_image_dir_name
args_opt.train_label_dir_name = cfg.train_label_dir_name
args_opt.results_dir = cfg.results_dir
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
return args_opt
if __name__ == '__main__':
args = parse_args()
device_num = int(os.environ.get("DEVICE_NUM", 1))
context.set_context(mode=context.GRAPH_MODE)
workers = 32
if args.is_distributed:
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id, device_target=args.device_target)
init()
elif args.device_target == "GPU":
context.set_context(device_target=args.device_target)
init()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# get network and init
loss_net, train_net = get_AdvancedEast_net()
loss_net.add_flags_recursive(fp32=True)
train_net.set_train(False)
# pre_trained
if args.pre_trained:
load_param_into_net(train_net, load_checkpoint(os.path.join(args.saved_model_file_path, args.last_model_name)))
# define callbacks
mindrecordfile256 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(256) + '.mindrecord')
train_dataset256, batch_num256 = load_adEAST_dataset(mindrecordfile256, batch_size=8,
device_num=device_num, rank_id=args.rank, is_training=True,
num_parallel_workers=workers)
mindrecordfile384 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(384) + '.mindrecord')
train_dataset384, batch_num384 = load_adEAST_dataset(mindrecordfile384, batch_size=4,
device_num=device_num, rank_id=args.rank, is_training=True,
num_parallel_workers=workers)
mindrecordfile448 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(448) + '.mindrecord')
train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2,
device_num=device_num, rank_id=args.rank, is_training=True,
num_parallel_workers=workers)
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
model = Model(train_net)
time_cb = TimeMonitor(data_size=batch_num256)
loss_cb = LossMonitor(per_print_times=batch_num256)
callbacks = [time_cb, loss_cb]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = args.saved_model_file_path
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='Epoch_A{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset256, callbacks=callbacks, dataset_sink_mode=False)
model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
time_cb = TimeMonitor(data_size=batch_num384)
loss_cb = LossMonitor(per_print_times=batch_num384)
callbacks = [time_cb, loss_cb]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = args.saved_model_file_path
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='Epoch_B{}'.format(args.rank))
callbacks.append(ckpt_cb)
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
model = Model(train_net)
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset384, callbacks=callbacks, dataset_sink_mode=False)
model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
time_cb = TimeMonitor(data_size=batch_num448)
loss_cb = LossMonitor(per_print_times=batch_num448)
callbacks = [time_cb, loss_cb]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = args.saved_model_file_path
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='Epoch_C{}'.format(args.rank))
callbacks.append(ckpt_cb)
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
model = Model(train_net)
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset448, callbacks=callbacks, dataset_sink_mode=False)

View File

@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train advanced_east on dataset########################
"""
import argparse
import datetime
import os
from mindspore import context, Model
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.logger import get_logger
from src.config import config as cfg
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net
set_seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore adveast training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
# network related
parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
parser.add_argument('--data_path', default='/disk1/ade/icpr/advanced-east-val_256.mindrecord', type=str)
parser.add_argument('--pre_trained_ckpt', default='/disk1/adeast/scripts/1.ckpt', type=str)
# logging and checkpoint related
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt.initial_epoch = cfg.initial_epoch
args_opt.epoch_num = cfg.epoch_num
args_opt.learning_rate = cfg.learning_rate
args_opt.decay = cfg.decay
args_opt.batch_size = cfg.batch_size
args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
args_opt.ckpt_save_max = cfg.ckpt_save_max
args_opt.data_dir = cfg.data_dir
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
args_opt.train_image_dir_name = cfg.train_image_dir_name
args_opt.train_label_dir_name = cfg.train_label_dir_name
args_opt.results_dir = cfg.results_dir
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
return args_opt
if __name__ == '__main__':
args = parse_args()
context.set_context(device_target=args.device_target, device_id=args.device_id, mode=context.GRAPH_MODE)
workers = 32
device_num = 1
if args.is_distributed:
init()
if args.device_target == "Ascend":
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
args.rank = args.device_id
elif args.device_target == "GPU":
context.set_context(device_target=args.device_target)
device_num = get_group_size()
args.rank = get_rank()
args.group_size = device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, gradients_mean=True, parallel_mode=ParallelMode.DATA_PARALLEL)
else:
context.set_context(device_id=args.device_id)
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
# dataset
mindrecordfile = args.data_path
train_dataset, batch_num = load_adEAST_dataset(mindrecordfile, batch_size=args.batch_size,
device_num=device_num, rank_id=args.rank, is_training=True,
num_parallel_workers=workers)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# get network and init
loss_net, train_net = get_AdvancedEast_net()
loss_net.add_flags_recursive(fp32=True)
train_net.set_train(True)
# pre_trained
if args.pre_trained:
load_param_into_net(train_net, load_checkpoint(args.pre_trained_ckpt))
# define callbacks
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
, eps=1e-7, weight_decay=cfg.decay)
model = Model(train_net)
time_cb = TimeMonitor(data_size=batch_num)
loss_cb = LossMonitor(per_print_times=batch_num)
callbacks = [time_cb, loss_cb]
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = args.saved_model_file_path
if args.rank_save_ckpt_flag:
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='Epoch_{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)