forked from mindspore-Ecosystem/mindspore
added check for invalid type for boolean args
This commit is contained in:
parent
709dfd7e81
commit
29aa589972
|
@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method):
|
|||
nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
|
||||
check_param_type(nreq_param_list, param_dict, list)
|
||||
|
||||
nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
|
||||
check_param_type(nbool_param_list, param_dict, bool)
|
||||
|
||||
# check column_names: must be list of string.
|
||||
column_names = param_dict.get("column_names")
|
||||
|
||||
if not column_names:
|
||||
raise ValueError("column_names cannot be empty")
|
||||
|
||||
all_string = all(isinstance(item, str) for item in column_names)
|
||||
if not all_string:
|
||||
raise TypeError("column_names should be a list of str.")
|
||||
|
|
|
@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input():
|
|||
negative_bucket_batch_sizes = [1, 2, 3, -4]
|
||||
zero_bucket_batch_sizes = [0, 1, 2, 3]
|
||||
|
||||
invalid_type_pad_to_bucket_boundary = ""
|
||||
invalid_type_drop_remainder = ""
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
|
||||
assert "column_names should be a list of str" in str(info.value)
|
||||
|
@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input():
|
|||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
|
||||
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
None, None, invalid_type_pad_to_bucket_boundary)
|
||||
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
None, None, False, invalid_type_drop_remainder)
|
||||
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
|
||||
|
||||
|
||||
def test_bucket_batch_multi_bucket_no_padding():
|
||||
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
||||
|
|
Loading…
Reference in New Issue