forked from mindspore-Ecosystem/mindspore
update tests/st/networks/models/resnet50/src/dataset.py.
This commit is contained in:
parent
2675d13804
commit
a5c16ba5c4
|
@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
||||||
|
|
||||||
device_num = int(os.getenv("RANK_SIZE"))
|
device_num = int(os.getenv("RANK_SIZE"))
|
||||||
rank_id = int(os.getenv("RANK_ID"))
|
rank_id = int(os.getenv("RANK_ID"))
|
||||||
if device_num == 1:
|
if do_train:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
if device_num == 1:
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
else:
|
else:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
|
||||||
num_shards=device_num, shard_id=rank_id)
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
|
||||||
image_size = 224
|
image_size = 224
|
||||||
|
|
Loading…
Reference in New Issue