From 7b5640da4e6d4c8d27f34e1289de65f5b96d0737 Mon Sep 17 00:00:00 2001 From: ms_yan Date: Tue, 31 Mar 2020 18:46:59 +0800 Subject: [PATCH] Repair parameter check problem in TFRecordDataset --- mindspore/dataset/engine/validators.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4dc1867808d..adfe54a02e0 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -398,6 +398,7 @@ def check_tfrecorddataset(method): nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_list = ['columns_list'] + nreq_param_bool = ['shard_equal_rows'] # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') @@ -410,6 +411,10 @@ def check_tfrecorddataset(method): check_param_type(nreq_param_list, param_dict, list) + check_param_type(nreq_param_bool, param_dict, bool) + + check_sampler_shuffle_shard_options(param_dict) + return method(*args, **kwargs) return new_method