added check for invalid type for boolean args

This commit is contained in:
peilin-wang 2020-07-02 16:49:09 -04:00
parent 709dfd7e81
commit 29aa589972
2 changed files with 20 additions and 0 deletions

View File

@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method):
nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
check_param_type(nreq_param_list, param_dict, list)
nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
check_param_type(nbool_param_list, param_dict, bool)
# check column_names: must be list of string.
column_names = param_dict.get("column_names")
if not column_names:
raise ValueError("column_names cannot be empty")
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")

View File

@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input():
negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3]
invalid_type_pad_to_bucket_boundary = ""
invalid_type_drop_remainder = ""
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value)
@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input():
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, invalid_type_pad_to_bucket_boundary)
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, False, invalid_type_drop_remainder)
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
def test_bucket_batch_multi_bucket_no_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])