!22117 Fix resnet st probabilistic failure in daily version.

Merge pull request !22117 from linqingke/resnet
This commit is contained in:
i-robot 2021-08-23 03:49:04 +00:00 committed by Gitee
commit 866d204658
1 changed files with 4 additions and 4 deletions

View File

@ -40,12 +40,12 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
rank_id = int(os.getenv("RANK_ID"))
if do_train:
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=16, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=False,
num_shards=device_num, shard_id=rank_id)
image_size = 224
@ -71,7 +71,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)
# apply batch operations