forked from mindspore-Ecosystem/mindspore
!14736 add eval process while training in ctpn and crnn
From: @qujianwei Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
64a0e5d08b
|
@ -131,6 +131,7 @@ crnn
|
|||
│ ├── crnn.py # crnn network definition
|
||||
│ ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
|
||||
│ ├── dataset.py # Data preprocessing for training and evaluation
|
||||
│ ├── eval_callback.py # evaluation callback while training
|
||||
│ ├── ic03_dataset.py # Data preprocessing for IC03
|
||||
│ ├── ic13_dataset.py # Data preprocessing for IC13
|
||||
│ ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
|
||||
|
@ -225,6 +226,10 @@ Check the `eval/log.txt` and you will get outputs as following:
|
|||
result: {'CRNNAccuracy': (0.806)}
|
||||
```
|
||||
|
||||
### Evaluation while training
|
||||
|
||||
To run evaluation while training, add `run_eval` to the start shell script and set it to True. You also need to add `eval_dataset` to select which dataset to evaluate on, and `eval_dataset_path` to point to that dataset. When `run_eval` is True, you can additionally set the optional arguments `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
|
||||
|
||||
## [Inference Process](#contents)
|
||||
|
||||
### [Export MindIR](#contents)
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of every qualifying epoch the callback calls
    `eval_function(eval_param_dict)`, prints the result, and, when the result is
    at least as good as the best one seen so far, records it and optionally
    saves the training network as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function; it is called with
            `eval_param_dict` as its only argument and its return value is
            compared with `>=` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to; it is
            created if it does not exist, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (The parameter name is misspelled; it is kept as is because callers
            pass it by keyword.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: if `interval` is less than 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        # Best metric value seen so far and the epoch it was reached in.
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race when several processes
        # start with the same checkpoint directory at the same time.
        os.makedirs(ckpt_directory, exist_ok=True)
        # Attribute name (including its spelling) is kept for compatibility.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the checkpoint file `file_name`; failures are logged as warnings, not raised."""
        try:
            # Make the file writable before deleting it (presumably because
            # saved checkpoints may be read-only -- confirm against save_checkpoint).
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and track/save the best result."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # `>=` means a tie also counts as a new best, so the latest epoch wins.
        if res >= self.best_res:
            self.best_res = res
            self.best_epoch = cur_epoch
            print("update best result: {}".format(res), flush=True)
            if self.save_best_ckpt:
                # Drop the previous best checkpoint before writing the new one.
                if os.path.exists(self.bast_ckpt_path):
                    self.remove_ckpoint_file(self.bast_ckpt_path)
                save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best metric value and its epoch."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
||||
|
|
@ -15,6 +15,7 @@
|
|||
"""crnn training"""
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
|
@ -28,7 +29,8 @@ from src.loss import CTCLoss
|
|||
from src.dataset import create_dataset
|
||||
from src.crnn import crnn
|
||||
from src.crnn_for_train import TrainOneStepCellWithGradClip
|
||||
|
||||
from src.metric import CRNNAccuracy
|
||||
from src.eval_callback import EvalCallBack
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="crnn training")
|
||||
|
@ -38,6 +40,16 @@ parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend']
|
|||
help='Running platform, only support Ascend now. Default is Ascend.')
|
||||
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
|
||||
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=5,
|
||||
help="Evaluation start epoch when run_eval is True, default is 5.")
|
||||
parser.add_argument("--eval_interval", type=int, default=5,
|
||||
help="Evaluation interval when run_eval is True, default is 5.")
|
||||
parser.set_defaults(run_distribute=False)
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
@ -50,6 +62,12 @@ if args_opt.platform == 'Ascend':
|
|||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
def apply_eval(eval_param):
    """Run one evaluation pass and return the requested metric.

    Args:
        eval_param (dict): must contain "model" (an object with an
            `eval(dataset)` method returning a dict of metrics), "dataset"
            (passed to `eval`), and "metrics_name" (the key to read from the
            returned dict).

    Returns:
        The value stored under `metrics_name` in the evaluation result.
    """
    # All three keys are read before evaluation starts, so a missing key fails fast.
    runner, data, key = (eval_param[name] for name in ("model", "dataset", "metrics_name"))
    metrics = runner.eval(data)
    return metrics[key]
