forked from mindspore-Ecosystem/mindspore
Add advanced east.
This commit is contained in: parent 1b136e1713, commit b1fe54df72

Binary image file not shown (119 KiB).

@ -0,0 +1,194 @@

# AdvancedEAST

## Introduction

AdvancedEAST is inspired by [EAST: An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
The architecture of AdvancedEAST is shown below: ![AdvancedEast network arch](AdvancedEast.network.png)
This project derives from [huoyijie/AdvancedEAST](https://github.com/huoyijie/AdvancedEAST) (preprocessing, network architecture, prediction) and [BaoWentz/AdvancedEAST-PyTorch](https://github.com/BaoWentz/AdvancedEAST-PyTorch) (performance).

## Environment

* EulerOS V2R7 x86_64
* Python 3.7.5

## Dependencies

* mindspore==1.1
* shapely==1.7.1
* numpy==1.19.4
* tqdm==4.36.1

## Project files

* configuration file:
  cfg.py, controls parameters
* pre-process data:
  preprocess.py, resizes images
* label data:
  label.py, produces label info
* define network:
  model.py and VGG.py
* define loss function:
  losses.py
* execute training:
  advanced_east.py and dataset.py
* predict:
  predict.py and nms.py
* scoring:
  score.py
* logging:
  logger.py
```shell
.
└──advanced_east
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh          # launch Ascend distributed training (8 pcs)
    ├── run_standalone_train_ascend.sh   # launch Ascend standalone training (1 pc)
    ├── run_distribute_train_gpu.sh      # launch GPU distributed training (8 pcs)
    └── run_standalone_train_gpu.sh      # launch GPU standalone training (1 pc)
  ├── src
    ├── cfg.py                           # parameter configuration
    ├── dataset.py                       # data preprocessing
    ├── label.py                         # produce label info
    ├── logger.py                        # logging
    ├── model.py                         # define network
    ├── nms.py                           # non-maximum suppression
    ├── predict.py                       # predict boxes
    ├── preprocess.py                    # pre-process data
    └── score.py                         # scoring
  ├── export.py                          # export model for inference
  ├── prepare_data.py                    # exec data preprocessing
  ├── eval.py                            # eval net
  ├── train.py                           # train net
  └── train_mindrecord.py                # train net on user-specified mindrecord
```

## Dataset

ICPR MTWI 2018 Challenge 2: text detection of network images, [link](https://tianchi.aliyun.com/competition/entrance/231651/introduction). The dataset can no longer be downloaded from the original webpage; it is now provided by the author of the original project ([Baidu Yun link](https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA), password: ye9y). The dataset contains 10000 images and the corresponding label information in total, divided into 2 directories with 9000 and 1000 samples respectively. In the original training setting, the training set and validation set are partitioned at a ratio of 9:1. If you want to use your own dataset, please modify the dataset configuration in /src/config.py. The dataset files are organized as below:

> ```bash
> .
> └─data_dir
>     ├─images    # dataset
>     └─txt       # vertex of text boxes
> ```

Some parameters in config.py:

```python
'validation_split_ratio': 0.1,                            # ratio of validation dataset
'total_img': 10000,                                       # total number of samples in dataset
'data_dir': './icpr/',                                    # dir of dataset
'train_fname': 'train.txt',                               # file that stores the image file names of the training dataset
'val_fname': 'val.txt',                                   # file that stores the image file names of the validation dataset
'mindsrecord_train_file': 'advanced-east.mindrecord',     # mindrecord of training dataset
'mindsrecord_test_file': 'advanced-east-val.mindrecord',  # mindrecord of validation dataset
'origin_image_dir_name': 'images_9000/',                  # dir that stores the original images
'train_image_dir_name': 'images_train/',                  # dir that stores the preprocessed images
'origin_txt_dir_name': 'txt_9000/',                       # dir that stores the original text vertices
'train_label_dir_name': 'labels_train/',                  # dir that stores the preprocessed text vertices
```

## Run the project

### Data preprocessing

Resize all the images to a fixed size, convert the label information (the vertices of text boxes) into the format used in training and evaluation, then generate the mindrecord files:

```bash
python prepare_data.py
```
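
For reference, the sketch below shows what prepare_data.py does for one of the input sizes it handles (256 here), using the helpers this commit adds under src/:

```python
from src.config import config as cfg
from src.dataset import transImage2Mind_size
from src.label import process_label_size
from src.preprocess import preprocess_size

preprocess_size(256)     # resize the raw images for this size
process_label_size(256)  # produce the per-pixel ground-truth arrays (*_gt.npy)
# write the train/val mindrecord files for this size
transImage2Mind_size(cfg.data_dir + cfg.mindsrecord_train_file_var, 256)
transImage2Mind_size(cfg.data_dir + cfg.mindsrecord_val_file_var, 256, True)
```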

### Training

Prepare the VGG16 pre-trained model. Due to copyright restrictions, please download the VGG16 pre-trained model from https://github.com/machrisaa/tensorflow-vgg and place it in the src folder.

single Ascend (1p):

```bash
python train.py --device_target="Ascend" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
```

single Ascend (1p), specific size:

```bash
python train_mindrecord.py --device_target="Ascend" --is_distributed=0 --device_id=2 --size=256 > output.train.log 2>&1 &
```

multiple Ascends:

```bash
# run in a distributed environment (8p); the script expects the rank table file as its first argument
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE]
```

multiple GPUs:

```bash
# run in a distributed environment (8p)
bash scripts/run_distribute_train_gpu.sh
```

The detailed training parameters are in /src/config.py.

config.py:

```python
'initial_epoch': 0,                         # epoch to start from
'learning_rate': 1e-4,                      # initial learning rate
'decay': 5e-4,                              # weight decay parameter
'epsilon': 1e-4,                            # value of epsilon in loss computation
'batch_size': 8,                            # batch size
'lambda_inside_score_loss': 4.0,            # coefficient of inside_score_loss
'lambda_side_vertex_code_loss': 1.0,        # coefficient of vertex_code_loss
'lambda_side_vertex_coord_loss': 1.0,       # coefficient of vertex_coord_loss
'max_train_img_size': 448,                  # max size of training images
'max_predict_img_size': 448,                # max size of images to predict
'ckpt_save_max': 10,                        # maximum number of checkpoints kept in dir
'saved_model_file_path': './saved_model/',  # dir of saved models
```

## Evaluation

### Evaluate

Run the evaluation command below; it runs in the background, and you can view the results through the file output.eval.log, which reports the accuracy.
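
A typical invocation, assuming the `score` method and the default checkpoint name from eval.py (substitute the checkpoint file produced by your own training run):

```bash
python eval.py --device_target="Ascend" --device_id=0 --method=score --ckpt=Epoch_C0-6_4500.ckpt > output.eval.log 2>&1 &
```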

## Performance

### Training performance

The performance listed below is acquired with the default configurations in /src/config.py.

| Parameters          | single Ascend                                    |
| ------------------- | ------------------------------------------------ |
| Model Version       | AdvancedEAST                                     |
| Resources           | Ascend 910                                       |
| MindSpore Version   | 1.1                                              |
| Dataset             | MTWI-2018                                        |
| Training Parameters | epoch=18, batch_size=8, lr=1e-3                  |
| Optimizer           | AdamWeightDecay                                  |
| Loss Function       | QuadLoss                                         |
| Outputs             | matrix with size of 3x64x64, 3x96x96, 3x112x112  |
| Loss                | 0.1                                              |
| Total Time          | 28 mins, 60 mins, 90 mins                        |
| Checkpoints         | 173 MB (.ckpt file)                              |

### Evaluation performance

On the default configuration:

| Parameters        | single Ascend              |
| ----------------- | -------------------------- |
| Model Version     | AdvancedEAST               |
| Resources         | Ascend 910                 |
| MindSpore Version | 1.1                        |
| Dataset           | 1000 images                |
| batch_size        | 8                          |
| Outputs           | precision, recall, F score |
| performance       | 94.35, 55.45, 66.31        |

@ -0,0 +1,180 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################eval advanced_east on dataset########################
"""
import argparse
import datetime
import os

import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.logger import get_logger
from src.predict import predict
from src.score import eval_pre_rec_f1

from src.config import config as cfg
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net, AdvancedEast
from src.preprocess import resize_image

set_seed(1)


def parse_args(cloud_args=None):
    """parameters"""
    parser = argparse.ArgumentParser('adveast evaling')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')

    # logging and checkpoint related
    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')

    parser.add_argument('--ckpt', type=str, default='Epoch_C0-6_4500.ckpt', help='ckpt to load')
    parser.add_argument('--method', type=str, default='score', choices=['loss', 'score', 'pred'], help='evaluation')
    parser.add_argument('--path', type=str, help='image path of prediction')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    args_opt = parser.parse_args()

    args_opt.data_dir = cfg.data_dir
    args_opt.batch_size = 8
    args_opt.train_image_dir_name = cfg.data_dir + cfg.train_image_dir_name
    args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
    args_opt.train_label_dir_name = cfg.data_dir + cfg.train_label_dir_name
    args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
    args_opt.results_dir = cfg.results_dir
    args_opt.val_fname = cfg.val_fname
    args_opt.pixel_threshold = cfg.pixel_threshold
    args_opt.max_predict_img_size = cfg.max_predict_img_size
    args_opt.last_model_name = cfg.last_model_name
    args_opt.saved_model_file_path = cfg.saved_model_file_path

    return args_opt


def eval_loss(arg):
    """get network and init"""
    loss_net, train_net = get_AdvancedEast_net()
    print(os.path.join(arg.saved_model_file_path, arg.ckpt))
    load_param_into_net(train_net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
    train_net.set_train(False)
    loss = 0
    idx = 0
    for item in dataset.create_tuple_iterator():
        loss += loss_net(item[0], item[1])
        idx += 1
    print(loss / idx)


def eval_score(arg):
    """get network and init"""
    net = AdvancedEast()
    load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
    net.set_train(False)
    obj = eval_pre_rec_f1()
    with open(os.path.join(arg.data_dir, arg.val_fname), 'r') as f_val:
        f_list = f_val.readlines()

    img_h, img_w = arg.max_predict_img_size, arg.max_predict_img_size
    x = np.zeros((arg.batch_size, 3, img_h, img_w), dtype=np.float32)
    batch_list = np.arange(0, len(f_list), arg.batch_size)
    for idx in batch_list:
        gt_list = []
        for i in range(idx, min(idx + arg.batch_size, len(f_list))):
            item = f_list[i]
            img_filename = str(item).strip().split(',')[0][:-4]
            img_path = os.path.join(arg.train_image_dir_name, img_filename) + '.jpg'
            img = Image.open(img_path)
            d_wight, d_height = resize_image(img, arg.max_predict_img_size)
            img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
            img = np.asarray(img)
            img = img / 1.
            mean = np.array((123.68, 116.779, 103.939)).reshape([1, 1, 3])
            img = ((img - mean)).astype(np.float32)
            img = img.transpose((2, 0, 1))
            x[i - idx] = img
            # predict(east, img_path, threshold, cfg.pixel_threshold)
            gt_list.append(np.load(os.path.join(arg.train_label_dir_name, img_filename) + '.npy'))
        if idx + arg.batch_size >= len(f_list):
            x = x[:len(f_list) - idx]
        y = net(Tensor(x))
        obj.add(y, gt_list)

    print(obj.val())


def pred(arg):
    """pred"""
    img_path = arg.path
    net = AdvancedEast()
    load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
    predict(net, img_path, arg.pixel_threshold)


if __name__ == '__main__':
    args = parse_args()

    device_num = int(os.environ.get("DEVICE_NUM", 1))
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
            context.set_context(device_id=args.device_id)
        elif args.device_target == "GPU":
            init()

        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        context.set_context(device_id=args.device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    print(os.path.join(args.data_dir, args.mindsrecord_test_file))
    dataset, batch_num = load_adEAST_dataset(os.path.join(args.data_dir,
                                                          args.mindsrecord_test_file),
                                             batch_size=args.batch_size,
                                             device_num=device_num, rank_id=args.rank, is_training=False,
                                             num_parallel_workers=device_num)

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')

    method_dict = {'loss': eval_loss, 'score': eval_score, 'pred': pred}

    method_dict[args.method](args)

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
python export.py
"""
import argparse
import numpy as np
from src.model import AdvancedEast
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context

parser = argparse.ArgumentParser(description='adveast export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, default="./saved_model/Fin0_9-10_2250.ckpt", help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="./out.om", help="output file name.")
parser.add_argument('--width', type=int, default=256, help='input width')
parser.add_argument('--height', type=int, default=256, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
                    choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    net = AdvancedEast()

    assert args.ckpt_file is not None, "checkpoint_path is None."

    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)

    # a zero tensor with the export input shape (NCHW)
    input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

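For reference, an export invocation consistent with the arguments parsed above (the checkpoint name is this script's default; substitute your own):

```bash
python export.py --ckpt_file=./saved_model/Fin0_9-10_2250.ckpt --file_format=MINDIR --file_name=./advanced_east --width=448 --height=448
```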
@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################prepare data for advanced_east########################
"""
import argparse
import os

from mindspore.common import set_seed
from src.label import process_label, process_label_size

from src.config import config as cfg
from src.dataset import transImage2Mind, transImage2Mind_size
from src.preprocess import preprocess, preprocess_size

set_seed(1)


def parse_args(cloud_args=None):
    """parameters"""
    parser = argparse.ArgumentParser('mindspore data prepare')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=3, help='device id of GPU or Ascend. (Default: 3)')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()
    # create data
    preprocess()
    process_label()
    preprocess_size(256)
    process_label_size(256)
    preprocess_size(384)
    process_label_size(384)
    preprocess_size(448)
    process_label_size(448)
    # preprocess_size(512)
    # process_label_size(512)
    # preprocess_size(640)
    # process_label_size(640)
    # preprocess_size(736)
    # process_label_size(736)
    # preprocess_size(768)
    # process_label_size(768)
    mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file)
    transImage2Mind(mindrecord_filename)
    mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_test_file)
    transImage2Mind(mindrecord_filename, True)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
    transImage2Mind_size(mindrecord_filename, 256)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
    transImage2Mind_size(mindrecord_filename, 256, True)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
    transImage2Mind_size(mindrecord_filename, 384)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
    transImage2Mind_size(mindrecord_filename, 384, True)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_train_file_var
    transImage2Mind_size(mindrecord_filename, 448)
    mindrecord_filename = cfg.data_dir + cfg.mindsrecord_val_file_var
    transImage2Mind_size(mindrecord_filename, 448, True)

@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

PATH1=$1

ulimit -u unlimited
export DEVICE_NUM=2
export RANK_SIZE=2
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log
    python train_mindrecord.py --device_target Ascend --is_distributed 1 --device_id $i --data_path /disk1/adenew/icpr/advanced-east_448.mindrecord > log.txt 2>&1 &
    cd ..
done

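This script reads the rank table file path as its first positional argument and spawns DEVICE_NUM training processes, one per device. Assuming it is the scripts/run_distribute_train.sh referenced in the README, a typical invocation (the rank table path is an example):

```bash
bash scripts/run_distribute_train.sh /path/to/rank_table.json
```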
@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_ascend.sh"
echo "for example: bash run_standalone_train_ascend.sh"
echo "=============================================================================================================="

python train_mindrecord.py \
    --device_target="Ascend" > output.train.log 2>&1 &

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
configs
"""
from easydict import EasyDict

config = EasyDict({
    'initial_epoch': 2,
    'epoch_num': 6,
    'learning_rate': 1e-4,
    'decay': 3e-5,
    'epsilon': 1e-4,
    'batch_size': 2,
    'ckpt_interval': 2,
    'lambda_inside_score_loss': 4.0,
    'lambda_side_vertex_code_loss': 1.0,
    'lambda_side_vertex_coord_loss': 1.0,
    'max_train_img_size': 448,
    'max_predict_img_size': 448,
    'train_img_size': [256, 384, 448, 512, 640, 736, 768],
    'predict_img_size': [256, 384, 448, 512, 640, 736, 768],
    'ckpt_save_max': 500,
    'validation_split_ratio': 0.1,
    'total_img': 10000,
    'data_dir': './icpr/',
    'val_data_dir': './icpr/images/',
    'train_fname': 'train.txt',
    'train_fname_var': 'train_',
    'val_fname': 'val.txt',
    'val_fname_var': 'val_',
    'mindsrecord_train_file': 'advanced-east.mindrecord',
    'mindsrecord_test_file': 'advanced-east-val.mindrecord',
    'results_dir': './results/',
    'origin_image_dir_name': 'images/',
    'train_image_dir_name': 'images_train/',
    'origin_txt_dir_name': 'txt/',
    'train_label_dir_name': 'labels_train/',
    'train_image_dir_name_var': 'images_train_',
    'mindsrecord_train_file_var': 'advanced-east_',
    'train_label_dir_name_var': 'labels_train_',
    'mindsrecord_val_file_var': 'advanced-east-val_',
    'show_gt_image_dir_name': 'show_gt_images/',
    'show_act_image_dir_name': 'show_act_images/',
    'saved_model_file_path': './saved_model/',
    'last_model_name': '_.ckpt',
    'pixel_threshold': 0.8,
    'iou_threshold': 0.1,
    'feature_layers_range': range(5, 1, -1),
    'feature_layers_num': len(range(5, 1, -1)),
    'pixel_size': 2 ** range(5, 1, -1)[-1],
    'quiet': True,
    'side_vertex_pixel_threshold': 0.8,
    'trunc_threshold': 0.1,
    'predict_cut_text_line': False,
    'predict_write2txt': True,
    'shrink_ratio': 0.2,
    'shrink_side_ratio': 0.6,
    'gen_origin_img': True,
    'draw_gt_quad': False,
    'draw_act_quad': False,
    'vgg_npy': '/disk1/ade/vgg16.npy',
})

@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset.
"""
import os

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as vision
from mindspore.mindrecord import FileWriter
import numpy as np
from PIL import Image, ImageFile

import src.config as cfg

ImageFile.LOAD_TRUNCATED_IMAGES = True
cfg = cfg.config


def gen(batch_size=cfg.batch_size, is_val=False):
    """generate image/label batches directly from files"""
    img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
    x = np.zeros((batch_size, 3, img_h, img_w), dtype=np.float32)
    pixel_num_h = img_h // 4
    pixel_num_w = img_w // 4
    y = np.zeros((batch_size, 7, pixel_num_h, pixel_num_w), dtype=np.float32)
    if is_val:
        with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
            f_list = f_val.readlines()
    else:
        with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
            f_list = f_train.readlines()
    while True:
        batch_list = np.arange(0, len(f_list), batch_size)
        for idx in batch_list:
            for idx2 in range(idx, idx + batch_size):
                # pick an image name; pad the last batch with random picks
                if idx2 < len(f_list):
                    img = f_list[idx2]
                else:
                    img = np.random.choice(f_list)
                img_filename = str(img).strip().split(',')[0]
                # load the image and its annotation
                img_path = os.path.join(cfg.data_dir,
                                        cfg.train_image_dir_name,
                                        img_filename)
                img = Image.open(img_path)
                img = np.asarray(img)
                img = img / 127.5 - 1
                x[idx2 - idx] = img.transpose((2, 0, 1))
                gt_file = os.path.join(cfg.data_dir,
                                       cfg.train_label_dir_name,
                                       img_filename[:-4] + '_gt.npy')
                y[idx2 - idx] = np.load(gt_file).transpose((2, 0, 1))
            yield x, y


def transImage2Mind(mindrecord_filename, is_val=False):
    """transfer the images to a mindrecord file"""
    if os.path.exists(mindrecord_filename):
        os.remove(mindrecord_filename)
        os.remove(mindrecord_filename + ".db")

    writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
    cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, 112, 112]}}
    writer.add_schema(cv_schema, "advancedEast dataset")

    if is_val:
        with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
            f_list = f_val.readlines()
    else:
        with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
            f_list = f_train.readlines()

    data = []
    for item in f_list:
        img_filename = str(item).strip().split(',')[0]
        img_path = os.path.join(cfg.data_dir,
                                cfg.train_image_dir_name,
                                img_filename)
        print(img_path)
        with open(img_path, 'rb') as f:
            img = f.read()

        gt_file = os.path.join(cfg.data_dir,
                               cfg.train_label_dir_name,
                               img_filename[:-4] + '_gt.npy')
        labels = np.load(gt_file)
        labels = np.transpose(labels, (2, 0, 1))
        data.append({"image": img,
                     "label": np.array(labels, np.float32)})
        if len(data) == 32:
            writer.write_raw_data(data)
            print('len(data):{}'.format(len(data)))
            data = []
    if data:
        writer.write_raw_data(data)
    writer.commit()


def transImage2Mind_size(mindrecord_filename, width=256, is_val=False):
    """transfer the images to a mindrecord file at the specified size"""
    mindrecord_filename = mindrecord_filename + str(width) + '.mindrecord'
    if os.path.exists(mindrecord_filename):
        os.remove(mindrecord_filename)
        os.remove(mindrecord_filename + ".db")

    writer = FileWriter(file_name=mindrecord_filename, shard_num=1)
    cv_schema = {"image": {"type": "bytes"}, "label": {"type": "float32", "shape": [7, width // 4, width // 4]}}
    writer.add_schema(cv_schema, "advancedEast dataset")

    if is_val:
        with open(os.path.join(cfg.data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
            f_list = f_val.readlines()
    else:
        with open(os.path.join(cfg.data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
            f_list = f_train.readlines()
    data = []
    for item in f_list:

        img_filename = str(item).strip().split(',')[0]
        img_path = os.path.join(cfg.data_dir,
                                cfg.train_image_dir_name_var + str(width),
                                img_filename)
        with open(img_path, 'rb') as f:
            img = f.read()
        gt_file = os.path.join(cfg.data_dir,
                               cfg.train_label_dir_name_var + str(width),
                               img_filename[:-4] + '_gt.npy')
        labels = np.load(gt_file)
        labels = np.transpose(labels, (2, 0, 1))
        data.append({"image": img,
                     "label": np.array(labels, np.float32)})

        if len(data) == 32:
            writer.write_raw_data(data)
            print('len(data):{}'.format(len(data)))
            data = []

    if data:
        writer.write_raw_data(data)
    writer.commit()


def load_adEAST_dataset(mindrecord_file, batch_size=64, device_num=1, rank_id=0,
                        is_training=True, num_parallel_workers=8):
    """load a mindrecord file as a MindSpore dataset"""
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "label"], num_shards=device_num, shard_id=rank_id,
                        num_parallel_workers=8, shuffle=is_training)
    hwc_to_chw = vision.HWC2CHW()
    cd = vision.Decode()
    ds = ds.map(operations=cd, input_columns=["image"])
    rc = vision.RandomColorAdjust(brightness=0.1, contrast=0.2, saturation=0.2)
    vn = vision.Normalize(mean=(123.68, 116.779, 103.939), std=(1., 1., 1.))
    ds = ds.map(operations=[rc, vn, hwc_to_chw], input_columns=["image"], num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
    batch_num = ds.get_dataset_size()
    ds = ds.shuffle(batch_num)
    return ds, batch_num

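For reference, a minimal sketch of consuming a mindrecord file with load_adEAST_dataset as defined above (paths follow the defaults in src/config.py):

```python
import os

from src.config import config as cfg
from src.dataset import load_adEAST_dataset

# batches come out as (image, label): image is NCHW float32, label is (N, 7, H/4, W/4)
ds, batch_num = load_adEAST_dataset(os.path.join(cfg.data_dir, cfg.mindsrecord_train_file),
                                    batch_size=cfg.batch_size, device_num=1, rank_id=0,
                                    is_training=True, num_parallel_workers=8)
for image, label in ds.create_tuple_iterator():
    print(image.shape, label.shape)
    break
```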
@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
labeling
"""
import os

import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm

from src.config import config as cfg


def point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
    """test whether the point is inside the quad"""
    if (p_min[0] <= px <= p_max[0]) and (p_min[1] <= py <= p_max[1]):
        xy_list = np.zeros((4, 2))
        xy_list[:3, :] = quad_xy_list[1:4, :] - quad_xy_list[:3, :]
        xy_list[3] = quad_xy_list[0, :] - quad_xy_list[3, :]
        yx_list = np.zeros((4, 2))
        yx_list[:, :] = quad_xy_list[:, -1:-3:-1]
        a = xy_list * ([py, px] - yx_list)
        b = a[:, 0] - a[:, 1]
        if np.amin(b) >= 0 or np.amax(b) <= 0:
            return True
    return False


def point_inside_of_nth_quad(px, py, xy_list, shrink_1, long_edge):
    """test which sub-quad the point is inside of"""
    nth = -1
    vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
          [[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
    for ith in range(2):
        quad_xy_list = np.concatenate((
            np.reshape(xy_list[vs[long_edge][ith][0]], (1, 2)),
            np.reshape(shrink_1[vs[long_edge][ith][1]], (1, 2)),
            np.reshape(shrink_1[vs[long_edge][ith][2]], (1, 2)),
            np.reshape(xy_list[vs[long_edge][ith][3]], (1, 2))), axis=0)
        p_min = np.amin(quad_xy_list, axis=0)
        p_max = np.amax(quad_xy_list, axis=0)
        if point_inside_of_quad(px, py, quad_xy_list, p_min, p_max):
            if nth == -1:
                nth = ith
            else:
                nth = -1
                break
    return nth


def shrink(xy_list, ratio=cfg.shrink_ratio):
    """shrink"""
    if ratio == 0.0:
        return xy_list, xy_list
    diff_1to3 = xy_list[:3, :] - xy_list[1:4, :]
    diff_4 = xy_list[3:4, :] - xy_list[0:1, :]
    diff = np.concatenate((diff_1to3, diff_4), axis=0)
    dis = np.sqrt(np.sum(np.square(diff), axis=-1))
    # determine which edges are long and which are short
    long_edge = int(np.argmax(np.sum(np.reshape(dis, (2, 2)), axis=0)))
    short_edge = 1 - long_edge
    # compute the r length array
    r = [np.minimum(dis[i], dis[(i + 1) % 4]) for i in range(4)]
    # compute the theta array
    diff_abs = np.abs(diff)
    diff_abs[:, 0] += cfg.epsilon
    theta = np.arctan(diff_abs[:, 1] / diff_abs[:, 0])
    # shrink the two long edges
    temp_new_xy_list = np.copy(xy_list)
    shrink_edge(xy_list, temp_new_xy_list, long_edge, r, theta, ratio)
    shrink_edge(xy_list, temp_new_xy_list, long_edge + 2, r, theta, ratio)
    # shrink the two short edges
    new_xy_list = np.copy(temp_new_xy_list)
    shrink_edge(temp_new_xy_list, new_xy_list, short_edge, r, theta, ratio)
    shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio)
    # returns: quad with long edges shrunk, quad with all edges shrunk, index of the long edge
    return temp_new_xy_list, new_xy_list, long_edge


def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio):
    """shrink one edge in place"""
    if ratio == 0.0:
        return
    start_point = edge  # index of the edge's start point (0 or 1)
    end_point = (edge + 1) % 4
    long_start_sign_x = np.sign(
        xy_list[end_point, 0] - xy_list[start_point, 0])
    new_xy_list[start_point, 0] = \
        xy_list[start_point, 0] + \
        long_start_sign_x * ratio * r[start_point] * np.cos(theta[start_point])
    long_start_sign_y = np.sign(
        xy_list[end_point, 1] - xy_list[start_point, 1])
    new_xy_list[start_point, 1] = \
        xy_list[start_point, 1] + \
        long_start_sign_y * ratio * r[start_point] * np.sin(theta[start_point])
    # long edge one, end point
    long_end_sign_x = -1 * long_start_sign_x
    new_xy_list[end_point, 0] = \
        xy_list[end_point, 0] + \
        long_end_sign_x * ratio * r[end_point] * np.cos(theta[start_point])
    long_end_sign_y = -1 * long_start_sign_y
    new_xy_list[end_point, 1] = \
        xy_list[end_point, 1] + \
        long_end_sign_y * ratio * r[end_point] * np.sin(theta[start_point])


def precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax,
                 jmin, jmax, p_min, p_max, gt, long_edge, draw):
    """process the pixel grid of one quad"""
    for i in range(imin, imax):
        for j in range(jmin, jmax):
            px = (j + 0.5) * cfg.pixel_size
            py = (i + 0.5) * cfg.pixel_size
            if point_inside_of_quad(px, py,
                                    shrink_xy_list, p_min, p_max):
                gt[i, j, 0] = 1
                line_width, line_color = 1, 'red'
                ith = point_inside_of_nth_quad(px, py,
                                               xy_list,
                                               shrink_1,
                                               long_edge)
                vs = [[[3, 0], [1, 2]], [[0, 1], [2, 3]]]
                if ith in range(2):
                    gt[i, j, 1] = 1
                    if ith == 0:
                        line_width, line_color = 2, 'yellow'
                    else:
                        line_width, line_color = 2, 'green'
                    gt[i, j, 2:3] = ith
                    gt[i, j, 3:5] = \
                        xy_list[vs[long_edge][ith][0]] - [px, py]
                    gt[i, j, 5:] = \
                        xy_list[vs[long_edge][ith][1]] - [px, py]
                    draw.line([(px - 0.5 * cfg.pixel_size,
                                py - 0.5 * cfg.pixel_size),
                               (px + 0.5 * cfg.pixel_size,
                                py - 0.5 * cfg.pixel_size),
                               (px + 0.5 * cfg.pixel_size,
                                py + 0.5 * cfg.pixel_size),
                               (px - 0.5 * cfg.pixel_size,
                                py + 0.5 * cfg.pixel_size),
                               (px - 0.5 * cfg.pixel_size,
                                py - 0.5 * cfg.pixel_size)],
                              width=line_width, fill=line_color)
    return gt


def process_label(data_dir=cfg.data_dir):
    """process label"""
    with open(os.path.join(data_dir, cfg.val_fname), 'r') as f_val:
        f_list = f_val.readlines()
    with open(os.path.join(data_dir, cfg.train_fname), 'r') as f_train:
        f_list.extend(f_train.readlines())
    for line, _ in zip(f_list, tqdm(range(len(f_list)))):
        line_cols = str(line).strip().split(',')
        img_name, width, height = \
            line_cols[0].strip(), int(line_cols[1].strip()), \
            int(line_cols[2].strip())
        gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
        train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
        xy_list_array = np.load(os.path.join(train_label_dir,
                                             img_name[:-4] + '.npy'))
        train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
        with Image.open(os.path.join(train_image_dir, img_name)) as im:
            draw = ImageDraw.Draw(im)
            for xy_list in xy_list_array:
                _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
                shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
                p_min = np.amin(shrink_xy_list, axis=0)
                p_max = np.amax(shrink_xy_list, axis=0)
                # floor of the float
                ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
                # +1 for the ceil of the float and +1 to include the end
                ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
                imin = np.maximum(0, ji_min[1])
                imax = np.minimum(height // cfg.pixel_size, ji_max[1])
                jmin = np.maximum(0, ji_min[0])
                jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
                gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax,
                                  p_min, p_max, gt, long_edge, draw)
            act_image_dir = os.path.join(cfg.data_dir,
                                         cfg.show_act_image_dir_name)
            if cfg.draw_act_quad:
                im.save(os.path.join(act_image_dir, img_name))
        train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
        np.save(os.path.join(train_label_dir,
                             img_name[:-4] + '_gt.npy'), gt)


def process_label_size(width=256, data_dir=cfg.data_dir):
    """process label at the specified size"""
    with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'r') as f_val:
        f_list = f_val.readlines()
    with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'r') as f_train:
        f_list.extend(f_train.readlines())
    for line, _ in zip(f_list, tqdm(range(len(f_list)))):
        line_cols = str(line).strip().split(',')
        img_name, width, height = \
            line_cols[0].strip(), int(line_cols[1].strip()), \
            int(line_cols[2].strip())
        gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7))
        train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
        xy_list_array = np.load(os.path.join(train_label_dir,
                                             img_name[:-4] + '.npy'))
        train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
        with Image.open(os.path.join(train_image_dir, img_name)) as im:
            draw = ImageDraw.Draw(im)
            for xy_list in xy_list_array:
                _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
                shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
                p_min = np.amin(shrink_xy_list, axis=0)
                p_max = np.amax(shrink_xy_list, axis=0)
                # floor of the float
                ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1
                # +1 for the ceil of the float and +1 to include the end
                ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3
                imin = np.maximum(0, ji_min[1])
                imax = np.minimum(height // cfg.pixel_size, ji_max[1])
                jmin = np.maximum(0, ji_min[0])
                jmax = np.minimum(width // cfg.pixel_size, ji_max[0])
                gt = precess_list(shrink_xy_list, xy_list, shrink_1, imin, imax, jmin, jmax, p_min, p_max, gt,
                                  long_edge,
                                  draw)
            act_image_dir = os.path.join(cfg.data_dir,
                                         cfg.show_act_image_dir_name)
            if cfg.draw_act_quad:
                im.save(os.path.join(act_image_dir, img_name))
        train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))
        np.save(os.path.join(train_label_dir,
                             img_name[:-4] + '_gt.npy'), gt)

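To make the geometry above concrete, here is a small worked example of shrink() on a hypothetical axis-aligned quad (a 100x40 box chosen for illustration):

```python
import numpy as np

from src.config import config as cfg
from src.label import shrink

# a hypothetical 100x40 axis-aligned text quad, vertices as (x, y) rows
quad = np.array([[0., 0.], [100., 0.], [100., 40.], [0., 40.]])
long_shrunk, fully_shrunk, long_edge = shrink(quad, cfg.shrink_ratio)
print(long_edge)     # 0: the two horizontal edges are the long ones
print(fully_shrunk)  # the quad with all four edges pulled inward
```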
@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime


class LOGGER(logging.Logger):
    """
    Logger that writes to the console and to a log file.

    Args:
        logger_name (string): logger name.
        rank (int): rank of the process in distributed training.
    """

    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """set up the log file"""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    """get logger"""
    logger = LOGGER("mindversion", rank)
    logger.setup_logging_file(path, rank)
    return logger

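For reference, a minimal usage sketch of the logger defined above (the output directory here is an example):

```python
from src.logger import get_logger

# console handler on ranks divisible by 8, plus a timestamped '*_rank_N.log' file under the given dir
logger = get_logger('outputs/', rank=0)
logger.info('evaluation started')
```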
@ -0,0 +1,273 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
dataset processing.
|
||||
"""
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ParameterTuple, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.config import config as cfg
|
||||
|
||||
|
||||
class AdvancedEast(nn.Cell):
|
||||
"""
|
||||
EAST network definition.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(AdvancedEast, self).__init__()
|
||||
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
|
||||
|
||||
def get_var(name, idx):
|
||||
value = vgg_dict[name][idx]
|
||||
if idx == 0:
|
||||
value = np.transpose(value, [3, 2, 0, 1])
|
||||
var = Tensor(value)
|
||||
return var
|
||||
|
||||
def get_conv_var(name):
|
||||
filters = get_var(name, 0)
|
||||
biases = get_var(name, 1)
|
||||
return filters, biases
|
||||
|
||||
class VGG_Conv(nn.Cell):
|
||||
"""
|
||||
VGG16 network definition.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
super(VGG_Conv, self).__init__()
|
||||
filters, conv_biases = get_conv_var(name)
|
||||
out_channels, in_channels, filter_size, _ = filters.shape
|
||||
self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
|
||||
name='weight')
|
||||
self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
|
||||
self.relu = P.ReLU()
|
||||
self.gn = nn.GroupNorm(32, out_channels)
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv2d(x, self.weight)
|
||||
output = self.bias_add(output, self.bias)
|
||||
output = self.gn(output)
|
||||
output = self.relu(output)
|
||||
return output
|
||||
|
||||
self.conv1_1 = VGG_Conv('conv1_1')
|
||||
self.conv1_2 = VGG_Conv('conv1_2')
|
||||
self.pool1 = nn.MaxPool2d(2, 2)
|
||||
self.conv2_1 = VGG_Conv('conv2_1')
|
||||
self.conv2_2 = VGG_Conv('conv2_2')
|
||||
self.pool2 = nn.MaxPool2d(2, 2)
|
||||
self.conv3_1 = VGG_Conv('conv3_1')
|
||||
self.conv3_2 = VGG_Conv('conv3_2')
|
||||
self.conv3_3 = VGG_Conv('conv3_3')
|
||||
self.pool3 = nn.MaxPool2d(2, 2)
|
||||
self.conv4_1 = VGG_Conv('conv4_1')
|
||||
self.conv4_2 = VGG_Conv('conv4_2')
|
||||
self.conv4_3 = VGG_Conv('conv4_3')
|
||||
self.pool4 = nn.MaxPool2d(2, 2)
|
||||
self.conv5_1 = VGG_Conv('conv5_1')
|
||||
self.conv5_2 = VGG_Conv('conv5_2')
|
||||
self.conv5_3 = VGG_Conv('conv5_3')
|
||||
self.pool5 = nn.MaxPool2d(2, 2)
|
||||
self.merging1 = self.merging(i=2)
|
||||
self.merging2 = self.merging(i=3)
|
||||
self.merging3 = self.merging(i=4)
|
||||
self.last_bn = nn.GroupNorm(16, 32)
|
||||
self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||
self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||
self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||
self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform')
|
||||
self.op_concat = P.Concat(axis=1)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def merging(self, i=2):
|
||||
"""
|
||||
def merge layer
|
||||
"""
|
||||
in_size = {'2': 1024, '3': 384, '4': 192}
|
||||
layers = [
|
||||
nn.Conv2d(in_size[str(i)], 128 // 2 ** (i - 2), kernel_size=1, stride=1, has_bias=True,
|
||||
weight_init='XavierUniform'),
|
||||
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128 // 2 ** (i - 2), 128 // 2 ** (i - 2), kernel_size=3, stride=1, has_bias=True,
|
||||
weight_init='XavierUniform'),
|
||||
nn.GroupNorm(16, 128 // 2 ** (i - 2)),
|
||||
nn.ReLU()]
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
forward func
|
||||
"""
|
||||
f4 = self.conv1_1(x)
|
||||
f4 = self.conv1_2(f4)
|
||||
f4 = self.pool1(f4)
|
||||
f4 = self.conv2_1(f4)
|
||||
f4 = self.conv2_2(f4)
|
||||
f4 = self.pool2(f4)
|
||||
f3 = self.conv3_1(f4)
|
||||
f3 = self.conv3_2(f3)
|
||||
f3 = self.conv3_3(f3)
|
||||
f3 = self.pool3(f3)
|
||||
f2 = self.conv4_1(f3)
|
||||
f2 = self.conv4_2(f2)
|
||||
f2 = self.conv4_3(f2)
|
||||
f2 = self.pool4(f2)
|
||||
f1 = self.conv5_1(f2)
|
||||
f1 = self.conv5_2(f1)
|
||||
f1 = self.conv5_3(f1)
|
||||
f1 = self.pool5(f1)
|
||||
h1 = f1
|
||||
_, _, h_, w_ = P.Shape()(h1)
|
||||
H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1)
|
||||
concat1 = self.op_concat((H1, f2))
|
||||
h2 = self.merging1(concat1)
|
||||
_, _, h_, w_ = P.Shape()(h2)
|
||||
H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2)
|
||||
concat2 = self.op_concat((H2, f3))
|
||||
h3 = self.merging2(concat2)
|
||||
_, _, h_, w_ = P.Shape()(h3)
|
||||
H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3)
|
||||
concat3 = self.op_concat((H3, f4))
|
||||
h4 = self.merging3(concat3)
|
||||
before_output = self.relu(self.last_bn(self.conv_last(h4)))
|
||||
inside_score = self.inside_score_conv(before_output)
|
||||
side_v_angle = self.side_v_angle_conv(before_output)
|
||||
side_v_coord = self.side_v_coord_conv(before_output)
|
||||
east_detect = self.op_concat((inside_score, side_v_coord, side_v_angle))
|
||||
return east_detect
|
||||
|
||||
|
||||
def dice_loss(gt_score, pred_score):
|
||||
"""dice_loss1"""
|
||||
inter = P.ReduceSum()(gt_score * pred_score)
|
||||
union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5
|
||||
return 1. - (2 * (inter / union))
|
||||
|
||||
|
||||
def dice_loss2(gt_score, pred_score, mask):
|
||||
"""dice_loss2"""
|
||||
inter = P.ReduceSum()(gt_score * pred_score * mask)
|
||||
union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5
|
||||
return 1. - (2 * (inter / union))
|
||||
|
||||
|
||||
def quad_loss(y_true, y_pred,
              lambda_inside_score_loss=0.2,
              lambda_side_vertex_code_loss=0.1,
              lambda_side_vertex_coord_loss=1.0,
              epsilon=1e-4):
    """Quad loss: weighted sum of score, vertex-code and vertex-coord losses."""
    y_true = P.Transpose()(y_true, (0, 2, 3, 1))
    y_pred = P.Transpose()(y_pred, (0, 2, 3, 1))
    logits = y_pred[:, :, :, :1]
    labels = y_true[:, :, :, :1]
    predicts = P.Sigmoid()(logits)
    inside_score_loss = dice_loss(labels, predicts)
    inside_score_loss = inside_score_loss * lambda_inside_score_loss
    # loss for side_vertex_code
    vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2])
    vertex_labelsp = y_true[:, :, :, 1:2]
    vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3])
    vertex_labelsn = y_true[:, :, :, 2:3]
    labels2 = y_true[:, :, :, 1:2]
    side_vertex_code_lossp = dice_loss2(vertex_labelsp, vertex_logitsp, labels)
    side_vertex_code_lossn = dice_loss2(vertex_labelsn, vertex_logitsn, labels2)
    side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss
    # loss for side_vertex_coord delta
    g_hat = y_pred[:, :, :, 3:]  # N*H*W*4, two (dx, dy) offsets per pixel
    g_true = y_true[:, :, :, 3:]
    vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32)
    pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights)
    side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / (
        P.ReduceSum()(vertex_weights) + epsilon)
    side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss
    return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss

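# The total objective is a weighted sum of the three terms above:
#     loss = 0.2 * inside_score_loss      (dice on the text/non-text map)
#          + 0.1 * side_vertex_code_loss  (dice on the two vertex-code maps,
#                                          masked to text / boundary pixels)
#          + 1.0 * side_vertex_coord_loss (smooth L1 on vertex offsets,
#                                          averaged over boundary pixels).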
def smooth_l1_loss(prediction_tensor, target_tensor, weights):
    """Per-pixel smooth L1 loss, scaled by the quad reference length."""
    n_q = P.Reshape()(quad_norm(target_tensor), weights.shape)
    diff = P.SmoothL1Loss()(prediction_tensor, target_tensor)
    pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights
    return pixel_wise_smooth_l1norm

def quad_norm(g_true, epsilon=1e-4):
    """Per-pixel reference length: 4x the distance between the two offset points."""
    shape = g_true.shape
    delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2))
    diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :]
    square = diff * diff
    distance = P.Sqrt()(P.ReduceSum()(square, -1))
    distance = distance * 4.0
    distance = distance + epsilon
    return P.Reshape()(distance, (shape[0], shape[1], shape[2]))

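# quad_norm turns the 4 ground-truth offset channels of each pixel into a
# single normalizer: reshape to two (dx, dy) points, take their Euclidean
# distance, scale by 4 and add epsilon. Dividing the smooth L1 term by this
# value makes the coordinate loss roughly scale-invariant across text sizes.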
class EastWithLossCell(nn.Cell):
    """Wrap the AdvancedEast network together with quad_loss."""

    def __init__(self, network, config=None):
        super(EastWithLossCell, self).__init__()
        self.East_network = network
        self.cat = P.Concat(axis=1)

    def construct(self, image, label):
        y_pred = self.East_network(image)
        loss = quad_loss(label, y_pred)
        return loss

class TrainStepWrap(nn.Cell):
    """Train-step wrapper: one forward/backward pass plus an optimizer update."""

    def __init__(self, network, steps, config=None):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = AdamWeightDecay(self.weights, learning_rate=cfg.learning_rate,
                                         eps=1e-7, weight_decay=cfg.decay)
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = 3.0

    def construct(self, image, label):
        weights = self.weights
        loss = self.network(image, label)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(image, label, sens)
        return F.depend(loss, self.optimizer(grads))

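# GradOperation is created with sens_param=True, so the backward pass starts
# from a caller-supplied sensitivity tensor instead of an implicit 1.0;
# filling it with self.sens = 3.0 scales every gradient by a constant factor
# of 3 before the AdamWeightDecay update (a fixed loss-scale choice of this
# implementation).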
def get_AdvancedEast_net(configure=None, steps=1, mode=True):
    """
    Build the AdvancedEast network wrapped in loss and train-step cells.
    """
    AdvancedEast_net = AdvancedEast()
    loss_net = EastWithLossCell(AdvancedEast_net, configure)
    train_net = TrainStepWrap(loss_net, steps, configure)
    return loss_net, train_net
@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NMS module
"""
import numpy as np

from src.config import config as cfg

def should_merge(region, i, j):
    """
    Check whether pixel (i, j) adjoins the region via its left neighbor.
    """
    neighbor = {(i, j - 1)}
    return not region.isdisjoint(neighbor)

def region_neighbor(region_set):
    """
    Calculate the next-row neighborhood of a region.
    """
    region_pixels = np.array(list(region_set))
    j_min = np.amin(region_pixels, axis=0)[1] - 1
    j_max = np.amax(region_pixels, axis=0)[1] + 1
    i_m = np.amin(region_pixels, axis=0)[0] + 1
    region_pixels[:, 0] += 1
    neighbor = {(region_pixels[n, 0], region_pixels[n, 1]) for n in
                range(len(region_pixels))}
    neighbor.add((i_m, j_min))
    neighbor.add((i_m, j_max))
    return neighbor

def region_group(region_list):
    """
    Group regions into connected components.
    """
    S = [i for i in range(len(region_list))]
    D = []
    while S:
        m = S.pop(0)
        if not S:
            # S has only one element, put it to D
            D.append([m])
        else:
            D.append(rec_region_merge(region_list, m, S))
    return D

def rec_region_merge(region_list, m, S):
    """
    Recursively merge all regions reachable from region m.
    """
    rows = [m]
    tmp = []
    for n in S:
        if not region_neighbor(region_list[m]).isdisjoint(region_list[n]) or \
                not region_neighbor(region_list[n]).isdisjoint(region_list[m]):
            tmp.append(n)
    for d in tmp:
        S.remove(d)
    for e in tmp:
        rows.extend(rec_region_merge(region_list, e, S))
    return rows

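# region_list holds sets of activated pixels grown row by row (should_merge
# attaches a pixel to a region containing its left neighbor). region_group /
# rec_region_merge then transitively merge regions whose next-row
# neighborhoods (region_neighbor) overlap, so each group in D is one
# connected text region that nms() below collapses into a single quad.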
def nms(predict, activation_pixels, threshold=cfg.side_vertex_pixel_threshold):
    """
    Perform locality-aware NMS: group activated pixels into regions and let
    each region's boundary pixels vote for one weighted-average quad.
    """
    region_list = []
    for i, j in zip(activation_pixels[0], activation_pixels[1]):
        merge = False
        for k, value in enumerate(region_list):
            if should_merge(value, i, j):
                region_list[k].add((i, j))
                merge = True
        if not merge:
            region_list.append({(i, j)})
    D = region_group(region_list)
    quad_list = np.zeros((len(D), 4, 2))
    score_list = np.zeros((len(D), 4))
    for group, g_th in zip(D, range(len(D))):
        total_score = np.zeros((4, 2))
        for row in group:
            for ij in region_list[row]:
                score = predict[ij[0], ij[1], 1]
                if score >= threshold:
                    ith_score = predict[ij[0], ij[1], 2:3]
                    if not (cfg.trunc_threshold <= ith_score < 1 -
                            cfg.trunc_threshold):
                        ith = int(np.around(ith_score))
                        total_score[ith * 2:(ith + 1) * 2] += score
                        px = (ij[1] + 0.5) * cfg.pixel_size
                        py = (ij[0] + 0.5) * cfg.pixel_size
                        p_v = [px, py] + np.reshape(predict[ij[0], ij[1], 3:7],
                                                    (2, 2))
                        quad_list[g_th, ith * 2:(ith + 1) * 2] += score * p_v
        score_list[g_th] = total_score[:, 0]
        quad_list[g_th] /= (total_score + cfg.epsilon)
    return score_list, quad_list
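# Typical call site (mirrored in predict.py below): threshold the sigmoid
# inside-score map, gather the activated pixel coordinates, then vote:
#     cond = np.greater_equal(y[:, :, 0], pixel_threshold)
#     activation_pixels = np.where(cond)
#     quad_scores, quad_after_nms = nms(y, activation_pixels)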
@ -0,0 +1,176 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################predict the quad lines of images########################
"""
import argparse

import numpy as np
from PIL import Image, ImageDraw
from mindspore import Tensor

from src.label import point_inside_of_quad
from src.config import config as cfg
from src.nms import nms
from src.preprocess import resize_image

def sigmoid(x):
    """`y = 1 / (1 + exp(-x))`"""
    return 1 / (1 + np.exp(-x))

def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
    """
    Cut one text line out of the original image and save it as a sub-image.
    """
    geo /= [scale_ratio_w, scale_ratio_h]
    p_min = np.amin(geo, axis=0)
    p_max = np.amax(geo, axis=0)
    min_xy = p_min.astype(int)
    max_xy = p_max.astype(int) + 2
    sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy()
    for m in range(min_xy[1], max_xy[1]):
        for n in range(min_xy[0], max_xy[0]):
            if not point_inside_of_quad(n, m, geo, p_min, p_max):
                sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255
    sub_im = Image.fromarray(sub_im_arr)
    sub_im.save(img_path + '_subim%d.jpg' % s)

def predict(east_detect, img_path, pixel_threshold, quiet=True):
    """
    Predict quads for one image, saving overlay images and a txt file.
    """
    img = Image.open(img_path)
    d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
    d_wight = max(d_wight, d_height)
    d_height = max(d_wight, d_height)
    img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
    img = np.asarray(img)
    img = img / 127.5 - 1
    img = img.transpose((2, 0, 1))
    x = Tensor(np.expand_dims(img, axis=0), "float32")
    y = east_detect(x).asnumpy()
    y = np.squeeze(y, axis=0)

    if y.shape[0] == 7:
        y = y.transpose((1, 2, 0))  # CHW->HWC

    y[:, :, :3] = sigmoid(y[:, :, :3])
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
    activation_pixels = np.where(cond)
    quad_scores, quad_after_nms = nms(y, activation_pixels)
    with Image.open(img_path) as im:
        im_array = np.asarray(im.convert('RGB'))
        d_wight, d_height = resize_image(im, cfg.max_predict_img_size)
        scale_ratio_w = d_wight / im.width
        scale_ratio_h = d_height / im.height
        im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
        quad_im = im.copy()
        draw = ImageDraw.Draw(im)
        for i, j in zip(activation_pixels[0], activation_pixels[1]):
            px = (j + 0.5) * cfg.pixel_size
            py = (i + 0.5) * cfg.pixel_size
            line_width, line_color = 1, 'red'
            if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:
                if y[i, j, 2] < cfg.trunc_threshold:
                    line_width, line_color = 2, 'yellow'
                elif y[i, j, 2] >= 1 - cfg.trunc_threshold:
                    line_width, line_color = 2, 'green'
            draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
                       (px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
                       (px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
                       (px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
                       (px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)],
                      width=line_width, fill=line_color)
        im.save(img_path + '_act.jpg')
        quad_draw = ImageDraw.Draw(quad_im)
        txt_items = []
        for score, geo, s in zip(quad_scores, quad_after_nms,
                                 range(len(quad_scores))):
            print(np.amin(score))
            if np.amin(score) > 0:
                quad_draw.line([tuple(geo[0]),
                                tuple(geo[1]),
                                tuple(geo[2]),
                                tuple(geo[3]),
                                tuple(geo[0])], width=2, fill='red')
                if cfg.predict_cut_text_line:
                    cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array,
                                  img_path, s)
                rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
                rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
                txt_item = ','.join(map(str, rescaled_geo_list))
                txt_items.append(txt_item + '\n')
            elif not quiet:
                print('quad invalid with vertex num less than 4.')
        quad_im.save(img_path + '_predict.jpg')
        if cfg.predict_write2txt and txt_items:
            with open(img_path[:-4] + '.txt', 'w') as f_txt:
                f_txt.writelines(txt_items)

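# predict() leaves three artifacts next to the input image: <img>_act.jpg
# (per-pixel activation squares, colored by vertex code), <img>_predict.jpg
# (the final quads after nms) and, when cfg.predict_write2txt is set, a .txt
# file with one comma-separated 8-value quad per line in original-image
# coordinates.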
def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
    """
    Predict quads for one image and write them to txt_path.
    """
    img = Image.open(img_path)
    d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
    scale_ratio_w = d_wight / img.width
    scale_ratio_h = d_height / img.height
    img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
    img = np.asarray(img)
    img = img / 127.5 - 1
    img = img.transpose((2, 0, 1))
    # wrap the batch in a Tensor so the network cell can consume it
    x = Tensor(np.expand_dims(img, axis=0), "float32")
    y = east_detect(x).asnumpy()

    y = np.squeeze(y, axis=0)

    if y.shape[0] == 7:
        y = y.transpose((1, 2, 0))  # CHW->HWC

    y[:, :, :3] = sigmoid(y[:, :, :3])
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
    activation_pixels = np.where(cond)
    quad_scores, quad_after_nms = nms(y, activation_pixels)

    txt_items = []
    for score, geo in zip(quad_scores, quad_after_nms):
        if np.amin(score) > 0:
            rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
            rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
            txt_item = ','.join(map(str, rescaled_geo_list))
            txt_items.append(txt_item + '\n')
        elif not quiet:
            print('quad invalid with vertex num less than 4.')
    if cfg.predict_write2txt and txt_items:
        with open(txt_path, 'w') as f_txt:
            f_txt.writelines(txt_items)

def parse_args():
    """
    parse_args
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p',
                        default='demo/004.jpg',
                        help='image path')
    parser.add_argument('--threshold', '-t',
                        default=cfg.pixel_threshold,
                        help='pixel activation threshold')
    return parser.parse_args()
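# A minimal driver sketch (not part of the original file; the network
# construction and checkpoint loading are assumed, see eval.py / export.py
# for the project's own loaders):
#     args = parse_args()
#     east = AdvancedEast()                      # hypothetical: built elsewhere
#     load_param_into_net(east, load_checkpoint('advanced_east.ckpt'))
#     predict(east, args.path, float(args.threshold), quiet=False)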
@ -0,0 +1,311 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################preprocess images########################
"""
import os
import random
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageDraw

from src.label import shrink
from src.config import config as cfg

def batch_reorder_vertexes(xy_list_array):
    """
    Reorder the vertices of a batch of quads.
    """
    reorder_xy_list_array = np.zeros_like(xy_list_array)
    for xy_list, i in zip(xy_list_array, range(len(xy_list_array))):
        reorder_xy_list_array[i] = reorder_vertexes(xy_list)
    return reorder_xy_list_array

def reorder_vertexes(xy_list):
    """
    Reorder the four vertices of one quad into a canonical order.
    """
    reorder_xy_list = np.zeros_like(xy_list)
    # determine the first point with the smallest x;
    # if two points share the same x, choose the one with the smallest y
    ordered = np.argsort(xy_list, axis=0)
    xmin1_index = ordered[0, 0]
    xmin2_index = ordered[1, 0]
    if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]:
        if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]:
            reorder_xy_list[0] = xy_list[xmin1_index]
            first_v = xmin1_index
        else:
            reorder_xy_list[0] = xy_list[xmin2_index]
            first_v = xmin2_index
    else:
        reorder_xy_list[0] = xy_list[xmin1_index]
        first_v = xmin1_index
    # connect the first point to the others; the third point lies on the
    # other end of the line with the middle slope
    others = list(range(4))
    others.remove(first_v)
    k = np.zeros((len(others),))
    for index, i in zip(others, range(len(others))):
        k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \
               / (xy_list[index, 0] - xy_list[first_v, 0] + cfg.epsilon)
    k_mid = np.argsort(k)[1]
    third_v = others[k_mid]
    reorder_xy_list[2] = xy_list[third_v]
    # determine the second point, which lies on the bigger side of the middle line
    others.remove(third_v)
    b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
    second_v, fourth_v = 0, 0
    for index, i in zip(others, range(len(others))):
        # delta = y - (k * x + b)
        delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
        if delta_y > 0:
            second_v = index
        else:
            fourth_v = index
    reorder_xy_list[1] = xy_list[second_v]
    reorder_xy_list[3] = xy_list[fourth_v]
    # compare the slopes of diagonals 1-3 and 2-4 to determine the final order
    k13 = k[k_mid]
    k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / (
        xy_list[second_v, 0] - xy_list[fourth_v, 0] + cfg.epsilon)
    if k13 < k24:
        tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1]
        for i in range(2, -1, -1):
            reorder_xy_list[i + 1] = reorder_xy_list[i]
        reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y
    return reorder_xy_list

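# The net effect is a rotation-normalized vertex order shared by all
# annotations: vertex 0 has the smallest x (ties broken by y), vertex 2 is
# its diagonal partner along the median-slope line, and a final cyclic shift
# keeps the diagonal slopes ordered, so label generation and the loss see
# quads with a consistent orientation.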
def resize_image(im, max_img_size=cfg.max_train_img_size):
    """
    Resize an image: clip the longer side to max_img_size while keeping the
    aspect ratio, then round both sides down to multiples of 32.
    """
    im_width = np.minimum(im.width, max_img_size)
    if im_width == max_img_size < im.width:
        im_height = int((im_width / im.width) * im.height)
    else:
        im_height = im.height
    o_height = np.minimum(im_height, max_img_size)
    if o_height == max_img_size < im_height:
        o_width = int((o_height / im_height) * im_width)
    else:
        o_width = im_width
    d_wight = o_width - (o_width % 32)
    d_height = o_height - (o_height % 32)
    return d_wight, d_height

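# Worked example: a 1000 x 562 image with max_img_size=448 first scales to
# 448 x 251 (width clipped, aspect ratio kept), then both sides round down to
# the nearest multiple of 32 (the network's total stride), giving (448, 224).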
def preprocess():
    """
    preprocess data
    """
    data_dir = cfg.data_dir
    origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
    origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)
    train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name)
    train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name)
    if not os.path.exists(train_image_dir):
        os.mkdir(train_image_dir)
    if not os.path.exists(train_label_dir):
        os.mkdir(train_label_dir)
    draw_gt_quad = cfg.draw_gt_quad
    show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
    if not os.path.exists(show_gt_image_dir):
        os.mkdir(show_gt_image_dir)
    show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
    if not os.path.exists(show_act_image_dir):
        os.mkdir(show_act_image_dir)

    o_img_list = os.listdir(origin_image_dir)
    print('found %d origin images.' % len(o_img_list))
    train_val_set = []
    for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
        with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
            if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
                    os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
                continue
            # d_wight, d_height = resize_image(im)
            d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size
            scale_ratio_w = d_wight / im.width
            scale_ratio_h = d_height / im.height
            im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
            show_gt_im = im.copy()
            # draw on the img
            draw = ImageDraw.Draw(show_gt_im)
            with open(os.path.join(origin_txt_dir,
                                   o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
                anno_list = f.readlines()
            xy_list_array = np.zeros((len(anno_list), 4, 2))
            for anno, i in zip(anno_list, range(len(anno_list))):
                anno_colums = anno.strip().split(',')
                anno_array = np.array(anno_colums)
                xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
                xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
                xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
                xy_list = reorder_vertexes(xy_list)
                xy_list_array[i] = xy_list
                _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
                shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
                if draw_gt_quad:
                    draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
                               tuple(xy_list[2]), tuple(xy_list[3]),
                               tuple(xy_list[0])],
                              width=2, fill='green')
                    draw.line([tuple(shrink_xy_list[0]),
                               tuple(shrink_xy_list[1]),
                               tuple(shrink_xy_list[2]),
                               tuple(shrink_xy_list[3]),
                               tuple(shrink_xy_list[0])],
                              width=2, fill='blue')
                    vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
                          [[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
                    for q_th in range(2):
                        draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
                                   tuple(shrink_1[vs[long_edge][q_th][1]]),
                                   tuple(shrink_1[vs[long_edge][q_th][2]]),
                                   tuple(xy_list[vs[long_edge][q_th][3]]),
                                   tuple(xy_list[vs[long_edge][q_th][4]])],
                                  width=3, fill='yellow')
            if cfg.gen_origin_img:
                im.save(os.path.join(train_image_dir, o_img_fname))
            np.save(os.path.join(
                train_label_dir,
                o_img_fname[:-4] + '.npy'),
                    xy_list_array)
            if draw_gt_quad:
                show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
            train_val_set.append('{},{},{}\n'.format(o_img_fname,
                                                     d_wight,
                                                     d_height))

    train_img_list = os.listdir(train_image_dir)
    print('found %d train images.' % len(train_img_list))
    train_label_list = os.listdir(train_label_dir)
    print('found %d train labels.' % len(train_label_list))

    random.shuffle(train_val_set)
    val_count = int(cfg.validation_split_ratio * len(train_val_set))
    with open(os.path.join(data_dir, cfg.val_fname), 'a') as f_val:
        f_val.writelines(train_val_set[:val_count])
    with open(os.path.join(data_dir, cfg.train_fname), 'a') as f_train:
        f_train.writelines(train_val_set[val_count:])

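# preprocess() writes four kinds of outputs under cfg.data_dir: resized
# training images, per-image .npy arrays with the reordered quad vertices,
# optional visualization images with the (shrunk) ground-truth quads drawn,
# and the shuffled train/val split files consumed by the later label and
# dataset steps.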
def preprocess_size(width=256):
    """
    preprocess data at specific size
    """
    if width not in cfg.train_img_size:
        raise NotImplementedError
    data_dir = cfg.data_dir
    origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name)
    origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name)

    train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name_var + str(width))
    train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name_var + str(width))

    if not os.path.exists(train_image_dir):
        os.mkdir(train_image_dir)
    if not os.path.exists(train_label_dir):
        os.mkdir(train_label_dir)
    draw_gt_quad = cfg.draw_gt_quad
    show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name)
    if not os.path.exists(show_gt_image_dir):
        os.mkdir(show_gt_image_dir)
    show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name)
    if not os.path.exists(show_act_image_dir):
        os.mkdir(show_act_image_dir)

    o_img_list = os.listdir(origin_image_dir)
    print('found %d origin images.' % len(o_img_list))
    train_val_set = []
    for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))):
        with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im:
            if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
                    os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
                continue
            # d_wight, d_height = resize_image(im)
            d_wight, d_height = width, width
            scale_ratio_w = d_wight / im.width
            scale_ratio_h = d_height / im.height
            im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
            show_gt_im = im.copy()
            # draw on the img
            draw = ImageDraw.Draw(show_gt_im)
            with open(os.path.join(origin_txt_dir,
                                   o_img_fname[:-4] + '.txt'), 'r', encoding='utf-8') as f:
                anno_list = f.readlines()
            xy_list_array = np.zeros((len(anno_list), 4, 2))
            for anno, i in zip(anno_list, range(len(anno_list))):
                anno_colums = anno.strip().split(',')
                anno_array = np.array(anno_colums)
                xy_list = np.reshape(anno_array[:8].astype(float), (4, 2))
                xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w
                xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h
                xy_list = reorder_vertexes(xy_list)
                xy_list_array[i] = xy_list
                _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio)
                shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio)
                if draw_gt_quad:
                    draw.line([tuple(xy_list[0]), tuple(xy_list[1]),
                               tuple(xy_list[2]), tuple(xy_list[3]),
                               tuple(xy_list[0])],
                              width=2, fill='green')
                    draw.line([tuple(shrink_xy_list[0]),
                               tuple(shrink_xy_list[1]),
                               tuple(shrink_xy_list[2]),
                               tuple(shrink_xy_list[3]),
                               tuple(shrink_xy_list[0])],
                              width=2, fill='blue')
                    vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]],
                          [[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]]
                    for q_th in range(2):
                        draw.line([tuple(xy_list[vs[long_edge][q_th][0]]),
                                   tuple(shrink_1[vs[long_edge][q_th][1]]),
                                   tuple(shrink_1[vs[long_edge][q_th][2]]),
                                   tuple(xy_list[vs[long_edge][q_th][3]]),
                                   tuple(xy_list[vs[long_edge][q_th][4]])],
                                  width=3, fill='yellow')
            if cfg.gen_origin_img:
                im.save(os.path.join(train_image_dir, o_img_fname))
            np.save(os.path.join(
                train_label_dir,
                o_img_fname[:-4] + '.npy'),
                    xy_list_array)
            if draw_gt_quad:
                show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname))
            train_val_set.append('{},{},{}\n'.format(o_img_fname,
                                                     d_wight,
                                                     d_height))

    train_img_list = os.listdir(train_image_dir)
    print('found %d train images.' % len(train_img_list))
    train_label_list = os.listdir(train_label_dir)
    print('found %d train labels.' % len(train_label_list))

    # random.shuffle(train_val_set)
    val_count = int(cfg.validation_split_ratio * len(train_val_set))
    with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'a') as f_val:
        f_val.writelines(train_val_set[:val_count])
    with open(os.path.join(data_dir, cfg.train_fname_var + str(width) + '.txt'), 'a') as f_train:
        f_train.writelines(train_val_set[val_count:])
@ -0,0 +1,151 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################calculate precision########################
"""
import numpy as np
from shapely.geometry import Polygon

from src.config import config as cfg
from src.nms import nms


# from src.preprocess import resize_image

class Averager():
    """Compute the running average of tensor values, used for loss averaging."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res

class eval_pre_rec_f1():
    '''Evaluate precision, recall and F1 over batches of predictions.'''

    def __init__(self):
        self.pixel_threshold = float(cfg.pixel_threshold)
        self.reset()

    def reset(self):
        self.img_num = 0
        self.pre = 0
        self.rec = 0
        self.f1_score = 0

    def val(self):
        mpre = self.pre / self.img_num * 100
        mrec = self.rec / self.img_num * 100
        mf1_score = self.f1_score / self.img_num * 100
        return mpre, mrec, mf1_score

    def sigmoid(self, x):
        """`y = 1 / (1 + exp(-x))`"""
        return 1 / (1 + np.exp(-x))

    def get_iou(self, g, p):
        """Intersection over union of two quads."""
        g = Polygon(g)
        p = Polygon(p)
        if not g.is_valid or not p.is_valid:
            return 0
        inter = g.intersection(p).area
        union = g.area + p.area - inter
        if union == 0:
            return 0
        return inter / union

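    # eval_one below follows the usual detection protocol: predictions are
    # visited in descending score order and greedily matched to unmatched
    # ground truths with IoU >= cfg.iou_threshold, then
    #     precision = TP / (TP + FP),  recall = TP / (TP + FN),
    #     F1 = 2 * precision * recall / (precision + recall).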
    def eval_one(self, quad_scores, quad_after_nms, gt_xy, quiet=cfg.quiet):
        """Evaluate precision, recall and F1 for one image."""
        num_gts = len(gt_xy)
        quad_scores_no_zero = []
        quad_after_nms_no_zero = []
        for score, geo in zip(quad_scores, quad_after_nms):
            if np.amin(score) > 0:
                quad_scores_no_zero.append(sum(score))
                quad_after_nms_no_zero.append(geo)
            elif not quiet:
                print('quad invalid with vertex num less than 4.')
                continue
        num_quads = len(quad_after_nms_no_zero)
        if num_quads == 0:
            return 0, 0, 0
        quad_flag = np.zeros(num_quads)
        gt_flag = np.zeros(num_gts)
        quad_scores_no_zero = np.array(quad_scores_no_zero)
        scores_idx = np.argsort(quad_scores_no_zero)[::-1]
        for i in range(num_quads):
            idx = scores_idx[i]
            geo = quad_after_nms_no_zero[idx]
            for j in range(num_gts):
                if gt_flag[j] == 0:
                    gt_geo = gt_xy[j]
                    iou = self.get_iou(geo, gt_geo)
                    if iou >= cfg.iou_threshold:
                        gt_flag[j] = 1
                        quad_flag[i] = 1
        tp = np.sum(quad_flag)
        fp = num_quads - tp
        fn = num_gts - tp
        pre = tp / (tp + fp)
        rec = tp / (tp + fn)
        if pre + rec == 0:
            f1_score = 0
        else:
            f1_score = 2 * pre * rec / (pre + rec)
        return pre, rec, f1_score

    def add(self, out, gt_xy_list):
        """Accumulate metrics over one batch of network outputs."""
        self.img_num += len(gt_xy_list)
        ys = out.asnumpy()  # (N, 7, 64, 64)
        print(ys.shape)
        if ys.shape[1] == 7:
            ys = ys.transpose((0, 2, 3, 1))  # NCHW->NHWC
        for y, gt_xy in zip(ys, gt_xy_list):
            y[:, :, :3] = self.sigmoid(y[:, :, :3])
            cond = np.greater_equal(y[:, :, 0], self.pixel_threshold)
            activation_pixels = np.where(cond)
            quad_scores, quad_after_nms = nms(y, activation_pixels)
            print(quad_scores)
            lens = len(quad_after_nms)
            if lens == 0 or (sum(sum(quad_scores)) == 0):
                if not cfg.quiet:
                    print('NMS')
                continue
            else:
                pre, rec, f1_score = self.eval_one(quad_scores, quad_after_nms, gt_xy)
                print(pre, rec, f1_score)
                self.pre += pre
                self.rec += rec
                self.f1_score += f1_score
@ -0,0 +1,188 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train advanced_east on dataset########################
"""
import argparse
import datetime
import os

from mindspore import context, Model
from mindspore.communication.management import init, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.common import set_seed
from mindspore.nn.optim import AdamWeightDecay
from src.logger import get_logger
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net
from src.config import config as cfg

set_seed(1)

def parse_args(cloud_args=None):
    """parameters"""
    parser = argparse.ArgumentParser('mindspore adveast training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')

    # network related
    parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')

    # logging and checkpoint related
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
    args_opt = parser.parse_args()

    args_opt.initial_epoch = cfg.initial_epoch
    args_opt.epoch_num = cfg.epoch_num
    args_opt.learning_rate = cfg.learning_rate
    args_opt.decay = cfg.decay
    args_opt.batch_size = cfg.batch_size
    args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
    args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
    args_opt.ckpt_save_max = cfg.ckpt_save_max
    args_opt.data_dir = cfg.data_dir
    args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
    args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
    args_opt.train_image_dir_name = cfg.train_image_dir_name
    args_opt.train_label_dir_name = cfg.train_label_dir_name
    args_opt.results_dir = cfg.results_dir
    args_opt.last_model_name = cfg.last_model_name
    args_opt.saved_model_file_path = cfg.saved_model_file_path
    return args_opt

if __name__ == '__main__':
    args = parse_args()
    device_num = int(os.environ.get("DEVICE_NUM", 1))
    context.set_context(mode=context.GRAPH_MODE)
    workers = 32
    if args.is_distributed:
        if args.device_target == "Ascend":
            context.set_context(device_id=args.device_id, device_target=args.device_target)
            init()
        elif args.device_target == "GPU":
            context.set_context(device_target=args.device_target)
            init()

        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        context.set_context(device_id=args.device_id)

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    args.logger.save_args(args)
    # network
    args.logger.important_info('start create network')

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # get network and init
    loss_net, train_net = get_AdvancedEast_net()
    loss_net.add_flags_recursive(fp32=True)
    train_net.set_train(True)
    # pre_trained
    if args.pre_trained:
        load_param_into_net(train_net, load_checkpoint(os.path.join(args.saved_model_file_path, args.last_model_name)))
    # define callbacks
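    # Multi-scale curriculum: the same train_net is fitted three times in a
    # row on mindrecords of growing resolution (256 -> 384 -> 448) with
    # shrinking batch sizes (8 -> 4 -> 2), re-creating the optimizer and the
    # checkpoint callback for each stage.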
    mindrecordfile256 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(256) + '.mindrecord')

    train_dataset256, batch_num256 = load_adEAST_dataset(mindrecordfile256, batch_size=8,
                                                         device_num=device_num, rank_id=args.rank, is_training=True,
                                                         num_parallel_workers=workers)

    mindrecordfile384 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(384) + '.mindrecord')
    train_dataset384, batch_num384 = load_adEAST_dataset(mindrecordfile384, batch_size=4,
                                                         device_num=device_num, rank_id=args.rank, is_training=True,
                                                         num_parallel_workers=workers)
    mindrecordfile448 = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file_var + str(448) + '.mindrecord')
    train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2,
                                                         device_num=device_num, rank_id=args.rank, is_training=True,
                                                         num_parallel_workers=workers)

    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                          eps=1e-7, weight_decay=cfg.decay)
    model = Model(train_net)
    time_cb = TimeMonitor(data_size=batch_num256)
    loss_cb = LossMonitor(per_print_times=batch_num256)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = args.saved_model_file_path
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='Epoch_A{}'.format(args.rank))
        callbacks.append(ckpt_cb)
    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset256, callbacks=callbacks, dataset_sink_mode=False)
    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                      eps=1e-7, weight_decay=cfg.decay)
    time_cb = TimeMonitor(data_size=batch_num384)
    loss_cb = LossMonitor(per_print_times=batch_num384)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = args.saved_model_file_path
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='Epoch_B{}'.format(args.rank))
        callbacks.append(ckpt_cb)
    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                          eps=1e-7, weight_decay=cfg.decay)
    model = Model(train_net)
    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset384, callbacks=callbacks, dataset_sink_mode=False)
    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                      eps=1e-7, weight_decay=cfg.decay)
    time_cb = TimeMonitor(data_size=batch_num448)
    loss_cb = LossMonitor(per_print_times=batch_num448)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = args.saved_model_file_path
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='Epoch_C{}'.format(args.rank))
        callbacks.append(ckpt_cb)
    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                          eps=1e-7, weight_decay=cfg.decay)
    model = Model(train_net)
    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset448, callbacks=callbacks, dataset_sink_mode=False)
@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train advanced_east on a user-specified mindrecord########################
"""
import argparse
import datetime
import os

from mindspore import context, Model
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from src.logger import get_logger

from src.config import config as cfg
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net

set_seed(1)

def parse_args(cloud_args=None):
    """parameters"""
    parser = argparse.ArgumentParser('mindspore adveast training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')

    # network related
    parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
    parser.add_argument('--data_path', default='/disk1/ade/icpr/advanced-east-val_256.mindrecord', type=str)
    parser.add_argument('--pre_trained_ckpt', default='/disk1/adeast/scripts/1.ckpt', type=str)

    # logging and checkpoint related
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=1, help='ckpt_interval')
    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
    args_opt = parser.parse_args()

    args_opt.initial_epoch = cfg.initial_epoch
    args_opt.epoch_num = cfg.epoch_num
    args_opt.learning_rate = cfg.learning_rate
    args_opt.decay = cfg.decay
    args_opt.batch_size = cfg.batch_size
    args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
    args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
    args_opt.ckpt_save_max = cfg.ckpt_save_max
    args_opt.data_dir = cfg.data_dir
    args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
    args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
    args_opt.train_image_dir_name = cfg.train_image_dir_name
    args_opt.train_label_dir_name = cfg.train_label_dir_name
    args_opt.results_dir = cfg.results_dir
    args_opt.last_model_name = cfg.last_model_name
    args_opt.saved_model_file_path = cfg.saved_model_file_path
    return args_opt

if __name__ == '__main__':
    args = parse_args()
    context.set_context(device_target=args.device_target, device_id=args.device_id, mode=context.GRAPH_MODE)
    workers = 32
    device_num = 1
    if args.is_distributed:
        init()
        if args.device_target == "Ascend":
            device_num = int(os.environ.get("RANK_SIZE"))
            rank = int(os.environ.get("RANK_ID"))
            args.rank = args.device_id
        elif args.device_target == "GPU":
            context.set_context(device_target=args.device_target)
            device_num = get_group_size()
            args.rank = get_rank()
        args.group_size = device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, gradients_mean=True,
                                          parallel_mode=ParallelMode.DATA_PARALLEL)
    else:
        context.set_context(device_id=args.device_id)

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataset
    mindrecordfile = args.data_path

    train_dataset, batch_num = load_adEAST_dataset(mindrecordfile, batch_size=args.batch_size,
                                                   device_num=device_num, rank_id=args.rank, is_training=True,
                                                   num_parallel_workers=workers)

    args.logger.save_args(args)
    # network
    args.logger.important_info('start create network')
    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # get network and init
    loss_net, train_net = get_AdvancedEast_net()
    loss_net.add_flags_recursive(fp32=True)
    train_net.set_train(True)
    # pre_trained
    if args.pre_trained:
        load_param_into_net(train_net, load_checkpoint(args.pre_trained_ckpt))
    # define callbacks
    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate,
                                          eps=1e-7, weight_decay=cfg.decay)
    model = Model(train_net)
    time_cb = TimeMonitor(data_size=batch_num)
    loss_cb = LossMonitor(per_print_times=batch_num)
    callbacks = [time_cb, loss_cb]
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num,
                                   keep_checkpoint_max=args.ckpt_save_max)
    save_ckpt_path = args.saved_model_file_path
    if args.rank_save_ckpt_flag:
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='Epoch_{}'.format(args.rank))
        callbacks.append(ckpt_cb)
    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
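# Example launch (standalone GPU; the mindrecord path is a placeholder for a
# file produced by the data-preparation step):
#     python train_mindrecord.py --device_target GPU --device_id 0 \
#         --data_path /path/to/advanced-east-train_256.mindrecord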