forked from mindspore-Ecosystem/mindspore
optimize resnet50 with imagenet2012 by prefetch_size and num_parallel_workers
This commit is contained in:
parent
9ef07e3291
commit
cc1a0fa59c
|
@ -123,6 +123,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
||||||
else:
|
else:
|
||||||
device_num = 1
|
device_num = 1
|
||||||
|
|
||||||
|
ds.config.set_prefetch_size(64)
|
||||||
if device_num == 1:
|
if device_num == 1:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -72,7 +72,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
||||||
type_cast_op = C2.TypeCast(mstype.int32)
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
|
||||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)
|
||||||
|
|
||||||
# apply batch operations
|
# apply batch operations
|
||||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
|
Loading…
Reference in New Issue