forked from mindspore-Ecosystem/mindspore
BasicTokenizer not case fold on preserverd words
This commit is contained in:
parent
d6d93f16b1
commit
cae77c0c22
|
@ -15,11 +15,16 @@
|
|||
*/
|
||||
#include "dataset/text/kernels/basic_tokenizer_op.h"
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "unicode/errorcode.h"
|
||||
#include "unicode/normalizer2.h"
|
||||
#include "unicode/utypes.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const bool BasicTokenizerOp::kDefLowerCase = false;
|
||||
|
@ -40,8 +45,8 @@ const char BasicTokenizerOp::kCommonPattern[] =
|
|||
"|[\\x{2B820}-\\x{2CEAF}]"
|
||||
"|[\\x{F900}-\\x{FAFF}]"
|
||||
"|[\\x{2F800}-\\x{2FA1F}]";
|
||||
const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|";
|
||||
|
||||
const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|";
|
||||
const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"};
|
||||
BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, NormalizeForm normalization_form,
|
||||
bool preserve_unused_token)
|
||||
: lower_case_(lower_case),
|
||||
|
@ -67,6 +72,69 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
|
|||
regex_tokenizer_ = std::make_unique<RegexTokenizerOp>(delim_pattern, keep_delim_pattern);
|
||||
}
|
||||
|
||||
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text,
|
||||
const std::unordered_set<std::string> &unused_words,
|
||||
std::string *outupt) {
|
||||
icu::ErrorCode error;
|
||||
const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
|
||||
outupt->clear();
|
||||
|
||||
// 1. get start and end offsets of not case fold strs
|
||||
std::queue<std::pair<int, int>> offsets; // offsets of not used words
|
||||
int start = -1;
|
||||
int len = 0;
|
||||
for (int i = 0; i < text.length(); i++) {
|
||||
if (text[i] == '[') {
|
||||
start = i;
|
||||
++len;
|
||||
} else if (text[i] == ']' && start >= 0) {
|
||||
++len;
|
||||
std::string word(text.substr(start, len));
|
||||
if (unused_words.find(word) != unused_words.end()) {
|
||||
offsets.push(std::make_pair(start, start + len - 1));
|
||||
}
|
||||
start = -1;
|
||||
len = 0;
|
||||
} else if (start >= 0) {
|
||||
++len;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Do not apply case fold on `unused_words`
|
||||
start = 0;
|
||||
for (int i = 0; i < text.length();) {
|
||||
std::string_view process_text;
|
||||
std::string preserve_token;
|
||||
if (offsets.empty()) {
|
||||
i = text.length();
|
||||
process_text = text.substr(start, i - start);
|
||||
} else {
|
||||
preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1);
|
||||
process_text = text.substr(start, offsets.front().first - start);
|
||||
i = offsets.front().second + 1;
|
||||
offsets.pop();
|
||||
}
|
||||
std::string temp;
|
||||
icu::StringByteSink<std::string> sink(&temp);
|
||||
nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error);
|
||||
*outupt += temp + preserve_token;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
|
||||
std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
std::vector<std::string> strs(input->Size());
|
||||
int i = 0;
|
||||
for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
|
||||
RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++]));
|
||||
}
|
||||
*output = std::make_shared<Tensor>(std::move(strs), input->shape());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
|
||||
|
@ -75,8 +143,13 @@ Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
|||
std::shared_ptr<Tensor> cur_input;
|
||||
std::shared_ptr<Tensor> processed_tensor;
|
||||
if (lower_case_) {
|
||||
// to lower case
|
||||
RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
|
||||
if (!preserve_unused_token_) {
|
||||
// to lower case
|
||||
RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
|
||||
} else {
|
||||
// to lower case except words in kUnusedWords
|
||||
RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input, &processed_tensor));
|
||||
}
|
||||
cur_input = processed_tensor;
|
||||
// strip accent characters
|
||||
RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor));
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
@ -45,9 +46,15 @@ class BasicTokenizerOp : public TensorOp {
|
|||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
protected:
|
||||
Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set<std::string> &unused_words,
|
||||
std::string *outupt);
|
||||
Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
private:
|
||||
static const char kCommonPattern[];
|
||||
static const char kUnusedPattern[];
|
||||
static const std::unordered_set<std::string> kUnusedWords;
|
||||
bool lower_case_;
|
||||
bool keep_whitespace_;
|
||||
NormalizeForm normalization_form_;
|
||||
|
|
|
@ -10,5 +10,6 @@ unused [SEP]
|
|||
unused [UNK]
|
||||
unused [PAD]
|
||||
unused [MASK]
|
||||
12+/-28=40/-16
|
||||
Hello World!
|
||||
[unused1]
|
||||
[unused10]
|
||||
12+/-28=40/-16
|
|
@ -27,7 +27,7 @@ vocab_bert = [
|
|||
"繁", "體", "字", "嘿", "哈", "大", "笑", "嘻",
|
||||
"i", "am", "mak", "make", "small", "mistake", "##s", "during", "work", "##ing", "hour",
|
||||
"😀", "😃", "😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I",
|
||||
"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"
|
||||
"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"
|
||||
]
|
||||
pad = '<pad>'
|
||||
test_paras = [
|
||||
|
@ -69,22 +69,40 @@ test_paras = [
|
|||
# test preserved tokens
|
||||
dict(
|
||||
first=8,
|
||||
last=12,
|
||||
last=14,
|
||||
expect_str=[
|
||||
['[UNK]', '[CLS]'],
|
||||
['[UNK]', '[SEP]'],
|
||||
['[UNK]', '[UNK]'],
|
||||
['[UNK]', '[PAD]'],
|
||||
['[UNK]', '[MASK]'],
|
||||
['[unused1]'],
|
||||
['[unused10]']
|
||||
],
|
||||
lower_case=False,
|
||||
vocab_list=vocab_bert,
|
||||
preserve_unused_token=True,
|
||||
),
|
||||
dict(
|
||||
first=8,
|
||||
last=14,
|
||||
expect_str=[
|
||||
['[UNK]', '[CLS]'],
|
||||
['[UNK]', '[SEP]'],
|
||||
['[UNK]', '[UNK]'],
|
||||
['[UNK]', '[PAD]'],
|
||||
['[UNK]', '[MASK]'],
|
||||
['[unused1]'],
|
||||
['[unused10]']
|
||||
],
|
||||
lower_case=True,
|
||||
vocab_list=vocab_bert,
|
||||
preserve_unused_token=True,
|
||||
),
|
||||
# test special symbol
|
||||
dict(
|
||||
first=13,
|
||||
last=13,
|
||||
first=15,
|
||||
last=15,
|
||||
expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']],
|
||||
preserve_unused_token=True,
|
||||
vocab_list=vocab_bert
|
||||
|
|
Loading…
Reference in New Issue