106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""
|
|
Testing SlidingWindow in mindspore.dataset
|
|
"""
|
|
import numpy as np
|
|
import mindspore.dataset as ds
|
|
import mindspore.dataset.text as text
|
|
|
|
def test_sliding_window_string():
    """Check SlidingWindow(2, 0) on a single row of five string tokens.

    The five tokens are expected to come back as four overlapping
    two-token windows; the result is compared as a 2-D array of
    decoded strings.
    """
    inputs = [["大", "家", "早", "上", "好"]]
    expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

    dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
    dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))

    result = []
    for data in dataset.create_dict_iterator(num_epochs=1):
        # Decode every window of this row and add it to the running list.
        # Extending (instead of indexing `result` by the window position)
        # keeps windows from different rows separate.
        result.extend(
            [token.decode('utf8') for token in window]
            for window in data['text']
        )
    result = np.array(result)
    np.testing.assert_array_equal(result, expect)
|
|
|
|
def test_sliding_window_number():
    """Check SlidingWindow(1, -1) on a one-element numeric column.

    The column is produced by a generator that yields a single
    one-element array; the expected output is the 2-D array [[1]].
    """
    values = [1]
    expected = np.array([[1]])

    def single_row(nums):
        # Yield exactly one row: a 1-tuple holding the numbers as an array.
        yield (np.array(nums),)

    pipeline = ds.GeneratorDataset(single_row(values), column_names=["number"])
    window_op = text.SlidingWindow(1, -1)
    pipeline = pipeline.map(input_columns=["number"], operations=window_op)

    for row in pipeline.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal(row['number'], expected)
|
|
|
|
def test_sliding_window_big_width():
    """Check SlidingWindow with a width (30) larger than the 5-element row.

    The expected output for the row is an empty array.
    """
    rows = [[1, 2, 3, 4, 5]]
    expected = np.array([])

    pipeline = ds.NumpySlicesDataset(rows, column_names=["number"], shuffle=False)
    window_op = text.SlidingWindow(30, 0)
    pipeline = pipeline.map(input_columns=["number"], operations=window_op)

    for row in pipeline.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal(row['number'], expected)
|
|
|
|
def test_sliding_window_exception():
    """Check the errors raised for invalid SlidingWindow arguments and inputs.

    Covered cases:
      * width 0                     -> ValueError at construction
      * non-integer width           -> TypeError at construction
      * non-integer axis            -> TypeError at construction
      * axis -100 in a pipeline     -> RuntimeError with the axis message
      * scalar string rows          -> RuntimeError with the 1D-only message

    A missing exception is reported with an explicit AssertionError (in the
    ``else`` branch) so the failure carries a message and is not removed when
    Python runs with assertions disabled.
    """

    def run_pipeline(inputs, width, axis):
        """Map SlidingWindow(width, axis) over `inputs` and drain the iterator."""
        dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
        dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(width, axis))
        for _ in dataset.create_dict_iterator(num_epochs=1):
            pass

    try:
        _ = text.SlidingWindow(0, 0)
    except ValueError:
        pass
    else:
        raise AssertionError("SlidingWindow(0, 0) did not raise ValueError")

    try:
        _ = text.SlidingWindow("1", 0)
    except TypeError:
        pass
    else:
        raise AssertionError("SlidingWindow(\"1\", 0) did not raise TypeError")

    try:
        _ = text.SlidingWindow(1, "0")
    except TypeError:
        pass
    else:
        raise AssertionError("SlidingWindow(1, \"0\") did not raise TypeError")

    try:
        run_pipeline([[1, 2, 3, 4, 5]], 3, -100)
    except RuntimeError as e:
        assert "axis supports 0 or -1 only for now." in str(e)
    else:
        raise AssertionError("SlidingWindow(3, -100) pipeline did not raise RuntimeError")

    try:
        # Each row here is a scalar string rather than a 1-D sequence.
        run_pipeline(["aa", "bb", "cc"], 2, 0)
    except RuntimeError as e:
        assert "SlidingWindosOp supports 1D Tensors only for now." in str(e)
    else:
        raise AssertionError("SlidingWindow(2, 0) on scalar rows did not raise RuntimeError")
|
|
|
|
if __name__ == '__main__':
    # Run every test in this module, in definition order, when executed as a script.
    for test_case in (
            test_sliding_window_string,
            test_sliding_window_number,
            test_sliding_window_big_width,
            test_sliding_window_exception,
    ):
        test_case()
|