forked from mindspore-Ecosystem/mindspore
add SlidingWindow Op
This commit is contained in:
parent
d89cedb980
commit
4136892a3e
|
@ -77,6 +77,7 @@
|
||||||
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
||||||
#include "minddata/dataset/text/kernels/lookup_op.h"
|
#include "minddata/dataset/text/kernels/lookup_op.h"
|
||||||
#include "minddata/dataset/text/kernels/ngram_op.h"
|
#include "minddata/dataset/text/kernels/ngram_op.h"
|
||||||
|
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||||
#include "minddata/dataset/text/kernels/to_number_op.h"
|
#include "minddata/dataset/text/kernels/to_number_op.h"
|
||||||
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||||
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||||
|
@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) {
|
||||||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
||||||
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
|
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
|
||||||
|
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
|
||||||
|
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
|
||||||
|
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void bindDependIcuTokenizerOps(py::module *m) {
|
void bindDependIcuTokenizerOps(py::module *m) {
|
||||||
|
|
|
@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp";
|
||||||
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
|
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
|
||||||
constexpr char kLookupOp[] = "LookupOp";
|
constexpr char kLookupOp[] = "LookupOp";
|
||||||
constexpr char kNgramOp[] = "NgramOp";
|
constexpr char kNgramOp[] = "NgramOp";
|
||||||
|
constexpr char kSlidingWindowOp[] = "SlidingWindowOp";
|
||||||
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
|
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
|
||||||
constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
|
constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
|
||||||
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
|
||||||
|
|
|
@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows"))
|
||||||
whitespace_tokenizer_op.cc)
|
whitespace_tokenizer_op.cc)
|
||||||
endif()
|
endif()
|
||||||
add_library(text-kernels OBJECT
|
add_library(text-kernels OBJECT
|
||||||
|
data_utils.cc
|
||||||
lookup_op.cc
|
lookup_op.cc
|
||||||
jieba_tokenizer_op.cc
|
jieba_tokenizer_op.cc
|
||||||
unicode_char_tokenizer_op.cc
|
unicode_char_tokenizer_op.cc
|
||||||
ngram_op.cc
|
ngram_op.cc
|
||||||
|
sliding_window_op.cc
|
||||||
wordpiece_tokenizer_op.cc
|
wordpiece_tokenizer_op.cc
|
||||||
truncate_sequence_pair_op.cc
|
truncate_sequence_pair_op.cc
|
||||||
to_number_op.cc
|
to_number_op.cc
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/pybind_support.h"
|
||||||
|
#include "minddata/dataset/kernels/data/type_cast_op.h"
|
||||||
|
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||||
|
#include "minddata/dataset/kernels/data/concatenate_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief Build the sliding-window result for `input` by slicing consecutive windows and concatenating them.
/// \param[in] input - Input tensor.
/// \param[out] output - Receives the result tensor, reshaped to `out_shape`.
/// \param[in] out_shape - Shape of the result; a rank-0 shape means the row is shorter than `width`.
/// \param[in] width - The width of the window.
/// \param[in] axis - The axis along which the sliding window is computed (may be negative).
/// \return Status return code; any failure of the slice, concatenate or reshape step is propagated.
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
                           uint32_t width, int32_t axis) {
  // if the data row has fewer items than width, the corresponding result row will be empty
  if (out_shape.Size() == 0) {
    MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty.";
    if (input->type().value() == DataType::DE_STRING) {
      RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector<std::string>{}, TensorShape({0})));
    } else {
      RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type()));
    }
    return Status::OK();
  }

  // A zero-width window would never terminate meaningfully (every position yields an empty slice).
  CHECK_FAIL_RETURN_UNEXPECTED(width > 0, "width must be greater than zero.");

  axis = Tensor::HandleNeg(axis, input->shape().Size());
  // Do the window arithmetic in one signed type so the loop bound is not subject to
  // signed/unsigned conversion between the index, the width and the axis length.
  const dsize_t window = static_cast<dsize_t>(width);
  const dsize_t axis_end = input->shape()[axis];
  auto concatenate_op = std::make_unique<ConcatenateOp>(axis, nullptr, nullptr);

  // Slice on specified axis and concatenate on new axis
  std::shared_ptr<Tensor> result;
  for (dsize_t start = 0; start + window <= axis_end; ++start) {
    std::shared_ptr<Tensor> piece;
    auto slice_op = std::make_unique<SliceOp>(Slice(start, start + window, 1));
    RETURN_IF_NOT_OK(slice_op->Compute(input, &piece));
    if (start == 0) {
      result = piece;
    } else {
      TensorRow in({result, piece});
      TensorRow out_row;
      RETURN_IF_NOT_OK(concatenate_op->Compute(in, &out_row));
      result = out_row[0];
    }
  }

  // out_shape promised at least one window; refuse to dereference a null result if none was built.
  CHECK_FAIL_RETURN_UNEXPECTED(result != nullptr, "no window was produced for the requested output shape.");
  RETURN_IF_NOT_OK(result->Reshape(out_shape));
  *output = result;
  return Status::OK();
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||||
|
#define DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
#include "minddata/dataset/core/constants.h"
|
||||||
|
#include "minddata/dataset/core/data_type.h"
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
#include "minddata/dataset/core/tensor_shape.h"
|
||||||
|
#include "minddata/dataset/core/tensor_row.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief Helper method that performs a sliding window on the input tensor.
|
||||||
|
/// \param[in] input - Input tensor.
|
||||||
|
/// \param[in] out_shape - Output shape of output tensor.
|
||||||
|
/// \param[in] width - The width of the window.
|
||||||
|
/// \param[in] axis - The axis along which sliding window is computed.
|
||||||
|
/// \param[out] output - Output tensor
|
||||||
|
/// \return Status return code
|
||||||
|
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
|
||||||
|
uint32_t width, int32_t axis);
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// Validate the input (1-D only, axis 0 or -1), derive the result shape through OutputShape,
// then let SlidingWindowHelper build the windows.
Status SlidingWindowOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now.");
  CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now.");

  // OutputShape replaces the placeholder entry with the computed result shape.
  std::vector<TensorShape> in_shapes = {input->shape()};
  std::vector<TensorShape> out_shapes = {TensorShape({})};
  RETURN_IF_NOT_OK(OutputShape(in_shapes, out_shapes));

  return SlidingWindowHelper(input, output, out_shapes[0], width_, axis_);
}
|
||||||
|
|
||||||
|
/// \brief Compute the result shape: the windowed axis of length n becomes two axes, (n - width + 1, width).
/// \param[in] inputs - Input tensor shapes; exactly NumInput() entries are expected.
/// \param[out] outputs - Its last entry (if any) is replaced by the computed shape. The shape is
///     rank-0 when the windowed axis holds fewer than `width` items.
/// \return Status return code
Status SlidingWindowOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n");
  CHECK_FAIL_RETURN_UNEXPECTED(width_ > 0, "width must be greater than zero.");

  const TensorShape &input_shape = inputs[0];
  const int32_t rank = static_cast<int32_t>(input_shape.Size());
  const int32_t axis = Tensor::HandleNeg(axis_, rank);
  // Reject an axis that does not address a dimension before it is used as an index.
  CHECK_FAIL_RETURN_UNEXPECTED(axis >= 0 && axis < rank, "axis is out of range of the input shape.");

  const dsize_t window = static_cast<dsize_t>(width_);
  std::vector<dsize_t> output_dims;

  // if a data row has fewer items than width, the corresponding result row will be empty.
  if (input_shape[axis] >= window) {
    output_dims.reserve(rank + 1);
    for (int32_t idx = 0; idx < rank; ++idx) {
      if (idx != axis) {
        output_dims.push_back(input_shape[idx]);
      } else {
        // Signed arithmetic throughout: number of window positions, then the window width.
        output_dims.push_back(input_shape[idx] - window + 1);
        output_dims.push_back(window);
      }
    }
  }

  // pop_back() on an empty vector is undefined behaviour, so only drop a placeholder that exists.
  if (!outputs.empty()) {
    outputs.pop_back();
  }
  outputs.emplace_back(TensorShape(output_dims));
  CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n");
  return Status::OK();
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
||||||
|
#define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/text/kernels/data_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class SlidingWindowOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor of SlidingWindowOp.
|
||||||
|
/// \param[in] width - The axis along which sliding window is computed.
|
||||||
|
/// \param[in] axis - The width of the window.
|
||||||
|
/// \return Status return code
|
||||||
|
explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {}
|
||||||
|
|
||||||
|
/// \brief Destructor of SlidingWindowOp.
|
||||||
|
~SlidingWindowOp() override = default;
|
||||||
|
|
||||||
|
/// \brief Perform sliding window to tensor.
|
||||||
|
/// \param[in] input - Input tensor of Op.
|
||||||
|
/// \param[out] output - output tensor of Op.
|
||||||
|
/// \return Status return code
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
/// \brief Calculate tensor shape for output tensor.
|
||||||
|
/// \param[in] inputs - Input tensor shapes.
|
||||||
|
/// \param[out] outputs - Output tensor shapes.
|
||||||
|
/// \return Status return code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
/// \brief Print args for debugging.
|
||||||
|
/// \param[in] out - std::ostream &out.
|
||||||
|
void Print(std::ostream &out) const override { out << "SliceWindowOp"; }
|
||||||
|
|
||||||
|
/// \brief Print name of op.
|
||||||
|
std::string Name() const override { return kSlidingWindowOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint32_t width_; // The width of the window. Must be an integer and greater than zero.
|
||||||
|
int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now.
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
|
|
@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing.
|
||||||
"""
|
"""
|
||||||
import platform
|
import platform
|
||||||
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
|
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
|
||||||
ToNumber
|
ToNumber, SlidingWindow
|
||||||
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
|
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
|
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
|
||||||
"to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
|
"to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
|
||||||
"PythonTokenizer"
|
"PythonTokenizer", "SlidingWindow"
|
||||||
]
|
]
|
||||||
|
|
||||||
if platform.system().lower() != 'windows':
|
if platform.system().lower() != 'windows':
|
||||||
|
|
|
@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str
|
||||||
from .validators import check_lookup, check_jieba_add_dict, \
|
from .validators import check_lookup, check_jieba_add_dict, \
|
||||||
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\
|
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\
|
||||||
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\
|
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\
|
||||||
check_to_number, check_bert_tokenizer, check_python_tokenizer
|
check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
|
||||||
from ..core.datatypes import mstype_to_detype
|
from ..core.datatypes import mstype_to_detype
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,6 +72,34 @@ class Lookup(cde.LookupOp):
|
||||||
def __init__(self, vocab, unknown_token=None):
|
def __init__(self, vocab, unknown_token=None):
|
||||||
super().__init__(vocab, unknown_token)
|
super().__init__(vocab, unknown_token)
|
||||||
|
|
||||||
|
class SlidingWindow(cde.SlidingWindowOp):
    """
    TensorOp that builds a tensor of windows from the input data (1-D only for now): each entry
    along the window axis is the slice of `width` items starting at that position.

    Args:
        width (int): Window width; must be an integer greater than zero.
        axis (int, optional): Axis along which the window slides (default=0).

    Examples:
        >>> # Data before
        >>> # |    col1     |
        >>> # +-------------+
        >>> # | [1,2,3,4,5] |
        >>> # +-------------+
        >>> data = data.map(operations=SlidingWindow(3, 0))
        >>> # Data after
        >>> # |    col1     |
        >>> # +-------------+
        >>> # | [[1,2,3],   |
        >>> # |  [2,3,4],   |
        >>> # |  [3,4,5]]   |
        >>> # +-------------+
    """

    @check_slidingwindow
    def __init__(self, width, axis=0):
        # Arguments are validated by the decorator before reaching the C++ op.
        super().__init__(width, axis)
