Add UnicodeCharTokenizer for nlp

This commit is contained in:
qianlong 2020-05-05 16:51:05 +08:00
parent 93e7c97a96
commit 451c20a6f5
14 changed files with 280 additions and 12 deletions

View File

@@ -61,7 +61,7 @@ set(submodules
        $<TARGET_OBJECTS:kernels>
        $<TARGET_OBJECTS:kernels-image>
        $<TARGET_OBJECTS:kernels-data>
-       $<TARGET_OBJECTS:kernels-nlp>
+       $<TARGET_OBJECTS:kernels-text>
        $<TARGET_OBJECTS:APItoPython>
        $<TARGET_OBJECTS:engine-datasetops-source>
        $<TARGET_OBJECTS:engine-datasetops-source-sampler>

View File

@@ -39,6 +39,7 @@
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
@@ -407,12 +408,16 @@ void bindTensorOps4(py::module *m) {
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
}
-void bindTensorOps6(py::module *m) {
+void bindTensorOps5(py::module *m) {
  (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
    .def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
         py::arg("mode") = JiebaMode::kMix)
    .def("add_word",
         [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
+  (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
+    *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string into Unicode characters.")
+    .def(py::init<>());
}
void bindSamplerOps(py::module *m) {
@@ -534,7 +539,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
  bindTensorOps2(&m);
  bindTensorOps3(&m);
  bindTensorOps4(&m);
-  bindTensorOps6(&m);
+  bindTensorOps5(&m);
  bindSamplerOps(&m);
  bindDatasetOps(&m);
  bindInfoObjects(&m);
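Once registered in bindTensorOps5, the op is reachable from Python through the _c_dataengine module. A minimal sketch of driving the binding directly, assuming a build that includes this commit (the cde alias follows the convention used in c_transforms.py):

# Sketch: construct the newly bound op from Python (illustration only).
import mindspore._c_dataengine as cde

op = cde.UnicodeCharTokenizerOp()  # bound above with a no-argument constructor
# In practice the op is handed to Dataset.map() rather than invoked directly.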

View File

@@ -1,5 +1,6 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(kernels-nlp OBJECT
+add_library(kernels-text OBJECT
        jieba_tokenizer_op.cc
+       unicode_char_tokenizer_op.cc
)

View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
-#define DATASET_ENGINE_NLP_JIEBA_OP_H_
+#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
+#define DATASET_ENGINE_TEXT_JIEBA_OP_H_
#include <string>
#include <memory>
@@ -61,4 +61,4 @@ class JiebaTokenizerOp : public TensorOp {
};
} // namespace dataset
} // namespace mindspore
-#endif  // DATASET_ENGINE_NLP_JIEBA_OP_H_
+#endif  // DATASET_ENGINE_TEXT_JIEBA_OP_H_

View File

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include <memory>
#include <string>
#include <string_view>
#include <vector>
#include "cppjieba/Unicode.hpp"
using cppjieba::DecodeRunesInString;
using cppjieba::RuneStrArray;
namespace mindspore {
namespace dataset {
Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor.");
  }
  std::string_view str;
  RETURN_IF_NOT_OK(input->GetItemAt(&str, {}));
  RuneStrArray runes;
  if (!DecodeRunesInString(str.data(), str.size(), runes)) {
    RETURN_STATUS_UNEXPECTED("Failed to decode the input as a UTF-8 string.");
  }
  // One output token per decoded rune (Unicode code point).
  std::vector<std::string> splits(runes.size());
  for (size_t i = 0; i < runes.size(); i++) {
    splits[i] = str.substr(runes[i].offset, runes[i].len);
  }
  // An empty input still produces one empty token, so the output tensor is never empty.
  if (splits.empty()) {
    splits.emplace_back("");
  }
  *output = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
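The implementation leans on cppjieba's DecodeRunesInString to walk the UTF-8 code points. The same behavior can be sketched in a few lines of Python, where strings already iterate by code point (an illustration, not part of the commit):

# Pure-Python sketch of Compute(): one token per Unicode code point,
# mirroring the splits.empty() branch that keeps a single empty token.
def unicode_char_tokenize(text):
    tokens = list(text)  # one element per code point
    return tokens if tokens else [""]

assert unicode_char_tokenize("Hello") == ["H", "e", "l", "l", "o"]
assert unicode_char_tokenize("中国") == ["中", "国"]
assert unicode_char_tokenize("") == [""]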

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
#define DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
#include <memory>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class UnicodeCharTokenizerOp : public TensorOp {
 public:
  UnicodeCharTokenizerOp() {}

  ~UnicodeCharTokenizerOp() override = default;

  void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; }

  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_

View File

@@ -284,10 +284,10 @@ class Dataset:
Examples:
>>> import mindspore.dataset as ds
->>> import mindspore.dataset.transforms.nlp.utils as nlp
+>>> import mindspore.dataset.transforms.text.utils as text
>>> # declare a function which returns a Dataset object
>>> def flat_map_func(x):
->>> data_dir = nlp.as_text(x[0])
+>>> data_dir = text.as_text(x[0])
>>> d = ds.ImageFolderDatasetV2(data_dir)
>>> return d
>>> # data is a Dataset object

View File

@@ -18,3 +18,4 @@ image augmentation module, which is developed with C++ OpenCV. Py_transforms
provides more kinds of image augmentations, which are developed with Python PIL.
"""
from .utils import as_text, JiebaMode
from . import c_transforms

View File

@@ -123,3 +123,9 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
        if not os.path.exists(model_path):
            raise ValueError(
                "jieba model file {} does not exist".format(model_path))
+class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
+    """
+    Tokenize a scalar tensor of UTF-8 string into Unicode characters.
+    """

View File

@@ -33,9 +33,7 @@ def as_text(array, encoding='utf8'):
    if not isinstance(array, np.ndarray):
        raise ValueError('input should be a numpy array')
-    def decode(x):
-        return x.decode(encoding)
-    decode = np.vectorize(decode)
+    decode = np.vectorize(lambda x: x.decode(encoding))
    return decode(array)
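For context, as_text decodes a NumPy array of UTF-8 byte strings elementwise; a short hedged example of calling it (the import path assumes the module layout introduced by this commit):

# Usage sketch for as_text (defined above): decode byte strings to str.
import numpy as np
from mindspore.dataset.transforms.text.utils import as_text

arr = np.array([b"Hello", "\u4e2d\u56fd".encode("utf8")])  # second item is b"中国"
print(as_text(arr))  # expected: ['Hello' '中国']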

View File

@@ -69,6 +69,7 @@ SET(DE_UT_SRCS
    filter_op_test.cc
    concat_op_test.cc
    jieba_tokenizer_op_test.cc
+    tokenizer_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})

View File

@@ -0,0 +1,107 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include <string_view>
#include "common/common.h"
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
class MindDataTestTokenizerOp : public UT::Common {
 public:
  // Assert that the string stored at `index` in tensor `o` equals `expect`.
  void CheckEqual(const std::shared_ptr<Tensor> &o, const std::vector<dsize_t> &index, const std::string &expect) {
    std::string_view str;
    Status s = o->GetItemAt(&str, index);
    EXPECT_TRUE(s.IsOk());
    EXPECT_EQ(str, expect);
  }
};
TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) {
  MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp.";
  std::unique_ptr<UnicodeCharTokenizerOp> op(new UnicodeCharTokenizerOp());
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Hello World!");
  std::shared_ptr<Tensor> output;
  Status s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 12);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  CheckEqual(output, {0}, "H");
  CheckEqual(output, {1}, "e");
  CheckEqual(output, {2}, "l");
  CheckEqual(output, {3}, "l");
  CheckEqual(output, {4}, "o");
  CheckEqual(output, {5}, " ");
  CheckEqual(output, {6}, "W");
  CheckEqual(output, {7}, "o");
  CheckEqual(output, {8}, "r");
  CheckEqual(output, {9}, "l");
  CheckEqual(output, {10}, "d");
  CheckEqual(output, {11}, "!");

  input = std::make_shared<Tensor>("中国 你好!");
  s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 6);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor2: " << output->ToString();
  CheckEqual(output, {0}, "中");
  CheckEqual(output, {1}, "国");
  CheckEqual(output, {2}, " ");
  CheckEqual(output, {3}, "你");
  CheckEqual(output, {4}, "好");
  CheckEqual(output, {5}, "!");

  input = std::make_shared<Tensor>("中");
  s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 1);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor3: " << output->ToString();
  CheckEqual(output, {0}, "中");

  input = std::make_shared<Tensor>("H");
  s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 1);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor4: " << output->ToString();
  CheckEqual(output, {0}, "H");

  input = std::make_shared<Tensor>("  ");  // two spaces
  s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 2);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor5: " << output->ToString();
  CheckEqual(output, {0}, " ");
  CheckEqual(output, {1}, " ");

  input = std::make_shared<Tensor>("");  // empty input still yields one empty token
  s = op->Compute(input, &output);
  EXPECT_TRUE(s.IsOk());
  EXPECT_EQ(output->Size(), 1);
  EXPECT_EQ(output->Rank(), 1);
  MS_LOG(INFO) << "Out tensor6: " << output->ToString();
  CheckEqual(output, {0}, "");
}

View File

@@ -0,0 +1,4 @@
Welcome to Beijing!
北京欢迎您!
我喜欢English!

View File

@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing UnicodeCharTokenizer op in DE
"""
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.text.c_transforms as text
import mindspore.dataset.transforms.text.utils as text_util

DATA_FILE = "../data/dataset/testTokenizerData/1.txt"


def split_by_unicode_char(input_strs):
    """
    Split UTF-8 strings into Unicode characters.
    """
    out = []
    for s in input_strs:
        out.append([c for c in s])
    return out


def test_unicode_char_tokenizer():
    """
    Test UnicodeCharTokenizer.
    """
    input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", " ")
    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
    tokenizer = text.UnicodeCharTokenizer()
    dataset = dataset.map(operations=tokenizer)
    tokens = []
    for i in dataset.create_dict_iterator():
        line = text_util.as_text(i['text']).tolist()
        tokens.append(line)
    logger.info("The output tokens are: {}".format(tokens))
    assert split_by_unicode_char(input_strs) == tokens


if __name__ == '__main__':
    test_unicode_char_tokenizer()