!13481 fix squeezenet 8p performance degradation by adjusting parameters.

From: @anzhengqi
Reviewed-by: @heleiwang,@liucunwei
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-03-18 09:11:55 +08:00 committed by Gitee
commit 2cef6a1143
1 changed files with 2 additions and 5 deletions

View File

@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path,
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True,
num_shards=device_num,
shard_id=rank_id)
@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path,
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op,
input_columns="label",
num_parallel_workers=8)
input_columns="label")
data_set = data_set.map(operations=trans,
input_columns="image",
num_parallel_workers=8)
num_parallel_workers=10)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)