forked from mindspore-Ecosystem/mindspore
!14008 Add AdvancedEast model.
From: @windaway Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5820159fcc
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/space_to_depth_base.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/space_to_batch_fp32.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/gather_parameter.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gatherNd_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/gather_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/leaky_relu_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/quant_dtype_cast_int8.c
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/sigmoid_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/squeeze_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/unsqueeze_int8.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Normal file → Executable file
0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/space_to_depth_parameter.h
Normal file → Executable file
Binary file not shown.
After Width: | Height: | Size: 119 KiB |
|
@ -0,0 +1,194 @@
|
||||||
|
# AdvancedEAST
|
||||||
|
|
||||||
|
## introduction
|
||||||
|
|
||||||
|
AdvancedEAST is inspired by EAST [EAST:An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
|
||||||
|
The architecture of AdvancedEAST is showed below![AdvancedEast network arch](AdvancedEast.network.png).
|
||||||
|
This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie/AdvancedEAST)(preprocess, network architecture, predict) and [BaoWentz/AdvancedEAST-PyTorch](https://github.com/BaoWentz/AdvancedEAST-PyTorch)(performance).
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
|
||||||
|
* euleros v2r7 x86_64
|
||||||
|
* python 3.7.5
|
||||||
|
|
||||||
|
## Dependences
|
||||||
|
|
||||||
|
* mindspore==1.1
|
||||||
|
* shapely==1.7.1
|
||||||
|
* numpy==1.19.4
|
||||||
|
* tqdm==4.36.1
|
||||||
|
|
||||||
|
## Project files
|
||||||
|
|
||||||
|
* configuration of file
|
||||||
|
cfg.py, control parameters
|
||||||
|
* pre-process data:
|
||||||
|
preprocess.py, resize image
|
||||||
|
* label data:
|
||||||
|
label.py,produce label info
|
||||||
|
* define network
|
||||||
|
model.py and VGG.py
|
||||||
|
* define loss function
|
||||||
|
losses.py
|
||||||
|
* execute training
|
||||||
|
advanced_east.py and dataset.py
|
||||||
|
* predict
|
||||||
|
predict.py and nms.py
|
||||||
|
* scoring
|
||||||
|
score.py
|
||||||
|
* logging
|
||||||
|
logger.py
|
||||||
|
|
||||||
|
```shell
|
||||||
|
.
|
||||||
|
└──advanced_east
|
||||||
|
├── README.md
|
||||||
|
├── scripts
|
||||||
|
├── run_distribute_train_gpu.sh # launch ascend distributed training(8 pcs)
|
||||||
|
├── run_standalone_train_ascend.sh # launch ascend standalone training(1 pcs)
|
||||||
|
├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs)
|
||||||
|
└── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs)
|
||||||
|
├── src
|
||||||
|
├── cfg.py # parameter configuration
|
||||||
|
├── dataset.py # data preprocessing
|
||||||
|
├── label.py # produce label info
|
||||||
|
├── logger.py # generate learning rate for each step
|
||||||
|
├── model.py # define network
|
||||||
|
├── nms.py # non-maximum suppression
|
||||||
|
├── predict.py # predict boxes
|
||||||
|
├── preprocess.py # pre-process data
|
||||||
|
└── score.py # scoring
|
||||||
|
├── export.py # export model for inference
|
||||||
|
├── prepare_data.py # exec data preprocessing
|
||||||
|
├── eval.py # eval net
|
||||||
|
├── train.py # train net
|
||||||
|
└── train_mindrecord.py # train net on user specified mindrecord
|
||||||
|
```
|
||||||
|
|
||||||
|
## dataset
|
||||||
|
|
||||||
|
ICPR MTWI 2018 challenge 2:Text detection of network image,[Link](https://tianchi.aliyun.com/competition/entrance/231651/introduction). It is not available to download dataset on the origin webpage,
|
||||||
|
the dataset is now provided by the author of the original project,[Baiduyun link](https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA), password: ye9y. There are 10000 images and corresponding label
|
||||||
|
information in total in the dataset, which is divided into 2 directories with 9000 and 1000 samples respectively. In the origin training setting, training set and validation set are partitioned at the ratio
|
||||||
|
of 9:1. If you want to use your own dataset, please modify the configuration of dataset in /src/config.py. The organization of dataset file is listed as below:
|
||||||
|
|
||||||
|
> ```bash
|
||||||
|
> .
|
||||||
|
> └─data_dir
|
||||||
|
> ├─images # dataset
|
||||||
|
> └─txt # vertex of text boxes
|
||||||
|
> ```
|
||||||
|
Some parameters in config.py:
|
||||||
|
|
||||||
|
```python
|
||||||
|
'validation_split_ratio': 0.1, # ratio of validation dataset
|
||||||
|
'total_img': 10000, # total number of samples in dataset
|
||||||
|
'data_dir': './icpr/', # dir of dataset
|
||||||
|
'train_fname': 'train.txt', # the file which stores the images file name in training dataset
|
||||||
|
'val_fname': 'val.txt', # the file which stores the images file name in validation dataset
|
||||||
|
'mindsrecord_train_file': 'advanced-east.mindrecord', # mindsrecord of training dataset
|
||||||
|
'mindsrecord_test_file': 'advanced-east-val.mindrecord', # mindsrecord of validation dataset
|
||||||
|
'origin_image_dir_name': 'images_9000/', # dir which stores the original images.
|
||||||
|
'train_image_dir_name': 'images_train/', # dir which stores the preprocessed images.
|
||||||
|
'origin_txt_dir_name': 'txt_9000/', # dir which stores the original text verteices.
|
||||||
|
'train_label_dir_name': 'labels_train/', # dir which stores the preprocessed text verteices.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run the project
|
||||||
|
|
||||||
|
### Data preprocess
|
||||||
|
|
||||||
|
Resize all the images to fixed size, and convert the label information(the vertex of text box) into the format used in training and evaluation, then the Mindsrecord files are generated.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python preparedata.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
Prepare the VGG16 pre-training model. Due to copyright restrictions, please go to https://github.com/machrisaa/tensorflow-vgg to download the VGG16 pre-training model and place it in the src folder.
|
||||||
|
|
||||||
|
single 1p)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train.py --device_target="Ascend" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
single 1p)specific size
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_mindrecord.py --device_target="Ascend" --is_distributed=0 --device_id=2 --size=256 > output.train.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
multi Ascends
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# running on distributed environment(8p)
|
||||||
|
bash scripts/run_distribute_train.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The detailed training parameters are in /src/config.py。
|
||||||
|
|
||||||
|
multi GPUs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# running on distributed environment(8p)
|
||||||
|
bash scripts/run_distribute_train_gpu.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The detailed training parameters are in /src/config.py。
|
||||||
|
|
||||||
|
config.py:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
'initial_epoch': 0, # epoch to init
|
||||||
|
'learning_rate': 1e-4, # learning rate when initialization
|
||||||
|
'decay': 5e-4, # weightdecay parameter
|
||||||
|
'epsilon': 1e-4, # the value of epsilon in loss computation
|
||||||
|
'batch_size': 8, # batch size
|
||||||
|
'lambda_inside_score_loss': 4.0, # coef of inside_score_loss
|
||||||
|
'lambda_side_vertex_code_loss': 1.0, # coef of vertex_code_loss
|
||||||
|
"lambda_side_vertex_coord_loss": 1.0, # coef of vertex_coord_loss
|
||||||
|
'max_train_img_size': 448, # max size of training images
|
||||||
|
'max_predict_img_size': 448, # max size of the images to predict
|
||||||
|
'ckpt_save_max': 10, # maximum of ckpt in dir
|
||||||
|
'saved_model_file_path': './saved_model/', # dir of saved model
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
### Evaluate
|
||||||
|
|
||||||
|
The above python command will run in the background, you can view the results through the file output.eval.log. You will get the accuracy as following:
|
||||||
|
|
||||||
|
## performance
|
||||||
|
|
||||||
|
### Training performance
|
||||||
|
|
||||||
|
The performance listed below are acquired with the default configurations in /src/config.py
|
||||||
|
| Parameters | single Ascend |
|
||||||
|
| -------------------------- | ---------------------------------------------- |
|
||||||
|
| Model Version | AdvancedEAST |
|
||||||
|
| Resources | Ascend 910 |
|
||||||
|
| MindSpore Version | 1.1 |
|
||||||
|
| Dataset | MTWI-2018 |
|
||||||
|
| Training Parameters | epoch=18, batch_size = 8, lr=1e-3 |
|
||||||
|
| Optimizer | AdamWeightDecay |
|
||||||
|
| Loss Function | QuadLoss |
|
||||||
|
| Outputs | matrix with size of 3x64x64,3x96x96,3x112x112 |
|
||||||
|
| Loss | 0.1 |
|
||||||
|
| Total Time | 28mins, 60mins, 90mins |
|
||||||
|
| Checkpoints | 173MB(.ckpt file) |
|
||||||
|
|
||||||
|
### Evaluation Performance
|
||||||
|
|
||||||
|
On the default
|
||||||
|
| Parameters | single Ascend |
|
||||||
|
| ------------------- | --------------------------- |
|
||||||
|
| Model Version | AdvancedEAST |
|
||||||
|
| Resources | Ascend 910 |
|
||||||
|
| MindSpore Version | 1.1 |
|
||||||
|
| Dataset | 1000 images |
|
||||||
|
| batch_size | 8 |
|
||||||
|
| Outputs | precision, recall, F score |
|
||||||
|
| performance | 94.35, 55.45, 66.31 |
|
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################eval advanced_east on dataset########################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
||||||
|
from src.logger import get_logger
|
||||||
|
from src.predict import predict
|
||||||
|
from src.score import eval_pre_rec_f1
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
from src.dataset import load_adEAST_dataset
|
||||||
|
from src.model import get_AdvancedEast_net, AdvancedEast
|
||||||
|
from src.preprocess import resize_image
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(cloud_args=None):
|
||||||
|
"""parameters"""
|
||||||
|
parser = argparse.ArgumentParser('adveast evaling')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
|
||||||
|
# logging and checkpoint related
|
||||||
|
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||||
|
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
|
||||||
|
|
||||||
|
parser.add_argument('--ckpt', type=str, default='Epoch_C0-6_4500.ckpt', help='ckpt to load')
|
||||||
|
parser.add_argument('--method', type=str, default='score', choices=['loss', 'score', 'pred'], help='evaluation')
|
||||||
|
parser.add_argument('--path', type=str, help='image path of prediction')
|
||||||
|
|
||||||
|
# distributed related
|
||||||
|
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||||
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||||
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||||
|
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
args_opt.data_dir = cfg.data_dir
|
||||||
|
args_opt.batch_size = 8
|
||||||
|
args_opt.train_image_dir_name = cfg.data_dir + cfg.train_image_dir_name
|
||||||
|
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
|
||||||
|
args_opt.train_label_dir_name = cfg.data_dir + cfg.train_label_dir_name
|
||||||
|
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
|
||||||
|
args_opt.results_dir = cfg.results_dir
|
||||||
|
args_opt.val_fname = cfg.val_fname
|
||||||
|
args_opt.pixel_threshold = cfg.pixel_threshold
|
||||||
|
args_opt.max_predict_img_size = cfg.max_predict_img_size
|
||||||
|
args_opt.last_model_name = cfg.last_model_name
|
||||||
|
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||||
|
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
|
||||||
|
def eval_loss(arg):
|
||||||
|
"""get network and init"""
|
||||||
|
loss_net, train_net = get_AdvancedEast_net()
|
||||||
|
print(os.path.join(arg.saved_model_file_path, arg.ckpt))
|
||||||
|
load_param_into_net(train_net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
|
||||||
|
train_net.set_train(False)
|
||||||
|
loss = 0
|
||||||
|
idx = 0
|
||||||
|
for item in dataset.create_tuple_iterator():
|
||||||
|
loss += loss_net(item[0], item[1])
|
||||||
|
idx += 1
|
||||||
|
print(loss / idx)
|
||||||
|
|
||||||
|
|
||||||
|
def eval_score(arg):
|
||||||
|
"""get network and init"""
|
||||||
|
net = AdvancedEast()
|
||||||
|
load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
|
||||||
|
net.set_train(False)
|
||||||
|
obj = eval_pre_rec_f1()
|
||||||
|
with open(os.path.join(arg.data_dir, arg.val_fname), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
|
||||||
|
img_h, img_w = arg.max_predict_img_size, arg.max_predict_img_size
|
||||||
|
x = np.zeros((arg.batch_size, 3, img_h, img_w), dtype=np.float32)
|
||||||
|
batch_list = np.arange(0, len(f_list), arg.batch_size)
|
||||||
|
for idx in batch_list:
|
||||||
|
gt_list = []
|
||||||
|
for i in range(idx, min(idx + arg.batch_size, len(f_list))):
|
||||||
|
item = f_list[i]
|
||||||
|
img_filename = str(item).strip().split(',')[0][:-4]
|
||||||
|
img_path = os.path.join(arg.train_image_dir_name, img_filename) + '.jpg'
|
||||||
|
img = Image.open(img_path)
|
||||||
|
d_wight, d_height = resize_image(img, arg.max_predict_img_size)
|
||||||
|
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
img = np.asarray(img)
|
||||||
|
img = img / 1.
|
||||||
|
mean = np.array((123.68, 116.779, 103.939)).reshape([1, 1, 3])
|
||||||
|
img = ((img - mean)).astype(np.float32)
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
x[i - idx] = img
|
||||||
|
# predict(east, img_path, threshold, cfg.pixel_threshold)
|
||||||
|
gt_list.append(np.load(os.path.join(arg.train_label_dir_name, img_filename) + '.npy'))
|
||||||
|
if idx + arg.batch_size >= len(f_list):
|
||||||
|
x = x[:len(f_list) - idx]
|
||||||
|
y = net(Tensor(x))
|
||||||
|
obj.add(y, gt_list)
|
||||||
|
|
||||||
|
print(obj.val())
|
||||||
|
|
||||||
|
|
||||||
|
def pred(arg):
|
||||||
|
"""pred"""
|
||||||
|
img_path = arg.path
|
||||||
|
net = AdvancedEast()
|
||||||
|
load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
|
||||||
|
predict(net, img_path, arg.pixel_threshold)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||||
|
if args.is_distributed:
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
init()
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
elif args.device_target == "GPU":
|
||||||
|
init()
|
||||||
|
|
||||||
|
args.rank = get_rank()
|
||||||
|
args.group_size = get_group_size()
|
||||||
|
device_num = args.group_size
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
|
||||||
|
# logger
|
||||||
|
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||||
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
|
print(os.path.join(args.data_dir, args.mindsrecord_test_file))
|
||||||
|
dataset, batch_num = load_adEAST_dataset(os.path.join(args.data_dir,
|
||||||
|
args.mindsrecord_test_file),
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
device_num=device_num, rank_id=args.rank, is_training=False,
|
||||||
|
num_parallel_workers=device_num)
|
||||||
|
|
||||||
|
args.logger.save_args(args)
|
||||||
|
|
||||||
|
# network
|
||||||
|
args.logger.important_info('start create network')
|
||||||
|
|
||||||
|
method_dict = {'loss': eval_loss, 'score': eval_score, 'pred': pred}
|
||||||
|
|
||||||
|
method_dict[args.method](args)
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
##############export checkpoint file into air and onnx models#################
|
||||||
|
python export.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from src.model import AdvancedEast
|
||||||
|
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='adveast export')
|
||||||
|
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||||
|
parser.add_argument("--ckpt_file", type=str, default="./saved_model/Fin0_9-10_2250.ckpt", help="Checkpoint file path.")
|
||||||
|
parser.add_argument("--file_name", type=str, default="./out.om", help="output file name.")
|
||||||
|
parser.add_argument('--width', type=int, default=256, help='input width')
|
||||||
|
parser.add_argument('--height', type=int, default=256, help='input height')
|
||||||
|
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||||
|
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
|
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
net = AdvancedEast()
|
||||||
|
|
||||||
|
assert args.ckpt_file is not None, "checkpoint_path is None."
|
||||||
|
|
||||||
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
|
||||||
|
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################train advanced_east on dataset########################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from src.label import process_label, process_label_size
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
from src.dataset import transImage2Mind, transImage2Mind_size
|
||||||
|
from src.preprocess import preprocess, preprocess_size
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(cloud_args=None):
|
||||||
|
"""parameters"""
|
||||||
|
parser = argparse.ArgumentParser('mindspore data prepare')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--device_id', type=int, default=3, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
# create data
|
||||||
|
preprocess()
|
||||||
|
process_label()
|
||||||
|
preprocess_size(256)
|
||||||
|
process_label_size(256)
|
||||||
|
preprocess_size(384)
|
||||||
|
process_label_size(384)
|
||||||
|
preprocess_size(448)
|
||||||
|
process_label_size(448)
|
||||||
|
# preprocess_size(512)
|
||||||
|
# process_label_size(512)
|
||||||
|
# preprocess_size(640)
|
||||||
|
# process_label_size(640)
|
||||||
|
# preprocess_size(736)
|
||||||
|
# process_label_size(736)
|
||||||
|
# preprocess_size(768)
|
||||||
|
# process_label_size(768)
|
||||||
|
mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file)
|
||||||
|
transImage2Mind(mindrecord_filename)
|
||||||
|
mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_test_file)
|
||||||
|
transImage2Mind(mindrecord_filename, True)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 256)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 256, True)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 384)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 384, True)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 448)
|
||||||
|
mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
|
||||||
|
transImage2Mind_size(mindrecord_filename, 448, True)
|
|
@ -0,0 +1,37 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
PATH1=$1
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=2
|
||||||
|
export RANK_SIZE=2
|
||||||
|
export RANK_TABLE_FILE=$PATH1
|
||||||
|
|
||||||
|
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$i
|
||||||
|
rm -rf ./train_parallel$i
|
||||||
|
mkdir ./train_parallel$i
|
||||||
|
cp ../*.py ./train_parallel$i
|
||||||
|
cp *.sh ./train_parallel$i
|
||||||
|
cp -r ../src ./train_parallel$i
|
||||||
|
cd ./train_parallel$i || exit
|
||||||
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
|
env >env.log
|
||||||
|
python train_mindrecord.py --device_target Ascend --is_distributed 1 --device_id $i --data_path /disk1/adenew/icpr/advanced-east_448.mindrecord > log.txt 2>&1 &
|
||||||
|
cd ..
|
||||||
|
done
|
|
@ -0,0 +1,24 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_standalone_train_ascend.sh"
|
||||||
|
echo "for example: bash run_standalone_train_ascend.sh"
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
python train_mindrecord.py \
|
||||||
|
--device_target="Ascend" > output.train.log 2>&1 &
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
configs
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict
|
||||||
|
|
||||||
|
config = EasyDict({
|
||||||
|
'initial_epoch': 2,
|
||||||
|
'epoch_num': 6,
|
||||||
|
'learning_rate': 1e-4,
|
||||||
|
'decay': 3e-5,
|
||||||
|
'epsilon': 1e-4,
|
||||||
|
'batch_size': 2,
|
||||||
|
'ckpt_interval': 2,
|
||||||
|
'lambda_inside_score_loss': 4.0,
|
||||||
|
'lambda_side_vertex_code_loss': 1.0,
|
||||||
|
"lambda_side_vertex_coord_loss": 1.0,
|
||||||
|
'max_train_img_size': 448,
|
||||||
|
'max_predict_img_size': 448,
|
||||||
|
'train_img_size': [256, 384, 448, 512, 640, 736, 768],
|
||||||
|
'predict_img_size': [256, 384, 448, 512, 640, 736, 768],
|
||||||
|
'ckpt_save_max': 500,
|
||||||
|
'validation_split_ratio': 0.1,
|
||||||
|
'total_img': 10000,
|
||||||
|
'data_dir': './icpr/',
|
||||||
|
'val_data_dir': './icpr/images/',
|
||||||
|
'train_fname': 'train.txt',
|
||||||
|
'train_fname_var': 'train_',
|
||||||
|
'val_fname': 'val.txt',
|
||||||
|
'val_fname_var': 'val_',
|
||||||
|
'mindsrecord_train_file': 'advanced-east.mindrecord',
|
||||||
|
'mindsrecord_test_file': 'advanced-east-val.mindrecord',
|
||||||
|
'results_dir': './results/',
|
||||||
|
'origin_image_dir_name': 'images/',
|
||||||
|
'train_image_dir_name': 'images_train/',
|
||||||
|
'origin_txt_dir_name': 'txt/',
|
||||||
|
'train_label_dir_name': 'labels_train/',
|
||||||
|
'train_image_dir_name_var': 'images_train_',
|
||||||
|
'mindsrecord_train_file_var': 'advanced-east_',
|
||||||
|
'train_label_dir_name_var': 'labels_train_',
|
||||||
|
'mindsrecord_val_file_var': 'advanced-east-val_',
|
||||||
|
'show_gt_image_dir_name': 'show_gt_images/',
|
||||||
|
'show_act_image_dir_name': 'show_act_images/',
|
||||||
|
'saved_model_file_path': './saved_model/',
|
||||||
|
'last_model_name': '_.ckpt',
|
||||||
|
'pixel_threshold': 0.8,
|
||||||
|
'iou_threshold': 0.1,
|
||||||
|
'feature_layers_range': range(5, 1, -1),
|
||||||
|
'feature_layers_num': len(range(5, 1, -1)),
|
||||||
|
'pixel_size': 2 ** range(5, 1, -1)[-1],
|
||||||
|
'quiet': True,
|
||||||
|
'side_vertex_pixel_threshold': 0.8,
|
||||||
|
'trunc_threshold': 0.1,
|
||||||
|
'predict_cut_text_line': False,
|
||||||
|
'predict_write2txt': True,
|
||||||
|
'shrink_ratio': 0.2,
|
||||||
|
'shrink_side_ratio': 0.6,
|
||||||
|
'gen_origin_img': True,
|
||||||
|
'draw_gt_quad': False,
|
||||||
|
'draw_act_quad': False,
|
||||||
|
'vgg_npy': '/disk1/ade/vgg16.npy',
|
||||||
|
})
|
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
dataset.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import mindspore.dataset as de
|
||||||
|
import mindspore.dataset.vision.c_transforms as vision
|
||||||
|
from mindspore.mindrecord import FileWriter
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
|
||||||
|
import src.config as cfg
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
cfg = cfg.config
|
||||||
|
|
||||||
|
|
||||||
|
def gen(batch_size=cfg.batch_size, is_val=False):
|
||||||
|
"""generate label"""
|
||||||
|
img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
|
||||||
|
x = np.zeros((batch_size, 3, img_h, img_w), dtype=np.float32)
|
||||||
|
pixel_num_h = img_h // 4
|
||||||
|
pixel_num_w = img_w // 4
|
||||||
|
y = np.zeros((batch_size, 7, pixel_num_h, pixel_num_w), dtype=np.float32)
|
||||||
|
if is_val:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
else:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
|
||||||
|
f_list = f_train.readlines()
|
||||||
|
while True:
|
||||||
|
batch_list = np.arange(0, len(f_list), batch_size)
|
||||||
|
for idx in batch_list:
|
||||||
|
for idx2 in range(idx, idx + batch_size):
|
||||||
|
# random gen an image name
|
||||||
|
if idx2 < len(f_list):
|
||||||
|
img = f_list[idx2]
|
||||||
|
else:
|
||||||
|
img = np.random.choice(f_list)
|
||||||
|
img_filename = str(img).strip().split(',')[0]
|
||||||
|
# load img and img anno
|
||||||
|
img_path = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_image_dir_name,
|
||||||
|
img_filename)
|
||||||
|
img = Image.open(img_path)
|
||||||
|
img = np.asarray(img)
|
||||||
|
img = img / 127.5 - 1
|
||||||
|
x[idx2 - idx] = img.transpose((2, 0, 1))
|
||||||
|
gt_file = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_label_dir_name,
|
||||||
|
img_filename[:-4] + '_gt.npy')
|
||||||
|
y[idx2 - idx] = np.load(gt_file).transpose((2, 0, 1))
|
||||||
|
yield x, y
|
||||||
|
|
||||||
|
|
||||||
|
def transImage2Mind(mindrecord_filename, is_val=False):
|
||||||
|
"""transfer the image to mindrecord"""
|
||||||
|
if os.path.exists(mindrecord_filename):
|
||||||
|
os.remove(mindrecord_filename)
|
||||||
|
os.remove(mindrecord_filename + ".db")
|
||||||
|
|
||||||
|
writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
|
||||||
|
cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, 112, 112]}}
|
||||||
|
writer.add_schema(cv_schema, "advancedEast dataset")
|
||||||
|
|
||||||
|
if is_val:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
else:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
|
||||||
|
f_list = f_train.readlines()
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for item in f_list:
|
||||||
|
img_filename = str(item).strip().split(',')[0]
|
||||||
|
img_path = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_image_dir_name,
|
||||||
|
img_filename)
|
||||||
|
print(img_path)
|
||||||
|
with open(img_path, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
gt_file = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_label_dir_name,
|
||||||
|
img_filename[:-4] + '_gt.npy')
|
||||||
|
labels = np.load(gt_file)
|
||||||
|
labels = np.transpose(labels, (2, 0, 1))
|
||||||
|
data.append({"image": img,
|
||||||
|
"label": np.array(labels, np.float32)})
|
||||||
|
if len(data) == 32:
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
print('len(data):{}'.format(len(data)))
|
||||||
|
data = []
|
||||||
|
if data:
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def transImage2Mind_size(mindrecord_filename, width=256, is_val=False):
|
||||||
|
"""transfer the image to mindrecord at specified size"""
|
||||||
|
mindrecord_filename = mindrecord_filename + str(width) + '.mindrecord'
|
||||||
|
if os.path.exists(mindrecord_filename):
|
||||||
|
os.remove(mindrecord_filename)
|
||||||
|
os.remove(mindrecord_filename + ".db")
|
||||||
|
|
||||||
|
writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
|
||||||
|
cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, width // 4, width // 4]}}
|
||||||
|
writer.add_schema(cv_schema, "advancedEast dataset")
|
||||||
|
|
||||||
|
if is_val:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
else:
|
||||||
|
with open(os.path.join(cfg.data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
|
||||||
|
f_list = f_train.readlines()
|
||||||
|
data = []
|
||||||
|
for item in f_list:
|
||||||
|
|
||||||
|
img_filename = str(item).strip().split(',')[0]
|
||||||
|
img_path = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_image_dir_name_var + str(width),
|
||||||
|
img_filename)
|
||||||
|
with open(img_path, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
gt_file = os.path.join(cfg.data_dir,
|
||||||
|
cfg.train_label_dir_name_var + str(width),
|
||||||
|
img_filename[:-4] + '_gt.npy')
|
||||||
|
labels = np.load(gt_file)
|
||||||
|
labels = np.transpose(labels, (2, 0, 1))
|
||||||
|
data.append({"image": img,
|
||||||
|
"label": np.array(labels, np.float32)})
|
||||||
|
|
||||||
|
if len(data) == 32:
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
print('len(data):{}'.format(len(data)))
|
||||||
|
data = []
|
||||||
|
|
||||||
|
if data:
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def load_adEAST_dataset(mindrecord_file, batch_size=64, device_num=1, rank_id=0,
|
||||||
|
is_training=True, num_parallel_workers=8):
|
||||||
|
"""load mindrecord"""
|
||||||
|
ds = de.MindDataset(mindrecord_file, columns_list=["image", "label"], num_shards=device_num, shard_id=rank_id,
|
||||||
|
num_parallel_workers=8, shuffle=is_training)
|
||||||
|
hwc_to_chw = vision.HWC2CHW()
|
||||||
|
cd = vision.Decode()
|
||||||
|
ds = ds.map(operations=cd, input_columns=["image"])
|
||||||
|
rc = vision.RandomColorAdjust(brightness=0.1, contrast=0.2, saturation=0.2)
|
||||||
|
vn = vision.Normalize(mean=(123.68, 116.779, 103.939), std=(1., 1., 1.))
|
||||||
|
ds = ds.map(operations=[rc, vn, hwc_to_chw], input_columns=["image"], num_parallel_workers=num_parallel_workers)
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
|
batch_num = ds.get_dataset_size()
|
||||||
|
ds = ds.shuffle(batch_num)
|
||||||
|
return ds, batch_num
|
|
@ -0,0 +1,242 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
labeling
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
|
||||||
|
|
||||||
|
def point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
|
||||||
|
"""test box in or not"""
|
||||||
|
if (p_min[0] <= px <= p_max[0]) and (p_min[1] <= py <= p_max[1]):
|
||||||
|
xy_list = np.zeros((4, 2))
|
||||||
|
xy_list[:3, :] = quad_xy_list[1:4, :] - quad_xy_list[:3, :]
|
||||||
|
xy_list[3] = quad_xy_list[0, :] - quad_xy_list[3, :]
|
||||||
|
yx_list = np.zeros((4, 2))
|
||||||
|
yx_list[:, :] = quad_xy_list[:, -1:-3:-1]
|
||||||
|
a = xy_list * ([py, px] - yx_list)
|
||||||
|
b = a[:, 0] - a[:, 1]
|
||||||
|
if np.amin(b) >= 0 or np.amax(b) <= 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def point_inside_of_nth_quad(px, py, xy_list, shrink_1, long_edge):
|
||||||
|
"""test point in which box"""
|
||||||
|
nth = -1
|
||||||
|
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
|
||||||
|
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
|
||||||
|
for ith in range(2):
|
||||||
|
quad_xy_list = np.concatenate((
|
||||||
|
np.reshape(xy_list[vs[long_edge][ith][0]], (1, 2)),
|
||||||
|
np.reshape(shrink_1[vs[long_edge][ith][1]], (1, 2)),
|
||||||
|
np.reshape(shrink_1[vs[long_edge][ith][2]], (1, 2)),
|
||||||
|
np.reshape(xy_list[vs[long_edge][ith][3]], (1, 2))), axis=0)
|
||||||
|
p_min = np.amin(quad_xy_list, axis=0)
|
||||||
|
p_max = np.amax(quad_xy_list, axis=0)
|
||||||
|
if point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
|
||||||
|
if nth == -1:
|
||||||
|
nth = ith
|
||||||
|
else:
|
||||||
|
nth = -1
|
||||||
|
break
|
||||||
|
return nth
|
||||||
|
|
||||||
|
|
||||||
|
def shrink(xy_list, ratio=cfg.shrink_ratio):
|
||||||
|
"""shrink"""
|
||||||
|
if ratio == 0.0:
|
||||||
|
return xy_list, xy_list
|
||||||
|
diff_1to3 = xy_list[:3, :] - xy_list[1:4, :]
|
||||||
|
diff_4 = xy_list[3:4, :] - xy_list[0:1, :]
|
||||||
|
diff = np.concatenate((diff_1to3, diff_4), axis=0)
|
||||||
|
dis = np.sqrt(np.sum(np.square(diff), axis=-1))
|
||||||
|
# determine which are long or short edges
|
||||||
|
long_edge = int(np.argmax(np.sum(np.reshape(dis, (2, 2)), axis=0)))
|
||||||
|
short_edge = 1 - long_edge
|
||||||
|
# cal r length array
|
||||||
|
r = [np.minimum(dis[i], dis[(i + 1) % 4]) for i in range(4)]
|
||||||
|
# cal theta array
|
||||||
|
diff_abs = np.abs(diff)
|
||||||
|
diff_abs[:, 0] += cfg.epsilon
|
||||||
|
theta = np.arctan(diff_abs[:, 1] / diff_abs[:, 0])
|
||||||
|
# shrink two long edges
|
||||||
|
temp_new_xy_list = np.copy(xy_list)
|
||||||
|
shrink_edge(xy_list, temp_new_xy_list, long_edge, r, theta, ratio)
|
||||||
|
shrink_edge(xy_list, temp_new_xy_list, long_edge + 2, r, theta, ratio)
|
||||||
|
# shrink two short edges
|
||||||
|
new_xy_list = np.copy(temp_new_xy_list)
|
||||||
|
shrink_edge(temp_new_xy_list, new_xy_list, short_edge, r, theta, ratio)
|
||||||
|
shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio)
|
||||||
|
return temp_new_xy_list, new_xy_list, long_edge # 缩短后的长边,缩短后的短边,长边下标
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio):
|
||||||
|
"""shrink edge"""
|
||||||
|
if ratio == 0.0:
|
||||||
|
return
|
||||||
|
start_point = edge # 边的起始点下标(0或1)
|
||||||
|
end_point = (edge + 1) % 4
|
||||||
|
long_start_sign_x = np.sign(
|
||||||
|
xy_list[end_point, 0] - xy_list[start_point, 0])
|
||||||
|
new_xy_list[start_point, 0] = \
|
||||||
|
xy_list[start_point, 0] + \
|
||||||
|
long_start_sign_x * ratio * r[start_point] * np.cos(theta[start_point])
|
||||||
|
long_start_sign_y = np.sign(
|
||||||
|
xy_list[end_point, 1] - xy_list[start_point, 1])
|
||||||
|
new_xy_list[start_point, 1] = \
|
||||||
|
xy_list[start_point, 1] + \
|
||||||
|
long_start_sign_y * ratio * r[start_point] * np.sin(theta[start_point])
|
||||||
|
# long edge one, end point
|
||||||
|
long_end_sign_x = -1 * long_start_sign_x
|
||||||
|
new_xy_list[end_point, 0] = \
|
||||||
|
xy_list[end_point, 0] + \
|
||||||
|
long_end_sign_x * ratio * r[end_point] * np.cos(theta[start_point])
|
||||||
|
long_end_sign_y = -1 * long_start_sign_y
|
||||||
|
new_xy_list[end_point, 1] = \
|
||||||
|
xy_list[end_point, 1] + \
|
||||||
|
long_end_sign_y * ratio * r[end_point] * np.sin(theta[start_point])
|
||||||
|
|
||||||
|
|
||||||
|
def precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax,
|
||||||
|
jmin, jmax, p_min, p_max, gt, long_edge, draw):
|
||||||
|
"""precess list"""
|
||||||
|
for i in range(imin, imax):
|
||||||
|
for j in range(jmin, jmax):
|
||||||
|
px = (j + 0.5) * cfg.pixel_size
|
||||||
|
py = (i + 0.5) * cfg.pixel_size
|
||||||
|
if point_inside_of_quad(px, py,
|
||||||
|
shrink_xy_list, p_min, p_max):
|
||||||
|
gt[i, j, 0] = 1
|
||||||
|
line_width, line_color = 1, 'red'
|
||||||
|
ith = point_inside_of_nth_quad(px, py,
|
||||||
|
xy_list,
|
||||||
|
shrink_1,
|
||||||
|
long_edge)
|
||||||
|
vs = [[[3, 0], [1, 2]], [[0, 1], [2, 3]]]
|
||||||
|
if ith in range(2):
|
||||||
|
gt[i, j, 1] = 1
|
||||||
|
if ith == 0:
|
||||||
|
line_width, line_color = 2, 'yellow'
|
||||||
|
else:
|
||||||
|
line_width, line_color = 2, 'green'
|
||||||
|
gt[i, j, 2:3] = ith
|
||||||
|
gt[i, j, 3:5] = \
|
||||||
|
xy_list[vs[long_edge][ith][0]] - [px, py]
|
||||||
|
gt[i, j, 5:] = \
|
||||||
|
xy_list[vs[long_edge][ith][1]] - [px, py]
|
||||||
|
draw.line([(px - 0.5 * cfg.pixel_size,
|
||||||
|
py - 0.5 * cfg.pixel_size),
|
||||||
|
(px + 0.5 * cfg.pixel_size,
|
||||||
|
py - 0.5 * cfg.pixel_size),
|
||||||
|
(px + 0.5 * cfg.pixel_size,
|
||||||
|
py + 0.5 * cfg.pixel_size),
|
||||||
|
(px - 0.5 * cfg.pixel_size,
|
||||||
|
py + 0.5 * cfg.pixel_size),
|
||||||
|
(px - 0.5 * cfg.pixel_size,
|
||||||
|
py - 0.5 * cfg.pixel_size)],
|
||||||
|
width=line_width, fill=line_color)
|
||||||
|
return gt
|
||||||
|
|
||||||
|
|
||||||
|
def process_label(data_dir=cfg.data_dir):
|
||||||
|
"""process label"""
|
||||||
|
with open(os.path.join(data_dir, cfg.val_fname), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
with open(os.path.join(data_dir, cfg.train_fname), 'r') as f_train:
|
||||||
|
f_list.extend(f_train.readlines())
|
||||||
|
for line, _ in zip(f_list, tqdm(range(len(f_list)))):
|
||||||
|
line_cols = str(line).strip().split(',')
|
||||||
|
img_name, width, height = \
|
||||||
|
line_cols[0].strip(), int(line_cols[1].strip()), \
|
||||||
|
int(line_cols[2].strip())
|
||||||
|
gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
|
||||||
|
xy_list_array = np.load(os.path.join(train_label_dir,
|
||||||
|
img_name[:-4] + '.npy'))
|
||||||
|
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
|
||||||
|
with Image.open(os.path.join(train_image_dir, img_name)) as im:
|
||||||
|
draw = ImageDraw.Draw(im)
|
||||||
|
for xy_list in xy_list_array:
|
||||||
|
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
|
||||||
|
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
|
||||||
|
p_min = np.amin(shrink_xy_list, axis=0)
|
||||||
|
p_max = np.amax(shrink_xy_list, axis=0)
|
||||||
|
# floor of the float
|
||||||
|
ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
|
||||||
|
# +1 for ceil of the float and +1 for include the end
|
||||||
|
ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
|
||||||
|
imin = np.maximum(0, ji_min[1])
|
||||||
|
imax = np.minimum(height // cfg.pixel_size, ji_max[1])
|
||||||
|
jmin = np.maximum(0, ji_min[0])
|
||||||
|
jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
|
||||||
|
gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax,
|
||||||
|
p_min, p_max, gt, long_edge, draw)
|
||||||
|
act_image_dir = os.path.join(cfg.data_dir,
|
||||||
|
cfg.show_act_image_dir_name)
|
||||||
|
if cfg.draw_act_quad:
|
||||||
|
im.save(os.path.join(act_image_dir, img_name))
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
|
||||||
|
np.save(os.path.join(train_label_dir,
|
||||||
|
img_name[:-4] + '_gt.npy'), gt)
|
||||||
|
|
||||||
|
|
||||||
|
def process_label_size(width=256, data_dir=cfg.data_dir):
|
||||||
|
"""process label at specific size"""
|
||||||
|
with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
|
||||||
|
f_list = f_val.readlines()
|
||||||
|
with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
|
||||||
|
f_list.extend(f_train.readlines())
|
||||||
|
for line, _ in zip(f_list, tqdm(range(len(f_list)))):
|
||||||
|
line_cols = str(line).strip().split(',')
|
||||||
|
img_name, width, height = \
|
||||||
|
line_cols[0].strip(), int(line_cols[1].strip()), \
|
||||||
|
int(line_cols[2].strip())
|
||||||
|
gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
|
||||||
|
xy_list_array = np.load(os.path.join(train_label_dir,
|
||||||
|
img_name[:-4] + '.npy'))
|
||||||
|
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
|
||||||
|
with Image.open(os.path.join(train_image_dir, img_name)) as im:
|
||||||
|
draw = ImageDraw.Draw(im)
|
||||||
|
for xy_list in xy_list_array:
|
||||||
|
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
|
||||||
|
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
|
||||||
|
p_min = np.amin(shrink_xy_list, axis=0)
|
||||||
|
p_max = np.amax(shrink_xy_list, axis=0)
|
||||||
|
# floor of the float
|
||||||
|
ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
|
||||||
|
# +1 for ceil of the float and +1 for include the end
|
||||||
|
ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
|
||||||
|
imin = np.maximum(0, ji_min[1])
|
||||||
|
imax = np.minimum(height // cfg.pixel_size, ji_max[1])
|
||||||
|
jmin = np.maximum(0, ji_min[0])
|
||||||
|
jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
|
||||||
|
gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax, p_min, p_max, gt,
|
||||||
|
long_edge,
|
||||||
|
draw)
|
||||||
|
act_image_dir = os.path.join(cfg.data_dir,
|
||||||
|
cfg.show_act_image_dir_name)
|
||||||
|
if cfg.draw_act_quad:
|
||||||
|
im.save(os.path.join(act_image_dir, img_name))
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
|
||||||
|
np.save(os.path.join(train_label_dir,
|
||||||
|
img_name[:-4] + '_gt.npy'), gt)
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
get logger.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class LOGGER(logging.Logger):
|
||||||
|
"""
|
||||||
|
set up logging file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger_name (string): logger name.
|
||||||
|
log_dir (string): path of logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
string, logger path
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger_name, rank=0):
|
||||||
|
super(LOGGER, self).__init__(logger_name)
|
||||||
|
if rank % 8 == 0:
|
||||||
|
console = logging.StreamHandler(sys.stdout)
|
||||||
|
console.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
console.setFormatter(formatter)
|
||||||
|
self.addHandler(console)
|
||||||
|
|
||||||
|
def setup_logging_file(self, log_dir, rank=0):
|
||||||
|
"""set up log file"""
|
||||||
|
self.rank = rank
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||||
|
self.log_fn = os.path.join(log_dir, log_name)
|
||||||
|
fh = logging.FileHandler(self.log_fn)
|
||||||
|
fh.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
self.addHandler(fh)
|
||||||
|
|
||||||
|
def info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO):
|
||||||
|
self._log(logging.INFO, msg, args, **kwargs)
|
||||||
|
|
||||||
|
def save_args(self, args):
|
||||||
|
self.info('Args:')
|
||||||
|
args_dict = vars(args)
|
||||||
|
for key in args_dict.keys():
|
||||||
|
self.info('--> %s: %s', key, args_dict[key])
|
||||||
|
self.info('')
|
||||||
|
|
||||||
|
def important_info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||||
|
line_width = 2
|
||||||
|
important_msg = '\n'
|
||||||
|
important_msg += ('*' * 70 + '\n') * line_width
|
||||||
|
important_msg += ('*' * line_width + '\n') * 2
|
||||||
|
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
|
||||||
|
important_msg += ('*' * line_width + '\n') * 2
|
||||||
|
important_msg += ('*' * 70 + '\n') * line_width
|
||||||
|
self.info(important_msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(path, rank):
|
||||||
|
""" get logger"""
|
||||||
|
logger = LOGGER("mindversion", rank)
|
||||||
|
logger.setup_logging_file(path, rank)
|
||||||
|
return logger
|
|
@ -0,0 +1,273 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
dataset processing.
|
||||||
|
"""
|
||||||
|
import mindspore
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, ParameterTuple, Parameter
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedEast(nn.Cell):
|
||||||
|
"""
|
||||||
|
EAST network definition.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(AdvancedEast, self).__init__()
|
||||||
|
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
|
||||||
|
|
||||||
|
def get_var(name, idx):
|
||||||
|
value = vgg_dict[name][idx]
|
||||||
|
if idx == 0:
|
||||||
|
value = np.transpose(value, [3, 2, 0, 1])
|
||||||
|
var = Tensor(value)
|
||||||
|
return var
|
||||||
|
|
||||||
|
def get_conv_var(name):
|
||||||
|
filters = get_var(name, 0)
|
||||||
|
biases = get_var(name, 1)
|
||||||
|
return filters, biases
|
||||||
|
|
||||||
|
class VGG_Conv(nn.Cell):
|
||||||
|
"""
|
||||||
|
VGG16 network definition.
|
||||||
|
"""
|
||||||
|
def __init__(self, name):
|
||||||
|
super(VGG_Conv, self).__init__()
|
||||||
|
filters, conv_biases = get_conv_var(name)
|
||||||
|
out_channels, in_channels, filter_size, _ = filters.shape
|
||||||
|
self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
|
||||||
|
name='weight')
|
||||||
|
self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
self.gn = nn.GroupNorm(32, out_channels)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
output = self.conv2d(x, self.weight)
|
||||||
|
output = self.bias_add(output, self.bias)
|
||||||
|
output = self.gn(output)
|
||||||
|
output = self.relu(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
self.conv1_1 = VGG_Conv('conv1_1')
|
||||||
|
self.conv1_2 = VGG_Conv('conv1_2')
|
||||||
|
self.pool1 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv2_1 = VGG_Conv('conv2_1')
|
||||||
|
self.conv2_2 = VGG_Conv('conv2_2')
|
||||||
|
self.pool2 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv3_1 = VGG_Conv('conv3_1')
|
||||||
|
self.conv3_2 = VGG_Conv('conv3_2')
|
||||||
|
self.conv3_3 = VGG_Conv('conv3_3')
|
||||||
|
self.pool3 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv4_1 = VGG_Conv('conv4_1')
|
||||||
|
self.conv4_2 = VGG_Conv('conv4_2')
|
||||||
|
self.conv4_3 = VGG_Conv('conv4_3')
|
||||||
|
self.pool4 = nn.MaxPool2d(2, 2)
|
||||||
|
self.conv5_1 = VGG_Conv('conv5_1')
|
||||||
|
self.conv5_2 = VGG_Conv('conv5_2')
|
||||||
|
self.conv5_3 = VGG_Conv('conv5_3')
|
||||||
|
self.pool5 = nn.MaxPool2d(2, 2)
|
||||||
|
self.merging1 = self.merging(i=2)
|
||||||
|
self.merging2 = self.merging(i=3)
|
||||||
|
self.merging3 = self.merging(i=4)
|
||||||
|
self.last_bn = nn.GroupNorm(16, 32)
|
||||||
|
self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||||
|
self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||||
|
self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||||
|
self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||||
|
self.op_concat = P.Concat(axis=1)
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
|
||||||
|
def merging(self, i=2):
|
||||||
|
"""
|
||||||
|
def merge layer
|
||||||
|
"""
|
||||||
|
in_size = {'2': 1024, '3': 384, '4': 192}
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(in_size[str(i)], 128 // 2 ** (i - 2), kernel_size=1, stride=1, has_bias=True,
|
||||||
|
weight_init='XavierUniform'),
|
||||||
|
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(128 // 2 ** (i - 2), 128 // 2 ** (i - 2), kernel_size=3, stride=1, has_bias=True,
|
||||||
|
weight_init='XavierUniform'),
|
||||||
|
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
|
||||||
|
nn.ReLU()]
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""
|
||||||
|
forward func
|
||||||
|
"""
|
||||||
|
f4 = self.conv1_1(x)
|
||||||
|
f4 = self.conv1_2(f4)
|
||||||
|
f4 = self.pool1(f4)
|
||||||
|
f4 = self.conv2_1(f4)
|
||||||
|
f4 = self.conv2_2(f4)
|
||||||
|
f4 = self.pool2(f4)
|
||||||
|
f3 = self.conv3_1(f4)
|
||||||
|
f3 = self.conv3_2(f3)
|
||||||
|
f3 = self.conv3_3(f3)
|
||||||
|
f3 = self.pool3(f3)
|
||||||
|
f2 = self.conv4_1(f3)
|
||||||
|
f2 = self.conv4_2(f2)
|
||||||
|
f2 = self.conv4_3(f2)
|
||||||
|
f2 = self.pool4(f2)
|
||||||
|
f1 = self.conv5_1(f2)
|
||||||
|
f1 = self.conv5_2(f1)
|
||||||
|
f1 = self.conv5_3(f1)
|
||||||
|
f1 = self.pool5(f1)
|
||||||
|
h1 = f1
|
||||||
|
_, _, h_, w_ = P.Shape()(h1)
|
||||||
|
H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1)
|
||||||
|
concat1 = self.op_concat((H1, f2))
|
||||||
|
h2 = self.merging1(concat1)
|
||||||
|
_, _, h_, w_ = P.Shape()(h2)
|
||||||
|
H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2)
|
||||||
|
concat2 = self.op_concat((H2, f3))
|
||||||
|
h3 = self.merging2(concat2)
|
||||||
|
_, _, h_, w_ = P.Shape()(h3)
|
||||||
|
H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3)
|
||||||
|
concat3 = self.op_concat((H3, f4))
|
||||||
|
h4 = self.merging3(concat3)
|
||||||
|
before_output = self.relu(self.last_bn(self.conv_last(h4)))
|
||||||
|
inside_score = self.inside_score_conv(before_output)
|
||||||
|
side_v_angle = self.side_v_angle_conv(before_output)
|
||||||
|
side_v_coord = self.side_v_coord_conv(before_output)
|
||||||
|
east_detect = self.op_concat((inside_score, side_v_coord, side_v_angle))
|
||||||
|
return east_detect
|
||||||
|
|
||||||
|
|
||||||
|
def dice_loss(gt_score, pred_score):
|
||||||
|
"""dice_loss1"""
|
||||||
|
inter = P.ReduceSum()(gt_score * pred_score)
|
||||||
|
union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5
|
||||||
|
return 1. - (2 * (inter / union))
|
||||||
|
|
||||||
|
|
||||||
|
def dice_loss2(gt_score, pred_score, mask):
|
||||||
|
"""dice_loss2"""
|
||||||
|
inter = P.ReduceSum()(gt_score * pred_score * mask)
|
||||||
|
union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5
|
||||||
|
return 1. - (2 * (inter / union))
|
||||||
|
|
||||||
|
|
||||||
|
def quad_loss(y_true, y_pred,
|
||||||
|
lambda_inside_score_loss=0.2,
|
||||||
|
lambda_side_vertex_code_loss=0.1,
|
||||||
|
lambda_side_vertex_coord_loss=1.0,
|
||||||
|
epsilon=1e-4):
|
||||||
|
"""quad loss"""
|
||||||
|
y_true = P.Transpose()(y_true, (0, 2, 3, 1))
|
||||||
|
y_pred = P.Transpose()(y_pred, (0, 2, 3, 1))
|
||||||
|
logits = y_pred[:, :, :, :1]
|
||||||
|
labels = y_true[:, :, :, :1]
|
||||||
|
predicts = P.Sigmoid()(logits)
|
||||||
|
inside_score_loss = dice_loss(labels, predicts)
|
||||||
|
inside_score_loss = inside_score_loss * lambda_inside_score_loss
|
||||||
|
# loss for side_vertex_code
|
||||||
|
vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2])
|
||||||
|
vertex_labelsp = y_true[:, :, :, 1:2]
|
||||||
|
vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3])
|
||||||
|
vertex_labelsn = y_true[:, :, :, 2:3]
|
||||||
|
labels2 = y_true[:, :, :, 1:2]
|
||||||
|
side_vertex_code_lossp = dice_loss2(vertex_labelsp, vertex_logitsp, labels)
|
||||||
|
side_vertex_code_lossn = dice_loss2(vertex_labelsn, vertex_logitsn, labels2)
|
||||||
|
side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss
|
||||||
|
# loss for side_vertex_coord delta
|
||||||
|
g_hat = y_pred[:, :, :, 3:] # N*W*H*8
|
||||||
|
g_true = y_true[:, :, :, 3:]
|
||||||
|
vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32)
|
||||||
|
# vertex_weights=y_true[:, :, :,1]
|
||||||
|
pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights)
|
||||||
|
side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / (
|
||||||
|
P.ReduceSum()(vertex_weights) + epsilon)
|
||||||
|
side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss
|
||||||
|
# print(inside_score_loss, side_vertex_code_loss, side_vertex_coord_loss)
|
||||||
|
return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss
|
||||||
|
|
||||||
|
|
||||||
|
def smooth_l1_loss(prediction_tensor, target_tensor, weights):
|
||||||
|
"""smooth l1 loss"""
|
||||||
|
n_q = P.Reshape()(quad_norm(target_tensor), weights.shape)
|
||||||
|
diff = P.SmoothL1Loss()(prediction_tensor, target_tensor)
|
||||||
|
pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights
|
||||||
|
return pixel_wise_smooth_l1norm
|
||||||
|
|
||||||
|
|
||||||
|
def quad_norm(g_true, epsilon=1e-4):
|
||||||
|
""" quad norm"""
|
||||||
|
shape = g_true.shape
|
||||||
|
delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2))
|
||||||
|
diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :]
|
||||||
|
square = diff * diff
|
||||||
|
distance = P.Sqrt()(P.ReduceSum()(square, -1))
|
||||||
|
distance = distance * 4.0
|
||||||
|
distance = distance + epsilon
|
||||||
|
return P.Reshape()(distance, (shape[0], shape[1], shape[2]))
|
||||||
|
|
||||||
|
|
||||||
|
class EastWithLossCell(nn.Cell):
|
||||||
|
"""get loss cell"""
|
||||||
|
|
||||||
|
def __init__(self, network, config=None):
|
||||||
|
super(EastWithLossCell, self).__init__()
|
||||||
|
self.East_network = network
|
||||||
|
self.cat = P.Concat(axis=1)
|
||||||
|
|
||||||
|
def construct(self, image, label):
|
||||||
|
y_pred = self.East_network(image)
|
||||||
|
loss = quad_loss(label, y_pred)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrainStepWrap(nn.Cell):
|
||||||
|
"""train net warper"""
|
||||||
|
def __init__(self, network, steps, config=None):
|
||||||
|
super(TrainStepWrap, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.network.set_train()
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = AdamWeightDecay(self.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||||
|
self.sens = 3.0
|
||||||
|
|
||||||
|
def construct(self, image, label):
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(image, label)
|
||||||
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||||
|
grads = self.grad(self.network, weights)(image, label, sens)
|
||||||
|
return F.depend(loss, self.optimizer(grads))
|
||||||
|
|
||||||
|
|
||||||
|
def get_AdvancedEast_net(configure=None, steps=1, mode=True):
|
||||||
|
"""
|
||||||
|
Get network of wide&deep model.
|
||||||
|
"""
|
||||||
|
AdvancedEast_net = AdvancedEast()
|
||||||
|
loss_net = EastWithLossCell(AdvancedEast_net, configure)
|
||||||
|
train_net = TrainStepWrap(loss_net, steps, configure)
|
||||||
|
return loss_net, train_net
|
|
@ -0,0 +1,114 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
NMS module
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
|
||||||
|
|
||||||
|
def should_merge(region, i, j):
|
||||||
|
"""
|
||||||
|
test merge
|
||||||
|
"""
|
||||||
|
neighbor = {(i, j - 1)}
|
||||||
|
return not region.isdisjoint(neighbor)
|
||||||
|
|
||||||
|
|
||||||
|
def region_neighbor(region_set):
|
||||||
|
"""
|
||||||
|
cal the neighbor of the region
|
||||||
|
"""
|
||||||
|
region_pixels = np.array(list(region_set))
|
||||||
|
j_min = np.amin(region_pixels, axis=0)[1] - 1
|
||||||
|
j_max = np.amax(region_pixels, axis=0)[1] + 1
|
||||||
|
i_m = np.amin(region_pixels, axis=0)[0] + 1
|
||||||
|
region_pixels[:, 0] += 1
|
||||||
|
neighbor = {(region_pixels[n, 0], region_pixels[n, 1]) for n in
|
||||||
|
range(len(region_pixels))}
|
||||||
|
neighbor.add((i_m, j_min))
|
||||||
|
neighbor.add((i_m, j_max))
|
||||||
|
return neighbor
|
||||||
|
|
||||||
|
|
||||||
|
def region_group(region_list):
|
||||||
|
"""
|
||||||
|
group regions
|
||||||
|
"""
|
||||||
|
S = [i for i in range(len(region_list))]
|
||||||
|
D = []
|
||||||
|
while S:
|
||||||
|
m = S.pop(0)
|
||||||
|
if not S:
|
||||||
|
# S has only one element, put it to D
|
||||||
|
D.append([m])
|
||||||
|
else:
|
||||||
|
D.append(rec_region_merge(region_list, m, S))
|
||||||
|
return D
|
||||||
|
|
||||||
|
|
||||||
|
def rec_region_merge(region_list, m, S):
|
||||||
|
"""
|
||||||
|
merge regions
|
||||||
|
"""
|
||||||
|
rows = [m]
|
||||||
|
tmp = []
|
||||||
|
for n in S:
|
||||||
|
if not region_neighbor(region_list[m]).isdisjoint(region_list[n]) or \
|
||||||
|
not region_neighbor(region_list[n]).isdisjoint(region_list[m]):
|
||||||
|
tmp.append(n)
|
||||||
|
for d in tmp:
|
||||||
|
S.remove(d)
|
||||||
|
for e in tmp:
|
||||||
|
rows.extend(rec_region_merge(region_list, e, S))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def nms(predict, activation_pixels, threshold=cfg.side_vertex_pixel_threshold):
|
||||||
|
"""
|
||||||
|
perform nms on results
|
||||||
|
"""
|
||||||
|
region_list = []
|
||||||
|
for i, j in zip(activation_pixels[0], activation_pixels[1]):
|
||||||
|
merge = False
|
||||||
|
for k, value in enumerate(region_list):
|
||||||
|
if should_merge(value, i, j):
|
||||||
|
region_list[k].add((i, j))
|
||||||
|
merge = True
|
||||||
|
if not merge:
|
||||||
|
region_list.append({(i, j)})
|
||||||
|
D = region_group(region_list)
|
||||||
|
quad_list = np.zeros((len(D), 4, 2))
|
||||||
|
score_list = np.zeros((len(D), 4))
|
||||||
|
for group, g_th in zip(D, range(len(D))):
|
||||||
|
total_score = np.zeros((4, 2))
|
||||||
|
for row in group:
|
||||||
|
for ij in region_list[row]:
|
||||||
|
score = predict[ij[0], ij[1], 1]
|
||||||
|
if score >= threshold:
|
||||||
|
ith_score = predict[ij[0], ij[1], 2:3]
|
||||||
|
if not (cfg.trunc_threshold <= ith_score < 1 -
|
||||||
|
cfg.trunc_threshold):
|
||||||
|
ith = int(np.around(ith_score))
|
||||||
|
total_score[ith * 2:(ith + 1) * 2] += score
|
||||||
|
px = (ij[1] + 0.5) * cfg.pixel_size
|
||||||
|
py = (ij[0] + 0.5) * cfg.pixel_size
|
||||||
|
p_v = [px, py] + np.reshape(predict[ij[0], ij[1], 3:7],
|
||||||
|
(2, 2))
|
||||||
|
quad_list[g_th, ith * 2:(ith + 1) * 2] += score * p_v
|
||||||
|
score_list[g_th] = total_score[:, 0]
|
||||||
|
quad_list[g_th] /= (total_score + cfg.epsilon)
|
||||||
|
return score_list, quad_list
|
|
@ -0,0 +1,176 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################predict the qual line of images ########################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
from src.label import point_inside_of_quad
|
||||||
|
from src.config import config as cfg
|
||||||
|
from src.nms import nms
|
||||||
|
from src.preprocess import resize_image
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid(x):
|
||||||
|
"""`y = 1 / (1 + exp(-x))`"""
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
|
|
||||||
|
def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
|
||||||
|
"""
|
||||||
|
cut text line
|
||||||
|
"""
|
||||||
|
geo /= [scale_ratio_w, scale_ratio_h]
|
||||||
|
p_min = np.amin(geo, axis=0)
|
||||||
|
p_max = np.amax(geo, axis=0)
|
||||||
|
min_xy = p_min.astype(int)
|
||||||
|
max_xy = p_max.astype(int) + 2
|
||||||
|
sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy()
|
||||||
|
for m in range(min_xy[1], max_xy[1]):
|
||||||
|
for n in range(min_xy[0], max_xy[0]):
|
||||||
|
if not point_inside_of_quad(n, m, geo, p_min, p_max):
|
||||||
|
sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255
|
||||||
|
sub_im = Image.fromarray(sub_im_arr)
|
||||||
|
sub_im.save(img_path + '_subim%d.jpg' % s)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(east_detect, img_path, pixel_threshold, quiet=True):
|
||||||
|
"""
|
||||||
|
predict to txt and image
|
||||||
|
"""
|
||||||
|
img = Image.open(img_path)
|
||||||
|
d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
|
||||||
|
d_wight = max(d_wight, d_height)
|
||||||
|
d_height = max(d_wight, d_height)
|
||||||
|
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
img = np.asarray(img)
|
||||||
|
img = img / 127.5 - 1
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
x = Tensor(np.expand_dims(img, axis=0), "float32")
|
||||||
|
y = east_detect(x).asnumpy()
|
||||||
|
y = np.squeeze(y, axis=0)
|
||||||
|
|
||||||
|
if y.shape[0] == 7:
|
||||||
|
y = y.transpose((1, 2, 0)) # CHW->HWC
|
||||||
|
|
||||||
|
y[:, :, :3] = sigmoid(y[:, :, :3])
|
||||||
|
cond = np.greater_equal(y[:, :, 0], pixel_threshold)
|
||||||
|
activation_pixels = np.where(cond)
|
||||||
|
quad_scores, quad_after_nms = nms(y, activation_pixels)
|
||||||
|
with Image.open(img_path) as im:
|
||||||
|
im_array = np.asarray(im.convert('RGB'))
|
||||||
|
d_wight, d_height = resize_image(im, cfg.max_predict_img_size)
|
||||||
|
scale_ratio_w = d_wight / im.width
|
||||||
|
scale_ratio_h = d_height / im.height
|
||||||
|
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
quad_im = im.copy()
|
||||||
|
draw = ImageDraw.Draw(im)
|
||||||
|
for i, j in zip(activation_pixels[0], activation_pixels[1]):
|
||||||
|
px = (j + 0.5) * cfg.pixel_size
|
||||||
|
py = (i + 0.5) * cfg.pixel_size
|
||||||
|
line_width, line_color = 1, 'red'
|
||||||
|
if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:
|
||||||
|
if y[i, j, 2] < cfg.trunc_threshold:
|
||||||
|
line_width, line_color = 2, 'yellow'
|
||||||
|
elif y[i, j, 2] >= 1 - cfg.trunc_threshold:
|
||||||
|
line_width, line_color = 2, 'green'
|
||||||
|
draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
|
||||||
|
(px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
|
||||||
|
(px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
|
||||||
|
(px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
|
||||||
|
(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)],
|
||||||
|
width=line_width, fill=line_color)
|
||||||
|
im.save(img_path + '_act.jpg')
|
||||||
|
quad_draw = ImageDraw.Draw(quad_im)
|
||||||
|
txt_items = []
|
||||||
|
for score, geo, s in zip(quad_scores, quad_after_nms,
|
||||||
|
range(len(quad_scores))):
|
||||||
|
print(np.amin(score))
|
||||||
|
if np.amin(score) > 0:
|
||||||
|
quad_draw.line([tuple(geo[0]),
|
||||||
|
tuple(geo[1]),
|
||||||
|
tuple(geo[2]),
|
||||||
|
tuple(geo[3]),
|
||||||
|
tuple(geo[0])], width=2, fill='red')
|
||||||
|
if cfg.predict_cut_text_line:
|
||||||
|
cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array,
|
||||||
|
img_path, s)
|
||||||
|
rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
|
||||||
|
rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
|
||||||
|
txt_item = ','.join(map(str, rescaled_geo_list))
|
||||||
|
txt_items.append(txt_item + '\n')
|
||||||
|
elif not quiet:
|
||||||
|
print('quad invalid with vertex num less then 4.')
|
||||||
|
quad_im.save(img_path + '_predict.jpg')
|
||||||
|
if cfg.predict_write2txt and txt_items:
|
||||||
|
with open(img_path[:-4] + '.txt', 'w') as f_txt:
|
||||||
|
f_txt.writelines(txt_items)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
|
||||||
|
"""
|
||||||
|
predict to txt
|
||||||
|
"""
|
||||||
|
img = Image.open(img_path)
|
||||||
|
d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
|
||||||
|
scale_ratio_w = d_wight / img.width
|
||||||
|
scale_ratio_h = d_height / img.height
|
||||||
|
img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
img = np.asarray(img)
|
||||||
|
img = img / 127.5 - 1
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
x = np.expand_dims(img, axis=0)
|
||||||
|
y = east_detect(x).asnumpy()
|
||||||
|
|
||||||
|
y = np.squeeze(y, axis=0)
|
||||||
|
|
||||||
|
if y.shape[0] == 7:
|
||||||
|
y = y.transpose((1, 2, 0)) # CHW->HWC
|
||||||
|
|
||||||
|
y[:, :, :3] = sigmoid(y[:, :, :3])
|
||||||
|
cond = np.greater_equal(y[:, :, 0], pixel_threshold)
|
||||||
|
activation_pixels = np.where(cond)
|
||||||
|
quad_scores, quad_after_nms = nms(y, activation_pixels)
|
||||||
|
|
||||||
|
txt_items = []
|
||||||
|
for score, geo in zip(quad_scores, quad_after_nms):
|
||||||
|
if np.amin(score) > 0:
|
||||||
|
rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
|
||||||
|
rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
|
||||||
|
txt_item = ','.join(map(str, rescaled_geo_list))
|
||||||
|
txt_items.append(txt_item + '\n')
|
||||||
|
elif not quiet:
|
||||||
|
print('quad invalid with vertex num less then 4.')
|
||||||
|
if cfg.predict_write2txt and txt_items:
|
||||||
|
with open(txt_path, 'w') as f_txt:
|
||||||
|
f_txt.writelines(txt_items)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""
|
||||||
|
parse_args
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--path', '-p',
|
||||||
|
default='demo/004.jpg',
|
||||||
|
help='image path')
|
||||||
|
parser.add_argument('--threshold', '-t',
|
||||||
|
default=cfg.pixel_threshold,
|
||||||
|
help='pixel activation threshold')
|
||||||
|
return parser.parse_args()
|
|
@ -0,0 +1,311 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################preprocess images########################
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
from src.label import shrink
|
||||||
|
from src.config import config as cfg
|
||||||
|
|
||||||
|
|
||||||
|
def batch_reorder_vertexes(xy_list_array):
|
||||||
|
"""
|
||||||
|
batch xy to batch vertices
|
||||||
|
"""
|
||||||
|
reorder_xy_list_array = np.zeros_like(xy_list_array)
|
||||||
|
for xy_list, i in zip(xy_list_array, range(len(xy_list_array))):
|
||||||
|
reorder_xy_list_array[i] = reorder_vertexes(xy_list)
|
||||||
|
return reorder_xy_list_array
|
||||||
|
|
||||||
|
|
||||||
|
def reorder_vertexes(xy_list):
|
||||||
|
"""
|
||||||
|
xy to vertices
|
||||||
|
"""
|
||||||
|
|
||||||
|
reorder_xy_list = np.zeros_like(xy_list)
|
||||||
|
# determine the first point with the smallest x,
|
||||||
|
# if two has same x, choose that with smallest y,
|
||||||
|
ordered = np.argsort(xy_list, axis=0)
|
||||||
|
xmin1_index = ordered[0, 0]
|
||||||
|
xmin2_index = ordered[1, 0]
|
||||||
|
if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]:
|
||||||
|
if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]:
|
||||||
|
reorder_xy_list[0] = xy_list[xmin1_index]
|
||||||
|
first_v = xmin1_index
|
||||||
|
else:
|
||||||
|
reorder_xy_list[0] = xy_list[xmin2_index]
|
||||||
|
first_v = xmin2_index
|
||||||
|
else:
|
||||||
|
reorder_xy_list[0] = xy_list[xmin1_index]
|
||||||
|
first_v = xmin1_index
|
||||||
|
# connect the first point to others, the third point on the other side of
|
||||||
|
# the line with the middle slope
|
||||||
|
others = list(range(4))
|
||||||
|
others.remove(first_v)
|
||||||
|
k = np.zeros((len(others),))
|
||||||
|
for index, i in zip(others, range(len(others))):
|
||||||
|
k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \
|
||||||
|
/ (xy_list[index, 0] - xy_list[first_v, 0] + cfg.epsilon)
|
||||||
|
k_mid = np.argsort(k)[1]
|
||||||
|
third_v = others[k_mid]
|
||||||
|
reorder_xy_list[2] = xy_list[third_v]
|
||||||
|
# determine the second point which on the bigger side of the middle line
|
||||||
|
others.remove(third_v)
|
||||||
|
b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
|
||||||
|
second_v, fourth_v = 0, 0
|
||||||
|
for index, i in zip(others, range(len(others))):
|
||||||
|
# delta = y - (k * x + b)
|
||||||
|
delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
|
||||||
|
if delta_y > 0:
|
||||||
|
second_v = index
|
||||||
|
else:
|
||||||
|
fourth_v = index
|
||||||
|
reorder_xy_list[1] = xy_list[second_v]
|
||||||
|
reorder_xy_list[3] = xy_list[fourth_v]
|
||||||
|
# compare slope of 13 and 24, determine the final order
|
||||||
|
k13 = k[k_mid]
|
||||||
|
k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / (
|
||||||
|
xy_list[second_v, 0] - xy_list[fourth_v, 0] + cfg.epsilon)
|
||||||
|
if k13 < k24:
|
||||||
|
tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1]
|
||||||
|
for i in range(2, -1, -1):
|
||||||
|
reorder_xy_list[i + 1] = reorder_xy_list[i]
|
||||||
|
reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y
|
||||||
|
return reorder_xy_list
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(im, max_img_size=cfg.max_train_img_size):
|
||||||
|
"""
|
||||||
|
resize image
|
||||||
|
"""
|
||||||
|
im_width = np.minimum(im.width, max_img_size)
|
||||||
|
if im_width == max_img_size < im.width:
|
||||||
|
im_height = int((im_width / im.width) * im.height)
|
||||||
|
else:
|
||||||
|
im_height = im.height
|
||||||
|
o_height = np.minimum(im_height, max_img_size)
|
||||||
|
if o_height == max_img_size < im_height:
|
||||||
|
o_width = int((o_height / im_height) * im_width)
|
||||||
|
else:
|
||||||
|
o_width = im_width
|
||||||
|
d_wight = o_width - (o_width % 32)
|
||||||
|
d_height = o_height - (o_height % 32)
|
||||||
|
return d_wight, d_height
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess():
|
||||||
|
"""
|
||||||
|
preprocess data
|
||||||
|
"""
|
||||||
|
data_dir = cfg.data_dir
|
||||||
|
origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
|
||||||
|
origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)
|
||||||
|
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
|
||||||
|
if not os.path.exists(train_image_dir):
|
||||||
|
os.mkdir(train_image_dir)
|
||||||
|
if not os.path.exists(train_label_dir):
|
||||||
|
os.mkdir(train_label_dir)
|
||||||
|
draw_gt_quad = cfg.draw_gt_quad
|
||||||
|
show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
|
||||||
|
if not os.path.exists(show_gt_image_dir):
|
||||||
|
os.mkdir(show_gt_image_dir)
|
||||||
|
show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
|
||||||
|
if not os.path.exists(show_act_image_dir):
|
||||||
|
os.mkdir(show_act_image_dir)
|
||||||
|
|
||||||
|
o_img_list = os.listdir(origin_image_dir)
|
||||||
|
print('found %d origin images.' % len(o_img_list))
|
||||||
|
train_val_set = []
|
||||||
|
for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
|
||||||
|
with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
|
||||||
|
if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
|
||||||
|
os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
|
||||||
|
continue
|
||||||
|
# d_wight, d_height = resize_image(im)
|
||||||
|
d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size
|
||||||
|
scale_ratio_w = d_wight / im.width
|
||||||
|
scale_ratio_h = d_height / im.height
|
||||||
|
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
show_gt_im = im.copy()
|
||||||
|
# draw on the img
|
||||||
|
draw = ImageDraw.Draw(show_gt_im)
|
||||||
|
with open(os.path.join(origin_txt_dir,
|
||||||
|
o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
|
||||||
|
anno_list = f.readlines()
|
||||||
|
xy_list_array = np.zeros((len(anno_list), 4, 2))
|
||||||
|
for anno, i in zip(anno_list, range(len(anno_list))):
|
||||||
|
anno_colums = anno.strip().split(',')
|
||||||
|
anno_array = np.array(anno_colums)
|
||||||
|
xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
|
||||||
|
xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
|
||||||
|
xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
|
||||||
|
xy_list = reorder_vertexes(xy_list)
|
||||||
|
xy_list_array[i] = xy_list
|
||||||
|
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
|
||||||
|
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
|
||||||
|
if draw_gt_quad:
|
||||||
|
draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
|
||||||
|
tuple(xy_list[2]), tuple(xy_list[3]),
|
||||||
|
tuple(xy_list[0])
|
||||||
|
],
|
||||||
|
width=2, fill='green')
|
||||||
|
draw.line([tuple(shrink_xy_list[0]),
|
||||||
|
tuple(shrink_xy_list[1]),
|
||||||
|
tuple(shrink_xy_list[2]),
|
||||||
|
tuple(shrink_xy_list[3]),
|
||||||
|
tuple(shrink_xy_list[0])
|
||||||
|
],
|
||||||
|
width=2, fill='blue')
|
||||||
|
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
|
||||||
|
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
|
||||||
|
for q_th in range(2):
|
||||||
|
draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
|
||||||
|
tuple(shrink_1[vs[long_edge][q_th][1]]),
|
||||||
|
tuple(shrink_1[vs[long_edge][q_th][2]]),
|
||||||
|
tuple(xy_list[vs[long_edge][q_th][3]]),
|
||||||
|
tuple(xy_list[vs[long_edge][q_th][4]])],
|
||||||
|
width=3, fill='yellow')
|
||||||
|
if cfg.gen_origin_img:
|
||||||
|
im.save(os.path.join(train_image_dir, o_img_fname))
|
||||||
|
np.save(os.path.join(
|
||||||
|
train_label_dir,
|
||||||
|
o_img_fname[:-4] + '.npy'),
|
||||||
|
xy_list_array)
|
||||||
|
if draw_gt_quad:
|
||||||
|
show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
|
||||||
|
train_val_set.append('{},{},{}\n'.format(o_img_fname,
|
||||||
|
d_wight,
|
||||||
|
d_height))
|
||||||
|
|
||||||
|
train_img_list = os.listdir(train_image_dir)
|
||||||
|
print('found %d train images.' % len(train_img_list))
|
||||||
|
train_label_list = os.listdir(train_label_dir)
|
||||||
|
print('found %d train labels.' % len(train_label_list))
|
||||||
|
|
||||||
|
random.shuffle(train_val_set)
|
||||||
|
val_count = int(cfg.validation_split_ratio * len(train_val_set))
|
||||||
|
with open(os.path.join(data_dir, cfg.val_fname), 'a') as f_val:
|
||||||
|
f_val.writelines(train_val_set[:val_count])
|
||||||
|
with open(os.path.join(data_dir, cfg.train_fname), 'a') as f_train:
|
||||||
|
f_train.writelines(train_val_set[val_count:])
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_size(width=256):
|
||||||
|
"""
|
||||||
|
preprocess data at specific size
|
||||||
|
"""
|
||||||
|
if width not in cfg.train_img_size:
|
||||||
|
raise NotImplementedError
|
||||||
|
data_dir = cfg.data_dir
|
||||||
|
origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
|
||||||
|
origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)
|
||||||
|
|
||||||
|
train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name_var + str(width))
|
||||||
|
train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
|
||||||
|
|
||||||
|
if not os.path.exists(train_image_dir):
|
||||||
|
os.mkdir(train_image_dir)
|
||||||
|
if not os.path.exists(train_label_dir):
|
||||||
|
os.mkdir(train_label_dir)
|
||||||
|
draw_gt_quad = cfg.draw_gt_quad
|
||||||
|
show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
|
||||||
|
if not os.path.exists(show_gt_image_dir):
|
||||||
|
os.mkdir(show_gt_image_dir)
|
||||||
|
show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
|
||||||
|
if not os.path.exists(show_act_image_dir):
|
||||||
|
os.mkdir(show_act_image_dir)
|
||||||
|
|
||||||
|
o_img_list = os.listdir(origin_image_dir)
|
||||||
|
print('found %d origin images.' % len(o_img_list))
|
||||||
|
train_val_set = []
|
||||||
|
for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
|
||||||
|
with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
|
||||||
|
if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
|
||||||
|
os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
|
||||||
|
continue
|
||||||
|
# d_wight, d_height = resize_image(im)
|
||||||
|
d_wight, d_height = width, width
|
||||||
|
scale_ratio_w = d_wight / im.width
|
||||||
|
scale_ratio_h = d_height / im.height
|
||||||
|
im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
|
||||||
|
show_gt_im = im.copy()
|
||||||
|
# draw on the img
|
||||||
|
draw = ImageDraw.Draw(show_gt_im)
|
||||||
|
with open(os.path.join(origin_txt_dir,
|
||||||
|
o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
|
||||||
|
anno_list = f.readlines()
|
||||||
|
xy_list_array = np.zeros((len(anno_list), 4, 2))
|
||||||
|
for anno, i in zip(anno_list, range(len(anno_list))):
|
||||||
|
anno_colums = anno.strip().split(',')
|
||||||
|
anno_array = np.array(anno_colums)
|
||||||
|
xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
|
||||||
|
xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
|
||||||
|
xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
|
||||||
|
xy_list = reorder_vertexes(xy_list)
|
||||||
|
xy_list_array[i] = xy_list
|
||||||
|
_, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
|
||||||
|
shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
|
||||||
|
if draw_gt_quad:
|
||||||
|
draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
|
||||||
|
tuple(xy_list[2]), tuple(xy_list[3]),
|
||||||
|
tuple(xy_list[0])
|
||||||
|
],
|
||||||
|
width=2, fill='green')
|
||||||
|
draw.line([tuple(shrink_xy_list[0]),
|
||||||
|
tuple(shrink_xy_list[1]),
|
||||||
|
tuple(shrink_xy_list[2]),
|
||||||
|
tuple(shrink_xy_list[3]),
|
||||||
|
tuple(shrink_xy_list[0])
|
||||||
|
],
|
||||||
|
width=2, fill='blue')
|
||||||
|
vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
|
||||||
|
[[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
|
||||||
|
for q_th in range(2):
|
||||||
|
draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
|
||||||
|
tuple(shrink_1[vs[long_edge][q_th][1]]),
|
||||||
|
tuple(shrink_1[vs[long_edge][q_th][2]]),
|
||||||
|
tuple(xy_list[vs[long_edge][q_th][3]]),
|
||||||
|
tuple(xy_list[vs[long_edge][q_th][4]])],
|
||||||
|
width=3, fill='yellow')
|
||||||
|
if cfg.gen_origin_img:
|
||||||
|
im.save(os.path.join(train_image_dir, o_img_fname))
|
||||||
|
np.save(os.path.join(
|
||||||
|
train_label_dir,
|
||||||
|
o_img_fname[:-4] + '.npy'),
|
||||||
|
xy_list_array)
|
||||||
|
if draw_gt_quad:
|
||||||
|
show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
|
||||||
|
train_val_set.append('{},{},{}\n'.format(o_img_fname,
|
||||||
|
d_wight,
|
||||||
|
d_height))
|
||||||
|
|
||||||
|
train_img_list = os.listdir(train_image_dir)
|
||||||
|
print('found %d train images.' % len(train_img_list))
|
||||||
|
train_label_list = os.listdir(train_label_dir)
|
||||||
|
print('found %d train labels.' % len(train_label_list))
|
||||||
|
|
||||||
|
# random.shuffle(train_val_set)
|
||||||
|
val_count = int(cfg.validation_split_ratio * len(train_val_set))
|
||||||
|
with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'a') as f_val:
|
||||||
|
f_val.writelines(train_val_set[:val_count])
|
||||||
|
with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'a') as f_train:
|
||||||
|
f_train.writelines(train_val_set[val_count:])
|
|
@ -0,0 +1,151 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################calculate precision ########################
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from shapely.geometry import Polygon
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
from src.nms import nms
|
||||||
|
|
||||||
|
|
||||||
|
# from src.preprocess import resize_image
|
||||||
|
|
||||||
|
class Averager():
|
||||||
|
"""Compute average for torch.Tensor, used for loss average."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def add(self, v):
|
||||||
|
count = v.data.numel()
|
||||||
|
v = v.data.sum()
|
||||||
|
self.n_count += count
|
||||||
|
self.sum += v
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.n_count = 0
|
||||||
|
self.sum = 0
|
||||||
|
|
||||||
|
def val(self):
|
||||||
|
res = 0
|
||||||
|
if self.n_count != 0:
|
||||||
|
res = self.sum / float(self.n_count)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class eval_pre_rec_f1():
|
||||||
|
'''eval batch'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.pixel_threshold = float(cfg.pixel_threshold)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.img_num = 0
|
||||||
|
self.pre = 0
|
||||||
|
self.rec = 0
|
||||||
|
self.f1_score = 0
|
||||||
|
|
||||||
|
def val(self):
|
||||||
|
mpre = self.pre / self.img_num * 100
|
||||||
|
mrec = self.rec / self.img_num * 100
|
||||||
|
mf1_score = self.f1_score / self.img_num * 100
|
||||||
|
return mpre, mrec, mf1_score
|
||||||
|
|
||||||
|
def sigmoid(self, x):
|
||||||
|
"""`y = 1 / (1 + exp(-x))`"""
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
|
def get_iou(self, g, p):
|
||||||
|
"""`get_iou`"""
|
||||||
|
g = Polygon(g)
|
||||||
|
p = Polygon(p)
|
||||||
|
if not g.is_valid or not p.is_valid:
|
||||||
|
return 0
|
||||||
|
inter = Polygon(g).intersection(Polygon(p)).area
|
||||||
|
union = g.area + p.area - inter
|
||||||
|
if union == 0:
|
||||||
|
return 0
|
||||||
|
return inter / union
|
||||||
|
|
||||||
|
def eval_one(self, quad_scores, quad_after_nms, gt_xy, quiet=cfg.quiet):
|
||||||
|
"""`eval_one`"""
|
||||||
|
num_gts = len(gt_xy)
|
||||||
|
quad_scores_no_zero = []
|
||||||
|
quad_after_nms_no_zero = []
|
||||||
|
for score, geo in zip(quad_scores, quad_after_nms):
|
||||||
|
if np.amin(score) > 0:
|
||||||
|
quad_scores_no_zero.append(sum(score))
|
||||||
|
quad_after_nms_no_zero.append(geo)
|
||||||
|
elif not quiet:
|
||||||
|
print('quad invalid with vertex num less then 4.')
|
||||||
|
continue
|
||||||
|
num_quads = len(quad_after_nms_no_zero)
|
||||||
|
if num_quads == 0:
|
||||||
|
return 0, 0, 0
|
||||||
|
quad_flag = np.zeros(num_quads)
|
||||||
|
gt_flag = np.zeros(num_gts)
|
||||||
|
# print(num_quads, '-------', num_gts)
|
||||||
|
quad_scores_no_zero = np.array(quad_scores_no_zero)
|
||||||
|
scores_idx = np.argsort(quad_scores_no_zero)[::-1]
|
||||||
|
for i in range(num_quads):
|
||||||
|
idx = scores_idx[i]
|
||||||
|
# score = quad_scores_no_zero[idx]
|
||||||
|
geo = quad_after_nms_no_zero[idx]
|
||||||
|
for j in range(num_gts):
|
||||||
|
if gt_flag[j] == 0:
|
||||||
|
gt_geo = gt_xy[j]
|
||||||
|
iou = self.get_iou(geo, gt_geo)
|
||||||
|
if iou >= cfg.iou_threshold:
|
||||||
|
gt_flag[j] = 1
|
||||||
|
quad_flag[i] = 1
|
||||||
|
tp = np.sum(quad_flag)
|
||||||
|
fp = num_quads - tp
|
||||||
|
fn = num_gts - tp
|
||||||
|
pre = tp / (tp + fp)
|
||||||
|
rec = tp / (tp + fn)
|
||||||
|
if pre + rec == 0:
|
||||||
|
f1_score = 0
|
||||||
|
else:
|
||||||
|
f1_score = 2 * pre * rec / (pre + rec)
|
||||||
|
# print(pre, '---', rec, '---', f1_score)
|
||||||
|
return pre, rec, f1_score
|
||||||
|
|
||||||
|
def add(self, out, gt_xy_list):
|
||||||
|
"""`add`"""
|
||||||
|
self.img_num += len(gt_xy_list)
|
||||||
|
ys = out.asnumpy() # (N, 7, 64, 64)
|
||||||
|
print(ys.shape)
|
||||||
|
if ys.shape[1] == 7:
|
||||||
|
ys = ys.transpose((0, 2, 3, 1)) # NCHW->NHWC
|
||||||
|
for y, gt_xy in zip(ys, gt_xy_list):
|
||||||
|
y[:, :, :3] = self.sigmoid(y[:, :, :3])
|
||||||
|
cond = np.greater_equal(y[:, :, 0], self.pixel_threshold)
|
||||||
|
activation_pixels = np.where(cond)
|
||||||
|
quad_scores, quad_after_nms = nms(y, activation_pixels)
|
||||||
|
print(quad_scores)
|
||||||
|
lens = len(quad_after_nms)
|
||||||
|
if lens == 0 or (sum(sum(quad_scores)) == 0):
|
||||||
|
if not cfg.quiet:
|
||||||
|
print('NMS')
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
pre, rec, f1_score = self.eval_one(quad_scores, quad_after_nms, gt_xy)
|
||||||
|
print(pre, rec, f1_score)
|
||||||
|
self.pre += pre
|
||||||
|
self.rec += rec
|
||||||
|
self.f1_score += f1_score
|
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################train advanced_east on dataset########################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mindspore import context, Model
|
||||||
|
from mindspore.communication.management import init, get_group_size
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
|
from src.logger import get_logger
|
||||||
|
from src.dataset import load_adEAST_dataset
|
||||||
|
from src.model import get_AdvancedEast_net
|
||||||
|
from src.config import config as cfg
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(cloud_args=None):
|
||||||
|
"""parameters"""
|
||||||
|
parser = argparse.ArgumentParser('mindspore adveast training')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
|
||||||
|
# network related
|
||||||
|
parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
|
||||||
|
|
||||||
|
# logging and checkpoint related
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||||
|
parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
|
||||||
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||||
|
|
||||||
|
# distributed related
|
||||||
|
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||||
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||||
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
args_opt.initial_epoch = cfg.initial_epoch
|
||||||
|
args_opt.epoch_num = cfg.epoch_num
|
||||||
|
args_opt.learning_rate = cfg.learning_rate
|
||||||
|
args_opt.decay = cfg.decay
|
||||||
|
args_opt.batch_size = cfg.batch_size
|
||||||
|
args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
|
||||||
|
args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
|
||||||
|
args_opt.ckpt_save_max = cfg.ckpt_save_max
|
||||||
|
args_opt.data_dir = cfg.data_dir
|
||||||
|
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
|
||||||
|
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
|
||||||
|
args_opt.train_image_dir_name = cfg.train_image_dir_name
|
||||||
|
args_opt.train_label_dir_name = cfg.train_label_dir_name
|
||||||
|
args_opt.results_dir = cfg.results_dir
|
||||||
|
args_opt.last_model_name = cfg.last_model_name
|
||||||
|
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
workers = 32
|
||||||
|
if args.is_distributed:
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id, device_target=args.device_target)
|
||||||
|
init()
|
||||||
|
elif args.device_target == "GPU":
|
||||||
|
context.set_context(device_target=args.device_target)
|
||||||
|
init()
|
||||||
|
|
||||||
|
args.group_size = get_group_size()
|
||||||
|
device_num = args.group_size
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
# logger
|
||||||
|
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||||
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
|
|
||||||
|
|
||||||
|
args.logger.save_args(args)
|
||||||
|
# network
|
||||||
|
args.logger.important_info('start create network')
|
||||||
|
|
||||||
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||||
|
args.rank_save_ckpt_flag = 0
|
||||||
|
if args.is_save_on_master:
|
||||||
|
if args.rank == 0:
|
||||||
|
args.rank_save_ckpt_flag = 1
|
||||||
|
else:
|
||||||
|
args.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
|
# get network and init
|
||||||
|
loss_net, train_net = get_AdvancedEast_net()
|
||||||
|
loss_net.add_flags_recursive(fp32=True)
|
||||||
|
train_net.set_train(False)
|
||||||
|
# pre_trained
|
||||||
|
if args.pre_trained:
|
||||||
|
load_param_into_net(train_net, load_checkpoint(os.path.join(args.saved_model_file_path, args.last_model_name)))
|
||||||
|
# define callbacks
|
||||||
|
|
||||||
|
mindrecordfile256 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(256) + '.mindrecord')
|
||||||
|
|
||||||
|
train_dataset256, batch_num256 = load_adEAST_dataset(mindrecordfile256, batch_size=8,
|
||||||
|
device_num=device_num, rank_id=args.rank, is_training=True,
|
||||||
|
num_parallel_workers=workers)
|
||||||
|
|
||||||
|
mindrecordfile384 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(384) + '.mindrecord')
|
||||||
|
train_dataset384, batch_num384 = load_adEAST_dataset(mindrecordfile384, batch_size=4,
|
||||||
|
device_num=device_num, rank_id=args.rank, is_training=True,
|
||||||
|
num_parallel_workers=workers)
|
||||||
|
mindrecordfile448 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(448) + '.mindrecord')
|
||||||
|
train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2,
|
||||||
|
device_num=device_num, rank_id=args.rank, is_training=True,
|
||||||
|
num_parallel_workers=workers)
|
||||||
|
|
||||||
|
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
model = Model(train_net)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num256)
|
||||||
|
loss_cb = LossMonitor(per_print_times=batch_num256)
|
||||||
|
callbacks = [time_cb, loss_cb]
|
||||||
|
if args.rank_save_ckpt_flag:
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
|
||||||
|
keep_checkpoint_max=args.ckpt_save_max)
|
||||||
|
save_ckpt_path = args.saved_model_file_path
|
||||||
|
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||||
|
directory=save_ckpt_path,
|
||||||
|
prefix='Epoch_A{}'.format(args.rank))
|
||||||
|
callbacks.append(ckpt_cb)
|
||||||
|
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset256, callbacks=callbacks, dataset_sink_mode=False)
|
||||||
|
model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num384)
|
||||||
|
loss_cb = LossMonitor(per_print_times=batch_num384)
|
||||||
|
callbacks = [time_cb, loss_cb]
|
||||||
|
if args.rank_save_ckpt_flag:
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
|
||||||
|
keep_checkpoint_max=args.ckpt_save_max)
|
||||||
|
save_ckpt_path = args.saved_model_file_path
|
||||||
|
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||||
|
directory=save_ckpt_path,
|
||||||
|
prefix='Epoch_B{}'.format(args.rank))
|
||||||
|
callbacks.append(ckpt_cb)
|
||||||
|
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
model = Model(train_net)
|
||||||
|
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset384, callbacks=callbacks, dataset_sink_mode=False)
|
||||||
|
model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num448)
|
||||||
|
loss_cb = LossMonitor(per_print_times=batch_num448)
|
||||||
|
callbacks = [time_cb, loss_cb]
|
||||||
|
if args.rank_save_ckpt_flag:
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
|
||||||
|
keep_checkpoint_max=args.ckpt_save_max)
|
||||||
|
save_ckpt_path = args.saved_model_file_path
|
||||||
|
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||||
|
directory=save_ckpt_path,
|
||||||
|
prefix='Epoch_C{}'.format(args.rank))
|
||||||
|
callbacks.append(ckpt_cb)
|
||||||
|
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
model = Model(train_net)
|
||||||
|
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset448, callbacks=callbacks, dataset_sink_mode=False)
|
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
#################train advanced_east on dataset########################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mindspore import context, Model
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
||||||
|
from src.logger import get_logger
|
||||||
|
|
||||||
|
from src.config import config as cfg
|
||||||
|
from src.dataset import load_adEAST_dataset
|
||||||
|
from src.model import get_AdvancedEast_net
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(cloud_args=None):
|
||||||
|
"""parameters"""
|
||||||
|
parser = argparse.ArgumentParser('mindspore adveast training')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
|
||||||
|
# network related
|
||||||
|
parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
|
||||||
|
parser.add_argument('--data_path', default='/disk1/ade/icpr/advanced-east-val_256.mindrecord', type=str)
|
||||||
|
parser.add_argument('--pre_trained_ckpt', default='/disk1/adeast/scripts/1.ckpt', type=str)
|
||||||
|
|
||||||
|
# logging and checkpoint related
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||||
|
parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
|
||||||
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||||
|
|
||||||
|
# distributed related
|
||||||
|
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||||
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||||
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
args_opt.initial_epoch = cfg.initial_epoch
|
||||||
|
args_opt.epoch_num = cfg.epoch_num
|
||||||
|
args_opt.learning_rate = cfg.learning_rate
|
||||||
|
args_opt.decay = cfg.decay
|
||||||
|
args_opt.batch_size = cfg.batch_size
|
||||||
|
args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
|
||||||
|
args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
|
||||||
|
args_opt.ckpt_save_max = cfg.ckpt_save_max
|
||||||
|
args_opt.data_dir = cfg.data_dir
|
||||||
|
args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
|
||||||
|
args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
|
||||||
|
args_opt.train_image_dir_name = cfg.train_image_dir_name
|
||||||
|
args_opt.train_label_dir_name = cfg.train_label_dir_name
|
||||||
|
args_opt.results_dir = cfg.results_dir
|
||||||
|
args_opt.last_model_name = cfg.last_model_name
|
||||||
|
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
context.set_context(device_target=args.device_target, device_id=args.device_id, mode=context.GRAPH_MODE)
|
||||||
|
workers = 32
|
||||||
|
device_num = 1
|
||||||
|
if args.is_distributed:
|
||||||
|
init()
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
device_num = int(os.environ.get("RANK_SIZE"))
|
||||||
|
rank = int(os.environ.get("RANK_ID"))
|
||||||
|
args.rank = args.device_id
|
||||||
|
elif args.device_target == "GPU":
|
||||||
|
context.set_context(device_target=args.device_target)
|
||||||
|
device_num = get_group_size()
|
||||||
|
args.rank = get_rank()
|
||||||
|
args.group_size = device_num
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=8, gradients_mean=True, parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
# logger
|
||||||
|
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||||
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
|
||||||
|
mindrecordfile = args.data_path
|
||||||
|
|
||||||
|
train_dataset, batch_num = load_adEAST_dataset(mindrecordfile, batch_size=args.batch_size,
|
||||||
|
device_num=device_num, rank_id=args.rank, is_training=True,
|
||||||
|
num_parallel_workers=workers)
|
||||||
|
|
||||||
|
args.logger.save_args(args)
|
||||||
|
# network
|
||||||
|
args.logger.important_info('start create network')
|
||||||
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||||
|
args.rank_save_ckpt_flag = 0
|
||||||
|
if args.is_save_on_master:
|
||||||
|
if args.rank == 0:
|
||||||
|
args.rank_save_ckpt_flag = 1
|
||||||
|
else:
|
||||||
|
args.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
|
# get network and init
|
||||||
|
loss_net, train_net = get_AdvancedEast_net()
|
||||||
|
loss_net.add_flags_recursive(fp32=True)
|
||||||
|
train_net.set_train(True)
|
||||||
|
# pre_trained
|
||||||
|
if args.pre_trained:
|
||||||
|
load_param_into_net(train_net, load_checkpoint(args.pre_trained_ckpt))
|
||||||
|
# define callbacks
|
||||||
|
|
||||||
|
train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
|
||||||
|
, eps=1e-7, weight_decay=cfg.decay)
|
||||||
|
model = Model(train_net)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num)
|
||||||
|
loss_cb = LossMonitor(per_print_times=batch_num)
|
||||||
|
callbacks = [time_cb, loss_cb]
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num,
|
||||||
|
keep_checkpoint_max=args.ckpt_save_max)
|
||||||
|
save_ckpt_path = args.saved_model_file_path
|
||||||
|
if args.rank_save_ckpt_flag:
|
||||||
|
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||||
|
directory=save_ckpt_path,
|
||||||
|
prefix='Epoch_{}'.format(args.rank))
|
||||||
|
callbacks.append(ckpt_cb)
|
||||||
|
model.train(epoch=cfg.epoch_num, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
|
Loading…
Reference in New Issue