forked from mindspore-Ecosystem/mindspore
add SlidingWindow Op
This commit is contained in:
parent
d89cedb980
commit
4136892a3e
|
@ -77,6 +77,7 @@
|
|||
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/lookup_op.h"
|
||||
#include "minddata/dataset/text/kernels/ngram_op.h"
|
||||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||
#include "minddata/dataset/text/kernels/to_number_op.h"
|
||||
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||
|
@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) {
|
|||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
||||
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
|
||||
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
|
||||
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
|
||||
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
|
||||
}
|
||||
|
||||
void bindDependIcuTokenizerOps(py::module *m) {
|
||||
|
|
|
@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp";
|
|||
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
|
||||
constexpr char kLookupOp[] = "LookupOp";
|
||||
constexpr char kNgramOp[] = "NgramOp";
|
||||
constexpr char kSlidingWindowOp[] = "SlidingWindowOp";
|
||||
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
|
||||
constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
|
||||
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
||||
|
|
|
@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows"))
|
|||
whitespace_tokenizer_op.cc)
|
||||
endif()
|
||||
add_library(text-kernels OBJECT
|
||||
data_utils.cc
|
||||
lookup_op.cc
|
||||
jieba_tokenizer_op.cc
|
||||
unicode_char_tokenizer_op.cc
|
||||
ngram_op.cc
|
||||
sliding_window_op.cc
|
||||
wordpiece_tokenizer_op.cc
|
||||
truncate_sequence_pair_op.cc
|
||||
to_number_op.cc
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/pybind_support.h"
|
||||
#include "minddata/dataset/kernels/data/type_cast_op.h"
|
||||
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||
#include "minddata/dataset/kernels/data/concatenate_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Build the sliding-window result by slicing `input` once per window position and
///     concatenating the slices, then reshaping the concatenation to `out_shape`.
/// \param[in] input - Input tensor.
/// \param[out] output - Receives the resulting tensor. It is only written on success.
/// \param[in] out_shape - Shape the result is reshaped to. A shape whose Size() is 0 signals that
///     the row has fewer items than `width`, in which case an empty tensor is produced.
/// \param[in] width - The width of the window.
/// \param[in] axis - The axis along which sliding window is computed.
/// \return Status return code. Errors reported by the slice, concatenate and reshape steps are
///     propagated instead of being silently dropped.
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
                           uint32_t width, int32_t axis) {
  // if the data row has fewer items than width, the corresponding result row will be empty
  if (out_shape.Size() == 0) {
    MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty.";
    if (input->type().value() == DataType::DE_STRING) {
      RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector<std::string>{}, TensorShape({0})));
    } else {
      RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type()));
    }
    return Status::OK();
  }

  axis = Tensor::HandleNeg(axis, input->shape().Size());
  // Do the loop arithmetic in the signed shape type so the bound check never mixes
  // signed and unsigned operands.
  const dsize_t axis_end = input->shape()[axis];
  const dsize_t window = width;
  std::shared_ptr<Tensor> slice;
  std::shared_ptr<Tensor> result;
  auto concatenate_op = std::make_unique<ConcatenateOp>(axis, nullptr, nullptr);

  // Slice on specified axis and concatenate on new axis
  for (dsize_t start = 0; start + window <= axis_end; ++start) {
    auto slice_op = std::make_unique<SliceOp>(Slice(start, start + window, 1));
    RETURN_IF_NOT_OK(slice_op->Compute(input, &slice));
    if (start == 0) {
      result = slice;
    } else {
      TensorRow in({result, slice});
      TensorRow out_row;
      RETURN_IF_NOT_OK(concatenate_op->Compute(in, &out_row));
      result = out_row[0];
    }
  }

  // A non-empty out_shape with no window produced means the caller passed inconsistent
  // arguments; report it rather than dereferencing a null tensor.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "SlidingWindow: no window fits into the input tensor.");
  RETURN_IF_NOT_OK(result->Reshape(out_shape));
  *output = result;
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||
#define DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Helper method that perform sliding window on input tensor.
|
||||
/// \param[in] input - Input tensor.
|
||||
/// \param[in] out_shape - Output shape of output tensor.
|
||||
/// \param[in] width - The width of the window.
|
||||
/// \param[in] axis - The axis along which sliding window is computed.
|
||||
/// \param[out] output - Output tensor
|
||||
/// \return Status return code
|
||||
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
|
||||
uint32_t width, int32_t axis);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Apply the sliding window to a single tensor.
/// Validates that the input is 1-D and that the configured axis is 0 or -1, derives the output
/// shape through OutputShape(), and delegates the actual work to SlidingWindowHelper().
/// \param[in] input - Input tensor of Op.
/// \param[out] output - Output tensor of Op.
/// \return Status return code
Status SlidingWindowOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);

  const bool is_one_dim = (input->shape().Rank() == 1);
  CHECK_FAIL_RETURN_UNEXPECTED(is_one_dim, "SlidingWindosOp supports 1D Tensors only for now.");

  const bool axis_supported = !(axis_ != 0 && axis_ != -1);
  CHECK_FAIL_RETURN_UNEXPECTED(axis_supported, "axis supports 0 or -1 only for now.");

  // OutputShape() replaces the single placeholder entry with the computed shape.
  std::vector<TensorShape> in_shapes{input->shape()};
  std::vector<TensorShape> out_shapes{TensorShape({})};
  RETURN_IF_NOT_OK(OutputShape(in_shapes, out_shapes));

  return SlidingWindowHelper(input, output, out_shapes[0], width_, axis_);
}
|
||||
|
||||
/// \brief Calculate the output shape for a given input shape.
/// The dimension at `axis_` of length n is replaced by two dimensions {n - width_ + 1, width_}.
/// If that dimension is shorter than `width_`, the output shape is left empty, which
/// SlidingWindowHelper() treats as "produce an empty tensor".
/// \param[in] inputs - Input tensor shapes; must hold exactly NumInput() entries.
/// \param[out] outputs - Output tensor shapes. Any previous content is discarded, so callers may
///     pass either an empty vector or a placeholder entry.
/// \return Status return code
Status SlidingWindowOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n");
  // width_ - 1 below would wrap around for an unsigned zero, so reject it up front.
  CHECK_FAIL_RETURN_UNEXPECTED(width_ > 0, "SlidingWindowOp: width must be greater than zero.");

  const TensorShape &input_shape = inputs[0];
  const int32_t rank = static_cast<int32_t>(input_shape.Size());
  const int32_t axis = Tensor::HandleNeg(axis_, rank);
  // Guard the shape indexing below against an axis that does not exist in the input.
  CHECK_FAIL_RETURN_UNEXPECTED(axis >= 0 && axis < rank, "SlidingWindowOp: axis is out of range for the input shape.");

  std::vector<dsize_t> output_shape_initializer;

  // if a data row has fewer items than width, the corresponding result row will be empty.
  if (input_shape[axis] >= width_) {
    for (int32_t idx = 0; idx < rank; ++idx) {
      if (idx != axis) {
        output_shape_initializer.push_back(input_shape[idx]);
      } else {
        output_shape_initializer.push_back(input_shape[idx] - (width_ - 1));
        output_shape_initializer.push_back(width_);
      }
    }
  }

  // clear() instead of pop_back(): pop_back() on an empty vector is undefined behaviour.
  outputs.clear();
  outputs.emplace_back(TensorShape(output_shape_initializer));
  CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n");
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
||||
#define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class SlidingWindowOp : public TensorOp {
|
||||
public:
|
||||
/// \brief Constructor of SlidingWindowOp.
|
||||
/// \param[in] width - The axis along which sliding window is computed.
|
||||
/// \param[in] axis - The width of the window.
|
||||
/// \return Status return code
|
||||
explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {}
|
||||
|
||||
/// \brief Destructor of SlidingWindowOp.
|
||||
~SlidingWindowOp() override = default;
|
||||
|
||||
/// \brief Perform sliding window to tensor.
|
||||
/// \param[in] input - Input tensor of Op.
|
||||
/// \param[out] output - output tensor of Op.
|
||||
/// \return Status return code
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// \brief Calculate tensor shape for output tensor.
|
||||
/// \param[in] inputs - Input tensor shapes.
|
||||
/// \param[out] outputs - Output tensor shapes.
|
||||
/// \return Status return code
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
/// \brief Print args for debugging.
|
||||
/// \param[in] out - std::ostream &out.
|
||||
void Print(std::ostream &out) const override { out << "SliceWindowOp"; }
|
||||
|
||||
/// \brief Print name of op.
|
||||
std::string Name() const override { return kSlidingWindowOp; }
|
||||
|
||||
private:
|
||||
uint32_t width_; // The width of the window. Must be an integer and greater than zero.
|
||||
int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now.
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
|
@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing.
|
|||
"""
|
||||
import platform
|
||||
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
|
||||
ToNumber
|
||||
ToNumber, SlidingWindow
|
||||
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
|
||||
|
||||
__all__ = [
|
||||
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
|
||||
"to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
|
||||
"PythonTokenizer"
|
||||
"PythonTokenizer", "SlidingWindow"
|
||||
]
|
||||
|
||||
if platform.system().lower() != 'windows':
|
||||
|
|
|
@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str
|
|||
from .validators import check_lookup, check_jieba_add_dict, \
|
||||
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\
|
||||
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\
|
||||
check_to_number, check_bert_tokenizer, check_python_tokenizer
|
||||
check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
|
||||
|
||||
|
@ -72,6 +72,34 @@ class Lookup(cde.LookupOp):
|
|||
def __init__(self, vocab, unknown_token=None):
|
||||
super().__init__(vocab, unknown_token)
|
||||
|
||||
class SlidingWindow(cde.SlidingWindowOp):
    """
    TensorOp that builds a tensor of overlapping windows from the input data (only 1-D for now):
    every element along the dimension `axis` becomes a slice of the data that starts at that
    position and spans `width` items.

    Args:
        width (int): The width of the window. Must be an integer and greater than zero.
        axis (int, optional): The axis along which sliding window is computed (default=0).

    Examples:
        >>> # Data before
        >>> # |    col1     |
        >>> # +-------------+
        >>> # | [1,2,3,4,5] |
        >>> # +-------------+
        >>> data = data.map(operations=SlidingWindow(3, 0))
        >>> # Data after
        >>> # |    col1     |
        >>> # +-------------+
        >>> # |  [[1,2,3],  |
        >>> # |   [2,3,4],  |
        >>> # |   [3,4,5]]  |
        >>> # +-------------+
    """

    @check_slidingwindow
    def __init__(self, width, axis=0):
        # Arguments are validated by the decorator and then handed to the C++ op by keyword.
        super(SlidingWindow, self).__init__(width=width, axis=axis)
|
||||
|
||||
class Ngram(cde.NgramOp):
|
||||
"""
|
||||
|
|
|
@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
|
|||
from mindspore._c_expression import typing
|
||||
|
||||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
|
||||
INT32_MAX, check_value, check_positive
|
||||
INT32_MAX, check_value, check_positive, check_pos_int32
|
||||
|
||||
|
||||
def check_unique_list_of_words(words, arg_name):
|
||||
|
@ -328,6 +328,17 @@ def check_from_dataset(method):
|
|||
|
||||
return new_method
|
||||
|
||||
def check_slidingwindow(method):
    """Decorator that validates the arguments of the sliding window operation before calling `method`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Resolve positional/keyword arguments into the order declared by `method`.
        params, _ = parse_user_args(method, *args, **kwargs)
        width, axis = params

        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")

        return method(self, *args, **kwargs)

    return new_method
|
||||
|
||||
def check_ngram(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function."""
|
||||
|
|
|
@ -92,6 +92,7 @@ SET(DE_UT_SRCS
|
|||
perf_data_test.cc
|
||||
c_api_test.cc
|
||||
tensor_op_fusion_pass_test.cc
|
||||
sliding_window_op_test.cc
|
||||
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
/// Test fixture shared by the SlidingWindowOp unit tests; it adds nothing on top of UT::Common.
class MindDataTestSlidingWindowOp : public UT::Common {
 protected:
  MindDataTestSlidingWindowOp() = default;
};
|
||||
|
||||
/// Runs SlidingWindowOp(width=3, axis=0) over eight strings and checks that the result is the
/// expected 6x3 string tensor.
TEST_F(MindDataTestSlidingWindowOp, Compute) {
  MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute.";
  std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
  TensorShape shape({static_cast<dsize_t>(strings.size())});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
  std::shared_ptr<Tensor> output;

  auto op = std::make_unique<SlidingWindowOp>(3, 0);
  Status s = op->Compute(input, &output);
  // Stop here on failure: `output` would be null and the checks below would crash instead of fail.
  ASSERT_TRUE(s.IsOk());
  ASSERT_TRUE(output != nullptr);

  std::vector<std::string> out = {"one",  "two",  "three", "two",  "three", "four",  "three", "four",  "five",
                                  "four", "five", "six",   "five", "six",   "seven", "six",   "seven", "eight"};
  std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(out, TensorShape({6, 3}));

  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());
  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;
  ASSERT_TRUE(*output == *expected);

  MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
|
||||
|
||||
/// Checks that SlidingWindowOp(width=3, axis=0) maps an input shape of {8} to an output shape of {6, 3}.
TEST_F(MindDataTestSlidingWindowOp, OutputShape) {
  MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape.";
  std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
  TensorShape shape({static_cast<dsize_t>(strings.size())});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
  std::vector<TensorShape> input_shape = {input->shape()};
  std::vector<TensorShape> output_shape = {TensorShape({})};

  auto op = std::make_unique<SlidingWindowOp>(3, 0);
  Status s = op->OutputShape(input_shape, output_shape);
  // The shape comparison is only meaningful if the call itself succeeded.
  ASSERT_TRUE(s.IsOk());
  ASSERT_TRUE(output_shape.size() == 1);

  MS_LOG(DEBUG) << "input_shape" << input_shape[0];
  MS_LOG(DEBUG) << "output_shape" << output_shape[0];
  ASSERT_TRUE(output_shape[0] == TensorShape({6, 3}));

  MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing SlidingWindow in mindspore.dataset
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
|
||||
def test_sliding_window_string():
    """Sliding window of width 2 over a 1-D row of strings yields every adjacent pair."""
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))

    result = []
    for data in dataset.create_dict_iterator():
        column = data['text']
        for i in range(column.shape[0]):
            result.append([])
            # The iterator yields bytes; decode each cell so it compares against the str array.
            result[i].extend(column[i][j].decode('utf8') for j in range(column.shape[1]))
    result = np.array(result)
    np.testing.assert_array_equal(result, expect)
