diff --git a/model_zoo/official/cv/squeezenet/src/dataset.py b/model_zoo/official/cv/squeezenet/src/dataset.py index fc4e8167e8c..75d5d9ed9c4 100755 --- a/model_zoo/official/cv/squeezenet/src/dataset.py +++ b/model_zoo/official/cv/squeezenet/src/dataset.py @@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path, if device_num == 1: data_set = ds.ImageFolderDataset(dataset_path, - num_parallel_workers=8, shuffle=True) else: data_set = ds.ImageFolderDataset(dataset_path, - num_parallel_workers=8, shuffle=True, num_shards=device_num, shard_id=rank_id) @@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path, type_cast_op = C2.TypeCast(mstype.int32) data_set = data_set.map(operations=type_cast_op, - input_columns="label", - num_parallel_workers=8) + input_columns="label") data_set = data_set.map(operations=trans, input_columns="image", - num_parallel_workers=8) + num_parallel_workers=10) # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True)