|
||||||
|
|
||||||
class Ngram(cde.NgramOp):
|
class Ngram(cde.NgramOp):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
|
||||||
from mindspore._c_expression import typing
|
from mindspore._c_expression import typing
|
||||||
|
|
||||||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
|
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
|
||||||
INT32_MAX, check_value, check_positive
|
INT32_MAX, check_value, check_positive, check_pos_int32
|
||||||
|
|
||||||
|
|
||||||
def check_unique_list_of_words(words, arg_name):
|
def check_unique_list_of_words(words, arg_name):
|
||||||
|
@ -328,6 +328,17 @@ def check_from_dataset(method):
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
def check_slidingwindow(method):
    """Decorator that validates the `width` and `axis` arguments of the sliding window operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        width, axis = parsed
        # width is checked as a positive int32; axis only needs to be an int here.
        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")
        return method(self, *args, **kwargs)

    return new_method
|
||||||
|
|
||||||
def check_ngram(method):
|
def check_ngram(method):
|
||||||
"""A wrapper that wraps a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
|
@ -92,6 +92,7 @@ SET(DE_UT_SRCS
|
||||||
perf_data_test.cc
|
perf_data_test.cc
|
||||||
c_api_test.cc
|
c_api_test.cc
|
||||||
tensor_op_fusion_pass_test.cc
|
tensor_op_fusion_pass_test.cc
|
||||||
|
sliding_window_op_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "minddata/dataset/text/kernels/sliding_window_op.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
|
||||||
|
// Test fixture for the SlidingWindowOp unit tests; adds nothing on top of UT::Common.
class MindDataTestSlidingWindowOp : public UT::Common {
 protected:
  MindDataTestSlidingWindowOp() = default;
};
|
||||||
|
|
||||||
|
// Width-3 windows over eight strings must produce a 6x3 string tensor.
TEST_F(MindDataTestSlidingWindowOp, Compute) {
  MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute.";
  std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
  TensorShape shape({static_cast<dsize_t>(strings.size())});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
  std::shared_ptr<Tensor> output;

  auto op = std::make_unique<SlidingWindowOp>(3, 0);
  Status s = op->Compute(input, &output);
  // Fail here rather than dereferencing a null output below when the op reports an error.
  ASSERT_TRUE(s.IsOk());
  ASSERT_TRUE(output != nullptr);

  std::vector<std::string> out = {"one",  "two",  "three", "two",  "three", "four",  "three", "four",  "five",
                                  "four", "five", "six",   "five", "six",   "seven", "six",   "seven", "eight"};
  std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(out, TensorShape({6, 3}));

  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());
  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;
  ASSERT_TRUE(*output == *expected);

  MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
|
||||||
|
|
||||||
|
// A 1-D shape of 8 items with width 3 must map to shape (6, 3).
TEST_F(MindDataTestSlidingWindowOp, OutputShape) {
  MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape.";
  std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
  TensorShape shape({static_cast<dsize_t>(strings.size())});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
  std::vector<TensorShape> input_shape = {input->shape()};
  std::vector<TensorShape> output_shape = {TensorShape({})};

  auto op = std::make_unique<SlidingWindowOp>(3, 0);
  Status s = op->OutputShape(input_shape, output_shape);
  // The shape comparison below is only meaningful if the call itself succeeded.
  ASSERT_TRUE(s.IsOk());
  ASSERT_TRUE(output_shape.size() == 1);

  MS_LOG(DEBUG) << "input_shape" << input_shape[0];
  MS_LOG(DEBUG) << "output_shape" << output_shape[0];
  ASSERT_TRUE(output_shape[0] == TensorShape({6, 3}));

  MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing SlidingWindow in mindspore.dataset
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.text as text
|
||||||
|
|
||||||
|
def test_sliding_window_string():
    """SlidingWindow(2, 0) over a single row of five strings yields four windows of two strings."""
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))

    result = []
    for data in dataset.create_dict_iterator():
        cells = data['text']
        n_rows, n_cols = cells.shape[0], cells.shape[1]
        for i in range(n_rows):
            result.append([])
            # The iterator returns bytes; decode each cell before comparing with the str expectation.
            for j in range(n_cols):
                result[i].append(cells[i][j].decode('utf8'))
        result = np.array(result)
    np.testing.assert_array_equal(result, expect)
