forked from mindspore-Ecosystem/mindspore
!8732 [MD][PERFORMANCE] Fix Faster RCNN x86 performance issue
From: @xiefangqi Reviewed-by: Signed-off-by:
This commit is contained in:
commit
0a9899e3a1
|
@ -41,8 +41,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a
|
||||||
|
|
||||||
def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
|
def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
|
||||||
"""FasterRcnn evaluation."""
|
"""FasterRcnn evaluation."""
|
||||||
ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size,
|
ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False)
|
||||||
repeat_num=1, is_training=False)
|
|
||||||
net = Faster_Rcnn_Resnet50(config)
|
net = Faster_Rcnn_Resnet50(config)
|
||||||
param_dict = load_checkpoint(ckpt_path)
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
|
|
|
@ -23,11 +23,9 @@ from numpy import random
|
||||||
import mmcv
|
import mmcv
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
import mindspore.dataset.vision.c_transforms as C
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
import mindspore.dataset.transforms.c_transforms as CC
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
from mindspore.mindrecord import FileWriter
|
from mindspore.mindrecord import FileWriter
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
import cv2
|
||||||
|
|
||||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
||||||
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
|
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
|
||||||
|
@ -74,10 +72,8 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
||||||
ious = ious.T
|
ious = ious.T
|
||||||
return ious
|
return ious
|
||||||
|
|
||||||
|
|
||||||
class PhotoMetricDistortion:
|
class PhotoMetricDistortion:
|
||||||
"""Photo Metric Distortion"""
|
"""Photo Metric Distortion"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
brightness_delta=32,
|
brightness_delta=32,
|
||||||
contrast_range=(0.5, 1.5),
|
contrast_range=(0.5, 1.5),
|
||||||
|
@ -136,10 +132,8 @@ class PhotoMetricDistortion:
|
||||||
|
|
||||||
return img, boxes, labels
|
return img, boxes, labels
|
||||||
|
|
||||||
|
|
||||||
class Expand:
|
class Expand:
|
||||||
"""expand image"""
|
"""expand image"""
|
||||||
|
|
||||||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
|
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
|
||||||
if to_rgb:
|
if to_rgb:
|
||||||
self.mean = mean[::-1]
|
self.mean = mean[::-1]
|
||||||
|
@ -162,7 +156,6 @@ class Expand:
|
||||||
boxes += np.tile((left, top), 2)
|
boxes += np.tile((left, top), 2)
|
||||||
return img, boxes, labels
|
return img, boxes, labels
|
||||||
|
|
||||||
|
|
||||||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""rescale operation for image"""
|
"""rescale operation for image"""
|
||||||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
|
||||||
|
@ -178,7 +171,6 @@ def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""resize operation for image"""
|
"""resize operation for image"""
|
||||||
img_data = img
|
img_data = img
|
||||||
|
@ -196,7 +188,6 @@ def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""resize operation for image of eval"""
|
"""resize operation for image of eval"""
|
||||||
img_data = img
|
img_data = img
|
||||||
|
@ -214,21 +205,18 @@ def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""impad operation for image"""
|
"""impad operation for image"""
|
||||||
img_data = mmcv.impad(img, (config.img_height, config.img_width))
|
img_data = mmcv.impad(img, (config.img_height, config.img_width))
|
||||||
img_data = img_data.astype(np.float32)
|
img_data = img_data.astype(np.float32)
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""imnormalize operation for image"""
|
"""imnormalize operation for image"""
|
||||||
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
|
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
|
||||||
img_data = img_data.astype(np.float32)
|
img_data = img_data.astype(np.float32)
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""flip operation for image"""
|
"""flip operation for image"""
|
||||||
img_data = img
|
img_data = img
|
||||||
|
@ -241,24 +229,6 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
return (img_data, img_shape, flipped, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|
||||||
"""flipped generation"""
|
|
||||||
img_data = img
|
|
||||||
flipped = gt_bboxes.copy()
|
|
||||||
_, w, _ = img_data.shape
|
|
||||||
|
|
||||||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
|
|
||||||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
|
|
||||||
|
|
||||||
return (img_data, img_shape, flipped, gt_label, gt_num)
|
|
||||||
|
|
||||||
|
|
||||||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
|
|
||||||
img_data = img[:, :, ::-1]
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
|
||||||
|
|
||||||
|
|
||||||
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""transpose operation for image"""
|
"""transpose operation for image"""
|
||||||
img_data = img.transpose(2, 0, 1).copy()
|
img_data = img.transpose(2, 0, 1).copy()
|
||||||
|
@ -270,7 +240,6 @@ def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""photo crop operation for image"""
|
"""photo crop operation for image"""
|
||||||
random_photo = PhotoMetricDistortion()
|
random_photo = PhotoMetricDistortion()
|
||||||
|
@ -278,7 +247,6 @@ def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
"""expand operation for image"""
|
"""expand operation for image"""
|
||||||
expand = Expand()
|
expand = Expand()
|
||||||
|
@ -286,10 +254,8 @@ def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
|
||||||
|
|
||||||
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
return (img, img_shape, gt_bboxes, gt_label, gt_num)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_fn(image, box, is_training):
|
def preprocess_fn(image, box, is_training):
|
||||||
"""Preprocess function for dataset."""
|
"""Preprocess function for dataset."""
|
||||||
|
|
||||||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
|
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
|
||||||
image_shape = image_shape[:2]
|
image_shape = image_shape[:2]
|
||||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
|
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
|
||||||
|
@ -298,10 +264,9 @@ def preprocess_fn(image, box, is_training):
|
||||||
input_data = rescale_column(*input_data)
|
input_data = rescale_column(*input_data)
|
||||||
else:
|
else:
|
||||||
input_data = resize_column_test(*input_data)
|
input_data = resize_column_test(*input_data)
|
||||||
|
input_data = imnormalize_column(*input_data)
|
||||||
|
|
||||||
input_data = image_bgr_rgb(*input_data)
|
output_data = transpose_column(*input_data)
|
||||||
|
|
||||||
output_data = input_data
|
|
||||||
return output_data
|
return output_data
|
||||||
|
|
||||||
def _data_aug(image, box, is_training):
|
def _data_aug(image, box, is_training):
|
||||||
|
@ -324,25 +289,25 @@ def preprocess_fn(image, box, is_training):
|
||||||
if not is_training:
|
if not is_training:
|
||||||
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)
|
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)
|
||||||
|
|
||||||
|
flip = (np.random.rand() < config.flip_ratio)
|
||||||
|
expand = (np.random.rand() < config.expand_ratio)
|
||||||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
|
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert
|
||||||
|
|
||||||
expand = (np.random.rand() < config.expand_ratio)
|
|
||||||
if expand:
|
if expand:
|
||||||
input_data = expand_column(*input_data)
|
input_data = expand_column(*input_data)
|
||||||
|
|
||||||
if config.keep_ratio:
|
if config.keep_ratio:
|
||||||
input_data = rescale_column(*input_data)
|
input_data = rescale_column(*input_data)
|
||||||
else:
|
else:
|
||||||
input_data = resize_column(*input_data)
|
input_data = resize_column(*input_data)
|
||||||
|
input_data = imnormalize_column(*input_data)
|
||||||
|
if flip:
|
||||||
|
input_data = flip_column(*input_data)
|
||||||
|
|
||||||
input_data = image_bgr_rgb(*input_data)
|
output_data = transpose_column(*input_data)
|
||||||
|
|
||||||
output_data = input_data
|
|
||||||
return output_data
|
return output_data
|
||||||
|
|
||||||
return _data_aug(image, box, is_training)
|
return _data_aug(image, box, is_training)
|
||||||
|
|
||||||
|
|
||||||
def create_coco_label(is_training):
|
def create_coco_label(is_training):
|
||||||
"""Get image path and annotation from COCO."""
|
"""Get image path and annotation from COCO."""
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
|
@ -393,7 +358,6 @@ def create_coco_label(is_training):
|
||||||
|
|
||||||
return image_files, image_anno_dict
|
return image_files, image_anno_dict
|
||||||
|
|
||||||
|
|
||||||
def anno_parser(annos_str):
|
def anno_parser(annos_str):
|
||||||
"""Parse annotation from string to list."""
|
"""Parse annotation from string to list."""
|
||||||
annos = []
|
annos = []
|
||||||
|
@ -402,7 +366,6 @@ def anno_parser(annos_str):
|
||||||
annos.append(anno)
|
annos.append(anno)
|
||||||
return annos
|
return annos
|
||||||
|
|
||||||
|
|
||||||
def filter_valid_data(image_dir, anno_path):
|
def filter_valid_data(image_dir, anno_path):
|
||||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||||
image_files = []
|
image_files = []
|
||||||
|
@ -424,7 +387,6 @@ def filter_valid_data(image_dir, anno_path):
|
||||||
image_files.append(image_path)
|
image_files.append(image_path)
|
||||||
return image_files, image_anno_dict
|
return image_files, image_anno_dict
|
||||||
|
|
||||||
|
|
||||||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
|
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
|
||||||
"""Create MindRecord file."""
|
"""Create MindRecord file."""
|
||||||
mindrecord_dir = config.mindrecord_dir
|
mindrecord_dir = config.mindrecord_dir
|
||||||
|
@ -449,59 +411,29 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
|
||||||
writer.write_raw_data([row])
|
writer.write_raw_data([row])
|
||||||
writer.commit()
|
writer.commit()
|
||||||
|
|
||||||
|
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0, is_training=True,
|
||||||
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
|
num_parallel_workers=8):
|
||||||
is_training=True, num_parallel_workers=4):
|
|
||||||
"""Creatr FasterRcnn dataset with MindDataset."""
|
"""Creatr FasterRcnn dataset with MindDataset."""
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
de.config.set_prefetch_size(8)
|
||||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
|
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
|
||||||
num_parallel_workers=1, shuffle=is_training)
|
num_parallel_workers=4, shuffle=is_training)
|
||||||
decode = C.Decode()
|
decode = C.Decode()
|
||||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
|
ds = ds.map(input_columns=["image"], operations=decode)
|
||||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||||
|
|
||||||
hwc_to_chw = C.HWC2CHW()
|
|
||||||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
|
|
||||||
horizontally_op = C.RandomHorizontalFlip(1)
|
|
||||||
type_cast0 = CC.TypeCast(mstype.float32)
|
|
||||||
type_cast1 = CC.TypeCast(mstype.float16)
|
|
||||||
type_cast2 = CC.TypeCast(mstype.int32)
|
|
||||||
type_cast3 = CC.TypeCast(mstype.bool_)
|
|
||||||
|
|
||||||
if is_training:
|
if is_training:
|
||||||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
ds = ds.map(input_columns=["image", "annotation"],
|
||||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
|
operations=compose_map_func, python_multiprocessing=False,
|
||||||
num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
flip = (np.random.rand() < config.flip_ratio)
|
|
||||||
if flip:
|
|
||||||
ds = ds.map(operations=[normalize_op, type_cast0, horizontally_op], input_columns=["image"],
|
|
||||||
num_parallel_workers=12)
|
|
||||||
ds = ds.map(operations=flipped_generation,
|
|
||||||
input_columns=["image", "image_shape", "box", "label", "valid_num"],
|
|
||||||
num_parallel_workers=num_parallel_workers)
|
|
||||||
else:
|
|
||||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
|
||||||
num_parallel_workers=12)
|
|
||||||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
|
|
||||||
num_parallel_workers=12)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ds = ds.map(operations=compose_map_func,
|
ds = ds.map(input_columns=["image", "annotation"],
|
||||||
input_columns=["image", "annotation"],
|
|
||||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
|
operations=compose_map_func,
|
||||||
num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
|
|
||||||
num_parallel_workers=24)
|
|
||||||
|
|
||||||
# transpose_column from python to c
|
|
||||||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
|
|
||||||
ds = ds.map(operations=[type_cast1], input_columns=["box"])
|
|
||||||
ds = ds.map(operations=[type_cast2], input_columns=["label"])
|
|
||||||
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
|
|
||||||
ds = ds.batch(batch_size, drop_remainder=True)
|
|
||||||
ds = ds.repeat(repeat_num)
|
|
||||||
|
|
||||||
return ds
|
return ds
|
||||||
|
|
|
@ -101,8 +101,8 @@ if __name__ == '__main__':
|
||||||
loss_scale = float(config.loss_scale)
|
loss_scale = float(config.loss_scale)
|
||||||
|
|
||||||
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
|
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
|
||||||
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1,
|
dataset = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.batch_size,
|
||||||
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
|
device_num=device_num, rank_id=rank)
|
||||||
|
|
||||||
dataset_size = dataset.get_dataset_size()
|
dataset_size = dataset.get_dataset_size()
|
||||||
print("Create dataset done!")
|
print("Create dataset done!")
|
||||||
|
|
Loading…
Reference in New Issue