forked from mindspore-Ecosystem/mindspore
!59 dataset: Repair parameter check problem in TFRecordDataset
Merge pull request !59 from ms_yan/repair_shard_tf
This commit is contained in:
commit
1fc5a69d6f
|
@ -398,6 +398,7 @@ def check_tfrecorddataset(method):
|
|||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_list = ['columns_list']
|
||||
nreq_param_bool = ['shard_equal_rows']
|
||||
|
||||
# check dataset_files; required argument
|
||||
dataset_files = param_dict.get('dataset_files')
|
||||
|
@ -410,6 +411,10 @@ def check_tfrecorddataset(method):
|
|||
|
||||
check_param_type(nreq_param_list, param_dict, list)
|
||||
|
||||
check_param_type(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
Loading…
Reference in New Issue