forked from mindspore-Ecosystem/mindspore
!13481 fix squeezenet 8p performance degradation by adjusting parameters.
From: @anzhengqi Reviewed-by: @heleiwang,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
2cef6a1143
|
@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path,
|
||||||
|
|
||||||
if device_num == 1:
|
if device_num == 1:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path,
|
data_set = ds.ImageFolderDataset(dataset_path,
|
||||||
num_parallel_workers=8,
|
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
else:
|
else:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path,
|
data_set = ds.ImageFolderDataset(dataset_path,
|
||||||
num_parallel_workers=8,
|
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_shards=device_num,
|
num_shards=device_num,
|
||||||
shard_id=rank_id)
|
shard_id=rank_id)
|
||||||
|
@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path,
|
||||||
type_cast_op = C2.TypeCast(mstype.int32)
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
|
||||||
data_set = data_set.map(operations=type_cast_op,
|
data_set = data_set.map(operations=type_cast_op,
|
||||||
input_columns="label",
|
input_columns="label")
|
||||||
num_parallel_workers=8)
|
|
||||||
data_set = data_set.map(operations=trans,
|
data_set = data_set.map(operations=trans,
|
||||||
input_columns="image",
|
input_columns="image",
|
||||||
num_parallel_workers=8)
|
num_parallel_workers=10)
|
||||||
|
|
||||||
# apply batch operations
|
# apply batch operations
|
||||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
|
Loading…
Reference in New Issue