forked from mindspore-Ecosystem/mindspore
repair issue for same column name and para check
This commit is contained in:
parent
ee199007ed
commit
83bc1fb343
|
@ -95,7 +95,7 @@ def check_uint32(value, arg_name=""):
|
|||
|
||||
def check_pos_int32(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [POS_INT_MIN, INT32_MAX])
|
||||
check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)
|
||||
|
||||
|
||||
def check_uint64(value, arg_name=""):
|
||||
|
@ -143,6 +143,8 @@ def check_columns(columns, name):
|
|||
|
||||
col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
|
||||
type_check_list(columns, (str,), col_names)
|
||||
if len(set(columns)) != len(columns):
|
||||
raise ValueError("Every column name should not be same with others in column_names.")
|
||||
|
||||
|
||||
def parse_user_args(method, *args, **kwargs):
|
||||
|
|
|
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_split, check_bucket_batch_by_length, check_cluedataset
|
||||
check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -939,6 +939,7 @@ class Dataset:
|
|||
raise TypeError("apply_func must return a dataset.")
|
||||
return dataset
|
||||
|
||||
@check_positive_int32
|
||||
def device_que(self, prefetch_size=None):
|
||||
"""
|
||||
Return a transferredDataset that transfer data through device.
|
||||
|
@ -956,6 +957,7 @@ class Dataset:
|
|||
"""
|
||||
return self.to_device()
|
||||
|
||||
@check_positive_int32
|
||||
def to_device(self, num_batch=None):
|
||||
"""
|
||||
Transfer data through CPU, GPU or Ascend devices.
|
||||
|
@ -973,7 +975,7 @@ class Dataset:
|
|||
Raises:
|
||||
TypeError: If device_type is empty.
|
||||
ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
|
||||
ValueError: If num_batch is None or 0 or larger than int_max.
|
||||
ValueError: If num_batch is negative or larger than int_max.
|
||||
RuntimeError: If dataset is unknown.
|
||||
RuntimeError: If distribution file path is given but failed to read.
|
||||
"""
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore._c_expression import typing
|
|||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
|
||||
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
|
||||
check_columns, check_positive
|
||||
check_columns, check_positive, check_pos_int32
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
@ -593,6 +593,25 @@ def check_take(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_positive_int32(method):
|
||||
"""check whether the input argument is positive and int, only works for functions with one input."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[count], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
para_name = None
|
||||
for key in list(param_dict.keys()):
|
||||
if key not in ['self', 'cls']:
|
||||
para_name = key
|
||||
# Need to get default value of param
|
||||
if count is not None:
|
||||
check_pos_int32(count, para_name)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_zip(method):
|
||||
"""check the input arguments of zip."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue