diff --git a/tests/st/networks/models/resnet50/src/dataset.py b/tests/st/networks/models/resnet50/src/dataset.py index 799b1fed748..0d019c02790 100755 --- a/tests/st/networks/models/resnet50/src/dataset.py +++ b/tests/st/networks/models/resnet50/src/dataset.py @@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): device_num = int(os.getenv("RANK_SIZE")) rank_id = int(os.getenv("RANK_ID")) - if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + if do_train: + if device_num == 1: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, num_shards=device_num, shard_id=rank_id) image_size = 224