forked from mindspore-Ecosystem/mindspore
!2580 BasicTokenizer do not case fold on preserved words
Merge pull request !2580 from qianlong21st/fix_basic_tokenizer
commit 363489d00f

@@ -15,11 +15,16 @@
 */
 #include "dataset/text/kernels/basic_tokenizer_op.h"
 #include <memory>
+#include <queue>
 #include <string>
 #include <string_view>
 #include <utility>
 #include <vector>
 
+#include "unicode/errorcode.h"
+#include "unicode/normalizer2.h"
+#include "unicode/utypes.h"
+
 namespace mindspore {
 namespace dataset {
 const bool BasicTokenizerOp::kDefLowerCase = false;

@@ -40,8 +45,8 @@ const char BasicTokenizerOp::kCommonPattern[] =
   "|[\\x{2B820}-\\x{2CEAF}]"
   "|[\\x{F900}-\\x{FAFF}]"
   "|[\\x{2F800}-\\x{2FA1F}]";
-const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|";
+const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|";
+const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"};
 BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, NormalizeForm normalization_form,
                                    bool preserve_unused_token)
     : lower_case_(lower_case),

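For illustration only (not part of this change set): the widened kUnusedPattern means the downstream regex tokenizer now keeps any "[unusedN]" placeholder together as a single token, exactly like the fixed BERT specials. A quick stand-alone check of an equivalent expression, using Python's re module as a stand-in for the ICU regex engine (the variable name unused_pattern is illustrative):

import re

# Equivalent of the C++ kUnusedPattern above, minus its trailing "|" (which only exists
# because the C++ string is concatenated with further pattern fragments).
unused_pattern = r"\[CLS\]|\[SEP\]|\[UNK\]|\[PAD\]|\[MASK\]|\[unused\d+\]"

for token in ["[CLS]", "[unused1]", "[unused10]", "[unused]", "[foo]"]:
    print(token, bool(re.fullmatch(unused_pattern, token)))
# Prints True for the specials and for [unused1]/[unused10]; False for [unused] and [foo].
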
@@ -67,6 +72,69 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
   regex_tokenizer_ = std::make_unique<RegexTokenizerOp>(delim_pattern, keep_delim_pattern);
 }
 
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text,
+                                                    const std::unordered_set<std::string> &unused_words,
+                                                    std::string *outupt) {
+  icu::ErrorCode error;
+  const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
+  CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
+  outupt->clear();
+
+  // 1. get start and end offsets of not case fold strs
+  std::queue<std::pair<int, int>> offsets;  // offsets of not used words
+  int start = -1;
+  int len = 0;
+  for (int i = 0; i < text.length(); i++) {
+    if (text[i] == '[') {
+      start = i;
+      ++len;
+    } else if (text[i] == ']' && start >= 0) {
+      ++len;
+      std::string word(text.substr(start, len));
+      if (unused_words.find(word) != unused_words.end()) {
+        offsets.push(std::make_pair(start, start + len - 1));
+      }
+      start = -1;
+      len = 0;
+    } else if (start >= 0) {
+      ++len;
+    }
+  }
+
+  // 2. Do not apply case fold on `unused_words`
+  start = 0;
+  for (int i = 0; i < text.length();) {
+    std::string_view process_text;
+    std::string preserve_token;
+    if (offsets.empty()) {
+      i = text.length();
+      process_text = text.substr(start, i - start);
+    } else {
+      preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1);
+      process_text = text.substr(start, offsets.front().first - start);
+      i = offsets.front().second + 1;
+      offsets.pop();
+    }
+    std::string temp;
+    icu::StringByteSink<std::string> sink(&temp);
+    nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error);
+    *outupt += temp + preserve_token;
+  }
+  return Status::OK();
+}
+
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
+                                                    std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::vector<std::string> strs(input->Size());
+  int i = 0;
+  for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
+    RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++]));
+  }
+  *output = std::make_shared<Tensor>(std::move(strs), input->shape());
+  return Status::OK();
+}
+
 Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {

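For illustration only (not part of this change set), a condensed Python sketch of the same two-pass idea as CaseFoldWithoutUnusedWords above: first record the offsets of bracketed words that belong to the preserved set, then case fold only the text between those spans. Python's str.casefold() stands in for ICU's NFKC_Casefold normalizer, a plain list stands in for the offset queue, and all names here are illustrative:

UNUSED_WORDS = {"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"}

def case_fold_without_unused_words(text, unused_words=UNUSED_WORDS):
    # Pass 1: collect (start, end) offsets of preserved bracketed words.
    offsets = []
    start = -1
    for i, ch in enumerate(text):
        if ch == "[":
            start = i
        elif ch == "]" and start >= 0:
            if text[start:i + 1] in unused_words:
                offsets.append((start, i))
            start = -1
    # Pass 2: case fold everything except the preserved spans.
    out = []
    pos = 0
    for begin, end in offsets:
        out.append(text[pos:begin].casefold())  # fold the text before the preserved word
        out.append(text[begin:end + 1])         # copy the preserved word unchanged
        pos = end + 1
    out.append(text[pos:].casefold())           # fold the tail
    return "".join(out)

print(case_fold_without_unused_words("Unused [MASK] And [SEP]"))  # unused [MASK] and [SEP]
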
@@ -75,8 +143,13 @@ Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
   std::shared_ptr<Tensor> cur_input;
   std::shared_ptr<Tensor> processed_tensor;
   if (lower_case_) {
-    // to lower case
-    RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
+    if (!preserve_unused_token_) {
+      // to lower case
+      RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
+    } else {
+      // to lower case except words in kUnusedWords
+      RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input, &processed_tensor));
+    }
     cur_input = processed_tensor;
     // strip accent characters
     RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor));

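For illustration only, a hypothetical end-to-end usage sketch. It assumes the op is exposed to Python as mindspore.dataset.text.BasicTokenizer with the same constructor arguments as the C++ class above (lower_case, preserve_unused_token, ...); the import paths, keyword names, and the fixture file name are assumptions here and may differ between MindSpore versions:

import mindspore.dataset as ds
import mindspore.dataset.text as text

# "tokenizer_input.txt" is a hypothetical one-sentence-per-line fixture file.
data = ds.TextFileDataset("tokenizer_input.txt", shuffle=False)

# With lower_case=True and preserve_unused_token=True, tokens such as "[CLS]", "[MASK]"
# and "[unused1]" survive lower-casing intact; all other text is case folded as before.
tokenizer = text.BasicTokenizer(lower_case=True, preserve_unused_token=True)
data = data.map(operations=tokenizer)

for row in data.create_dict_iterator():
    print(row["text"])
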
@@ -17,6 +17,7 @@
 #define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
 #include <memory>
 #include <string>
+#include <unordered_set>
 
 #include "dataset/core/tensor.h"
 #include "dataset/kernels/tensor_op.h"

@@ -45,9 +46,15 @@ class BasicTokenizerOp : public TensorOp {
 
   Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
 
+ protected:
+  Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set<std::string> &unused_words,
+                                    std::string *outupt);
+  Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
  private:
   static const char kCommonPattern[];
   static const char kUnusedPattern[];
+  static const std::unordered_set<std::string> kUnusedWords;
   bool lower_case_;
   bool keep_whitespace_;
   NormalizeForm normalization_form_;

@@ -10,5 +10,6 @@ unused [SEP]
 unused [UNK]
 unused [PAD]
 unused [MASK]
+[unused1]
+[unused10]
 12+/-28=40/-16
 Hello World!

@@ -27,7 +27,7 @@ vocab_bert = [
     "繁", "體", "字", "嘿", "哈", "大", "笑", "嘻",
     "i", "am", "mak", "make", "small", "mistake", "##s", "during", "work", "##ing", "hour",
     "😀", "😃", "😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I",
-    "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"
+    "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"
 ]
 pad = '<pad>'
 test_paras = [

@@ -69,22 +69,40 @@ test_paras = [
     # test preserved tokens
     dict(
         first=8,
-        last=12,
+        last=14,
         expect_str=[
             ['[UNK]', '[CLS]'],
             ['[UNK]', '[SEP]'],
             ['[UNK]', '[UNK]'],
             ['[UNK]', '[PAD]'],
             ['[UNK]', '[MASK]'],
+            ['[unused1]'],
+            ['[unused10]']
         ],
         lower_case=False,
         vocab_list=vocab_bert,
         preserve_unused_token=True,
     ),
+    dict(
+        first=8,
+        last=14,
+        expect_str=[
+            ['[UNK]', '[CLS]'],
+            ['[UNK]', '[SEP]'],
+            ['[UNK]', '[UNK]'],
+            ['[UNK]', '[PAD]'],
+            ['[UNK]', '[MASK]'],
+            ['[unused1]'],
+            ['[unused10]']
+        ],
+        lower_case=True,
+        vocab_list=vocab_bert,
+        preserve_unused_token=True,
+    ),
     # test special symbol
     dict(
-        first=13,
-        last=13,
+        first=15,
+        last=15,
         expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']],
         preserve_unused_token=True,
         vocab_list=vocab_bert