forked from mindspore-Ecosystem/mindspore
!22117 Fix resnet st probabilistic failure in daily version.
Merge pull request !22117 from linqingke/resnet
This commit is contained in:
commit
866d204658
|
@ -40,12 +40,12 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
rank_id = int(os.getenv("RANK_ID"))
|
||||
if do_train:
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=16, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=False,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 224
|
||||
|
@ -71,7 +71,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)
|
||||
|
||||
# apply batch operations
|
||||
|
|
Loading…
Reference in New Issue