|
||||||
|
|
||||||
|
def test_sliding_window_number():
    """SlidingWindow(1, -1) over the one-element numeric row [1] yields [[1]]."""
    inputs = [1]
    expect = np.array([[1]])

    def gen(nums):
        # Single-row source: one tuple holding the whole array as its only column.
        row = np.array(nums)
        yield (row,)

    window_op = text.SlidingWindow(1, -1)
    dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"])
    dataset = dataset.map(input_columns=["number"], operations=window_op)

    for data in dataset.create_dict_iterator():
        np.testing.assert_array_equal(data['number'], expect)
|
||||||
|
|
||||||
|
def test_sliding_window_big_width():
    """A window wider than the row (30 > 5) is expected to give an empty result."""
    inputs = [[1, 2, 3, 4, 5]]
    expect = np.array([])

    window_op = text.SlidingWindow(30, 0)
    dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False)
    dataset = dataset.map(input_columns=["number"], operations=window_op)

    for data in dataset.create_dict_iterator():
        np.testing.assert_array_equal(data['number'], expect)
|
||||||
|
|
||||||
|
def test_sliding_window_exception():
    """Invalid arguments fail at construction; unsupported axis / non-1-D input fail when iterating."""
    # width must be greater than zero
    try:
        _ = text.SlidingWindow(0, 0)
    except ValueError:
        pass
    else:
        assert False

    # width must be an int
    try:
        _ = text.SlidingWindow("1", 0)
    except TypeError:
        pass
    else:
        assert False

    # axis must be an int
    try:
        _ = text.SlidingWindow(1, "0")
    except TypeError:
        pass
    else:
        assert False

    # an axis other than 0 / -1 is rejected once the pipeline runs
    try:
        inputs = [[1, 2, 3, 4, 5]]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100))
        for _ in dataset.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "axis supports 0 or -1 only for now." in str(e)
    else:
        assert False

    # rows that are not 1-D are rejected once the pipeline runs
    try:
        inputs = ["aa", "bb", "cc"]
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))
        for _ in dataset.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "SlidingWindosOp supports 1D Tensors only for now." in str(e)
    else:
        assert False
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Run every test in declaration order when the file is executed directly.
    for case in (test_sliding_window_string,
                 test_sliding_window_number,
                 test_sliding_window_big_width,
                 test_sliding_window_exception):
        case()
|
Loading…
Reference in New Issue