diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 34ab2869a6b..626b8da805a 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -123,6 +123,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= else: device_num = 1 + ds.config.set_prefetch_size(64) if device_num == 1: data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True) else: diff --git a/tests/st/networks/models/resnet50/src/dataset.py b/tests/st/networks/models/resnet50/src/dataset.py index 0d019c02790..ada36751590 100755 --- a/tests/st/networks/models/resnet50/src/dataset.py +++ b/tests/st/networks/models/resnet50/src/dataset.py @@ -72,7 +72,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): type_cast_op = C2.TypeCast(mstype.int32) data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4) # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True)