From 29aa589972070f3cd0b43e60ff01e25d6ad1190d Mon Sep 17 00:00:00 2001 From: peilin-wang Date: Thu, 2 Jul 2020 16:49:09 -0400 Subject: [PATCH] added check for invalid type for boolean args --- mindspore/dataset/engine/validators.py | 7 +++++++ .../python/dataset/test_bucket_batch_by_length.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 2a0bef3b422..744a9b94be7 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method): nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] check_param_type(nreq_param_list, param_dict, list) + nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder'] + check_param_type(nbool_param_list, param_dict, bool) + # check column_names: must be list of string. column_names = param_dict.get("column_names") + + if not column_names: + raise ValueError("column_names cannot be empty") + all_string = all(isinstance(item, str) for item in column_names) if not all_string: raise TypeError("column_names should be a list of str.") diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index 4436f98e534..febcc6483f7 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input(): negative_bucket_batch_sizes = [1, 2, 3, -4] zero_bucket_batch_sizes = [0, 1, 2, 3] + invalid_type_pad_to_bucket_boundary = "" + invalid_type_drop_remainder = "" + with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) assert "column_names should be a list of str" in str(info.value) @@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input(): _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, + None, None, invalid_type_pad_to_bucket_boundary) + assert "Wrong input type for pad_to_bucket_boundary, should be " in str(info.value) + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, + None, None, False, invalid_type_drop_remainder) + assert "Wrong input type for drop_remainder, should be " in str(info.value) + def test_bucket_batch_multi_bucket_no_padding(): dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])