forked from mindspore-Ecosystem/mindspore
!13159 delete image_meta in ctpn eval process
From: @qujianwei Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
63e033b8f1
|
@ -65,7 +65,7 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
# run net
|
# run net
|
||||||
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
|
output = net(img_data, gt_bboxes, gt_labels, gt_num)
|
||||||
gt_bboxes = gt_bboxes.asnumpy()
|
gt_bboxes = gt_bboxes.asnumpy()
|
||||||
gt_labels = gt_labels.asnumpy()
|
gt_labels = gt_labels.asnumpy()
|
||||||
gt_num = gt_num.asnumpy().astype(bool)
|
gt_num = gt_num.asnumpy().astype(bool)
|
||||||
|
|
|
@ -119,19 +119,14 @@ class CTPN(nn.Cell):
|
||||||
config.activate_num_classes,
|
config.activate_num_classes,
|
||||||
config.use_sigmoid_cls)
|
config.use_sigmoid_cls)
|
||||||
self.proposal_generator_test.set_train_local(config, False)
|
self.proposal_generator_test.set_train_local(config, False)
|
||||||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
|
def construct(self, img_data, gt_bboxes, gt_labels, gt_valids, img_metas=None):
|
||||||
# (1,3,600,900)
|
|
||||||
x = self.vgg16_feature_extractor(img_data)
|
x = self.vgg16_feature_extractor(img_data)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.cast(x, mstype.float16)
|
x = self.cast(x, mstype.float16)
|
||||||
# (1, 512, 38, 57)
|
|
||||||
x = self.transpose(x, (0, 2, 1, 3))
|
x = self.transpose(x, (0, 2, 1, 3))
|
||||||
x = self.reshape(x, (-1, self.input_size, self.num_step))
|
x = self.reshape(x, (-1, self.input_size, self.num_step))
|
||||||
x = self.transpose(x, (2, 0, 1))
|
x = self.transpose(x, (2, 0, 1))
|
||||||
# (57, 38, 512)
|
|
||||||
x = self.rnn(x)
|
x = self.rnn(x)
|
||||||
# (57, 38, 256)
|
|
||||||
#x = self.cast(x, mstype.float32)
|
|
||||||
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
|
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
|
||||||
img_metas,
|
img_metas,
|
||||||
self.anchor_list,
|
self.anchor_list,
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
"""CTPN dataset"""
|
"""CTPN dataset"""
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
import os
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import random
|
from numpy import random
|
||||||
import mmcv
|
import mmcv
|
||||||
|
@ -23,7 +22,6 @@ import mindspore.dataset as de
|
||||||
import mindspore.dataset.vision.c_transforms as C
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
import mindspore.dataset.transforms.c_transforms as CC
|
import mindspore.dataset.transforms.c_transforms as CC
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.mindrecord import FileWriter
|
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
|
||||||
class PhotoMetricDistortion:
|
class PhotoMetricDistortion:
|
||||||
|
@ -98,7 +96,7 @@ class Expand:
|
||||||
boxes += np.tile((left, top), 2)
|
boxes += np.tile((left, top), 2)
|
||||||
return img, boxes, labels
|
return img, boxes, labels
|
||||||
|
|
||||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def rescale_column(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""rescale operation for image"""
|
"""rescale operation for image"""
|
||||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
||||||
if img_data.shape[0] > config.img_height:
|
if img_data.shape[0] > config.img_height:
|
||||||
|
@ -112,10 +110,10 @@ def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
|
|
||||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def resize_column(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""resize operation for image"""
|
"""resize operation for image"""
|
||||||
img_data = img
|
img_data = img
|
||||||
img_data, w_scale, h_scale = mmcv.imresize(
|
img_data, w_scale, h_scale = mmcv.imresize(
|
||||||
|
@ -129,10 +127,10 @@ def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
|
|
||||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def resize_column_test(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""resize operation for image of eval"""
|
"""resize operation for image of eval"""
|
||||||
img_data = img
|
img_data = img
|
||||||
img_data, w_scale, h_scale = mmcv.imresize(
|
img_data, w_scale, h_scale = mmcv.imresize(
|
||||||
|
@ -149,34 +147,34 @@ def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
|
||||||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def flipped_generation(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""flipped generation"""
|
"""flipped generation"""
|
||||||
img_data = img
|
img_data = img
|
||||||
flipped = gt_bboxes.copy()
|
flipped = gt_bboxes.copy()
|
||||||
_, w, _ = img_data.shape
|
_, w, _ = img_data.shape
|
||||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
|
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
|
||||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
|
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
|
||||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
return (img_data, flipped, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def image_bgr_rgb(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
img_data = img[:, :, ::-1]
|
img_data = img[:, :, ::-1]
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def photo_crop_column(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""photo crop operation for image"""
|
"""photo crop operation for image"""
|
||||||
random_photo = PhotoMetricDistortion()
|
random_photo = PhotoMetricDistortion()
|
||||||
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
|
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def expand_column(img, gt_bboxes, gt_label, gt_num, img_shape):
|
||||||
"""expand operation for image"""
|
"""expand operation for image"""
|
||||||
expand = Expand()
|
expand = Expand()
|
||||||
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
|
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
|
||||||
|
|
||||||
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img, gt_bboxes, gt_label, gt_num, img_shape)
|
||||||
|
|
||||||
def split_gtbox_label(gt_bbox_total):
|
def split_gtbox_label(gt_bbox_total):
|
||||||
"""split ground truth box label"""
|
"""split ground truth box label"""
|
||||||
|
@ -193,7 +191,7 @@ def split_gtbox_label(gt_bbox_total):
|
||||||
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
|
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
|
||||||
return np.array(gtbox_list)
|
return np.array(gtbox_list)
|
||||||
|
|
||||||
def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid):
|
def pad_label(img, gt_bboxes, gt_label, gt_valid, img_shape):
|
||||||
"""pad ground truth label"""
|
"""pad ground truth label"""
|
||||||
pad_max_number = 256
|
pad_max_number = 256
|
||||||
gt_label = gt_bboxes[:, 4]
|
gt_label = gt_bboxes[:, 4]
|
||||||
|
@ -208,13 +206,13 @@ def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid):
|
||||||
gt_box = gt_bboxes[0:pad_max_number]
|
gt_box = gt_bboxes[0:pad_max_number]
|
||||||
gt_label = gt_label[0:pad_max_number]
|
gt_label = gt_label[0:pad_max_number]
|
||||||
gt_valid = gt_valid[0:pad_max_number]
|
gt_valid = gt_valid[0:pad_max_number]
|
||||||
return (img, img_shape, gt_box[:, :4], gt_label, gt_valid)
|
return (img, gt_box[:, :4], gt_label, gt_valid, img_shape)
|
||||||
|
|
||||||
def preprocess_fn(image, box, is_training):
|
def preprocess_fn(image, box, is_training):
|
||||||
"""Preprocess function for dataset."""
|
"""Preprocess function for dataset."""
|
||||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid):
|
def _infer_data(image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape):
|
||||||
image_shape = image_shape[:2]
|
image_shape = image_shape[:2]
|
||||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid
|
input_data = image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape
|
||||||
if config.keep_ratio:
|
if config.keep_ratio:
|
||||||
input_data = rescale_column(*input_data)
|
input_data = rescale_column(*input_data)
|
||||||
else:
|
else:
|
||||||
|
@ -234,9 +232,9 @@ def preprocess_fn(image, box, is_training):
|
||||||
gt_box = box[:, :4]
|
gt_box = box[:, :4]
|
||||||
gt_label = box[:, 4]
|
gt_label = box[:, 4]
|
||||||
gt_valid = box[:, 4]
|
gt_valid = box[:, 4]
|
||||||
input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid
|
input_data = image_bgr, gt_box, gt_label, gt_valid, image_shape
|
||||||
if not is_training:
|
if not is_training:
|
||||||
return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid)
|
return _infer_data(image_bgr, gt_box, gt_label, gt_valid, image_shape)
|
||||||
expand = (np.random.rand() < config.expand_ratio)
|
expand = (np.random.rand() < config.expand_ratio)
|
||||||
if expand:
|
if expand:
|
||||||
input_data = expand_column(*input_data)
|
input_data = expand_column(*input_data)
|
||||||
|
@ -260,46 +258,6 @@ def anno_parser(annos_str):
|
||||||
annos.append(anno)
|
annos.append(anno)
|
||||||
return annos
|
return annos
|
||||||
|
|
||||||
def filter_valid_data(image_dir, anno_path):
|
|
||||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
|
||||||
image_files = []
|
|
||||||
image_anno_dict = {}
|
|
||||||
if not os.path.isdir(image_dir):
|
|
||||||
raise RuntimeError("Path given is not valid.")
|
|
||||||
if not os.path.isfile(anno_path):
|
|
||||||
raise RuntimeError("Annotation file is not valid.")
|
|
||||||
|
|
||||||
with open(anno_path, "rb") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
for line in lines:
|
|
||||||
line_str = line.decode("utf-8").strip()
|
|
||||||
line_split = str(line_str).split(' ')
|
|
||||||
file_name = line_split[0]
|
|
||||||
image_path = os.path.join(image_dir, file_name)
|
|
||||||
if os.path.isfile(image_path):
|
|
||||||
image_anno_dict[image_path] = anno_parser(line_split[1:])
|
|
||||||
image_files.append(image_path)
|
|
||||||
return image_files, image_anno_dict
|
|
||||||
|
|
||||||
def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8):
|
|
||||||
"""Create MindRecord file."""
|
|
||||||
mindrecord_dir = config.mindrecord_dir
|
|
||||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
|
||||||
writer = FileWriter(mindrecord_path, file_num)
|
|
||||||
image_files, image_anno_dict = create_icdar_test_label()
|
|
||||||
ctpn_json = {
|
|
||||||
"image": {"type": "bytes"},
|
|
||||||
"annotation": {"type": "int32", "shape": [-1, 6]},
|
|
||||||
}
|
|
||||||
writer.add_schema(ctpn_json, "ctpn_json")
|
|
||||||
for image_name in image_files:
|
|
||||||
with open(image_name, 'rb') as f:
|
|
||||||
img = f.read()
|
|
||||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
|
||||||
row = {"image": img, "annotation": annos}
|
|
||||||
writer.write_raw_data([row])
|
|
||||||
writer.commit()
|
|
||||||
|
|
||||||
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
|
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
|
||||||
is_training=True, num_parallel_workers=12):
|
is_training=True, num_parallel_workers=12):
|
||||||
"""Creatr ctpn dataset with MindDataset."""
|
"""Creatr ctpn dataset with MindDataset."""
|
||||||
|
@ -316,8 +274,8 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
|
||||||
type_cast3 = CC.TypeCast(mstype.bool_)
|
type_cast3 = CC.TypeCast(mstype.bool_)
|
||||||
if is_training:
|
if is_training:
|
||||||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
||||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
output_columns=["image", "box", "label", "valid_num", "image_shape"],
|
||||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
column_order=["image", "box", "label", "valid_num", "image_shape"],
|
||||||
num_parallel_workers=num_parallel_workers,
|
num_parallel_workers=num_parallel_workers,
|
||||||
python_multiprocessing=True)
|
python_multiprocessing=True)
|
||||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
||||||
|
@ -329,8 +287,8 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
|
||||||
else:
|
else:
|
||||||
ds = ds.map(operations=compose_map_func,
|
ds = ds.map(operations=compose_map_func,
|
||||||
input_columns=["image", "annotation"],
|
input_columns=["image", "annotation"],
|
||||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
output_columns=["image", "box", "label", "valid_num", "image_shape"],
|
||||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
column_order=["image", "box", "label", "valid_num", "image_shape"],
|
||||||
num_parallel_workers=num_parallel_workers,
|
num_parallel_workers=num_parallel_workers,
|
||||||
python_multiprocessing=True)
|
python_multiprocessing=True)
|
||||||
|
|
||||||
|
|
|
@ -99,8 +99,8 @@ class WithLossCell(nn.Cell):
|
||||||
self._backbone = backbone
|
self._backbone = backbone
|
||||||
self._loss_fn = loss_fn
|
self._loss_fn = loss_fn
|
||||||
|
|
||||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None):
|
||||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, gt_bbox, gt_label, gt_num, img_shape)
|
||||||
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
|
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -144,10 +144,10 @@ class TrainOneStepCell(nn.Cell):
|
||||||
if reduce_flag:
|
if reduce_flag:
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
|
|
||||||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
|
def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None):
|
||||||
weights = self.weights
|
weights = self.weights
|
||||||
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
|
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, gt_bbox, gt_label, gt_num, img_shape)
|
||||||
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
|
grads = self.grad(self.network, weights)(x, gt_bbox, gt_label, gt_num, img_shape, self.sens)
|
||||||
if self.reduce_flag:
|
if self.reduce_flag:
|
||||||
grads = self.grad_reducer(grads)
|
grads = self.grad_reducer(grads)
|
||||||
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
|
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
|
||||||
|
|
Loading…
Reference in New Issue