|
||||
|
||||
def test_sliding_window_number():
    """Sliding window of width 1 along axis -1 over a single-element numeric row yields [[1]]."""
    inputs = [1]
    expect = np.array([[1]])

    def gen(nums):
        # One row whose only column is the whole list as a 1-D array.
        row = (np.array(nums),)
        yield row

    source = gen(inputs)
    dataset = ds.GeneratorDataset(source, column_names=["number"])
    dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(1, -1))

    for data in dataset.create_dict_iterator():
        np.testing.assert_array_equal(data['number'], expect)
|
||||
|
||||
def test_sliding_window_big_width():
    """A window wider than the row is expected to produce an empty result."""
    inputs = [[1, 2, 3, 4, 5]]
    expect = np.array([])

    window_op = text.SlidingWindow(30, 0)
    dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False)
    dataset = dataset.map(input_columns=["number"], operations=window_op)

    for data in dataset.create_dict_iterator():
        np.testing.assert_array_equal(data['number'], expect)
|
||||
|
||||
def test_sliding_window_exception():
    """Invalid arguments are rejected at construction time, unsupported inputs at iteration time."""
    # width must be greater than zero
    try:
        _ = text.SlidingWindow(0, 0)
        assert False
    except ValueError:
        pass

    # width must be an int
    try:
        _ = text.SlidingWindow("1", 0)
        assert False
    except TypeError:
        pass

    # axis must be an int
    try:
        _ = text.SlidingWindow(1, "0")
        assert False
    except TypeError:
        pass

    # an axis other than 0 / -1 is only detected when the pipeline runs
    try:
        inputs = [[1, 2, 3, 4, 5]]
        window_op = text.SlidingWindow(3, -100)
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=window_op)
        for _ in dataset.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "axis supports 0 or -1 only for now." in str(e)

    # the op is expected to reject rows that are not 1-D
    try:
        inputs = ["aa", "bb", "cc"]
        window_op = text.SlidingWindow(2, 0)
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=window_op)
        for _ in dataset.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "SlidingWindosOp supports 1D Tensors only for now." in str(e)
|
||||
|
||||
if __name__ == '__main__':
    # Run every test in this module, in declaration order, when executed as a script.
    for test_case in (test_sliding_window_string,
                      test_sliding_window_number,
                      test_sliding_window_big_width,
                      test_sliding_window_exception):
        test_case()
|
Loading…
Reference in New Issue