forked from mindspore-Ecosystem/mindspore
Added missing input check for nlp operator
This commit is contained in:
parent
212eccbbae
commit
cf802c839c
|
@ -131,6 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text
|
|||
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
|
||||
std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||
std::vector<std::string> strs(input->Size());
|
||||
int i = 0;
|
||||
for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace dataset {
|
|||
|
||||
Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||
icu::ErrorCode error;
|
||||
const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
|
||||
|
|
|
@ -29,6 +29,8 @@ namespace dataset {
|
|||
const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc;
|
||||
Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||
|
||||
icu::ErrorCode error;
|
||||
const icu::Normalizer2 *normalize = nullptr;
|
||||
switch (normalize_form_) {
|
||||
|
|
|
@ -42,6 +42,7 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std:
|
|||
|
||||
Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||
UErrorCode icu_error = U_ZERO_ERROR;
|
||||
icu::RegexMatcher matcher(pattern_, 0, icu_error);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern");
|
||||
|
|
|
@ -1940,7 +1940,7 @@ class BlockReleasePair:
|
|||
|
||||
Args:
|
||||
init_release_rows (int): Number of lines to allow through the pipeline.
|
||||
callback (function): The callback function that will be called when release is called.
|
||||
callback (function): The callback function that will be called when release is called (default=None).
|
||||
"""
|
||||
|
||||
def __init__(self, init_release_rows, callback=None):
|
||||
|
@ -2015,7 +2015,7 @@ class SyncWaitDataset(Dataset):
|
|||
input_dataset (Dataset): Input dataset to apply flow control.
|
||||
num_batch (int): Number of batches without blocking at the start of each epoch.
|
||||
condition_name (str): Condition name that is used to toggle sending next row.
|
||||
callback (function): Callback function that will be invoked when sync_update is called.
|
||||
callback (function): Callback function that will be invoked when sync_update is called (default=None).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If condition name already exists.
|
||||
|
|
|
@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name():
|
|||
dataset.sync_update(condition_name="", data=data)
|
||||
|
||||
|
||||
def test_sync_exception_06():
|
||||
"""
|
||||
Test sync: with string batch size
|
||||
"""
|
||||
logger.info("test_sync_exception_03")
|
||||
|
||||
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
||||
|
||||
aug = Augment(0)
|
||||
# try to create dataset with batch_size < 0
|
||||
with pytest.raises(TypeError) as e:
|
||||
dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update)
|
||||
assert "is not of type (<class 'int'>" in str(e.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simple_sync_wait()
|
||||
test_simple_shuffle_sync()
|
||||
|
@ -279,6 +294,7 @@ if __name__ == "__main__":
|
|||
test_sync_exception_03()
|
||||
test_sync_exception_04()
|
||||
test_sync_exception_05()
|
||||
test_sync_exception_06()
|
||||
test_sync_epoch()
|
||||
test_multiple_iterators()
|
||||
test_simple_sync_wait_empty_condition_name()
|
||||
|
|
Loading…
Reference in New Issue