From 451c20a6f5b9f97171bda44cd95e462ca5bb1d39 Mon Sep 17 00:00:00 2001
From: qianlong
Date: Tue, 5 May 2020 16:51:05 +0800
Subject: [PATCH] Add UnicodeCharTokenizer for nlp

---
 mindspore/ccsrc/dataset/CMakeLists.txt        |   2 +-
 .../ccsrc/dataset/api/python_bindings.cc      |   9 +-
 .../ccsrc/dataset/kernels/text/CMakeLists.txt |   3 +-
 .../dataset/kernels/text/jieba_tokenizer_op.h |   6 +-
 .../kernels/text/unicode_char_tokenizer_op.cc |  52 +++++++++
 .../kernels/text/unicode_char_tokenizer_op.h  |  40 +++++++
 mindspore/dataset/engine/datasets.py          |   4 +-
 mindspore/dataset/transforms/text/__init__.py |   1 +
 .../dataset/transforms/text/c_transforms.py   |   6 +
 mindspore/dataset/transforms/text/utils.py    |   4 +-
 tests/ut/cpp/dataset/CMakeLists.txt           |   1 +
 tests/ut/cpp/dataset/tokenizer_op_test.cc     | 107 ++++++++++++++++++
 tests/ut/data/dataset/testTokenizerData/1.txt |   4 +
 tests/ut/python/dataset/test_tokenizer.py     |  53 +++++++++
 14 files changed, 280 insertions(+), 12 deletions(-)
 create mode 100644 mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.cc
 create mode 100644 mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.h
 create mode 100644 tests/ut/cpp/dataset/tokenizer_op_test.cc
 create mode 100644 tests/ut/data/dataset/testTokenizerData/1.txt
 create mode 100644 tests/ut/python/dataset/test_tokenizer.py

diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt
index 8b8ade52e38..abea7a7c47f 100644
--- a/mindspore/ccsrc/dataset/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/CMakeLists.txt
@@ -61,7 +61,7 @@ set(submodules
         $<TARGET_OBJECTS:...>
         $<TARGET_OBJECTS:...>
         $<TARGET_OBJECTS:...>
-        $<TARGET_OBJECTS:kernels-nlp>
+        $<TARGET_OBJECTS:kernels-text>
         $<TARGET_OBJECTS:...>
         $<TARGET_OBJECTS:...>
         $<TARGET_OBJECTS:...>
diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index c5ba2629010..951aaaeccf5 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -39,6 +39,7 @@
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/kernels/text/jieba_tokenizer_op.h"
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -407,12 +408,16 @@ void bindTensorOps4(py::module *m) {
   py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
 }
 
-void bindTensorOps6(py::module *m) {
+void bindTensorOps5(py::module *m) {
   (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
     .def(py::init<const std::string &, const std::string &, const JiebaMode &>(), py::arg("hmm_path"),
          py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix)
     .def("add_word",
          [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
+
+  (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
+    *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
+    .def(py::init<>());
 }
 
 void bindSamplerOps(py::module *m) {
@@ -534,7 +539,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindTensorOps2(&m);
   bindTensorOps3(&m);
   bindTensorOps4(&m);
-  bindTensorOps6(&m);
+  bindTensorOps5(&m);
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
diff --git a/mindspore/ccsrc/dataset/kernels/text/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/text/CMakeLists.txt
index 40d9fbca9c6..6d2e72fa521 100644
--- a/mindspore/ccsrc/dataset/kernels/text/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/kernels/text/CMakeLists.txt
@@ -1,5 +1,6 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(kernels-nlp OBJECT
+add_library(kernels-text OBJECT
         jieba_tokenizer_op.cc
+        unicode_char_tokenizer_op.cc
     )
\ No newline at end of file
diff --git a/mindspore/ccsrc/dataset/kernels/text/jieba_tokenizer_op.h b/mindspore/ccsrc/dataset/kernels/text/jieba_tokenizer_op.h
index beddc321a75..41736e4fdb8 100644
--- a/mindspore/ccsrc/dataset/kernels/text/jieba_tokenizer_op.h
+++ b/mindspore/ccsrc/dataset/kernels/text/jieba_tokenizer_op.h
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
-#define DATASET_ENGINE_NLP_JIEBA_OP_H_
+#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
+#define DATASET_ENGINE_TEXT_JIEBA_OP_H_
 
 #include <memory>
 #include <string>
@@ -61,4 +61,4 @@ class JiebaTokenizerOp : public TensorOp {
 };
 }  // namespace dataset
 }  // namespace mindspore
-#endif  // DATASET_ENGINE_NLP_JIEBA_OP_H_
+#endif  // DATASET_ENGINE_TEXT_JIEBA_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.cc
new file mode 100644
index 00000000000..19bcb52203a
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "cppjieba/Unicode.hpp"
+
+using cppjieba::DecodeRunesInString;
+using cppjieba::RuneStrArray;
+
+namespace mindspore {
+namespace dataset {
+
+Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
+    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor.");
+  }
+  std::string_view str;
+  RETURN_IF_NOT_OK(input->GetItemAt(&str, {}));
+
+  RuneStrArray runes;
+  if (!DecodeRunesInString(str.data(), str.size(), runes)) {
+    RETURN_STATUS_UNEXPECTED("Decode utf8 string failed.");
+  }
+  std::vector<std::string> splits(runes.size());
+  for (size_t i = 0; i < runes.size(); i++) {
+    splits[i] = str.substr(runes[i].offset, runes[i].len);
+  }
+  if (splits.empty()) {
+    splits.emplace_back("");
+  }
+  *output = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.h b/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.h
new file mode 100644
index 00000000000..53c42d599eb
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/text/unicode_char_tokenizer_op.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#define DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
+#include <memory>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class UnicodeCharTokenizerOp : public TensorOp {
+ public:
+  UnicodeCharTokenizerOp() {}
+
+  ~UnicodeCharTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 20a40d5fb0d..df6d191b840 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -284,10 +284,10 @@ class Dataset:
         Examples:
             >>> import mindspore.dataset as ds
-            >>> import mindspore.dataset.transforms.nlp.utils as nlp
+            >>> import mindspore.dataset.transforms.text.utils as text
             >>> # declare a function which returns a Dataset object
             >>> def flat_map_func(x):
-            >>>     data_dir = nlp.as_text(x[0])
+            >>>     data_dir = text.as_text(x[0])
             >>>     d = ds.ImageFolderDatasetV2(data_dir)
             >>>     return d
             >>> # data is a Dataset object
diff --git a/mindspore/dataset/transforms/text/__init__.py b/mindspore/dataset/transforms/text/__init__.py
index 9698b0ab709..1353983deec 100644
--- a/mindspore/dataset/transforms/text/__init__.py
+++ b/mindspore/dataset/transforms/text/__init__.py
@@ -18,3 +18,4 @@ image augmentation module which is developed with c++ opencv. Py_transforms
 provide more kinds of image augmentations which is developed with python PIL.
 """
 from .utils import as_text, JiebaMode
+from . import c_transforms
diff --git a/mindspore/dataset/transforms/text/c_transforms.py b/mindspore/dataset/transforms/text/c_transforms.py
index f17def79bbc..534cdafa0ec 100644
--- a/mindspore/dataset/transforms/text/c_transforms.py
+++ b/mindspore/dataset/transforms/text/c_transforms.py
@@ -123,3 +123,9 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         if not os.path.exists(model_path):
             raise ValueError(
                 " jieba mode file {} is not exist".format(model_path))
+
+
+class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
+    """
+    Tokenize a scalar tensor of UTF-8 string to Unicode characters.
+ """ diff --git a/mindspore/dataset/transforms/text/utils.py b/mindspore/dataset/transforms/text/utils.py index 8c817cb00b6..7e5dcf48970 100644 --- a/mindspore/dataset/transforms/text/utils.py +++ b/mindspore/dataset/transforms/text/utils.py @@ -33,9 +33,7 @@ def as_text(array, encoding='utf8'): if not isinstance(array, np.ndarray): raise ValueError('input should be a numpy array') - def decode(x): - return x.decode(encoding) - decode = np.vectorize(decode) + decode = np.vectorize(lambda x: x.decode(encoding)) return decode(array) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 53fd0aaf064..f80cc74a8b9 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -69,6 +69,7 @@ SET(DE_UT_SRCS filter_op_test.cc concat_op_test.cc jieba_tokenizer_op_test.cc + tokenizer_op_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/tokenizer_op_test.cc b/tests/ut/cpp/dataset/tokenizer_op_test.cc new file mode 100644 index 00000000000..5beb725ec3a --- /dev/null +++ b/tests/ut/cpp/dataset/tokenizer_op_test.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "common/common.h" +#include "dataset/kernels/text/unicode_char_tokenizer_op.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; + +class MindDataTestTokenizerOp : public UT::Common { + public: + void CheckEqual(const std::shared_ptr &o, + const std::vector &index, + const std::string &expect) { + std::string_view str; + Status s = o->GetItemAt(&str, index); + EXPECT_TRUE(s.IsOk()); + EXPECT_EQ(str, expect); + } +}; + +TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { + MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp."; + std::unique_ptr op(new UnicodeCharTokenizerOp()); + std::shared_ptr input = std::make_shared("Hello World!"); + std::shared_ptr output; + Status s = op->Compute(input, &output); + EXPECT_TRUE(s.IsOk()); + EXPECT_EQ(output->Size(), 12); + EXPECT_EQ(output->Rank(), 1); + MS_LOG(INFO) << "Out tensor1: " << output->ToString(); + CheckEqual(output, {0}, "H"); + CheckEqual(output, {1}, "e"); + CheckEqual(output, {2}, "l"); + CheckEqual(output, {3}, "l"); + CheckEqual(output, {4}, "o"); + CheckEqual(output, {5}, " "); + CheckEqual(output, {6}, "W"); + CheckEqual(output, {7}, "o"); + CheckEqual(output, {8}, "r"); + CheckEqual(output, {9}, "l"); + CheckEqual(output, {10}, "d"); + CheckEqual(output, {11}, "!"); + + input = std::make_shared("中国 你好!"); + s = op->Compute(input, &output); + EXPECT_TRUE(s.IsOk()); + EXPECT_EQ(output->Size(), 6); + EXPECT_EQ(output->Rank(), 1); + MS_LOG(INFO) << "Out tensor2: " << output->ToString(); + CheckEqual(output, {0}, "中"); + CheckEqual(output, {1}, "国"); + CheckEqual(output, {2}, " "); + CheckEqual(output, {3}, "你"); + CheckEqual(output, {4}, "好"); + CheckEqual(output, {5}, "!"); + + input = std::make_shared("中"); 
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor3: " << output->ToString();
+  CheckEqual(output, {0}, "中");
+
+  input = std::make_shared<Tensor>("H");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor4: " << output->ToString();
+  CheckEqual(output, {0}, "H");
+
+  input = std::make_shared<Tensor>("  ");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 2);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor5: " << output->ToString();
+  CheckEqual(output, {0}, " ");
+  CheckEqual(output, {1}, " ");
+
+  input = std::make_shared<Tensor>("");
+  s = op->Compute(input, &output);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_EQ(output->Size(), 1);
+  EXPECT_EQ(output->Rank(), 1);
+  MS_LOG(INFO) << "Out tensor6: " << output->ToString();
+  CheckEqual(output, {0}, "");
+}
diff --git a/tests/ut/data/dataset/testTokenizerData/1.txt b/tests/ut/data/dataset/testTokenizerData/1.txt
new file mode 100644
index 00000000000..6df736afe6d
--- /dev/null
+++ b/tests/ut/data/dataset/testTokenizerData/1.txt
@@ -0,0 +1,4 @@
+Welcome to Beijing!
+北京欢迎您!
+我喜欢English!
+  
diff --git a/tests/ut/python/dataset/test_tokenizer.py b/tests/ut/python/dataset/test_tokenizer.py
new file mode 100644
index 00000000000..3cc7e7533d6
--- /dev/null
+++ b/tests/ut/python/dataset/test_tokenizer.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing UnicodeCharTokenizer op in DE
+"""
+import mindspore.dataset as ds
+from mindspore import log as logger
+import mindspore.dataset.transforms.text.c_transforms as nlp
+import mindspore.dataset.transforms.text.utils as nlp_util
+
+DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
+
+
+def split_by_unicode_char(input_strs):
+    """
+    Split UTF-8 strings into Unicode characters.
+    """
+    out = []
+    for s in input_strs:
+        out.append([c for c in s])
+    return out
+
+
+def test_unicode_char_tokenizer():
+    """
+    Test UnicodeCharTokenizer
+    """
+    input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", "  ")
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    tokenizer = nlp.UnicodeCharTokenizer()
+    dataset = dataset.map(operations=tokenizer)
+    tokens = []
+    for i in dataset.create_dict_iterator():
+        text = nlp_util.as_text(i['text']).tolist()
+        tokens.append(text)
+    logger.info("The output tokens are: {}".format(tokens))
+    assert split_by_unicode_char(input_strs) == tokens
+
+
+if __name__ == '__main__':
+    test_unicode_char_tokenizer()
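
-- 
Usage sketch (illustrative, mirroring the Python unit test above; the corpus
path is a hypothetical stand-in): with this patch applied, the new op plugs
into the usual dataset pipeline.

    import mindspore.dataset as ds
    import mindspore.dataset.transforms.text.c_transforms as text
    import mindspore.dataset.transforms.text.utils as text_util

    # Each line of the text file becomes one sample; the tokenizer turns the
    # scalar string into a 1-D tensor of single Unicode characters.
    data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)
    data = data.map(operations=text.UnicodeCharTokenizer())
    for row in data.create_dict_iterator():
        # as_text decodes the numpy byte-string array back to Python strings.
        print(text_util.as_text(row['text']).tolist())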