Optimize the ResNet network's dataset module by tuning prefetch_size and num_parallel_workers

This commit is contained in:
anzhengqi 2021-07-01 11:07:10 +08:00
parent 4703c3085f
commit 81a02891d6
1 changed file with 3 additions and 2 deletions

View File

@@ -49,10 +49,11 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
device_num = get_group_size()
else:
device_num = 1
ds.config.set_prefetch_size(64)
if device_num == 1:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)
# define map operations