!12658 fix ctpn performance tingle

From: @qujianwei
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-01 10:42:30 +08:00 committed by Gitee
commit 0f3c5b1d0f
1 changed files with 12 additions and 8 deletions

View File

@ -301,12 +301,12 @@ def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord"
writer.commit()
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=4):
"""Creatr deeptext dataset with MindDataset."""
is_training=True, num_parallel_workers=12):
"""Creatr ctpn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
num_parallel_workers=8, shuffle=is_training)
num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=num_parallel_workers)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
@ -318,17 +318,21 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
num_parallel_workers=12)
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=12)
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
else:
ds = ds.map(operations=compose_map_func,
input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True)
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=24)