[feat][assistant][I5EWHZ] add new data operator AddToken

This commit is contained in:
xyuuuyx 2022-11-24 21:32:54 +08:00
parent 98d9e95256
commit d1459b200e
20 changed files with 456 additions and 3 deletions

View File

@ -0,0 +1,14 @@
mindspore.dataset.text.AddToken
===============================
.. py:class:: mindspore.dataset.text.AddToken(token, begin=True)
将分词(token)添加到序列的开头或结尾处。
参数:
- **token** (str) - 待添加的分词(token)。
- **begin** (bool, 可选) - 是否在序列的开头或结尾插入分词(token)。默认值True。
异常:
- **TypeError** - 如果 `token` 的类型不为str。
- **TypeError** - 如果 `begin` 的类型不为bool。

View File

@ -277,6 +277,7 @@ API样例中常用的导入模块如下
:nosignatures:
:template: classtemplate.rst
mindspore.dataset.text.AddToken
mindspore.dataset.text.BasicTokenizer
mindspore.dataset.text.BertTokenizer
mindspore.dataset.text.CaseFold

View File

@ -155,6 +155,7 @@ Transforms
:nosignatures:
:template: classtemplate.rst
mindspore.dataset.text.AddToken
mindspore.dataset.text.BasicTokenizer
mindspore.dataset.text.BertTokenizer
mindspore.dataset.text.CaseFold

View File

