forked from mindspore-Ecosystem/mindspore
Add UnicodeCharTokenizer for nlp
This commit is contained in:
parent
93e7c97a96
commit
451c20a6f5
|
@ -61,7 +61,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:kernels>
|
||||
$<TARGET_OBJECTS:kernels-image>
|
||||
$<TARGET_OBJECTS:kernels-data>
|
||||
$<TARGET_OBJECTS:kernels-nlp>
|
||||
$<TARGET_OBJECTS:kernels-text>
|
||||
$<TARGET_OBJECTS:APItoPython>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/kernels/text/jieba_tokenizer_op.h"
|
||||
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
|
@ -407,12 +408,16 @@ void bindTensorOps4(py::module *m) {
|
|||
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
|
||||
}
|
||||
|
||||
void bindTensorOps6(py::module *m) {
|
||||
void bindTensorOps5(py::module *m) {
|
||||
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
|
||||
.def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
|
||||
py::arg("mode") = JiebaMode::kMix)
|
||||
.def("add_word",
|
||||
[](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
|
||||
|
||||
(void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
|
||||
*m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
|
||||
.def(py::init<>());
|
||||
}
|
||||
|
||||
void bindSamplerOps(py::module *m) {
|
||||
|
@ -534,7 +539,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
bindTensorOps2(&m);
|
||||
bindTensorOps3(&m);
|
||||
bindTensorOps4(&m);
|
||||
bindTensorOps6(&m);
|
||||
bindTensorOps5(&m);
|
||||
bindSamplerOps(&m);
|
||||
bindDatasetOps(&m);
|
||||
bindInfoObjects(&m);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(kernels-nlp OBJECT
|
||||
add_library(kernels-text OBJECT
|
||||
jieba_tokenizer_op.cc
|
||||
unicode_char_tokenizer_op.cc
|
||||
)
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
|
||||
#define DATASET_ENGINE_NLP_JIEBA_OP_H_
|
||||
#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
|
||||
#define DATASET_ENGINE_TEXT_JIEBA_OP_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -61,4 +61,4 @@ class JiebaTokenizerOp : public TensorOp {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_NLP_JIEBA_OP_H_
|
||||
#endif // DATASET_ENGINE_TEXT_JIEBA_OP_H_
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "cppjieba/Unicode.hpp"
|
||||
|
||||
using cppjieba::DecodeRunesInString;
|
||||
using cppjieba::RuneStrArray;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor");
|
||||
}
|
||||
std::string_view str;
|
||||
RETURN_IF_NOT_OK(input->GetItemAt(&str, {}));
|
||||
|
||||
RuneStrArray runes;
|
||||
if (!DecodeRunesInString(str.data(), str.size(), runes)) {
|
||||
RETURN_STATUS_UNEXPECTED("Decode utf8 string failed.");
|
||||
}
|
||||
std::vector<std::string> splits(runes.size());
|
||||
for (size_t i = 0; i < runes.size(); i++) {
|
||||
splits[i] = str.substr(runes[i].offset, runes[i].len);
|
||||
}
|
||||
if (splits.empty()) {
|
||||
splits.emplace_back("");
|
||||
}
|
||||
*output = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
|
||||
#define DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
|
||||
#include <memory>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class UnicodeCharTokenizerOp : public TensorOp {
|
||||
public:
|
||||
UnicodeCharTokenizerOp() {}
|
||||
|
||||
~UnicodeCharTokenizerOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_TEXT_UNICODE_CHAR_TOKENIZER_OP_H_
|
|
@ -284,10 +284,10 @@ class Dataset:
|
|||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> import mindspore.dataset.transforms.nlp.utils as nlp
|
||||
>>> import mindspore.dataset.transforms.text.utils as text
|
||||
>>> # declare a function which returns a Dataset object
|
||||
>>> def flat_map_func(x):
|
||||
>>> data_dir = nlp.as_text(x[0])
|
||||
>>> data_dir = text.as_text(x[0])
|
||||
>>> d = ds.ImageFolderDatasetV2(data_dir)
|
||||
>>> return d
|
||||
>>> # data is a Dataset object
|
||||
|
|
|
@ -18,3 +18,4 @@ image augmentation module which is developed with c++ opencv. Py_transforms
|
|||
provide more kinds of image augmentations which is developed with python PIL.
|
||||
"""
|
||||
from .utils import as_text, JiebaMode
|
||||
from . import c_transforms
|
||||
|
|
|
@ -123,3 +123,9 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
|
|||
if not os.path.exists(model_path):
|
||||
raise ValueError(
|
||||
" jieba mode file {} is not exist".format(model_path))
|
||||
|
||||
|
||||
class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
|
||||
"""
|
||||
Tokenize a scalar tensor of UTF-8 string to Unicode characters.
|
||||
"""
|
||||
|
|
|
@ -33,9 +33,7 @@ def as_text(array, encoding='utf8'):
|
|||
if not isinstance(array, np.ndarray):
|
||||
raise ValueError('input should be a numpy array')
|
||||
|
||||
def decode(x):
|
||||
return x.decode(encoding)
|
||||
decode = np.vectorize(decode)
|
||||
decode = np.vectorize(lambda x: x.decode(encoding))
|
||||
return decode(array)
|
||||
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ SET(DE_UT_SRCS
|
|||
filter_op_test.cc
|
||||
concat_op_test.cc
|
||||
jieba_tokenizer_op_test.cc
|
||||
tokenizer_op_test.cc
|
||||
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
class MindDataTestTokenizerOp : public UT::Common {
|
||||
public:
|
||||
void CheckEqual(const std::shared_ptr<Tensor> &o,
|
||||
const std::vector<dsize_t> &index,
|
||||
const std::string &expect) {
|
||||
std::string_view str;
|
||||
Status s = o->GetItemAt(&str, index);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(str, expect);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) {
|
||||
MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp.";
|
||||
std::unique_ptr<UnicodeCharTokenizerOp> op(new UnicodeCharTokenizerOp());
|
||||
std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Hello World!");
|
||||
std::shared_ptr<Tensor> output;
|
||||
Status s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 12);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor1: " << output->ToString();
|
||||
CheckEqual(output, {0}, "H");
|
||||
CheckEqual(output, {1}, "e");
|
||||
CheckEqual(output, {2}, "l");
|
||||
CheckEqual(output, {3}, "l");
|
||||
CheckEqual(output, {4}, "o");
|
||||
CheckEqual(output, {5}, " ");
|
||||
CheckEqual(output, {6}, "W");
|
||||
CheckEqual(output, {7}, "o");
|
||||
CheckEqual(output, {8}, "r");
|
||||
CheckEqual(output, {9}, "l");
|
||||
CheckEqual(output, {10}, "d");
|
||||
CheckEqual(output, {11}, "!");
|
||||
|
||||
input = std::make_shared<Tensor>("中国 你好!");
|
||||
s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 6);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor2: " << output->ToString();
|
||||
CheckEqual(output, {0}, "中");
|
||||
CheckEqual(output, {1}, "国");
|
||||
CheckEqual(output, {2}, " ");
|
||||
CheckEqual(output, {3}, "你");
|
||||
CheckEqual(output, {4}, "好");
|
||||
CheckEqual(output, {5}, "!");
|
||||
|
||||
input = std::make_shared<Tensor>("中");
|
||||
s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 1);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor3: " << output->ToString();
|
||||
CheckEqual(output, {0}, "中");
|
||||
|
||||
input = std::make_shared<Tensor>("H");
|
||||
s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 1);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor4: " << output->ToString();
|
||||
CheckEqual(output, {0}, "H");
|
||||
|
||||
input = std::make_shared<Tensor>(" ");
|
||||
s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 2);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor5: " << output->ToString();
|
||||
CheckEqual(output, {0}, " ");
|
||||
CheckEqual(output, {1}, " ");
|
||||
|
||||
input = std::make_shared<Tensor>("");
|
||||
s = op->Compute(input, &output);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_EQ(output->Size(), 1);
|
||||
EXPECT_EQ(output->Rank(), 1);
|
||||
MS_LOG(INFO) << "Out tensor6: " << output->ToString();
|
||||
CheckEqual(output, {0}, "");
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
Welcome to Beijing!
|
||||
北京欢迎您!
|
||||
我喜欢English!
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing UnicodeCharTokenizer op in DE
|
||||
"""
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.transforms.text.c_transforms as nlp
|
||||
import mindspore.dataset.transforms.text.utils as nlp_util
|
||||
|
||||
DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
|
||||
|
||||
|
||||
def split_by_unicode_char(input_strs):
|
||||
"""
|
||||
Split utf-8 strings to unicode characters
|
||||
"""
|
||||
out = []
|
||||
for s in input_strs:
|
||||
out.append([c for c in s])
|
||||
return out
|
||||
|
||||
|
||||
def test_unicode_char_tokenizer():
|
||||
"""
|
||||
Test UnicodeCharTokenizer
|
||||
"""
|
||||
input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", " ")
|
||||
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
tokenizer = nlp.UnicodeCharTokenizer()
|
||||
dataset = dataset.map(operations=tokenizer)
|
||||
tokens = []
|
||||
for i in dataset.create_dict_iterator():
|
||||
text = nlp_util.as_text(i['text']).tolist()
|
||||
tokens.append(text)
|
||||
logger.info("The out tokens is : {}".format(tokens))
|
||||
assert split_by_unicode_char(input_strs) == tokens
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_unicode_char_tokenizer()
|
Loading…
Reference in New Issue