|
||||
|
||||
if __name__ == '__main__':
|
||||
lr_scale = 1
|
||||
|
@ -86,16 +104,31 @@ if __name__ == '__main__':
|
|||
net = crnn(config)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
|
||||
|
||||
net = WithLossCell(net, loss)
|
||||
net = TrainOneStepCellWithGradClip(net, opt).set_train()
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
net_with_grads = TrainOneStepCellWithGradClip(net_with_loss, opt).set_train()
|
||||
# define model
|
||||
model = Model(net)
|
||||
model = Model(net_with_grads)
|
||||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
if args_opt.run_eval:
    # Evaluation while training needs an existing evaluation dataset directory.
    if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
    eval_dataset = create_dataset(name=args_opt.eval_dataset,
                                  dataset_path=args_opt.eval_dataset_path,
                                  batch_size=config.batch_size,
                                  is_training=False,
                                  config=config)
    # Separate Model wrapping the bare network with the accuracy metric attached.
    eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
    eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
    # Honour the --save_best_ckpt command line option instead of always saving.
    eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
                           eval_start_epoch=args_opt.eval_start_epoch,
                           save_best_ckpt=args_opt.save_best_ckpt,
                           ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
                           metrics_name="acc")
    callbacks += [eval_cb]
|
||||
if config.save_checkpoint and rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks)
|
||||
|
|
|
@ -96,6 +96,8 @@ Here we used 6 datasets for training, and 1 datasets for Evaluation.
|
|||
│ ├── create_dataset.py # create mindrecord dataset
|
||||
│ ├── ctpn.py # ctpn network definition
|
||||
│ ├── dataset.py # data preprocessing
|
||||
│ ├── eval_callback.py # evaluation callback while training
|
||||
│ ├── eval_utils.py # evaluation function
|
||||
│ ├── lr_schedule.py # learning rate scheduler
|
||||
│ ├── network_define.py # network definition
|
||||
│ └── text_connector
|
||||
|
@ -235,6 +237,10 @@ Then you can run the scripts/eval_res.sh to calculate the evalulation result.
|
|||
bash eval_res.sh
|
||||
```
|
||||
|
||||
### Evaluation while training
|
||||
|
||||
To run evaluation while training, add `run_eval` to the start shell script and set it to True. When `run_eval` is True, you can also set the optional arguments `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
|
||||
|
||||
### Result
|
||||
|
||||
Evaluation results will be stored in the example path; you can find results like the following in `log`.
|
||||
|
|
|
@ -14,17 +14,14 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Evaluation for CTPN"""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.ctpn import CTPN
|
||||
from src.config import config
|
||||
from src.dataset import create_ctpn_dataset
|
||||
from src.text_connector.detector import detect
|
||||
from src.eval_utils import eval_for_ctpn
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="CTPN evaluation")
|
||||
|
@ -39,80 +36,13 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
|
|||
"""ctpn infer."""
|
||||
print("ckpt path is {}".format(ckpt_path))
|
||||
ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
|
||||
config.batch_size = config.test_batch_size
|
||||
total = ds.get_dataset_size()
|
||||
print("*************total dataset size is {}".format(total))
|
||||
net = CTPN(config, is_training=False)
|
||||
print("eval dataset size is {}".format(total))
|
||||
net = CTPN(config, batch_size=config.test_batch_size, is_training=False)
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
eval_iter = 0
|
||||
|
||||
print("\n========================================\n")
|
||||
print("Processing, please wait a moment.")
|
||||
img_basenames = []
|
||||
output_dir = os.path.join(os.getcwd(), "submit")
|
||||
if not os.path.exists(output_dir):
|
||||
os.mkdir(output_dir)
|
||||
for file in os.listdir(img_dir):
|
||||
img_basenames.append(os.path.basename(file))
|
||||
for data in ds.create_dict_iterator():
|
||||
img_data = data['image']
|
||||
img_metas = data['image_shape']
|
||||
gt_bboxes = data['box']
|
||||
gt_labels = data['label']
|
||||
gt_num = data['valid_num']
|
||||
|
||||
start = time.time()
|
||||
# run net
|
||||
output = net(img_data, gt_bboxes, gt_labels, gt_num)
|
||||
gt_bboxes = gt_bboxes.asnumpy()
|
||||
gt_labels = gt_labels.asnumpy()
|
||||
gt_num = gt_num.asnumpy().astype(bool)
|
||||
end = time.time()
|
||||
proposal = output[0]
|
||||
proposal_mask = output[1]
|
||||
print("start to draw pic")
|
||||
for j in range(config.test_batch_size):
|
||||
img = img_basenames[config.test_batch_size * eval_iter + j]
|
||||
all_box_tmp = proposal[j].asnumpy()
|
||||
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
|
||||
using_boxes_mask = all_box_tmp * all_mask_tmp
|
||||
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
|
||||
scores = using_boxes_mask[:, 4].astype(np.float32)
|
||||
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
|
||||
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
|
||||
from PIL import Image, ImageDraw
|
||||
im = Image.open(img_dir + '/' + img)
|
||||
draw = ImageDraw.Draw(im)
|
||||
image_h = img_metas.asnumpy()[j][2]
|
||||
image_w = img_metas.asnumpy()[j][3]
|
||||
gt_boxs = gt_bboxes[j][gt_num[j], :]
|
||||
for gt_box in gt_boxs:
|
||||
gt_x1 = gt_box[0] / image_w
|
||||
gt_y1 = gt_box[1] / image_h
|
||||
gt_x2 = gt_box[2] / image_w
|
||||
gt_y2 = gt_box[3] / image_h
|
||||
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
|
||||
fill='green', width=2)
|
||||
file_name = "res_" + img.replace("jpg", "txt")
|
||||
output_file = os.path.join(output_dir, file_name)
|
||||
f = open(output_file, 'w')
|
||||
for bbox in bboxes:
|
||||
x1 = bbox[0] / image_w
|
||||
y1 = bbox[1] / image_h
|
||||
x2 = bbox[2] / image_w
|
||||
y2 = bbox[3] / image_h
|
||||
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
|
||||
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
|
||||
f.write(str_tmp)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
im.save(img)
|
||||
percent = round(eval_iter / total * 100, 2)
|
||||
eval_iter = eval_iter + 1
|
||||
print("Iter {} cost time {}".format(eval_iter, end - start))
|
||||
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
|
||||
eval_for_ctpn(net, ds, img_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)
|
||||
|
|
|
@ -36,7 +36,7 @@ if args.device_target == "Ascend":
|
|||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = CTPN_Infer(config=config)
|
||||
net = CTPN_Infer(config=config, batch_size=config.test_batch_size)
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
|
||||
|
|
|
@ -56,6 +56,8 @@ do
|
|||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp ./*.zip ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
|
|
|
@ -29,8 +29,7 @@ finetune_config = EasyDict({
|
|||
"total_epoch": 50,
|
||||
})
|
||||
|
||||
# use for low case number
|
||||
config = EasyDict({
|
||||
config_default = EasyDict({
|
||||
"img_width": 960,
|
||||
"img_height": 576,
|
||||
"keep_ratio": False,
|
||||
|
@ -39,7 +38,6 @@ config = EasyDict({
|
|||
"expand_ratio": 1.0,
|
||||
|
||||
# anchor
|
||||
"feature_shapes": (36, 60),
|
||||
"num_anchors": 14,
|
||||
"anchor_base": 16,
|
||||
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
|
||||
|
@ -56,7 +54,6 @@ config = EasyDict({
|
|||
"neg_iou_thr": 0.5,
|
||||
"pos_iou_thr": 0.7,
|
||||
"min_pos_iou": 0.001,
|
||||
"num_bboxes": 30240,
|
||||
"num_gts": 256,
|
||||
"num_expected_neg": 512,
|
||||
"num_expected_pos": 256,
|
||||
|
@ -75,12 +72,11 @@ config = EasyDict({
|
|||
|
||||
# rnn structure
|
||||
"input_size": 512,
|
||||
"num_step": 60,
|
||||
"rnn_batch_size": 36,
|
||||
"hidden_size": 128,
|
||||
|
||||
# training
|
||||
"warmup_mode": "linear",
|
||||
# batch_size only support 1
|
||||
"batch_size": 1,
|
||||
"momentum": 0.9,
|
||||
"save_checkpoint": True,
|
||||
|
@ -131,3 +127,12 @@ config = EasyDict({
|
|||
"pretraining_dataset_file": "",
|
||||
"finetune_dataset_file": ""
|
||||
})
|
||||
|
||||
config_add = {
|
||||
"feature_shapes": (config_default["img_height"] // 16, config_default["img_width"] // 16),
|
||||
"num_bboxes": (config_default["img_height"] // 16) * \
|
||||
(config_default["img_width"] // 16) *config_default["num_anchors"],
|
||||
"num_step": config_default["img_width"] // 16,
|
||||
"rnn_batch_size": config_default["img_height"] // 16
|
||||
}
|
||||
config = EasyDict({**config_default, **config_add})
|
||||
|
|
|
@ -145,7 +145,7 @@ def create_train_dataset(dataset_type):
|
|||
# test: icdar2013 test
|
||||
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
|
||||
config.icdar13_test_path[1], "")
|
||||
image_files = icdar_test_image_files
|
||||
image_files = sorted(icdar_test_image_files)
|
||||
image_anno_dict = icdar_test_anno_dict
|
||||
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
|
||||
prefix="ctpn_test.mindrecord", file_num=1)
|
||||
|
|
|
@ -29,16 +29,13 @@ class BiLSTM(nn.Cell):
|
|||
Define a BiLSTM network which contains two LSTM layers
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
config(EasyDict): config for ctpn network
|
||||
batch_size(int): batch size of input data, only support 1
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
def __init__(self, config, batch_size):
|
||||
super(BiLSTM, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.batch_size = config.batch_size * config.rnn_batch_size
|
||||
print("batch size is {} ".format(self.batch_size))
|
||||
self.batch_size = batch_size
|
||||
self.batch_size = self.batch_size * config.rnn_batch_size
|
||||
self.input_size = config.input_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_step = config.num_step
|
||||
|
@ -84,25 +81,24 @@ class CTPN(nn.Cell):
|
|||
Define CTPN network
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
config(EasyDict): config for ctpn network
|
||||
batch_size(int): batch size of input data, only support 1
|
||||
is_training(bool): whether training, default is True
|
||||
"""
|
||||
def __init__(self, config, is_training=True):
|
||||
def __init__(self, config, batch_size, is_training=True):
|
||||
super(CTPN, self).__init__()
|
||||
self.config = config
|
||||
self.is_training = is_training
|
||||
self.batch_size = batch_size
|
||||
self.num_step = config.num_step
|
||||
self.input_size = config.input_size
|
||||
self.batch_size = config.batch_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction()
|
||||
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
|
||||
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
|
||||
self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16)
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.is_training = is_training
|
||||
|
||||
# rpn block
|
||||
self.rpn_with_loss = RPN(config,
|
||||
|
@ -115,7 +111,7 @@ class CTPN(nn.Cell):
|
|||
self.featmap_size = config.feature_shapes
|
||||
self.anchor_list = self.get_anchors(self.featmap_size)
|
||||
self.proposal_generator_test = Proposal(config,
|
||||
config.test_batch_size,
|
||||
self.batch_size,
|
||||
config.activate_num_classes,
|
||||
config.use_sigmoid_cls)
|
||||
self.proposal_generator_test.set_train_local(config, False)
|
||||
|
@ -143,9 +139,9 @@ class CTPN(nn.Cell):
|
|||
return Tensor(anchors, mstype.float16)
|
||||
|
||||
class CTPN_Infer(nn.Cell):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, batch_size):
|
||||
super(CTPN_Infer, self).__init__()
|
||||
self.network = CTPN(config, is_training=False)
|
||||
self.network = CTPN(config, batch_size=batch_size, is_training=False)
|
||||
self.network.set_train(False)
|
||||
|
||||
def construct(self, img_data):
|
||||
|
|
|
@ -289,11 +289,11 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
|
|||
input_columns=["image", "annotation"],
|
||||
output_columns=["image", "box", "label", "valid_num", "image_shape"],
|
||||
column_order=["image", "box", "label", "valid_num", "image_shape"],
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
num_parallel_workers=8,
|
||||
python_multiprocessing=True)
|
||||
|
||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=24)
|
||||
num_parallel_workers=8)
|
||||
# transpose_column from python to c
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
|
||||
ds = ds.map(operations=[type_cast1], input_columns=["box"])
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    At the end of every qualifying epoch the callback calls
    `eval_function(eval_param_dict)`, prints the result, and, when the result is
    at least as good as the best one seen so far, records it and optionally
    saves the training network as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function; it is called with
            `eval_param_dict` as its only argument and its return value is
            compared with `>=` against the best result so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is written to; it is
            created if it does not exist, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (The parameter name is misspelled; it is kept as is because callers
            pass it by keyword.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: if `interval` is less than 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        # Best metric value seen so far and the epoch it was reached in.
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race when several processes
        # start with the same checkpoint directory at the same time.
        os.makedirs(ckpt_directory, exist_ok=True)
        # Attribute name (including its spelling) is kept for compatibility.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the checkpoint file `file_name`; failures are logged as warnings, not raised."""
        try:
            # Make the file writable before deleting it (presumably because
            # saved checkpoints may be read-only -- confirm against save_checkpoint).
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and track/save the best result."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # `>=` means a tie also counts as a new best, so the latest epoch wins.
        if res >= self.best_res:
            self.best_res = res
            self.best_epoch = cur_epoch
            print("update best result: {}".format(res), flush=True)
            if self.save_best_ckpt:
                # Drop the previous best checkpoint before writing the new one.
                if os.path.exists(self.bast_ckpt_path):
                    self.remove_ckpoint_file(self.bast_ckpt_path)
                save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best metric value and its epoch."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation utils for CTPN"""
|
||||
import os
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
from src.text_connector.detector import detect
|
||||
|
||||
def exec_shell_cmd(cmd):
    """Run `cmd` through the shell and return its stripped standard output.

    Args:
        cmd (str): shell command line; it is executed with `shell=True`, so it
            must come from trusted code, never from untrusted input.

    Returns:
        str: standard output of the command with surrounding whitespace removed.

    Raises:
        ValueError: if the command exits with a non-zero status. The message
            carries the exit code and the captured standard error so the
            failure can be diagnosed.
    """
    sub = subprocess.Popen(args="{}".format(cmd), shell=True, stdin=subprocess.PIPE,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    stdout_data, stderr_data = sub.communicate()
    if sub.returncode != 0:
        raise ValueError("{} is not a executable command, please check. Exit code: {}, stderr: {}".format(
            cmd, sub.returncode, stderr_data.strip()))
    return stdout_data.strip()
|
||||
|
||||
def get_eval_result():
    """Package the detection results, run the evaluation script and return the hmean.

    The text files under ./submit are zipped into submit.zip, `eval_res.sh` is
    run, and the last `hmean` value found in the file `log` in the current
    working directory is parsed.

    Returns:
        float: the hmean value extracted from `log`.

    Raises:
        ValueError: if no hmean value can be found in `log`, or if the
            extracted text is not a number.
    """
    create_eval_bbox = 'rm -rf submit*.zip;cd ./submit/;zip -r ../submit.zip *.txt;cd ../;bash eval_res.sh'
    os.system(create_eval_bbox)
    get_eval_output = "grep hmean log | awk '{print $NF}' | awk -F} '{print $1}' |tail -n 1"
    hmean = exec_shell_cmd(get_eval_output)
    # The pipeline's exit status is that of `tail`, so a missing log or a log
    # without an hmean line shows up here as empty output, not as a failure.
    if not hmean:
        raise ValueError("no hmean value found in file 'log', please check the output of eval_res.sh.")
    return float(hmean)
|
||||
|
||||
def eval_for_ctpn(network, dataset, eval_image_path):
    """Run CTPN inference over `dataset` and write detection results to disk.

    For every image, the detected boxes are written to
    `<cwd>/submit/res_<image name with "jpg" replaced by "txt">` as one
    `x1,y1,x2,y2` line per box, and a copy of the image with ground-truth boxes
    (green) and detected boxes (red) drawn on it is saved into the current
    working directory under the image's base name.

    Args:
        network: CTPN network; it is switched to eval mode and called as
            `network(image, box, label, valid_num)`.
        dataset: dataset whose `create_dict_iterator()` yields dicts with the
            keys 'image', 'image_shape', 'box', 'label' and 'valid_num'.
        eval_image_path (str): directory holding the evaluation images. Images
            are matched to dataset items by position in the sorted file list,
            so the dataset is assumed to follow the same order
            -- NOTE(review): confirm against the dataset creation script.

    Returns:
        None
    """
    # Imported inside the function, as in the original code, so importing this
    # module does not itself require Pillow.
    from PIL import Image, ImageDraw

    network.set_train(False)
    output_dir = os.path.join(os.getcwd(), "submit")
    os.makedirs(output_dir, exist_ok=True)
    img_basenames = sorted(os.path.basename(name) for name in os.listdir(eval_image_path))
    batch_size = config.test_batch_size

    for eval_iter, data in enumerate(dataset.create_dict_iterator()):
        img_data = data['image']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']
        # run net
        output = network(img_data, gt_bboxes, gt_labels, gt_num)
        # Convert once per batch instead of once per use.
        img_metas = data['image_shape'].asnumpy()
        gt_bboxes = gt_bboxes.asnumpy()
        gt_num = gt_num.asnumpy().astype(bool)
        proposal = output[0]
        proposal_mask = output[1]

        for j in range(batch_size):
            img = img_basenames[batch_size * eval_iter + j]
            # Zero out the proposals that the mask marks as invalid.
            all_box_tmp = proposal[j].asnumpy()
            all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
            using_boxes_mask = all_box_tmp * all_mask_tmp
            textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
            scores = using_boxes_mask[:, 4].astype(np.float32)
            # NOTE(review): index 0 (not j) is kept from the original code; the two
            # only coincide when the batch holds a single image -- confirm intent.
            shape = img_metas[0][:2].astype(np.int32)
            bboxes = detect(textsegs, scores[:, np.newaxis], shape)

            # Presumably the height/width scale factors used to map network
            # coordinates back to the original image -- TODO confirm.
            image_h = img_metas[j][2]
            image_w = img_metas[j][3]

            # NOTE(review): replace() rewrites every "jpg" occurrence in the name,
            # not only the extension; kept because the result names feed eval_res.sh.
            file_name = "res_" + img.replace("jpg", "txt")
            output_file = os.path.join(output_dir, file_name)

            # Context managers close the image and the result file even if
            # drawing or writing fails part-way through.
            with Image.open(os.path.join(eval_image_path, img)) as im:
                draw = ImageDraw.Draw(im)
                for gt_box in gt_bboxes[j][gt_num[j], :]:
                    gt_x1 = gt_box[0] / image_w
                    gt_y1 = gt_box[1] / image_h
                    gt_x2 = gt_box[2] / image_w
                    gt_y2 = gt_box[3] / image_h
                    draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],
                              fill='green', width=2)
                with open(output_file, 'w') as f:
                    for bbox in bboxes:
                        x1 = bbox[0] / image_w
                        y1 = bbox[1] / image_h
                        x2 = bbox[2] / image_w
                        y2 = bbox[3] / image_h
                        draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
                        str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
                        f.write(str_tmp)
                        f.write("\n")
                im.save(img)
