diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc index f530edf779d..69ee1b388c2 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc @@ -131,6 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); std::vector strs(input->Size()); int i = 0; for (auto iter = input->begin(); iter != input->end(); iter++) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc index b38df2f0f6a..a2458a04cd9 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc @@ -29,6 +29,7 @@ namespace dataset { Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); icu::ErrorCode error; const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc index b669ca9a8a4..3d3cbf1d5b9 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc @@ -29,6 +29,8 @@ namespace dataset { const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); + icu::ErrorCode error; const icu::Normalizer2 *normalize = nullptr; switch (normalize_form_) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc index b36afba8fc5..485413cd528 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc @@ -42,6 +42,7 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std: Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); UErrorCode icu_error = U_ZERO_ERROR; icu::RegexMatcher matcher(pattern_, 0, icu_error); CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 901de8f4619..c73302d3691 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1940,7 +1940,7 @@ class BlockReleasePair: Args: init_release_rows (int): Number of lines to allow through the pipeline. - callback (function): The callback function that will be called when release is called. + callback (function): The callback function that will be called when release is called (default=None). """ def __init__(self, init_release_rows, callback=None): @@ -2015,7 +2015,7 @@ class SyncWaitDataset(Dataset): input_dataset (Dataset): Input dataset to apply flow control. num_batch (int): Number of batches without blocking at the start of each epoch. condition_name (str): Condition name that is used to toggle sending next row. - callback (function): Callback function that will be invoked when sync_update is called. + callback (function): Callback function that will be invoked when sync_update is called (default=None). Raises: RuntimeError: If condition name already exists. diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py index a601ce896a5..7e7fa1b5d62 100644 --- a/tests/ut/python/dataset/test_sync_wait.py +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name(): dataset.sync_update(condition_name="", data=data) +def test_sync_exception_06(): + """ + Test sync: with string batch size + """ + logger.info("test_sync_exception_03") + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # try to create dataset with batch_size < 0 + with pytest.raises(TypeError) as e: + dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update) + assert "is not of type (" in str(e.value) + + if __name__ == "__main__": test_simple_sync_wait() test_simple_shuffle_sync() @@ -279,6 +294,7 @@ if __name__ == "__main__": test_sync_exception_03() test_sync_exception_04() test_sync_exception_05() + test_sync_exception_06() test_sync_epoch() test_multiple_iterators() test_simple_sync_wait_empty_condition_name()