From cc1a0fa59ccb315904bdb3ad0ffcad2b44bda994 Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Fri, 2 Jul 2021 10:00:14 +0800 Subject: [PATCH] optimize resnet50 with imagenet2012 by prefetch_size and num_parallel_workers --- model_zoo/official/cv/resnet/src/dataset.py | 1 + tests/st/networks/models/resnet50/src/dataset.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 34ab2869a6b..626b8da805a 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -123,6 +123,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= else: device_num = 1 + ds.config.set_prefetch_size(64) if device_num == 1: data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True) else: diff --git a/tests/st/networks/models/resnet50/src/dataset.py b/tests/st/networks/models/resnet50/src/dataset.py index 0d019c02790..ada36751590 100755 --- a/tests/st/networks/models/resnet50/src/dataset.py +++ b/tests/st/networks/models/resnet50/src/dataset.py @@ -72,7 +72,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): type_cast_op = C2.TypeCast(mstype.int32) data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4) # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True)