@ -23,6 +23,16 @@
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(AddTokenOperation, 1, ([](const py::module *m) {
(void)py::class_<text::AddTokenOperation, TensorOperation, std::shared_ptr<text::AddTokenOperation>>(
*m, "AddTokenOperation")
.def(py::init([](const std::string &token, bool begin) {
auto add_token = std::make_shared<text::AddTokenOperation>(token, begin);
THROW_IF_ERROR(add_token->ValidateParams());
return add_token;
}));
}));
#ifdef ENABLE_ICU4C
PYBIND_REGISTER(

View File

@ -36,6 +36,19 @@ constexpr size_t kMaxLoggedRows = 10;
// FUNCTIONS TO CREATE TEXT OPERATIONS
// (In alphabetical order)
// AddToken
struct AddToken::Data {
Data(const std::string &token, bool begin) : token_(token), begin_(begin) {}
std::string token_;
bool begin_;
};
AddToken::AddToken(const std::string &token, bool begin) : data_(std::make_shared<Data>(token, begin)) {}
std::shared_ptr<TensorOperation> AddToken::Parse() {
return std::make_shared<AddTokenOperation>(data_->token_, data_->begin_);
}
#ifndef _WIN32
// BasicTokenizer
struct BasicTokenizer::Data {

View File

@ -234,6 +234,36 @@ class SentencePieceVocab {
// Transform operations for text
namespace text {
/// \brief Add token to beginning or end of sequence.
class DATASET_API AddToken final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] token The token to be added.
/// \param[in] begin Whether to insert token at start or end of sequence. Default: true.
/// \par Example
/// \code
/// /* Define operations */
/// auto add_token_op = text::AddToken(token='TOKEN', begin=True);
///
/// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({add_token_op}, // operations
/// {"text"}); // input columns
/// \endcode
explicit AddToken(const std::string &token, bool begin = true);
/// \brief Destructor.
~AddToken() override = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
#ifndef _WIN32
/// \brief Tokenize a scalar tensor of UTF-8 string by specific rules.
/// \note BasicTokenizer is not supported on the Windows platform yet.

View File

@ -133,6 +133,7 @@ constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
constexpr char kDvppDecodeVideoOp[] = "DvppDecodeVideoOp";
// text
constexpr char kAddTokenOp[] = "AddTokenOp";
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";
constexpr char kBertTokenizerOp[] = "BertTokenizerOp";
constexpr char kCaseFoldOp[] = "CaseFoldOp";

View File

@ -16,6 +16,7 @@
#include "minddata/dataset/text/ir/kernels/text_ir.h"
#include "minddata/dataset/text/kernels/add_token_op.h"
#ifndef _WIN32
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
@ -56,6 +57,34 @@ namespace text {
// (In alphabetical order)
// AddToken
AddTokenOperation::AddTokenOperation(const std::string &token, bool begin) : token_(token), begin_(begin) {}
AddTokenOperation::~AddTokenOperation() = default;
std::shared_ptr<TensorOp> AddTokenOperation::Build() {
std::shared_ptr<AddTokenOp> tensor_op = std::make_shared<AddTokenOp>(token_, begin_);
return tensor_op;
}
Status AddTokenOperation::ValidateParams() {
if (token_.empty()) {
std::string err_msg = "AddToken: Parameter token is not provided.";
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::string AddTokenOperation::Name() const { return kAddTokenOperation; }
Status AddTokenOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["token"] = token_;
args["begin"] = begin_;
*out_json = args;
return Status::OK();
}
#ifndef _WIN32
// BasicTokenizerOperation
BasicTokenizerOperation::BasicTokenizerOperation(bool lower_case, bool keep_whitespace,

View File

@ -35,6 +35,7 @@ class Vocab;
namespace text {
constexpr int kStatusSum = 4;
// Char arrays storing name of corresponding classes (in alphabetical order)
constexpr char kAddTokenOperation[] = "AddToken";
constexpr char kBasicTokenizerOperation[] = "BasicTokenizer";
constexpr char kBertTokenizerOperation[] = "BertTokenizer";
constexpr char kCaseFoldOperation[] = "CaseFold";
@ -57,6 +58,28 @@ constexpr char kWordpieceTokenizerOperation[] = "WordpieceTokenizer";
/* ####################################### Derived TensorOperation classes ################################# */
class AddTokenOperation : public TensorOperation {
public:
/// \brief Constructor.
/// \param[in] token The token to be added.
/// \param[in] begin Whether to insert token at start or end of sequence.
AddTokenOperation(const std::string &token, bool begin);
~AddTokenOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
private:
std::string token_;
bool begin_;
};
#ifndef _WIN32
class BasicTokenizerOperation : public TensorOperation {
public:

View File

@ -13,6 +13,7 @@ if(NOT (CMAKE_SYSTEM_NAME MATCHES "Windows"))
whitespace_tokenizer_op.cc)
endif()
add_library(text-kernels OBJECT
add_token_op.cc
data_utils.cc
lookup_op.cc
jieba_tokenizer_op.cc

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/add_token_op.h"
#include <vector>
#include "minddata/dataset/text/kernels/data_utils.h"
namespace mindspore {
namespace dataset {
Status AddTokenOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() == 1 || input->Rank() == 2,
"AddToken: input tensor rank should be 1 or 2.");
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "AddToken: input tensor type should be string.");
IO_CHECK(input, output);
return AddToken(input, output, token_, begin_);
}
Status AddTokenOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape input_shape = inputs[0];
std::vector<dsize_t> output_shape_vector = input_shape.AsVector();
output_shape_vector[input_shape.Size() == 1 ? 0 : 1] += 1;
TensorShape out = TensorShape(output_shape_vector);
outputs.emplace_back(out);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_ADD_TOKEN_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_ADD_TOKEN_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class AddTokenOp : public TensorOp {
public:
/// \brief Constructor.
/// \param[in] token The token to be added.
/// \param[in] begin Whether to insert token at start or end of sequence.
AddTokenOp(const std::string &token, bool begin) : token_(token), begin_(begin) {}
~AddTokenOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kAddTokenOp; }
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
private:
const std::string token_;
bool begin_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_ADD_TOKEN_OP_H_

View File

@ -22,6 +22,7 @@
#include "minddata/dataset/core/pybind_support.h"
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/concatenate_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
namespace mindspore {
namespace dataset {
@ -65,5 +66,37 @@ Status AppendOffsetsHelper(const std::vector<uint32_t> &offsets_start, const std
output->push_back(offsets_limit_tensor);
return Status::OK();
}
Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &token,
bool begin) {
if (input->Rank() == 1) {
std::shared_ptr<Tensor> append;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(std::vector<std::string>({token}), &append));
TensorRow in({input}), out;
RETURN_IF_NOT_OK(Concatenate(in, &out, 0, begin ? append : nullptr, begin ? nullptr : append));
*output = out[0];
} else {
std::vector<std::string> output_vector;
int dim = input->shape()[0];
int shape = input->shape()[-1];
int count = 0;
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
if (count >= shape) {
count = 0;
}
if (begin && count == 0) {
output_vector.emplace_back(token);
}
output_vector.emplace_back(*it);
if (!begin && count == shape - 1) {
output_vector.emplace_back(token);
}
count++;
}
shape++;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(output_vector, TensorShape({dim, shape}), output));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -46,6 +46,15 @@ Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr
/// \return Status return code
Status AppendOffsetsHelper(const std::vector<uint32_t> &offsets_start, const std::vector<uint32_t> &offsets_limit,
TensorRow *output);
/// \brief Helper method that add token on input tensor.
/// \param[in] input Input tensor.
/// \param[in] token The token to be added.
/// \param[in] begin Whether to insert token at start or end of sequence.
/// \param[out] output Output tensor.
/// \return Status return code.
Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &token,
bool begin);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_

View File

@ -70,8 +70,8 @@ import platform
from . import transforms
from . import utils
from .transforms import JiebaTokenizer, Lookup, Ngram, PythonTokenizer, SentencePieceTokenizer, SlidingWindow, \
ToNumber, ToVectors, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
from .transforms import AddToken, JiebaTokenizer, Lookup, Ngram, PythonTokenizer, SentencePieceTokenizer, \
SlidingWindow, ToNumber, ToVectors, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
from .utils import CharNGram, FastText, GloVe, JiebaMode, NormalizeForm, SentencePieceModel, SentencePieceVocab, \
SPieceTokenizerLoadType, SPieceTokenizerOutType, Vectors, Vocab, to_bytes, to_str

View File

@ -49,7 +49,7 @@ import mindspore._c_dataengine as cde
from mindspore.common import dtype as mstype
from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType, SentencePieceVocab
from .validators import check_lookup, check_jieba_add_dict, check_to_vectors, \
from .validators import check_add_token, check_lookup, check_jieba_add_dict, check_to_vectors, \
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
@ -91,6 +91,47 @@ DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
}
class AddToken(TextTensorOperation):
"""
Add token to beginning or end of sequence.
Args:
token (str): The token to be added.
begin (bool, optional): Whether to insert token at start or end of sequence. Default: True.
Raises:
TypeError: If `token` is not of type string.
TypeError: If `begin` is not of type bool.
Supported Platforms:
``CPU``
Examples:
>>> dataset = ds.NumpySlicesDataset(data={"text": [['a', 'b', 'c', 'd', 'e']]})
>>> # Data before
>>> # | text |
>>> # +---------------------------+
>>> # | ['a', 'b', 'c', 'd', 'e'] |
>>> # +---------------------------+
>>> add_token_op = text.AddToken(token='TOKEN', begin=True)
>>> dataset = dataset.map(operations=add_token_op)
>>> # Data after
>>> # | text |
>>> # +---------------------------+
>>> # | ['TOKEN', 'a', 'b', 'c', 'd', 'e'] |
>>> # +---------------------------+
"""
@check_add_token
def __init__(self, token, begin=True):
super().__init__()
self.token = token
self.begin = begin
def parse(self):
return cde.AddTokenOperation(self.token, self.begin)
class JiebaTokenizer(TextTensorOperation):
"""
Tokenize Chinese string into words based on dictionary.

View File

@ -27,6 +27,19 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
INT32_MAX, check_value, check_positive, check_pos_int32, check_filename, check_non_negative_int32
def check_add_token(method):
"""Wrapper method to check the parameters of add token."""
@wraps(method)
def new_method(self, *args, **kwargs):
[token, begin], _ = parse_user_args(method, *args, **kwargs)
type_check(token, (str,), "token")
type_check(begin, (bool,), "begin")
return method(self, *args, **kwargs)
return new_method
def check_unique_list_of_words(words, arg_name):
"""Check that words is a list and each element is a str without any duplication"""

View File

@ -5239,3 +5239,59 @@ TEST_F(MindDataTestPipeline, TestCharNGramsWithNotExistFile) {
Status s = CharNGram::BuildFromFile(&char_n_gram, vectors_dir);
EXPECT_NE(s, Status::OK());
}
/// Feature: AddToken op
/// Description: Test input 1d of AddToken op successfully
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestPipeline, TestAddTokenPipelineSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAddTokenPipelineSuccess.";
// Create a TextFile dataset
std::string data_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create Take operation on ds
ds = ds->Take(1);
EXPECT_NE(ds, nullptr);
// Create white_tokenizer operation on ds
std::shared_ptr<TensorTransform> white_tokenizer = std::make_shared<text::WhitespaceTokenizer>();
EXPECT_NE(white_tokenizer, nullptr);
// Create add_token operation on ds
std::shared_ptr<TensorTransform> add_token = std::make_shared<text::AddToken>("TOKEN", true);
EXPECT_NE(add_token, nullptr);
// Create Map operation on ds
ds = ds->Map({white_tokenizer, add_token}, {"text"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<std::string> expected = {"TOKEN", "This", "is", "a", "text", "file."};
std::shared_ptr<Tensor> de_expected_tensor;
ASSERT_OK(Tensor::CreateFromVector(expected, &de_expected_tensor));
mindspore::MSTensor expected_tensor =
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_tensor));
uint64_t i = 0;
while (row.size() != 0) {
auto ind = row["text"];
EXPECT_MSTENSOR_EQ(ind, expected_tensor);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}

View File

@ -2984,3 +2984,19 @@ TEST_F(MindDataTestExecute, TestPerspective) {
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
/// Feature: AddToken op
/// Description: Test basic usage of AddToken op
/// Expectation: The data is processed successfully
TEST_F(MindDataTestExecute, TestAddToken) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestAddToken.";
std::vector<std::string> input_vectors = {"a", "b", "c", "d", "e"};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(input_vectors, &input));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> add_token_op = std::make_shared<text::AddToken>("Token", true);
// apply AddToken
mindspore::dataset::Execute trans({add_token_op});
Status status = trans(input_ms, &input_ms);
EXPECT_TRUE(status.IsOk());
}

View File

@ -0,0 +1,68 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing AddToken op
"""
import numpy as np
import mindspore.dataset.text as text
def test_add_token_at_begin():
"""
Feature: AddToken op
Description: Test AddToken with begin = True
Expectation: Output is equal to the expected output
"""
input_one_dimension = ['a', 'b', 'c', 'd', 'e']
expected = ['TOKEN', 'a', 'b', 'c', 'd', 'e']
out = text.AddToken(token='TOKEN', begin=True)
result = out(input_one_dimension)
assert np.array_equal(result, np.array(expected))
def test_add_token_at_end():
"""
Feature: AddToken op
Description: Test AddToken with begin = False
Expectation: Output is equal to the expected output
"""
input_one_dimension = ['a', 'b', 'c', 'd', 'e']
expected = ['a', 'b', 'c', 'd', 'e', 'TOKEN']
out = text.AddToken(token='TOKEN', begin=False)
result = out(input_one_dimension)
assert np.array_equal(result, np.array(expected))
def test_add_token_fail():
"""
Feature: AddToken op
Description: fail to test AddToken
Expectation: TypeError is raised as expected
"""
try:
_ = text.AddToken(token=1.0, begin=True)
except TypeError as error:
assert "Argument token with value 1.0 is not of type [<class 'str'>], but got <class 'float'>." in str(error)
try:
_ = text.AddToken(token='TOKEN', begin=12.3)
except TypeError as error:
assert "Argument begin with value 12.3 is not of type [<class 'bool'>], but got <class 'float'>." in str(error)
if __name__ == "__main__":
test_add_token_at_begin()
test_add_token_at_end()
test_add_token_fail()