From ccfd572aa1fd7a917df0bc6826f1b62f6f2b03bc Mon Sep 17 00:00:00 2001 From: qujianwei Date: Fri, 26 Feb 2021 15:14:53 +0800 Subject: [PATCH] fix ctpn performance tingle --- model_zoo/official/cv/ctpn/src/dataset.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/model_zoo/official/cv/ctpn/src/dataset.py b/model_zoo/official/cv/ctpn/src/dataset.py index cdc4cc582f3..03acea4b620 100644 --- a/model_zoo/official/cv/ctpn/src/dataset.py +++ b/model_zoo/official/cv/ctpn/src/dataset.py @@ -301,12 +301,12 @@ def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord" writer.commit() def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0, - is_training=True, num_parallel_workers=4): - """Creatr deeptext dataset with MindDataset.""" + is_training=True, num_parallel_workers=12): + """Creatr ctpn dataset with MindDataset.""" ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\ - num_parallel_workers=8, shuffle=is_training) + num_parallel_workers=num_parallel_workers, shuffle=is_training) decode = C.Decode() - ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1) + ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=num_parallel_workers) compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) hwc_to_chw = C.HWC2CHW() normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) @@ -318,17 +318,21 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num= ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], column_order=["image", "image_shape", "box", "label", "valid_num"], - num_parallel_workers=num_parallel_workers) + num_parallel_workers=num_parallel_workers, + python_multiprocessing=True) ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], - num_parallel_workers=12) + num_parallel_workers=num_parallel_workers, + python_multiprocessing=True) ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"], - num_parallel_workers=12) + num_parallel_workers=num_parallel_workers, + python_multiprocessing=True) else: ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], column_order=["image", "image_shape", "box", "label", "valid_num"], - num_parallel_workers=num_parallel_workers) + num_parallel_workers=num_parallel_workers, + python_multiprocessing=True) ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], num_parallel_workers=24)