optimize network resnet's dataset module by prefetch_size and num_parallel_workers
This commit is contained in:
parent
4703c3085f
commit
81a02891d6
|
@ -49,10 +49,11 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
device_num = get_group_size()
|
||||
else:
|
||||
device_num = 1
|
||||
ds.config.set_prefetch_size(64)
|
||||
if device_num == 1:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
# define map operations
|
||||
|
|
Loading…
Reference in New Issue