!9248 Add input check to nlp operators

From: @ezphlow
Reviewed-by: @nsyca, @mikef
Signed-off-by: @nsyca
This commit is contained in:
mindspore-ci-bot 2020-12-01 06:45:26 +08:00 committed by Gitee
commit c3e7a8255c
6 changed files with 23 additions and 2 deletions

View File

@@ -131,6 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
                                                    std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
  std::vector<std::string> strs(input->Size());
  int i = 0;
  for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
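For context, here is a minimal sketch (not part of this commit) of the happy path the new check guards: BasicTokenizer with lower_case=True runs the case-fold helper above, so the incoming column must already hold strings. The generator, column name, and iterator flags are illustrative, and the Python surface (mindspore.dataset.text.BasicTokenizer, create_dict_iterator) is assumed from the MindSpore dataset API of this era.

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text

def gen_text():
    # One string sample; string tensors are the only type the new check accepts.
    yield (np.array("Welcome To MindSpore!", dtype=np.str_),)

data = ds.GeneratorDataset(gen_text, column_names=["text"])
data = data.map(operations=text.BasicTokenizer(lower_case=True), input_columns=["text"])
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(row["text"])  # lower-cased, tokenized pieces
```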

View File

@@ -29,6 +29,7 @@ namespace dataset {
Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
  icu::ErrorCode error;
  const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
  CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");

View File

@@ -29,6 +29,8 @@ namespace dataset {
const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc;
Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
  icu::ErrorCode error;
  const icu::Normalizer2 *normalize = nullptr;
  switch (normalize_form_) {

View File

@@ -42,6 +42,7 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std:
Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
  UErrorCode icu_error = U_ZERO_ERROR;
  icu::RegexMatcher matcher(pattern_, 0, icu_error);
  CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern");
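The same one-line guard is added to all four operators above. A hedged sketch of what it changes at the Python level (not part of the commit; the exact error text propagated to Python and the iterator flags are assumptions): feeding a non-string column into one of these ops, e.g. CaseFold, should now fail fast with the new message instead of handing ill-typed data to ICU.

```python
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.text as text

def gen_ints():
    # Wrong column type on purpose: int32 instead of string.
    yield (np.array([1, 2, 3], dtype=np.int32),)

data = ds.GeneratorDataset(gen_ints, column_names=["text"])
data = data.map(operations=text.CaseFold(), input_columns=["text"])
with pytest.raises(RuntimeError) as info:
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        pass
assert "Input tensor not of type string" in str(info.value)
```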

View File

@@ -1940,7 +1940,7 @@ class BlockReleasePair:
    Args:
        init_release_rows (int): Number of lines to allow through the pipeline.
        callback (function): The callback function that will be called when release is called (default=None).
    """

    def __init__(self, init_release_rows, callback=None):
@@ -2015,7 +2015,7 @@ class SyncWaitDataset(Dataset):
        input_dataset (Dataset): Input dataset to apply flow control.
        num_batch (int): Number of batches without blocking at the start of each epoch.
        condition_name (str): Condition name that is used to toggle sending next row.
        callback (function): Callback function that will be invoked when sync_update is called (default=None).

    Raises:
        RuntimeError: If condition name already exists.
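The two docstring tweaks above only document that callback already defaults to None. A minimal usage sketch (hypothetical Augment class and condition name; sync_wait/sync_update as documented in datasets.py): sync_wait blocks the pipeline until sync_update is called, and the callback, if given, receives the data passed to sync_update.

```python
import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(10):
        yield (np.array([i]),)

class Augment:
    """Toy transform whose update() is used as the sync_wait callback."""
    def __init__(self, factor):
        self.factor = factor
    def preprocess(self, x):
        return x + self.factor
    def update(self, data):
        # Receives the `data` dict passed to sync_update().
        self.factor = data["loss"]

aug = Augment(0)
dataset = ds.GeneratorDataset(gen, column_names=["input"])
# callback defaults to None; pass one to consume the sync_update payload.
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
dataset = dataset.batch(4)
for count, _ in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
    dataset.sync_update(condition_name="policy", data={"loss": count})
```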

View File

@@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name():
        dataset.sync_update(condition_name="", data=data)

def test_sync_exception_06():
    """
    Test sync_wait with a string num_batch value (expects a TypeError)
    """
    logger.info("test_sync_exception_06")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # try to create a sync_wait node with a string num_batch
    with pytest.raises(TypeError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update)
    assert "is not of type (<class 'int'>" in str(e.value)

if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
@@ -279,6 +294,7 @@ if __name__ == "__main__":
    test_sync_exception_03()
    test_sync_exception_04()
    test_sync_exception_05()
    test_sync_exception_06()
    test_sync_epoch()
    test_multiple_iterators()
    test_simple_sync_wait_empty_condition_name()