forked from mindspore-Ecosystem/mindspore
fixes for PR-2908: avoid empty strings for column names
This commit is contained in:
parent
b23fc4e492
commit
2c7fd248f8
|
@ -123,25 +123,39 @@ def check_valid_detype(type_):
|
|||
|
||||
|
||||
def check_columns(columns, name):
    """
    Validate a column name or a list of column names.

    Args:
        columns (Union[str, list]): a single column name or a list of column names.
        name (str): name of the argument being validated, used in error messages.

    Raises:
        ValueError: when columns is an empty string, an empty list, or a list
            that contains an empty string.
        Exception: propagated from type_check / type_check_list when columns is
            neither a str nor a list, or when a list element is not a str.
    """
    type_check(columns, (list, str), name)

    if isinstance(columns, str):
        # A bare string is a single column name; it must not be empty either.
        if not columns:
            raise ValueError("{0} should not be empty".format(name))
    elif isinstance(columns, list):
        if not columns:
            raise ValueError("{0} should not be empty".format(name))

        # Check element types before emptiness so that falsy non-strings
        # (0, None, [], ...) are reported as type errors, not as "empty" names.
        col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
        type_check_list(columns, (str,), col_names)

        for col_name, column_name in zip(col_names, columns):
            if not column_name:
                raise ValueError("{0} should not be empty".format(col_name))
