!12658 fix ctpn performance tingle
From: @qujianwei Reviewed-by: Signed-off-by:
This commit is contained in:
commit
0f3c5b1d0f
|
@ -301,12 +301,12 @@ def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord"
|
|||
writer.commit()
|
||||
|
||||
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
|
||||
is_training=True, num_parallel_workers=4):
|
||||
"""Creatr deeptext dataset with MindDataset."""
|
||||
is_training=True, num_parallel_workers=12):
|
||||
"""Creatr ctpn dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
|
||||
num_parallel_workers=8, shuffle=is_training)
|
||||
num_parallel_workers=num_parallel_workers, shuffle=is_training)
|
||||
decode = C.Decode()
|
||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
|
||||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=num_parallel_workers)
|
||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||
hwc_to_chw = C.HWC2CHW()
|
||||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
|
||||
|
@ -318,17 +318,21 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
|
|||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
python_multiprocessing=True)
|
||||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
python_multiprocessing=True)
|
||||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=12)
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
python_multiprocessing=True)
|
||||
else:
|
||||
ds = ds.map(operations=compose_map_func,
|
||||
input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
column_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
python_multiprocessing=True)
|
||||
|
||||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
|
||||
num_parallel_workers=24)
|
||||
|
|
Loading…
Reference in New Issue