forked from mindspore-Ecosystem/mindspore
update tests/st/networks/models/resnet50/src/dataset.py.
This commit is contained in:
parent
2675d13804
commit
a5c16ba5c4
|
@ -38,11 +38,15 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
if do_train:
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 224
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
|
|
Loading…
Reference in New Issue