|
||||
|
||||
|
||||
def parse_user_args(method, *args, **kwargs):
|
||||
"""
|
||||
Parse user arguments in a function
|
||||
Parse user arguments in a function.
|
||||
|
||||
Args:
|
||||
method (method): a callable function
|
||||
*args: user passed args
|
||||
**kwargs: user passed kwargs
|
||||
method (method): a callable function.
|
||||
*args: user passed args.
|
||||
**kwargs: user passed kwargs.
|
||||
|
||||
Returns:
|
||||
user_filled_args (list): values of what the user passed in for the arguments,
|
||||
user_filled_args (list): values of what the user passed in for the arguments.
|
||||
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
|
||||
"""
|
||||
sig = inspect.signature(method)
|
||||
|
@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
|
|||
|
||||
def type_check_list(args, types, arg_names):
|
||||
"""
|
||||
Check the type of each parameter in the list
|
||||
Check the type of each parameter in the list.
|
||||
|
||||
Args:
|
||||
args (list, tuple): a list or tuple of any variable
|
||||
types (tuple): tuple of all valid types for arg
|
||||
arg_names (list, tuple of str): the names of args
|
||||
args (list, tuple): a list or tuple of any variable.
|
||||
types (tuple): tuple of all valid types for arg.
|
||||
arg_names (list, tuple of str): the names of args.
|
||||
|
||||
Returns:
|
||||
Exception: when the type is not correct, otherwise nothing
|
||||
Exception: when the type is not correct, otherwise nothing.
|
||||
"""
|
||||
type_check(args, (list, tuple,), arg_names)
|
||||
if len(args) != len(arg_names):
|
||||
|
@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
|
|||
|
||||
def type_check(arg, types, arg_name):
|
||||
"""
|
||||
Check the type of the parameter
|
||||
Check the type of the parameter.
|
||||
|
||||
Args:
|
||||
arg : any variable
|
||||
types (tuple): tuple of all valid types for arg
|
||||
arg_name (str): the name of arg
|
||||
arg : any variable.
|
||||
types (tuple): tuple of all valid types for arg.
|
||||
arg_name (str): the name of arg.
|
||||
|
||||
Returns:
|
||||
Exception: when the type is not correct, otherwise nothing
|
||||
Exception: when the type is not correct, otherwise nothing.
|
||||
"""
|
||||
# handle special case of booleans being a subclass of ints
|
||||
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||
|
@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
|
|||
|
||||
def check_filename(path):
|
||||
"""
|
||||
check the filename in the path
|
||||
check the filename in the path.
|
||||
|
||||
Args:
|
||||
path (str): the path
|
||||
path (str): the path.
|
||||
|
||||
Returns:
|
||||
Exception: when error
|
||||
Exception: when error.
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
raise TypeError("path: {} is not string".format(path))
|
||||
|
@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
"""
|
||||
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
|
||||
Args:
|
||||
param_dict (dict): param_dict
|
||||
param_dict (dict): param_dict.
|
||||
|
||||
Returns:
|
||||
Exception: ValueError or RuntimeError if error
|
||||
Exception: ValueError or RuntimeError if error.
|
||||
"""
|
||||
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
|
@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
|
||||
def check_padding_options(param_dict):
|
||||
"""
|
||||
Check for valid padded_sample and num_padded of padded samples
|
||||
Check for valid padded_sample and num_padded of padded samples.
|
||||
|
||||
Args:
|
||||
param_dict (dict): param_dict
|
||||
param_dict (dict): param_dict.
|
||||
|
||||
Returns:
|
||||
Exception: ValueError or RuntimeError if error
|
||||
Exception: ValueError or RuntimeError if error.
|
||||
"""
|
||||
|
||||
columns_list = param_dict.get('columns_list')
|
||||
|
@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|||
Check if the input parameter is list or numpy.ndarray.
|
||||
|
||||
Args:
|
||||
param (list, nd.ndarray): param
|
||||
param_name (str): param_name
|
||||
param (list, nd.ndarray): param.
|
||||
param_name (str): param_name.
|
||||
|
||||
Returns:
|
||||
Exception: TypeError if error
|
||||
Exception: TypeError if error.
|
||||
"""
|
||||
|
||||
type_check(param, (list, np.ndarray), param_name)
|
||||
|
|
|
@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
|
|||
type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
|
||||
|
||||
# check column_names: must be list of string.
|
||||
if not column_names:
|
||||
raise ValueError("column_names cannot be empty")
|
||||
|
||||
all_string = all(isinstance(item, str) for item in column_names)
|
||||
if not all_string:
|
||||
raise TypeError("column_names should be a list of str.")
|
||||
check_columns(column_names, "column_names")
|
||||
|
||||
if element_length_function is None and len(column_names) != 1:
|
||||
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
|
|||
|
||||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
|
||||
assert "column_names should be a list of str" in str(info.value)
|
||||
assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as de
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
|
|||
|
||||
|
||||
def test_numpy_slices_sequential_sampler():
|
||||
|
||||
logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
|
||||
|
||||
np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
|
||||
|
@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
|
|||
assert np.equal(data[0], np_data[i % 8]).all()
|
||||
|
||||
|
||||
def test_numpy_slices_invalid_column_names_type():
    """A non-string entry in column_names must be rejected with a TypeError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected = "Argument column_names[0] with value 1 is not of type (<class 'str'>,)."

    with pytest.raises(TypeError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[1], shuffle=False)

    message = str(exc_info.value)
    assert expected in message
|
||||
|
||||
|
||||
def test_numpy_slices_invalid_column_names_string():
    """An empty string inside column_names must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected = "column_names[0] should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[""], shuffle=False)

    message = str(exc_info.value)
    assert expected in message
|
||||
|
||||
|
||||
def test_numpy_slices_invalid_empty_column_names():
    """An empty column_names list must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected = "column_names should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[], shuffle=False)

    message = str(exc_info.value)
    assert expected in message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_numpy_slices_list_1()
|
||||
test_numpy_slices_list_2()
|
||||
|
@ -197,3 +224,6 @@ if __name__ == "__main__":
|
|||
test_numpy_slices_num_samplers()
|
||||
test_numpy_slices_distributed_sampler()
|
||||
test_numpy_slices_sequential_sampler()
|
||||
test_numpy_slices_invalid_column_names_type()
|
||||
test_numpy_slices_invalid_column_names_string()
|
||||
test_numpy_slices_invalid_empty_column_names()
|
||||
|
|
Loading…
Reference in New Issue