|
|
@ -32,6 +32,8 @@ from src.config import config, pretrain_config, finetune_config
|
|||
from src.dataset import create_ctpn_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell
|
||||
from src.eval_utils import eval_for_ctpn, get_eval_result
|
||||
from src.eval_callback import EvalCallBack
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
@ -43,10 +45,30 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
|
|||
# --- Command-line options -------------------------------------------------
# Calls are wrapped inside their own parentheses, so no backslash
# continuations are needed.
parser.add_argument(
    "--rank_id",
    type=int,
    default=0,
    help="Rank id, default: 0.",
)
parser.add_argument(
    "--task_type",
    type=str,
    default="Pretraining",
    choices=['Pretraining', 'Finetune'],
    help="task type, default:Pretraining",
)

# Options controlling evaluation during training.  The two boolean flags are
# parsed with ast.literal_eval, so they expect Python literals such as
# "True" / "False" on the command line.
parser.add_argument(
    "--run_eval",
    type=ast.literal_eval,
    default=False,
    help="Run evaluation when training, default is False.",
)
parser.add_argument(
    "--save_best_ckpt",
    type=ast.literal_eval,
    default=True,
    help="Save best checkpoint when run_eval is True, default is True.",
)
parser.add_argument(
    "--eval_image_path",
    type=str,
    default="",
    help="eval image path, when run_eval is True, eval_image_path should be set.",
)
parser.add_argument(
    "--eval_dataset_path",
    type=str,
    default="",
    help="eval dataset path, when run_eval is True, eval_dataset_path should be set.",
)
parser.add_argument(
    "--eval_start_epoch",
    type=int,
    default=10,
    help="Evaluation start epoch when run_eval is True, default is 10.",
)
parser.add_argument(
    "--eval_interval",
    type=int,
    default=10,
    help="Evaluation interval when run_eval is True, default is 10.",
)
args_opt = parser.parse_args()

