forked from mindspore-Ecosystem/mindspore
update tests/st/networks/models/resnet50/src/dataset.py.
This commit is contained in:
parent
2675d13804
commit
a5c16ba5c4
|
@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
if do_train:
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 224
|
||||
|
|
Loading…
Reference in New Issue