[feat][assistant][I5EWI8] add new data operator Truncate
This commit is contained in:
parent
3b5e68cca7
commit
6b42fd45f1
|
@ -0,0 +1,14 @@
|
|||
mindspore.dataset.text.Truncate
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.dataset.text.Truncate(max_seq_len)
|
||||
|
||||
截断输入序列,使其不超过最大长度。
|
||||
|
||||
参数:
|
||||
- **max_seq_len** (int) - 最大截断长度。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `max_seq_len` 的类型不是int。
|
||||
- **ValueError** - 如果 `max_seq_len` 的值小于或等于0。
|
||||
- **RuntimeError** - 如果输入张量的数据类型不是bool、int、float、double或者str。
|
|
@ -293,6 +293,7 @@ API样例中常用的导入模块如下:
|
|||
mindspore.dataset.text.SlidingWindow
|
||||
mindspore.dataset.text.ToNumber
|
||||
mindspore.dataset.text.ToVectors
|
||||
mindspore.dataset.text.Truncate
|
||||
mindspore.dataset.text.TruncateSequencePair
|
||||
mindspore.dataset.text.UnicodeCharTokenizer
|
||||
mindspore.dataset.text.UnicodeScriptTokenizer
|
||||
|
|
|
@ -171,6 +171,7 @@ Transforms
|
|||
mindspore.dataset.text.SlidingWindow
|
||||
mindspore.dataset.text.ToNumber
|
||||
mindspore.dataset.text.ToVectors
|
||||
mindspore.dataset.text.Truncate
|
||||
mindspore.dataset.text.TruncateSequencePair
|
||||
mindspore.dataset.text.UnicodeCharTokenizer
|
||||
mindspore.dataset.text.UnicodeScriptTokenizer
|
||||
|
|
|
@ -241,6 +241,16 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TruncateOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<text::TruncateOperation, TensorOperation, std::shared_ptr<text::TruncateOperation>>(
|
||||
*m, "TruncateOperation")
|
||||
.def(py::init([](int32_t max_seq_len) {
|
||||
auto truncate = std::make_shared<text::TruncateOperation>(max_seq_len);
|
||||
THROW_IF_ERROR(truncate->ValidateParams());
|
||||
return truncate;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TruncateSequencePairOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<text::TruncateSequencePairOperation, TensorOperation,
|
||||
std::shared_ptr<text::TruncateSequencePairOperation>>(
|
||||
|
|
|
@ -394,6 +394,16 @@ std::shared_ptr<TensorOperation> ToVectors::Parse() {
|
|||
return std::make_shared<ToVectorsOperation>(data_->vectors_, data_->unk_init_, data_->lower_case_backup_);
|
||||
}
|
||||
|
||||
// Truncate
|
||||
struct Truncate::Data {
|
||||
explicit Data(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||
int32_t max_seq_len_;
|
||||
};
|
||||
|
||||
Truncate::Truncate(int32_t max_seq_len) : data_(std::make_shared<Data>(max_seq_len)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Truncate::Parse() { return std::make_shared<TruncateOperation>(data_->max_seq_len_); }
|
||||
|
||||
// TruncateSequencePair
|
||||
struct TruncateSequencePair::Data {
|
||||
explicit Data(int32_t max_length) : max_length_(max_length) {}
|
||||
|
|
|
@ -894,6 +894,35 @@ class DATASET_API ToVectors final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Truncate the input sequence so that it does not exceed the maximum length.
|
||||
class DATASET_API Truncate final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] max_seq_len Maximum allowable length.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define operations */
|
||||
/// auto truncate_op = text::Truncate(5);
|
||||
///
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({truncate_op}, // operations
|
||||
/// {"text"}); // input columns
|
||||
/// \endcode
|
||||
explicit Truncate(int32_t max_seq_len);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Truncate() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to the TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
||||
class DATASET_API TruncateSequencePair final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -147,6 +147,7 @@ constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
|
|||
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
||||
constexpr char kToNumberOp[] = "ToNumberOp";
|
||||
constexpr char kToVectorsOp[] = "ToVectorsOp";
|
||||
constexpr char kTruncateOp[] = "TruncateOp";
|
||||
constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp";
|
||||
constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
|
||||
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||
#include "minddata/dataset/text/kernels/to_number_op.h"
|
||||
#include "minddata/dataset/text/kernels/to_vectors_op.h"
|
||||
#include "minddata/dataset/text/kernels/truncate_op.h"
|
||||
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||
|
@ -47,6 +48,7 @@
|
|||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/text/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -583,6 +585,27 @@ std::shared_ptr<TensorOp> ToVectorsOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// TruncateOperation
|
||||
TruncateOperation::TruncateOperation(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||
|
||||
Status TruncateOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Truncate", "max_seq_len", max_seq_len_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> TruncateOperation::Build() {
|
||||
std::shared_ptr<TruncateOp> tensor_op = std::make_shared<TruncateOp>(max_seq_len_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status TruncateOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["max_seq_len"] = max_seq_len_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TruncateSequencePairOperation
|
||||
TruncateSequencePairOperation::TruncateSequencePairOperation(int32_t max_length) : max_length_(max_length) {}
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ constexpr char kSentencepieceTokenizerOperation[] = "SentencepieceTokenizer";
|
|||
constexpr char kSlidingWindowOperation[] = "SlidingWindow";
|
||||
constexpr char kToNumberOperation[] = "ToNumber";
|
||||
constexpr char kToVectorsOperation[] = "ToVectors";
|
||||
constexpr char kTruncateOperation[] = "Truncate";
|
||||
constexpr char kTruncateSequencePairOperation[] = "TruncateSequencePair";
|
||||
constexpr char kUnicodeCharTokenizerOperation[] = "UnicodeCharTokenizer";
|
||||
constexpr char kUnicodeScriptTokenizerOperation[] = "UnicodeScriptTokenizer";
|
||||
|
@ -353,6 +354,24 @@ class ToVectorsOperation : public TensorOperation {
|
|||
bool lower_case_backup_;
|
||||
};
|
||||
|
||||
class TruncateOperation : public TensorOperation {
|
||||
public:
|
||||
explicit TruncateOperation(int32_t max_seq_len);
|
||||
|
||||
~TruncateOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kTruncateOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t max_seq_len_;
|
||||
};
|
||||
|
||||
class TruncateSequencePairOperation : public TensorOperation {
|
||||
public:
|
||||
explicit TruncateSequencePairOperation(int32_t max_length);
|
||||
|
|
|
@ -22,6 +22,7 @@ add_library(text-kernels OBJECT
|
|||
ngram_op.cc
|
||||
sliding_window_op.cc
|
||||
wordpiece_tokenizer_op.cc
|
||||
truncate_op.cc
|
||||
truncate_sequence_pair_op.cc
|
||||
to_number_op.cc
|
||||
to_vectors_op.cc
|
||||
|
|
|
@ -98,5 +98,16 @@ Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Truncate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int max_seq_len) {
|
||||
if (input->shape().Rank() == 1) {
|
||||
return input->Slice(output, {SliceOption(Slice(max_seq_len))});
|
||||
} else {
|
||||
int dim = input->shape()[0];
|
||||
Slice slice_dim = Slice(dim);
|
||||
Slice slice_len = Slice(max_seq_len);
|
||||
return input->Slice(output, {SliceOption(slice_dim), SliceOption(slice_len)});
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,11 @@ Status AppendOffsetsHelper(const std::vector<uint32_t> &offsets_start, const std
|
|||
/// \return Status return code.
|
||||
Status AddToken(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &token,
|
||||
bool begin);
|
||||
|
||||
/// \brief Truncate the input sequence so that it does not exceed the maximum length.
|
||||
/// \param[in] max_seq_len Maximum allowable length.
|
||||
/// \param[out] output Output Tensor.
|
||||
Status Truncate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int max_seq_len);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/text/kernels/truncate_op.h"
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status TruncateOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
constexpr int kMaxSeqRank = 2;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1 || input->shape().Rank() == kMaxSeqRank,
|
||||
"Truncate: the input tensor should be of dimension 1 or 2.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type() == DataType::DE_STRING || input->type().IsNumeric(),
|
||||
"Truncate: Truncate: the input tensor should be in type of [bool, int, float, double, string].");
|
||||
return Truncate(input, output, max_seq_len_);
|
||||
}
|
||||
|
||||
Status TruncateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||
constexpr int kMaxSeqRank = 2;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1 || inputs[0].Rank() == kMaxSeqRank,
|
||||
"Truncate: the input tensor should be of dimension 1 or 2.");
|
||||
if (inputs[0].Rank() == 1) {
|
||||
outputs.clear();
|
||||
auto shape = inputs[0].AsVector();
|
||||
int length = shape[0];
|
||||
shape[0] = std::min(length, max_seq_len_);
|
||||
outputs.emplace_back(TensorShape{shape});
|
||||
} else {
|
||||
outputs.clear();
|
||||
auto shape = inputs[0].AsVector();
|
||||
int length = shape[1];
|
||||
shape[1] = std::min(length, max_seq_len_);
|
||||
outputs.emplace_back(TensorShape{shape});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TruncateOp : public TensorOp {
|
||||
public:
|
||||
explicit TruncateOp(int32_t max_seq_len) : max_seq_len_(max_seq_len) {}
|
||||
|
||||
~TruncateOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kTruncateOp; }
|
||||
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
private:
|
||||
int32_t max_seq_len_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TRUNCATE_OP_H_
|
|
@ -73,7 +73,7 @@ import platform
|
|||
from . import transforms
|
||||
from . import utils
|
||||
from .transforms import AddToken, JiebaTokenizer, Lookup, Ngram, PythonTokenizer, SentencePieceTokenizer, \
|
||||
SlidingWindow, ToNumber, ToVectors, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
|
||||
SlidingWindow, ToNumber, ToVectors, Truncate, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer
|
||||
from .utils import CharNGram, FastText, GloVe, JiebaMode, NormalizeForm, SentencePieceModel, SentencePieceVocab, \
|
||||
SPieceTokenizerLoadType, SPieceTokenizerOutType, Vectors, Vocab, to_bytes, to_str
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ from .validators import check_add_token, check_lookup, check_jieba_add_dict, che
|
|||
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
|
||||
check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
|
||||
check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
|
||||
check_sentence_piece_tokenizer
|
||||
check_sentence_piece_tokenizer, check_truncate
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
from ..core.validator_helpers import replace_none
|
||||
from ..transforms.py_transforms_util import Implementation
|
||||
|
@ -620,6 +620,46 @@ class ToVectors(TextTensorOperation):
|
|||
return cde.ToVectorsOperation(self.vectors, self.unk_init, self.lower_case_backup)
|
||||
|
||||
|
||||
class Truncate(TextTensorOperation):
|
||||
"""
|
||||
Truncate the input sequence so that it does not exceed the maximum length.
|
||||
|
||||
Args:
|
||||
max_seq_len (int): Maximum allowable length.
|
||||
|
||||
Raises:
|
||||
TypeError: If `max_length_len` is not of type int.
|
||||
ValueError: If value of `max_length_len` is not greater than or equal to 0.
|
||||
RuntimeError: If the input tensor is not of dtype bool, int, float, double or str.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> dataset = ds.NumpySlicesDataset(data=[['a', 'b', 'c', 'd', 'e']], column_names=["text"], shuffle=False)
|
||||
>>> # Data before
|
||||
>>> # | col1 |
|
||||
>>> # +---------------------------+
|
||||
>>> # | ['a', 'b', 'c', 'd', 'e'] |
|
||||
>>> # +---------------------------+
|
||||
>>> truncate = text.Truncate(4)
|
||||
>>> dataset = dataset.map(operations=truncate, input_columns=["text"])
|
||||
>>> # Data after
|
||||
>>> # | col1 |
|
||||
>>> # +------------------------+
|
||||
>>> # | ['a', 'b', 'c', 'd'] |
|
||||
>>> # +------------------------+
|
||||
"""
|
||||
|
||||
@check_truncate
|
||||
def __init__(self, max_seq_len):
|
||||
super().__init__()
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def parse(self):
|
||||
return cde.TruncateOperation(self.max_seq_len)
|
||||
|
||||
|
||||
class TruncateSequencePair(TextTensorOperation):
|
||||
"""
|
||||
Truncate a pair of rank-1 tensors such that the total length is less than max_length.
|
||||
|
|
|
@ -456,6 +456,18 @@ def check_ngram(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_truncate(method):
|
||||
"""Wrapper method to check the parameters of number of truncate."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[max_seq_len], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_pos_int32(max_seq_len, "max_seq_len")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_pair_truncate(method):
|
||||
"""Wrapper method to check the parameters of number of pair truncate."""
|
||||
|
||||
|
|
|
@ -5295,3 +5295,85 @@ TEST_F(MindDataTestPipeline, TestAddTokenPipelineSuccess) {
|
|||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Truncate
|
||||
/// Description: Test Truncate basic usage max_seq_len less length
|
||||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestPipeline, TestTruncateSuccess1D) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTruncateSuccess1D.";
|
||||
// Testing basic Truncate
|
||||
|
||||
// Create a TextFile dataset
|
||||
std::string data_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create white_tokenizer operation on ds
|
||||
std::shared_ptr<TensorTransform> white_tokenizer = std::make_shared<text::WhitespaceTokenizer>();
|
||||
EXPECT_NE(white_tokenizer, nullptr);
|
||||
|
||||
// Create a truncate operation on ds
|
||||
std::shared_ptr<TensorTransform> truncate = std::make_shared<text::Truncate>(3);
|
||||
EXPECT_NE(truncate, nullptr);
|
||||
|
||||
// Create Map operation on ds
|
||||
ds = ds->Map({white_tokenizer, truncate}, {"text"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<std::vector<std::string>> expected = {
|
||||
{"This", "is", "a"}, {"Be", "happy", "every"}, {"Good", "luck", "to"}};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto ind = row["text"];
|
||||
|
||||
std::shared_ptr<Tensor> de_expected_tensor;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected[i], &de_expected_tensor));
|
||||
mindspore::MSTensor expected_tensor =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_tensor));
|
||||
EXPECT_MSTENSOR_EQ(ind, expected_tensor);
|
||||
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Truncate
|
||||
/// Description: Test the incorrect parameter of Truncate interface
|
||||
/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr
|
||||
TEST_F(MindDataTestPipeline, TestTruncateFail) {
|
||||
// Testing the incorrect parameter of Truncate interface.
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTruncateFail.";
|
||||
|
||||
// Create a TextFile dataset
|
||||
std::string data_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Testing the parameter max_seq_len less than 0
|
||||
// Create a truncate operation on ds
|
||||
std::shared_ptr<TensorTransform> truncate = std::make_shared<text::Truncate>(-1);
|
||||
EXPECT_NE(truncate, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({truncate});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Truncate input (The parameter max_seq_len must be greater than 0)
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -3038,3 +3038,18 @@ TEST_F(MindDataTestExecute, TestLFCCWrongArgsDctType) {
|
|||
Status status = trans(input_ms, &input_ms);
|
||||
EXPECT_FALSE(status.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Truncate
|
||||
/// Description: Test basic usage of Truncate op
|
||||
/// Expectation: The data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestTruncateOpStr) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTruncateOpStr.";
|
||||
std::shared_ptr<Tensor> input;
|
||||
Tensor::CreateFromVector(std::vector<std::string>({"hello", "hhx", "hyx", "world", "this", "is"}), &input);
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> truncate_op = std::make_shared<text::Truncate>(3);
|
||||
// apply Truncate
|
||||
mindspore::dataset::Execute trans({truncate_op});
|
||||
Status status = trans(input_ms, &input_ms);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Truncate Python API
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.text as text
|
||||
|
||||
|
||||
def test_truncate_max_len_1d():
|
||||
"""
|
||||
Feature: Truncate op
|
||||
Description: Test Truncate op using 1D str as the input
|
||||
Expectation: Output is equal to the expected output
|
||||
"""
|
||||
truncate = text.Truncate(3)
|
||||
input1 = ["1", "2", "3", "4", "5"]
|
||||
result1 = truncate(input1)
|
||||
expect1 = (["1", "2", "3"])
|
||||
assert np.array_equal(result1, expect1)
|
||||
|
||||
|
||||
def test_truncate_max_len_2d():
|
||||
"""
|
||||
Feature: Truncate op
|
||||
Description: Test Truncate op using 2D int as the input
|
||||
Expectation: Output is equal to the expected output
|
||||
"""
|
||||
truncate = text.Truncate(3)
|
||||
input1 = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]]
|
||||
result1 = truncate(input1)
|
||||
expect1 = ([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
|
||||
assert np.array_equal(result1, expect1)
|
||||
|
||||
|
||||
def test_truncate_input_error():
|
||||
"""
|
||||
Feature:Truncate Op
|
||||
Description: Test input Error
|
||||
Expectation: Throw ValueError, TypeError or RuntimeError exception
|
||||
"""
|
||||
try:
|
||||
_ = text.Truncate(-1)
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in Truncate: {}".format(str(error)))
|
||||
assert "Input max_seq_len is not within the required interval of [1, 2147483647]." in str(
|
||||
error)
|
||||
try:
|
||||
_ = text.Truncate('a')
|
||||
except TypeError as error:
|
||||
logger.info("Got an exception in Truncate: {}".format(str(error)))
|
||||
assert "Argument max_seq_len with value a is not of type [<class 'int'>], but got <class 'str'>." in str(
|
||||
error)
|
||||
try:
|
||||
truncate = text.Truncate(2)
|
||||
input1 = [b'1', b'2', b'3', b'4', b'5']
|
||||
truncate(input1)
|
||||
except RuntimeError as error:
|
||||
logger.info("Got an exception in Truncate: {}".format(str(error)))
|
||||
assert "Truncate: Truncate: the input tensor should be in type of [bool, int, float, double, string]." in str(
|
||||
error)
|
||||
try:
|
||||
truncate = text.Truncate(2)
|
||||
input1 = [[[1, 2, 3]]]
|
||||
truncate(input1)
|
||||
except RuntimeError as error:
|
||||
logger.info("Got an exception in Truncate: {}".format(str(error)))
|
||||
assert "Truncate: the input tensor should be of dimension 1 or 2."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_truncate_max_len_1d()
|
||||
test_truncate_max_len_2d()
|
||||
test_truncate_input_error()
|
Loading…
Reference in New Issue