forked from mindspore-Ecosystem/mindspore
[feat][assistant][I5EWI8] add new data operator Truncate
This commit is contained in:
parent
3b5e68cca7
commit
6b42fd45f1
|
@ -0,0 +1,14 @@
|
||||||
|
mindspore.dataset.text.Truncate
|
||||||
|
===============================
|
||||||
|
|
||||||
|
.. py:class:: mindspore.dataset.text.Truncate(max_seq_len)
|
||||||
|
|
||||||
|
截断输入序列,使其不超过最大长度。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **max_seq_len** (int) - 最大截断长度。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果 `max_seq_len` 的类型不是int。
|
||||||
|
- **ValueError** - 如果 `max_seq_len` 的值小于或等于0。
|
||||||
|
- **RuntimeError** - 如果输入张量的数据类型不是bool、int、float、double或者str。
|
|
@ -293,6 +293,7 @@ API样例中常用的导入模块如下:
|
||||||
mindspore.dataset.text.SlidingWindow
|
mindspore.dataset.text.SlidingWindow
|
||||||
mindspore.dataset.text.ToNumber
|
mindspore.dataset.text.ToNumber
|
||||||
mindspore.dataset.text.ToVectors
|
mindspore.dataset.text.ToVectors
|
||||||
|
mindspore.dataset.text.Truncate
|
||||||
mindspore.dataset.text.TruncateSequencePair
|
mindspore.dataset.text.TruncateSequencePair
|
||||||
mindspore.dataset.text.UnicodeCharTokenizer
|
mindspore.dataset.text.UnicodeCharTokenizer
|
||||||
mindspore.dataset.text.UnicodeScriptTokenizer
|
mindspore.dataset.text.UnicodeScriptTokenizer
|
||||||
|
|
|
@ -171,6 +171,7 @@ Transforms
|
||||||
mindspore.dataset.text.SlidingWindow
|
mindspore.dataset.text.SlidingWindow
|
||||||
mindspore.dataset.text.ToNumber
|
mindspore.dataset.text.ToNumber
|
||||||
mindspore.dataset.text.ToVectors
|
mindspore.dataset.text.ToVectors
|
||||||
|
mindspore.dataset.text.Truncate
|
||||||
mindspore.dataset.text.TruncateSequencePair
|
mindspore.dataset.text.TruncateSequencePair
|
||||||
mindspore.dataset.text.UnicodeCharTokenizer
|
mindspore.dataset.text.UnicodeCharTokenizer
|
||||||
mindspore.dataset.text.UnicodeScriptTokenizer
|
mindspore.dataset.text.UnicodeScriptTokenizer
|
||||||
|
|
|
@ -241,6 +241,16 @@ PYBIND_REGISTER(
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Register text::TruncateOperation with the Python module under the name "TruncateOperation".
PYBIND_REGISTER(TruncateOperation, 1, ([](const py::module *m) {
                  (void)py::class_<text::TruncateOperation, TensorOperation, std::shared_ptr<text::TruncateOperation>>(
                    *m, "TruncateOperation")
                    .def(py::init([](int32_t max_seq_len) {
                      // Parameters are validated before the new object is handed back to the caller.
                      auto operation = std::make_shared<text::TruncateOperation>(max_seq_len);
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
|
||||||
|
|
||||||
PYBIND_REGISTER(TruncateSequencePairOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(TruncateSequencePairOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<text::TruncateSequencePairOperation, TensorOperation,
|
(void)py::class_<text::TruncateSequencePairOperation, TensorOperation,
|
||||||
std::shared_ptr<text::TruncateSequencePairOperation>>(
|
std::shared_ptr<text::TruncateSequencePairOperation>>(
|
||||||
|
|
|
@ -394,6 +394,16 @@ std::shared_ptr<TensorOperation> ToVectors::Parse() {
|
||||||
return std::make_shared<ToVectorsOperation>(data_->vectors_, data_->unk_init_, data_->lower_case_backup_);
|
return std::make_shared<ToVectorsOperation>(data_->vectors_, data_->unk_init_, data_->lower_case_backup_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Truncate
|
||||||
|
struct Truncate::Data {
|
||||||
|
explicit Data(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||||
|
int32_t max_seq_len_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Truncate::Truncate(int32_t max_seq_len) : data_(std::make_shared<Data>(max_seq_len)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> Truncate::Parse() { return std::make_shared<TruncateOperation>(data_->max_seq_len_); }
|
||||||
|
|
||||||
// TruncateSequencePair
|
// TruncateSequencePair
|
||||||
struct TruncateSequencePair::Data {
|
struct TruncateSequencePair::Data {
|
||||||
explicit Data(int32_t max_length) : max_length_(max_length) {}
|
explicit Data(int32_t max_length) : max_length_(max_length) {}
|
||||||
|
|
|
@ -894,6 +894,35 @@ class DATASET_API ToVectors final : public TensorTransform {
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Truncate the input sequence so that it does not exceed the maximum length.
|
||||||
|
class DATASET_API Truncate final : public TensorTransform {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
/// \param[in] max_seq_len Maximum allowable length.
|
||||||
|
/// \par Example
|
||||||
|
/// \code
|
||||||
|
/// /* Define operations */
|
||||||
|
/// auto truncate_op = text::Truncate(5);
|
||||||
|
///
|
||||||
|
/// /* dataset is an instance of Dataset object */
|
||||||
|
/// dataset = dataset->Map({truncate_op}, // operations
|
||||||
|
/// {"text"}); // input columns
|
||||||
|
/// \endcode
|
||||||
|
explicit Truncate(int32_t max_seq_len);
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~Truncate() = default;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||||
|
/// \return Shared pointer to the TensorOperation object.
|
||||||
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Data;
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
/// \brief Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
||||||
class DATASET_API TruncateSequencePair final : public TensorTransform {
|
class DATASET_API TruncateSequencePair final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -147,6 +147,7 @@ constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
|
||||||
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
||||||
constexpr char kToNumberOp[] = "ToNumberOp";
|
constexpr char kToNumberOp[] = "ToNumberOp";
|
||||||
constexpr char kToVectorsOp[] = "ToVectorsOp";
|
constexpr char kToVectorsOp[] = "ToVectorsOp";
|
||||||
|
constexpr char kTruncateOp[] = "TruncateOp";
|
||||||
constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp";
|
constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp";
|
||||||
constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
|
constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
|
||||||
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
|
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||||
#include "minddata/dataset/text/kernels/to_number_op.h"
|
#include "minddata/dataset/text/kernels/to_number_op.h"
|
||||||
#include "minddata/dataset/text/kernels/to_vectors_op.h"
|
#include "minddata/dataset/text/kernels/to_vectors_op.h"
|
||||||
|
#include "minddata/dataset/text/kernels/truncate_op.h"
|
||||||
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||||
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||||
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||||
|
@ -47,6 +48,7 @@
|
||||||
#include "minddata/dataset/util/path.h"
|
#include "minddata/dataset/util/path.h"
|
||||||
#include "minddata/dataset/util/validators.h"
|
#include "minddata/dataset/util/validators.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/audio/ir/validators.h"
|
||||||
#include "minddata/dataset/text/ir/validators.h"
|
#include "minddata/dataset/text/ir/validators.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -583,6 +585,27 @@ std::shared_ptr<TensorOp> ToVectorsOperation::Build() {
|
||||||
return tensor_op;
|
return tensor_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TruncateOperation
|
||||||
|
TruncateOperation::TruncateOperation(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||||
|
|
||||||
|
Status TruncateOperation::ValidateParams() {
|
||||||
|
RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Truncate", "max_seq_len", max_seq_len_));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> TruncateOperation::Build() {
|
||||||
|
std::shared_ptr<TruncateOp> tensor_op = std::make_shared<TruncateOp>(max_seq_len_);
|
||||||
|
return tensor_op;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TruncateOperation::to_json(nlohmann::json *out_json) {
|
||||||
|
nlohmann::json args;
|
||||||
|
args["max_seq_len"] = max_seq_len_;
|
||||||
|
*out_json = args;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// TruncateSequencePairOperation
|
// TruncateSequencePairOperation
|
||||||
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}
|
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,7 @@ constexpr char kSentencepieceTokenizerOperation[] = "SentencepieceTokenizer";
|
||||||
constexpr char kSlidingWindowOperation[] = "SlidingWindow";
|
constexpr char kSlidingWindowOperation[] = "SlidingWindow";
|
||||||
constexpr char kToNumberOperation[] = "ToNumber";
|
constexpr char kToNumberOperation[] = "ToNumber";
|
||||||
constexpr char kToVectorsOperation[] = "ToVectors";
|
constexpr char kToVectorsOperation[] = "ToVectors";
|
||||||
|
constexpr char kTruncateOperation[] = "Truncate";
|
||||||
constexpr char kTruncateSequencePairOperation[] = "TruncateSequencePair";
|
constexpr char kTruncateSequencePairOperation[] = "TruncateSequencePair";
|
||||||
constexpr char kUnicodeCharTokenizerOperation[] = "UnicodeCharTokenizer";
|
constexpr char kUnicodeCharTokenizerOperation[] = "UnicodeCharTokenizer";
|
||||||
constexpr char kUnicodeScriptTokenizerOperation[] = "UnicodeScriptTokenizer";
|
constexpr char kUnicodeScriptTokenizerOperation[] = "UnicodeScriptTokenizer";
|
||||||
|
@ -353,6 +354,24 @@ class ToVectorsOperation : public TensorOperation {
|
||||||
bool lower_case_backup_;
|
bool lower_case_backup_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TruncateOperation : public TensorOperation {
|
||||||
|
public:
|
||||||
|
explicit TruncateOperation(int32_t max_seq_len);
|
||||||
|
|
||||||
|
~TruncateOperation() = default;
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> Build() override;
|
||||||
|
|
||||||
|
Status ValidateParams() override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kTruncateOperation; }
|
||||||
|
|
||||||
|
Status to_json(nlohmann::json *out_json) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t max_seq_len_;
|
||||||
|
};
|
||||||
|
|
||||||
class TruncateSequencePairOperation : public TensorOperation {
|
class TruncateSequencePairOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit TruncateSequencePairOperation(int32_t max_length);
|
explicit TruncateSequencePairOperation(int32_t max_length);
|
||||||
|
|
|
@ -22,6 +22,7 @@ add_library(text-kernels OBJECT
|
||||||
ngram_op.cc
|
ngram_op.cc
|
||||||
sliding_window_op.cc
|
sliding_window_op.cc
|
||||||
wordpiece_tokenizer_op.cc
|
wordpiece_tokenizer_op.cc
|
||||||
|
truncate_op.cc
|
||||||
truncate_sequence_pair_op.cc
|
truncate_sequence_pair_op.cc
|
||||||
to_number_op.cc
|
to_number_op.cc
|
||||||
to_vectors_op.cc
|
to_vectors_op.cc
|
||||||
|
|
|
@ -98,5 +98,16 @@ Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Slice the input with Slice(max_seq_len): on axis 0 for a rank-1 tensor, otherwise on axis 1 while
// axis 0 is sliced with its full extent.
Status Truncate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int max_seq_len) {
  const Slice keep_len(max_seq_len);
  if (input->shape().Rank() == 1) {
    return input->Slice(output, {SliceOption(keep_len)});
  }
  // Not rank 1: keep every entry along axis 0 and cut along axis 1.
  const int rows = input->shape()[0];
  const Slice keep_rows(rows);
  return input->Slice(output, {SliceOption(keep_rows), SliceOption(keep_len)});
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -55,6 +55,11 @@ Status AppendOffsetsHelper(const std::vector<uint32_t> &offsets_start, const std
|
||||||
/// \return Status return code.
|
/// \return Status return code.
|
||||||
Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &token,
|
Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &token,
|
||||||
bool begin);
|
bool begin);
|
||||||
|
|
||||||
|
/// \brief Truncate the input sequence so that it does not exceed the maximum length.
|
||||||
|
/// \param[in] max_seq_len Maximum allowable length.
|
||||||
|
/// \param[out] output Output Tensor.
|
||||||
|
Status Truncate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int max_seq_len);
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/text/kernels/truncate_op.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// Check the input (rank 1 or 2; string or numeric type) and delegate the slicing to Truncate().
Status TruncateOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  constexpr int kMaxSeqRank = 2;
  const auto rank = input->shape().Rank();
  CHECK_FAIL_RETURN_UNEXPECTED(rank == 1 || rank == kMaxSeqRank,
                               "Truncate: the input tensor should be of dimension 1 or 2.");
  const bool is_supported_type = input->type() == DataType::DE_STRING || input->type().IsNumeric();
  // The doubled "Truncate:" prefix is kept verbatim: the Python test matches this exact text.
  CHECK_FAIL_RETURN_UNEXPECTED(
    is_supported_type,
    "Truncate: Truncate: the input tensor should be in type of [bool, int, float, double, string].");
  return Truncate(input, output, max_seq_len_);
}
|
||||||
|
|
||||||
|
// Compute the output shape: the input shape with its last axis capped at max_seq_len_.
Status TruncateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
  constexpr int kMaxSeqRank = 2;
  const auto rank = inputs[0].Rank();
  CHECK_FAIL_RETURN_UNEXPECTED(rank == 1 || rank == kMaxSeqRank,
                               "Truncate: the input tensor should be of dimension 1 or 2.");
  // The truncated axis is axis 0 for a rank-1 input and axis 1 for a rank-2 input.
  const int seq_axis = (rank == 1) ? 0 : 1;
  outputs.clear();
  auto dims = inputs[0].AsVector();
  int seq_len = dims[seq_axis];
  dims[seq_axis] = std::min(seq_len, max_seq_len_);
  outputs.emplace_back(TensorShape{dims});
  return Status::OK();
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class TruncateOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
explicit TruncateOp(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||||
|
|
||||||
|
~TruncateOp() override = default;
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kTruncateOp; }
|
||||||
|
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t max_seq_len_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
|
@ -73,7 +73,7 @@ import platform
|
||||||
from . import transforms
|
from . import transforms
|
||||||
from . import utils
|
from . import utils
|
||||||
from .transforms import AddToken, JiebaTokenizer, Lookup, Ngram, PythonTokenizer, SentencePieceTokenizer, \
|
from .transforms import AddToken, JiebaTokenizer, Lookup, Ngram, PythonTokenizer, SentencePieceTokenizer, \
|
||||||
SlidingWindow, ToNumber, ToVectors, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
|
SlidingWindow, ToNumber, ToVectors, Truncate, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
|
||||||
from .utils import CharNGram, FastText, GloVe, JiebaMode, NormalizeForm, SentencePieceModel, SentencePieceVocab, \
|
from .utils import CharNGram, FastText, GloVe, JiebaMode, NormalizeForm, SentencePieceModel, SentencePieceVocab, \
|
||||||
SPieceTokenizerLoadType, SPieceTokenizerOutType, Vectors, Vocab, to_bytes, to_str
|
SPieceTokenizerLoadType, SPieceTokenizerOutType, Vectors, Vocab, to_bytes, to_str
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ from .validators import check_add_token, check_lookup, check_jieba_add_dict, che
|
||||||
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
|
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
|
||||||
check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
|
check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
|
||||||
check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
|
check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
|
||||||
check_sentence_piece_tokenizer
|
check_sentence_piece_tokenizer, check_truncate
|
||||||
from ..core.datatypes import mstype_to_detype
|
from ..core.datatypes import mstype_to_detype
|
||||||
from ..core.validator_helpers import replace_none
|
from ..core.validator_helpers import replace_none
|
||||||
from ..transforms.py_transforms_util import Implementation
|
from ..transforms.py_transforms_util import Implementation
|
||||||
|
@ -620,6 +620,46 @@ class ToVectors(TextTensorOperation):
|
||||||
return cde.ToVectorsOperation(self.vectors, self.unk_init, self.lower_case_backup)
|
return cde.ToVectorsOperation(self.vectors, self.unk_init, self.lower_case_backup)
|
||||||
|
|
||||||
|
|
||||||
|
class Truncate(TextTensorOperation):
    """
    Truncate the input sequence so that it does not exceed the maximum length.

    Args:
        max_seq_len (int): Maximum allowable length.

    Raises:
        TypeError: If `max_seq_len` is not of type int.
        ValueError: If value of `max_seq_len` is not greater than 0.
        RuntimeError: If the input tensor is not of dtype bool, int, float, double or str.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> dataset = ds.NumpySlicesDataset(data=[['a', 'b', 'c', 'd', 'e']], column_names=["text"], shuffle=False)
        >>> # Data before
        >>> # |           text            |
        >>> # +---------------------------+
        >>> # | ['a', 'b', 'c', 'd', 'e'] |
        >>> # +---------------------------+
        >>> truncate = text.Truncate(4)
        >>> dataset = dataset.map(operations=truncate, input_columns=["text"])
        >>> # Data after
        >>> # |          text          |
        >>> # +------------------------+
        >>> # |  ['a', 'b', 'c', 'd']  |
        >>> # +------------------------+
    """

    @check_truncate
    def __init__(self, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len

    def parse(self):
        # Hand the validated length to the C++ TruncateOperation binding.
        return cde.TruncateOperation(self.max_seq_len)
|
||||||
|
|
||||||
|
|
||||||
class TruncateSequencePair(TextTensorOperation):
|
class TruncateSequencePair(TextTensorOperation):
|
||||||
"""
|
"""
|
||||||
Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
||||||
|
|
|
@ -456,6 +456,18 @@ def check_ngram(method):
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_truncate(method):
    """Decorator that validates the `max_seq_len` argument of Truncate before calling `method`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed_args, _ = parse_user_args(method, *args, **kwargs)
        # Exactly one user argument is expected.
        [max_seq_len] = parsed_args
        check_pos_int32(max_seq_len, "max_seq_len")
        return method(self, *args, **kwargs)

    return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_pair_truncate(method):
|
def check_pair_truncate(method):
|
||||||
"""Wrapper method to check the parameters of number of pair truncate."""
|
"""Wrapper method to check the parameters of number of pair truncate."""
|
||||||
|
|
||||||
|
|
|
@ -5295,3 +5295,85 @@ TEST_F(MindDataTestPipeline, TestAddTokenPipelineSuccess) {
|
||||||
// Manually terminate the pipeline
|
// Manually terminate the pipeline
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Feature: Truncate
|
||||||
|
/// Description: Test Truncate basic usage with max_seq_len smaller than the sequence length
|
||||||
|
/// Expectation: Output is equal to the expected output
|
||||||
|
TEST_F(MindDataTestPipeline, TestTruncateSuccess1D) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTruncateSuccess1D.";
  // Testing basic Truncate

  // Create a TextFile dataset
  std::string data_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
  // Fatal assertion: ds is dereferenced below.
  ASSERT_NE(ds, nullptr);

  // Create white_tokenizer operation on ds
  std::shared_ptr<TensorTransform> white_tokenizer = std::make_shared<text::WhitespaceTokenizer>();
  EXPECT_NE(white_tokenizer, nullptr);

  // Create a truncate operation on ds
  std::shared_ptr<TensorTransform> truncate = std::make_shared<text::Truncate>(3);
  EXPECT_NE(truncate, nullptr);

  // Create Map operation on ds
  ds = ds->Map({white_tokenizer, truncate}, {"text"});
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Fatal assertion: iter is dereferenced below.
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<std::string>> expected = {
    {"This", "is", "a"}, {"Be", "happy", "every"}, {"Good", "luck", "to"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    // Guard the lookup into `expected` in case the dataset yields more rows than anticipated.
    ASSERT_LT(i, expected.size());
    auto ind = row["text"];

    std::shared_ptr<Tensor> de_expected_tensor;
    ASSERT_OK(Tensor::CreateFromVector(expected[i], &de_expected_tensor));
    mindspore::MSTensor expected_tensor =
      mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_tensor));
    EXPECT_MSTENSOR_EQ(ind, expected_tensor);

    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();
}
|
||||||
|
|
||||||
|
/// Feature: Truncate
|
||||||
|
/// Description: Test the incorrect parameter of Truncate interface
|
||||||
|
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||||
|
TEST_F(MindDataTestPipeline, TestTruncateFail) {
  // Truncate constructed with an invalid max_seq_len must make the pipeline fail to start.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTruncateFail.";

  // Build the TextFile source dataset
  std::string data_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Negative max_seq_len: the transform object itself is still created
  std::shared_ptr<TensorTransform> invalid_truncate = std::make_shared<text::Truncate>(-1);
  EXPECT_NE(invalid_truncate, nullptr);

  // Attach the invalid transform through Map
  ds = ds->Map({invalid_truncate});
  EXPECT_NE(ds, nullptr);

  // CreateIterator() is expected to return nullptr for the invalid Truncate parameter
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_EQ(iter, nullptr);
}
|
||||||
|
|
|
@ -3038,3 +3038,18 @@ TEST_F(MindDataTestExecute, TestLFCCWrongArgsDctType) {
|
||||||
Status status = trans(input_ms, &input_ms);
|
Status status = trans(input_ms, &input_ms);
|
||||||
EXPECT_FALSE(status.IsOk());
|
EXPECT_FALSE(status.IsOk());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Feature: Truncate
|
||||||
|
/// Description: Test basic usage of Truncate op
|
||||||
|
/// Expectation: The data is processed successfully
|
||||||
|
TEST_F(MindDataTestExecute, TestTruncateOpStr) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestTruncateOpStr.";
  std::shared_ptr<Tensor> input;
  // Fail right here if the input tensor cannot be built instead of running the op on a null tensor.
  ASSERT_OK(
    Tensor::CreateFromVector(std::vector<std::string>({"hello", "hhx", "hyx", "world", "this", "is"}), &input));
  auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> truncate_op = std::make_shared<text::Truncate>(3);
  // apply Truncate
  mindspore::dataset::Execute trans({truncate_op});
  Status status = trans(input_ms, &input_ms);
  EXPECT_TRUE(status.IsOk());
}
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing Truncate Python API
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import log as logger
|
||||||
|
import mindspore.dataset.text as text
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_max_len_1d():
    """
    Feature: Truncate op
    Description: Test Truncate op using 1D str as the input
    Expectation: Output is equal to the expected output
    """
    truncate = text.Truncate(3)
    tokens = ["1", "2", "3", "4", "5"]
    expected = ["1", "2", "3"]
    # Only the first three elements are expected to remain.
    assert np.array_equal(truncate(tokens), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_max_len_2d():
    """
    Feature: Truncate op
    Description: Test Truncate op using 2D int as the input
    Expectation: Output is equal to the expected output
    """
    truncate = text.Truncate(3)
    rows = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]]
    expected = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
    # Every row is expected to be cut to its first three elements.
    assert np.array_equal(truncate(rows), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_input_error():
    """
    Feature: Truncate op
    Description: Test input Error
    Expectation: Throw ValueError, TypeError or RuntimeError exception
    """
    # max_seq_len outside the accepted interval
    try:
        _ = text.Truncate(-1)
    except ValueError as error:
        logger.info("Got an exception in Truncate: {}".format(str(error)))
        assert "Input max_seq_len is not within the required interval of [1, 2147483647]." in str(
            error)
    # max_seq_len of the wrong type
    try:
        _ = text.Truncate('a')
    except TypeError as error:
        logger.info("Got an exception in Truncate: {}".format(str(error)))
        assert "Argument max_seq_len with value a is not of type [<class 'int'>], but got <class 'str'>." in str(
            error)
    # bytes input
    try:
        truncate = text.Truncate(2)
        input1 = [b'1', b'2', b'3', b'4', b'5']
        truncate(input1)
    except RuntimeError as error:
        logger.info("Got an exception in Truncate: {}".format(str(error)))
        assert "Truncate: Truncate: the input tensor should be in type of [bool, int, float, double, string]." in str(
            error)
    # rank-3 input
    try:
        truncate = text.Truncate(2)
        input1 = [[[1, 2, 3]]]
        truncate(input1)
    except RuntimeError as error:
        logger.info("Got an exception in Truncate: {}".format(str(error)))
        # Compare against the raised message; asserting the bare string alone would always pass.
        assert "Truncate: the input tensor should be of dimension 1 or 2." in str(error)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_truncate_max_len_1d()
|
||||||
|
test_truncate_max_len_2d()
|
||||||
|
test_truncate_input_error()
|
Loading…
Reference in New Issue