forked from mindspore-Ecosystem/mindspore
!9248 Add input check to nlp operators
From: @ezphlow Reviewed-by: @nsyca,@mikef Signed-off-by: @nsyca
This commit is contained in:
commit
c3e7a8255c
|
@ -131,6 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text
|
||||||
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
|
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
|
||||||
std::shared_ptr<Tensor> *output) {
|
std::shared_ptr<Tensor> *output) {
|
||||||
IO_CHECK(input, output);
|
IO_CHECK(input, output);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||||
std::vector<std::string> strs(input->Size());
|
std::vector<std::string> strs(input->Size());
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
|
for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
|
||||||
|
|
|
@ -29,6 +29,7 @@ namespace dataset {
|
||||||
|
|
||||||
Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
IO_CHECK(input, output);
|
IO_CHECK(input, output);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||||
icu::ErrorCode error;
|
icu::ErrorCode error;
|
||||||
const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
|
const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
|
CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
|
||||||
|
|
|
@ -29,6 +29,8 @@ namespace dataset {
|
||||||
const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc;
|
const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc;
|
||||||
Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
IO_CHECK(input, output);
|
IO_CHECK(input, output);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||||
|
|
||||||
icu::ErrorCode error;
|
icu::ErrorCode error;
|
||||||
const icu::Normalizer2 *normalize = nullptr;
|
const icu::Normalizer2 *normalize = nullptr;
|
||||||
switch (normalize_form_) {
|
switch (normalize_form_) {
|
||||||
|
|
|
@ -42,6 +42,7 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std:
|
||||||
|
|
||||||
Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
IO_CHECK(input, output);
|
IO_CHECK(input, output);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string");
|
||||||
UErrorCode icu_error = U_ZERO_ERROR;
|
UErrorCode icu_error = U_ZERO_ERROR;
|
||||||
icu::RegexMatcher matcher(pattern_, 0, icu_error);
|
icu::RegexMatcher matcher(pattern_, 0, icu_error);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern");
|
CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern");
|
||||||
|
|
|
@ -1940,7 +1940,7 @@ class BlockReleasePair:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
init_release_rows (int): Number of lines to allow through the pipeline.
|
init_release_rows (int): Number of lines to allow through the pipeline.
|
||||||
callback (function): The callback function that will be called when release is called.
|
callback (function): The callback function that will be called when release is called (default=None).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, init_release_rows, callback=None):
|
def __init__(self, init_release_rows, callback=None):
|
||||||
|
@ -2015,7 +2015,7 @@ class SyncWaitDataset(Dataset):
|
||||||
input_dataset (Dataset): Input dataset to apply flow control.
|
input_dataset (Dataset): Input dataset to apply flow control.
|
||||||
num_batch (int): Number of batches without blocking at the start of each epoch.
|
num_batch (int): Number of batches without blocking at the start of each epoch.
|
||||||
condition_name (str): Condition name that is used to toggle sending next row.
|
condition_name (str): Condition name that is used to toggle sending next row.
|
||||||
callback (function): Callback function that will be invoked when sync_update is called.
|
callback (function): Callback function that will be invoked when sync_update is called (default=None).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If condition name already exists.
|
RuntimeError: If condition name already exists.
|
||||||
|
|
|
@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name():
|
||||||
dataset.sync_update(condition_name="", data=data)
|
dataset.sync_update(condition_name="", data=data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_exception_06():
|
||||||
|
"""
|
||||||
|
Test sync: with string batch size
|
||||||
|
"""
|
||||||
|
logger.info("test_sync_exception_03")
|
||||||
|
|
||||||
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
||||||
|
|
||||||
|
aug = Augment(0)
|
||||||
|
# try to create dataset with batch_size < 0
|
||||||
|
with pytest.raises(TypeError) as e:
|
||||||
|
dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update)
|
||||||
|
assert "is not of type (<class 'int'>" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_simple_sync_wait()
|
test_simple_sync_wait()
|
||||||
test_simple_shuffle_sync()
|
test_simple_shuffle_sync()
|
||||||
|
@ -279,6 +294,7 @@ if __name__ == "__main__":
|
||||||
test_sync_exception_03()
|
test_sync_exception_03()
|
||||||
test_sync_exception_04()
|
test_sync_exception_04()
|
||||||
test_sync_exception_05()
|
test_sync_exception_05()
|
||||||
|
test_sync_exception_06()
|
||||||
test_sync_epoch()
|
test_sync_epoch()
|
||||||
test_multiple_iterators()
|
test_multiple_iterators()
|
||||||
test_simple_sync_wait_empty_condition_name()
|
test_simple_sync_wait_empty_condition_name()
|
||||||
|
|
Loading…
Reference in New Issue