update tests/st/networks/models/resnet50/src/dataset.py.

This commit is contained in:
wangmin0104 2020-12-27 20:06:44 +08:00 committed by wangmin
parent 2675d13804
commit a5c16ba5c4
1 changed files with 7 additions and 3 deletions

View File

@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
device_num = int(os.getenv("RANK_SIZE")) device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID")) rank_id = int(os.getenv("RANK_ID"))
if device_num == 1: if do_train:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
else: else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
num_shards=device_num, shard_id=rank_id) num_shards=device_num, shard_id=rank_id)
image_size = 224 image_size = 224