From 4136892a3ec2e92f0bc2f744ba87fa29f61e8f6d Mon Sep 17 00:00:00 2001 From: YangLuo Date: Wed, 1 Jul 2020 14:34:57 +0800 Subject: [PATCH] add SlidingWindow Op --- .../minddata/dataset/api/python_bindings.cc | 4 + .../minddata/dataset/kernels/tensor_op.h | 1 + .../dataset/text/kernels/CMakeLists.txt | 2 + .../dataset/text/kernels/data_utils.cc | 66 +++++++++++ .../dataset/text/kernels/data_utils.h | 43 +++++++ .../dataset/text/kernels/sliding_window_op.cc | 57 ++++++++++ .../dataset/text/kernels/sliding_window_op.h | 68 ++++++++++++ mindspore/dataset/text/__init__.py | 4 +- mindspore/dataset/text/transforms.py | 30 ++++- mindspore/dataset/text/validators.py | 13 ++- tests/ut/cpp/dataset/CMakeLists.txt | 1 + .../ut/cpp/dataset/sliding_window_op_test.cc | 69 ++++++++++++ .../ut/python/dataset/test_sliding_window.py | 105 ++++++++++++++++++ 13 files changed, 459 insertions(+), 4 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc create mode 100644 mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h create mode 100644 mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h create mode 100644 tests/ut/cpp/dataset/sliding_window_op_test.cc create mode 100644 tests/ut/python/dataset/test_sliding_window.py diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index 145291ec3be..36741637d13 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -77,6 +77,7 @@ #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" #include "minddata/dataset/text/kernels/lookup_op.h" #include "minddata/dataset/text/kernels/ngram_op.h" +#include "minddata/dataset/text/kernels/sliding_window_op.h" #include "minddata/dataset/text/kernels/to_number_op.h" #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" @@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) { py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.") + .def(py::init(), py::arg("width"), py::arg("axis")); } void bindDependIcuTokenizerOps(py::module *m) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 3bcba4b4630..d4f5abc4b69 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp"; constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; constexpr char kLookupOp[] = "LookupOp"; constexpr char kNgramOp[] = "NgramOp"; +constexpr char kSlidingWindowOp[] = "SlidingWindowOp"; constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt index 449bb93d8b9..a932a2089eb 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt @@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows")) whitespace_tokenizer_op.cc) endif() add_library(text-kernels OBJECT + data_utils.cc lookup_op.cc jieba_tokenizer_op.cc unicode_char_tokenizer_op.cc ngram_op.cc + sliding_window_op.cc wordpiece_tokenizer_op.cc truncate_sequence_pair_op.cc to_number_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc new file mode 100644 index 00000000000..74b1d930775 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/data_utils.h" + +#include +#include +#include +#include + +#include "minddata/dataset/core/pybind_support.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" + +namespace mindspore { +namespace dataset { +Status SlidingWindowHelper(const std::shared_ptr &input, std::shared_ptr *output, TensorShape out_shape, + uint32_t width, int32_t axis) { + // if the data row has fewer items than width, the corresponding result row will be empty + if (out_shape.Size() == 0) { + MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty."; + if (input->type().value() == DataType::DE_STRING) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector{}, TensorShape({0}))); + } else { + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type())); + } + return Status::OK(); + } + + axis = Tensor::HandleNeg(axis, input->shape().Size()); + int32_t axis_end = input->shape()[axis]; + std::shared_ptr tmp; + auto concatenate_op = std::make_unique(axis, nullptr, nullptr); + + // Slice on specified axis and concatenate on new axis + for (int32_t i = 0; i + width <= axis_end; i++) { + auto slice_op = std::make_unique(Slice(i, i + width, 1)); + slice_op->Compute(input, &tmp); + if (i == 0) { + *output = tmp; + } else { + TensorRow in({*output, tmp}); + TensorRow out_row; + concatenate_op->Compute(in, &out_row); + *output = out_row[0]; + } + } + (*output)->Reshape(out_shape); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h new file mode 100644 index 00000000000..2af69cd3d6f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_ +#define DATASET_KERNELS_TEXT_DATA_UTILS_H_ + +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +/// \brief Helper method that perform sliding window on input tensor. +/// \param[in] input - Input tensor. +/// \param[in] out_shape - Output shape of output tensor. +/// \param[in] width - The axis along which sliding window is computed. +/// \param[in] axis - The width of the window. +/// \param[out] output - Output tensor +/// \return Status return code +Status SlidingWindowHelper(const std::shared_ptr &input, std::shared_ptr *output, TensorShape out_shape, + uint32_t width, int32_t axis); +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc new file mode 100644 index 00000000000..f857f1ab966 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/sliding_window_op.h" + +namespace mindspore { +namespace dataset { +Status SlidingWindowOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now."); + CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now."); + + std::vector input_shape = {input->shape()}; + std::vector output_shape = {TensorShape({})}; + RETURN_IF_NOT_OK(OutputShape(input_shape, output_shape)); + + RETURN_IF_NOT_OK(SlidingWindowHelper(input, output, output_shape[0], width_, axis_)); + return Status::OK(); +} + +Status SlidingWindowOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); + int32_t axis = Tensor::HandleNeg(axis_, inputs[0].Size()); + TensorShape input_shape = inputs[0]; + std::vector output_shape_initializer; + + // if a data row has fewer items than width, the corresponding result row will be empty. + if (input_shape[axis] >= width_) { + for (int32_t idx = 0; idx < input_shape.Size(); ++idx) { + if (idx != axis) { + output_shape_initializer.push_back(input_shape[idx]); + } else { + output_shape_initializer.push_back(input_shape[idx] - (width_ - 1)); + output_shape_initializer.push_back(width_); + } + } + } + + outputs.pop_back(); + outputs.emplace_back(TensorShape(output_shape_initializer)); + CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h new file mode 100644 index 00000000000..a9340d12bd0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ +#define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/kernels/data_utils.h" + +namespace mindspore { +namespace dataset { + +class SlidingWindowOp : public TensorOp { + public: + /// \brief Constructor of SlidingWindowOp. + /// \param[in] width - The axis along which sliding window is computed. + /// \param[in] axis - The width of the window. + /// \return Status return code + explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {} + + /// \brief Destructor of SlidingWindowOp. + ~SlidingWindowOp() override = default; + + /// \brief Perform sliding window to tensor. + /// \param[in] input - Input tensor of Op. + /// \param[out] output - output tensor of Op. + /// \return Status return code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + /// \brief Calculate tensor shape for output tensor. + /// \param[in] inputs - Input tensor shapes. + /// \param[out] outputs - Output tensor shapes. + /// \return Status return code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// \brief Print args for debugging. + /// \param[in] out - std::ostream &out. + void Print(std::ostream &out) const override { out << "SliceWindowOp"; } + + /// \brief Print name of op. + std::string Name() const override { return kSlidingWindowOp; } + + private: + uint32_t width_; // The width of the window. Must be an integer and greater than zero. + int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now. +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py index 04eb90a0b6d..048f345cfab 100644 --- a/mindspore/dataset/text/__init__.py +++ b/mindspore/dataset/text/__init__.py @@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing. """ import platform from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ - ToNumber + ToNumber, SlidingWindow from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm __all__ = [ "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", - "PythonTokenizer" + "PythonTokenizer", "SlidingWindow" ] if platform.system().lower() != 'windows': diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 30fa2b8f429..7f60f05107b 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str from .validators import check_lookup, check_jieba_add_dict, \ check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ - check_to_number, check_bert_tokenizer, check_python_tokenizer + check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow from ..core.datatypes import mstype_to_detype @@ -72,6 +72,34 @@ class Lookup(cde.LookupOp): def __init__(self, vocab, unknown_token=None): super().__init__(vocab, unknown_token) +class SlidingWindow(cde.SlidingWindowOp): + """ + TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis + is a slice of data starting at the corresponding position, with a specified width. + + Args: + width (int): The width of the window. Must be an integer and greater than zero. + axis (int, optional): The axis along which sliding window is computed (default=0). + + Examples: + >>> # Data before + >>> # | col1 | + >>> # +-------------+ + >>> # | [1,2,3,4,5] | + >>> # +-------------+ + >>> data = data.map(operations=SlidingWindow(3, 0)) + >>> # Data after + >>> # | col1 | + >>> # +-------------+ + >>> # | [[1,2,3], | + >>> # | [2,3,4], | + >>> # | [3,4,5]] | + >>> # +--------------+ + """ + + @check_slidingwindow + def __init__(self, width, axis=0): + super().__init__(width=width, axis=axis) class Ngram(cde.NgramOp): """ diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index b0327f5609c..71f48a1238a 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ - INT32_MAX, check_value, check_positive + INT32_MAX, check_value, check_positive, check_pos_int32 def check_unique_list_of_words(words, arg_name): @@ -328,6 +328,17 @@ def check_from_dataset(method): return new_method +def check_slidingwindow(method): + """A wrapper that wrap a parameter checker to the original function(sliding window operation).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [width, axis], _ = parse_user_args(method, *args, **kwargs) + check_pos_int32(width, "width") + type_check(axis, (int,), "axis") + return method(self, *args, **kwargs) + + return new_method def check_ngram(method): """A wrapper that wraps a parameter checker to the original function.""" diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 8bbf42a6404..084bd05ab41 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -92,6 +92,7 @@ SET(DE_UT_SRCS perf_data_test.cc c_api_test.cc tensor_op_fusion_pass_test.cc + sliding_window_op_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/sliding_window_op_test.cc b/tests/ut/cpp/dataset/sliding_window_op_test.cc new file mode 100644 index 00000000000..7020229d9af --- /dev/null +++ b/tests/ut/cpp/dataset/sliding_window_op_test.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/text/kernels/sliding_window_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestSlidingWindowOp : public UT::Common { + protected: + MindDataTestSlidingWindowOp() {} +}; + +TEST_F(MindDataTestSlidingWindowOp, Compute) { + MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute."; + std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; + TensorShape shape({static_cast(strings.size())}); + std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr output; + + std::unique_ptr op(new SlidingWindowOp(3, 0)); + Status s = op->Compute(input, &output); + + std::vector out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", + "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; + std::shared_ptr expected = std::make_shared(out, TensorShape({6, 3})); + + ASSERT_TRUE(output->shape() == expected->shape()); + ASSERT_TRUE(output->type() == expected->type()); + MS_LOG(DEBUG) << *output << std::endl; + MS_LOG(DEBUG) << *expected << std::endl; + ASSERT_TRUE(*output == *expected); + + MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; +} + +TEST_F(MindDataTestSlidingWindowOp, OutputShape) { + MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape."; + std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; + TensorShape shape({static_cast(strings.size())}); + std::shared_ptr input = std::make_shared(strings, shape); + std::vector input_shape = {input->shape()}; + std::vector output_shape = {TensorShape({})}; + + std::unique_ptr op(new SlidingWindowOp(3, 0)); + Status s = op->OutputShape(input_shape, output_shape); + + MS_LOG(DEBUG) << "input_shape" << input_shape[0]; + MS_LOG(DEBUG) << "output_shape" << output_shape[0]; + ASSERT_TRUE(output_shape[0] == TensorShape({6, 3})); + + MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; +} diff --git a/tests/ut/python/dataset/test_sliding_window.py b/tests/ut/python/dataset/test_sliding_window.py new file mode 100644 index 00000000000..4fdd7a25c07 --- /dev/null +++ b/tests/ut/python/dataset/test_sliding_window.py @@ -0,0 +1,105 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing SlidingWindow in mindspore.dataset +""" +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.text as text + +def test_sliding_window_string(): + """ test sliding_window with string type""" + inputs = [["大", "家", "早", "上", "好"]] + expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']]) + + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) + + result = [] + for data in dataset.create_dict_iterator(): + for i in range(data['text'].shape[0]): + result.append([]) + for j in range(data['text'].shape[1]): + result[i].append(data['text'][i][j].decode('utf8')) + result = np.array(result) + np.testing.assert_array_equal(result, expect) + +def test_sliding_window_number(): + inputs = [1] + expect = np.array([[1]]) + + def gen(nums): + yield (np.array(nums),) + + dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"]) + dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(1, -1)) + + for data in dataset.create_dict_iterator(): + np.testing.assert_array_equal(data['number'], expect) + +def test_sliding_window_big_width(): + inputs = [[1, 2, 3, 4, 5]] + expect = np.array([]) + + dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False) + dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(30, 0)) + + for data in dataset.create_dict_iterator(): + np.testing.assert_array_equal(data['number'], expect) + +def test_sliding_window_exception(): + try: + _ = text.SlidingWindow(0, 0) + assert False + except ValueError: + pass + + try: + _ = text.SlidingWindow("1", 0) + assert False + except TypeError: + pass + + try: + _ = text.SlidingWindow(1, "0") + assert False + except TypeError: + pass + + try: + inputs = [[1, 2, 3, 4, 5]] + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100)) + for _ in dataset.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "axis supports 0 or -1 only for now." in str(e) + + try: + inputs = ["aa", "bb", "cc"] + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) + for _ in dataset.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "SlidingWindosOp supports 1D Tensors only for now." in str(e) + +if __name__ == '__main__': + test_sliding_window_string() + test_sliding_window_number() + test_sliding_window_big_width() + test_sliding_window_exception()