From cae77c0c22fa6979b0032a505edb4d55e1383d22 Mon Sep 17 00:00:00 2001
From: qianlong
Date: Wed, 24 Jun 2020 16:51:38 +0800
Subject: [PATCH] BasicTokenizer: do not case fold preserved words

When lower_case and preserve_unused_token are both enabled, the CaseFold
pass folded special tokens such as [CLS] to [cls], which no longer
matched kUnusedPattern and were split apart by the tokenizer. Add a case
fold variant that skips the preserved words, and extend kUnusedPattern
to also cover [unusedN] tokens.
---
 .../text/kernels/basic_tokenizer_op.cc        | 84 +++++++++++++++++-
 .../dataset/text/kernels/basic_tokenizer_op.h |  7 ++
 .../testTokenizerData/bert_tokenizer.txt      |  5 +-
 .../ut/python/dataset/test_bert_tokenizer.py  | 26 +++++-
 4 files changed, 112 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc
index 1128990b44e..3512a4b2d71 100644
--- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc
+++ b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc
@@ -15,11 +15,16 @@
  */
 #include "dataset/text/kernels/basic_tokenizer_op.h"
 #include <memory>
+#include <queue>
 #include <string>
 #include <string_view>
 #include <utility>
 #include <vector>
 
+#include "unicode/errorcode.h"
+#include "unicode/normalizer2.h"
+#include "unicode/utypes.h"
+
 namespace mindspore {
 namespace dataset {
 const bool BasicTokenizerOp::kDefLowerCase = false;
@@ -40,8 +45,8 @@ const char BasicTokenizerOp::kCommonPattern[] =
   "|[\\x{2B820}-\\x{2CEAF}]"
   "|[\\x{F900}-\\x{FAFF}]"
   "|[\\x{2F800}-\\x{2FA1F}]";
-const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|";
-
+const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|";
+const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"};
 BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, NormalizeForm normalization_form,
                                    bool preserve_unused_token)
     : lower_case_(lower_case),
@@ -67,6 +72,72 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
   regex_tokenizer_ = std::make_unique<RegexTokenizer>(delim_pattern, keep_delim_pattern);
 }
 
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text,
+                                                    const std::unordered_set<std::string> &unused_words,
+                                                    std::string *output) {
+  icu::ErrorCode error;
+  const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
+  CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
+  output->clear();
+
+  // 1. Get the start and end offsets of the words that must not be case folded
+  std::queue<std::pair<int, int>> offsets;  // offsets of the preserved words
+  int start = -1;
+  int len = 0;
+  for (int i = 0; i < text.length(); i++) {
+    if (text[i] == '[') {
+      start = i;
+      ++len;
+    } else if (text[i] == ']' && start >= 0) {
+      ++len;
+      std::string word(text.substr(start, len));
+      if (unused_words.find(word) != unused_words.end()) {
+        offsets.push(std::make_pair(start, start + len - 1));
+      }
+      start = -1;
+      len = 0;
+    } else if (start >= 0) {
+      ++len;
+    }
+  }
+
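+  // `offsets` now holds the [start, end] ranges of the preserved words,
+  // in left-to-right order, ready to be skipped by the fold pass below.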
+  // 2. Do not apply case fold on the `unused_words` spans
+  start = 0;
+  for (int i = 0; i < text.length();) {
+    std::string_view process_text;
+    std::string preserve_token;
+    if (offsets.empty()) {
+      i = text.length();
+      process_text = text.substr(start, i - start);
+    } else {
+      preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1);
+      process_text = text.substr(start, offsets.front().first - start);
+      i = offsets.front().second + 1;
+      offsets.pop();
+    }
+    std::string temp;
+    icu::StringByteSink<std::string> sink(&temp);
+    nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error);
+    *output += temp + preserve_token;
+    start = i;  // resume folding after the preserved token
+  }
+  return Status::OK();
+}
+
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
+                                                    std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::vector<std::string> strs(input->Size());
+  int i = 0;
+  for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
+    RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++]));
+  }
+  *output = std::make_shared<Tensor>(std::move(strs), input->shape());
+  return Status::OK();
+}
+
 Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
@@ -75,8 +146,13 @@ Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
   std::shared_ptr<Tensor> cur_input;
   std::shared_ptr<Tensor> processed_tensor;
   if (lower_case_) {
-    // to lower case
-    RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
+    if (!preserve_unused_token_) {
+      // to lower case
+      RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
+    } else {
+      // to lower case except the words in kUnusedWords
+      RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input, &processed_tensor));
+    }
     cur_input = processed_tensor;
     // strip accent characters
     RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor));
diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h
index a37e841573e..01827a0ba4c 100644
--- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h
+++ b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h
@@ -17,6 +17,7 @@
 #define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
 #include <memory>
 #include <string>
+#include <unordered_set>
 
 #include "dataset/core/tensor.h"
 #include "dataset/kernels/tensor_op.h"
@@ -45,9 +46,15 @@ class BasicTokenizerOp : public TensorOp {
 
   Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
 
+ protected:
+  Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set<std::string> &unused_words,
+                                    std::string *output);
+  Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
  private:
   static const char kCommonPattern[];
   static const char kUnusedPattern[];
+  static const std::unordered_set<std::string> kUnusedWords;
   bool lower_case_;
   bool keep_whitespace_;
   NormalizeForm normalization_form_;
diff --git a/tests/ut/data/dataset/testTokenizerData/bert_tokenizer.txt b/tests/ut/data/dataset/testTokenizerData/bert_tokenizer.txt
index 657b7599765..27fc4e5db19 100644
--- a/tests/ut/data/dataset/testTokenizerData/bert_tokenizer.txt
+++ b/tests/ut/data/dataset/testTokenizerData/bert_tokenizer.txt
@@ -10,5 +10,6 @@ unused [SEP]
 unused [UNK]
 unused [PAD]
 unused [MASK]
-12+/-28=40/-16
-Hello World!
\ No newline at end of file
+[unused1]
+[unused10]
+12+/-28=40/-16
\ No newline at end of file
diff --git a/tests/ut/python/dataset/test_bert_tokenizer.py b/tests/ut/python/dataset/test_bert_tokenizer.py
index ad7a663e933..ba487343a03 100644
--- a/tests/ut/python/dataset/test_bert_tokenizer.py
+++ b/tests/ut/python/dataset/test_bert_tokenizer.py
@@ -27,7 +27,7 @@ vocab_bert = [
     "繁", "體", "字", "嘿", "哈", "大", "笑", "嘻", "i", "am", "mak", "make", "small", "mistake",
     "##s", "during", "work", "##ing", "hour", "😀", "😃", "😄", "😁", "+", "/", "-", "=", "12",
     "28", "40", "16", " ", "I",
-    "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"
+    "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"
 ]
 pad = '<pad>'
 test_paras = [
@@ -69,22 +69,40 @@
     # test preserved tokens
     dict(
         first=8,
-        last=12,
+        last=14,
         expect_str=[
             ['[UNK]', '[CLS]'],
             ['[UNK]', '[SEP]'],
             ['[UNK]', '[UNK]'],
             ['[UNK]', '[PAD]'],
             ['[UNK]', '[MASK]'],
+            ['[unused1]'],
+            ['[unused10]']
         ],
         lower_case=False,
         vocab_list=vocab_bert,
         preserve_unused_token=True,
     ),
+    dict(
+        first=8,
+        last=14,
+        expect_str=[
+            ['[UNK]', '[CLS]'],
+            ['[UNK]', '[SEP]'],
+            ['[UNK]', '[UNK]'],
+            ['[UNK]', '[PAD]'],
+            ['[UNK]', '[MASK]'],
+            ['[unused1]'],
+            ['[unused10]']
+        ],
+        lower_case=True,
+        vocab_list=vocab_bert,
+        preserve_unused_token=True,
+    ),
     # test special symbol
     dict(
-        first=13,
-        last=13,
+        first=15,
+        last=15,
         expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']],
         preserve_unused_token=True,
         vocab_list=vocab_bert
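
Usage sketch (not part of the patch): the new behavior is easiest to see
from the Python layer. A minimal example in the style of
test_bert_tokenizer.py; the file path and the map() keyword arguments are
illustrative and may need adjusting to this tree's dataset API:

    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    # With lower_case=True and preserve_unused_token=True, ordinary text is
    # case folded while [CLS], [SEP], [UNK], [PAD], [MASK] and [unusedN]
    # keep their case, so they still match kUnusedPattern and stay whole.
    tokenizer = text.BasicTokenizer(lower_case=True, preserve_unused_token=True)

    data = ds.TextFileDataset("bert_tokenizer.txt", shuffle=False)
    data = data.map(input_columns=["text"], operations=tokenizer)
    for row in data.create_dict_iterator():
        print(row["text"])  # e.g. ['unused' '[CLS]'] instead of ['[' 'cls' ']']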