# Configure the MindSpore execution context from the parsed device id.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    device_id=args_opt.device_id,
    save_graphs=True,
)
|
||||
|
||||
def apply_eval(eval_param):
    """Run one CTPN evaluation pass and return its result.

    Args:
        eval_param (dict): Must contain the keys "eval_network",
            "eval_dataset" and "eval_image_path". Their values are passed,
            in that order, to ``eval_for_ctpn``.

    Returns:
        The value produced by ``get_eval_result()`` after the evaluation
        pass has run (the training script reports it under the metric
        name "hmean").

    Raises:
        KeyError: If any of the three required keys is missing.
    """
    # eval_for_ctpn is called for its side effects only; the score is read
    # back separately through get_eval_result().
    eval_for_ctpn(
        eval_param["eval_network"],
        eval_param["eval_dataset"],
        eval_param["eval_image_path"],
    )
    return get_eval_result()
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.run_distribute:
|
||||
rank = args_opt.rank_id
|
||||
|
@ -78,7 +100,7 @@ if __name__ == '__main__':
|
|||
dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\
|
||||
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
net = CTPN(config=config, is_training=True)
|
||||
net = CTPN(config=config, batch_size=config.batch_size)
|
||||
net = net.set_train()
|
||||
|
||||
load_path = args_opt.pre_trained
|
||||
|
@ -100,20 +122,34 @@ if __name__ == '__main__':
|
|||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
if args_opt.run_distribute:
|
||||
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
|
||||
mean=True, degree=device_num)
|
||||
net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, \
|
||||
mean=True, degree=device_num)
|
||||
else:
|
||||
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
|
||||
net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
|
||||
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
loss_cb = LossCallBack(rank_id=rank)
|
||||
cb = [time_cb, loss_cb]
|
||||
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
|
||||
if config.save_checkpoint:
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig)
|
||||
cb += [ckpoint_cb]
|
||||
|
||||
model = Model(net)
|
||||
# Optional evaluation during training: validate the evaluation inputs, build
# the evaluation dataset and register an EvalCallBack alongside the other
# callbacks before training starts.
if args_opt.run_eval:
    # Fail fast on bad paths so the error surfaces before training begins
    # rather than at the first evaluation epoch.
    if args_opt.eval_dataset_path is None or (not os.path.isfile(args_opt.eval_dataset_path)):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
    if args_opt.eval_image_path is None or (not os.path.isdir(args_opt.eval_image_path)):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_image_path))

    eval_dataset = create_ctpn_dataset(args_opt.eval_dataset_path,
                                       batch_size=config.batch_size, repeat_num=1, is_training=False)

    # The evaluation pass reuses the network object that is being trained.
    # NOTE(review): `net` was switched to training mode via set_train()
    # earlier in this script; confirm the evaluation path handles that.
    eval_net = net
    eval_param_dict = {"eval_network": eval_net, "eval_dataset": eval_dataset,
                       "eval_image_path": args_opt.eval_image_path}

    # Honour the --save_best_ckpt command-line option instead of hard-coding
    # True, which made the option a no-op.
    # NOTE(review): the keyword "besk_ckpt_name" is kept as written; it
    # presumably mirrors the spelling in EvalCallBack's signature. Confirm
    # before renaming.
    eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
                           eval_start_epoch=args_opt.eval_start_epoch,
                           save_best_ckpt=args_opt.save_best_ckpt,
                           ckpt_directory=save_checkpoint_path, besk_ckpt_name="best_acc.ckpt",
                           metrics_name="hmean")
    cb += [eval_cb]

# Train with the gradient-wrapped network and every registered callback.
model = Model(net_with_grads)
model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
|
Loading…
Reference in New Issue