optimize resnet50 on imagenet2012 by tuning prefetch_size and num_parallel_workers

anzhengqi 2021-07-02 10:00:14 +08:00
parent 9ef07e3291
commit cc1a0fa59c
2 changed files with 2 additions and 1 deletion


@@ -123,6 +123,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     else:
         device_num = 1
 
+    ds.config.set_prefetch_size(64)
     if device_num == 1:
         data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
     else:
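The hunk above raises the pipeline prefetch depth before the dataset is constructed, so host-side decode and augmentation work can run further ahead of the device. Below is a minimal sketch of how that setting combines with a worker-parallel ImageFolderDataset; the data directory, transform list, and batch size are illustrative placeholders, and the c_transforms import path assumes a MindSpore 1.x-era API.

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C

# Let each dataset op buffer up to 64 rows in its output queue (the value
# taken from the hunk above), reducing stalls while images are decoded.
ds.config.set_prefetch_size(64)

# "imagenet_dir" is a placeholder for an ImageNet-style folder of class subfolders.
data_set = ds.ImageFolderDataset("imagenet_dir", num_parallel_workers=12, shuffle=True)

# The image map is the CPU-heavy stage, so it keeps a high worker count.
trans = [C.RandomCropDecodeResize(224), C.RandomHorizontalFlip(prob=0.5), C.HWC2CHW()]
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.batch(32, drop_remainder=True)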


@@ -72,7 +72,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
     type_cast_op = C2.TypeCast(mstype.int32)
 
     data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
-    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)
 
     # apply batch operations
     data_set = data_set.batch(batch_size, drop_remainder=True)
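The second hunk trades worker threads between the two map operations: decoding and augmenting images is expensive and keeps 8 parallel workers, while casting the integer labels is trivial and runs fine with 4, leaving more CPU for the image stage. A self-contained sketch of that split, with a placeholder data directory and an eval-style transform list assumed for illustration:

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.common.dtype as mstype

data_set = ds.ImageFolderDataset("imagenet_dir", num_parallel_workers=8, shuffle=True)

# Heavy per-image work: decode, resize, crop, layout change.
trans = [C.Decode(), C.Resize(256), C.CenterCrop(224), C.HWC2CHW()]
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

# Cheap per-label work: a single dtype cast, so fewer workers suffice.
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)

# apply batch operations
data_set = data_set.batch(32, drop_remainder=True)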