# mindspore/tests/ut/python/dataset/test_skip.py
#
# 239 lines
# 5.8 KiB
# Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
# Fixture paths consumed by test_tf_skip: the list of TFRecord data files and
# the JSON schema file, both passed to ds.TFRecordDataset.
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_tf_skip():
    """
    Skip two rows of a decoded, resized TFRecord pipeline and check one row is left.
    """
    target_size = (32, 32)
    pipeline = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    # Decode first, then resize; each op is mapped over the "image" column.
    image_ops = (
        vision.Decode(),
        vision.Resize(target_size, interpolation=ds.transforms.vision.Inter.LINEAR),
    )
    for image_op in image_ops:
        pipeline = pipeline.map(input_columns=["image"], operations=image_op)
    pipeline = pipeline.skip(2)
    remaining_rows = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))
    assert remaining_rows == 1
def generator_md():
    """
    Yield five single-column rows whose arrays are [0], [1], [2], [3], [4].
    """
    yield from ((np.array([value]),) for value in range(5))
def test_generator_skip():
    """
    Skip the first three of five generated rows, reading with four parallel workers.
    """
    source = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
    # After the skip only [3, 4] should be left.
    skipped = source.skip(3)
    values = [row[0][0] for row in skipped]
    assert len(values) == 2
    assert values == [3, 4]
def test_skip_1():
    """
    Skip more rows (7) than the generator provides (5); nothing should be left.
    """
    # Expected remaining rows: []
    skipped = ds.GeneratorDataset(generator_md, ["data"]).skip(7)
    values = [row[0][0] for row in skipped]
    assert values == []
def test_skip_2():
    """
    Skip zero rows; all five generated rows should come through unchanged.
    """
    # Expected remaining rows: [0, 1, 2, 3, 4]
    skipped = ds.GeneratorDataset(generator_md, ["data"]).skip(0)
    values = [row[0][0] for row in skipped]
    assert len(values) == 5
    assert values == [0, 1, 2, 3, 4]
def test_skip_repeat_1():
    """
    Repeat the five rows twice, then skip three rows of the repeated stream.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    # repeat(2) -> [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    repeated = source.repeat(2)
    # skip(3)   -> [3, 4, 0, 1, 2, 3, 4]
    skipped = repeated.skip(3)
    values = [row[0][0] for row in skipped]
    assert len(values) == 7
    assert values == [3, 4, 0, 1, 2, 3, 4]
def test_skip_repeat_2():
    """
    Skip three of the five rows first, then repeat the shortened dataset twice.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    # skip(3)   -> [3, 4]
    skipped = source.skip(3)
    # repeat(2) -> [3, 4, 3, 4]
    repeated = skipped.repeat(2)
    values = [row[0][0] for row in repeated]
    assert len(values) == 4
    assert values == [3, 4, 3, 4]
def test_skip_repeat_3():
    """
    Chain repeat -> skip -> repeat and check the skip is applied between the repeats.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    # repeat(2) -> [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    doubled = source.repeat(2)
    # skip(8)   -> [3, 4]
    tail = doubled.skip(8)
    # repeat(3) -> [3, 4, 3, 4, 3, 4]
    tripled = tail.repeat(3)
    values = [row[0][0] for row in tripled]
    assert len(values) == 6
    assert values == [3, 4, 3, 4, 3, 4]
def test_skip_take_1():
    """
    Take the first four rows, then skip two of them.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    # take(4) -> [0, 1, 2, 3]
    taken = source.take(4)
    # skip(2) -> [2, 3]
    skipped = taken.skip(2)
    values = [row[0][0] for row in skipped]
    assert len(values) == 2
    assert values == [2, 3]
def test_skip_take_2():
    """
    Skip the first two rows, then take two rows from what remains.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    # skip(2) -> [2, 3, 4]
    skipped = source.skip(2)
    # take(2) -> [2, 3]
    taken = skipped.take(2)
    values = [row[0][0] for row in taken]
    assert len(values) == 2
    assert values == [2, 3]
def generator_1d():
    """
    Yield 64 single-column rows whose arrays are [0], [1], ..., [63].
    """
    yield from ((np.array([value]),) for value in range(64))
def test_skip_filter_1():
    """
    Skip five rows of the 64-row generator, then keep only values below 11.
    """
    source = ds.GeneratorDataset(generator_1d, ['data'])
    skipped = source.skip(5)
    filtered = skipped.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    values = [row[0][0] for row in filtered]
    assert values == [5, 6, 7, 8, 9, 10]
def test_skip_filter_2():
    """
    Keep only values below 11 from the 64-row generator, then skip five rows.
    """
    source = ds.GeneratorDataset(generator_1d, ['data'])
    filtered = source.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    skipped = filtered.skip(5)
    values = [row[0][0] for row in skipped]
    assert values == [5, 6, 7, 8, 9, 10]
def test_skip_exception_1():
    """
    Check that building and iterating a pipeline with skip(count=-1) raises RuntimeError.

    The error message must contain "Skip count must be positive integer or 0.".
    pytest.raises is used instead of a bare try/except so that the test fails
    when no exception is raised at all, rather than passing vacuously.
    """
    data1 = ds.GeneratorDataset(generator_md, ["data"])
    with pytest.raises(RuntimeError) as exc_info:
        data1 = data1.skip(count=-1)
        # Drain the iterator: the error may only surface once the pipeline runs.
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
    logger.info("Got an exception in DE: {}".format(str(exc_info.value)))
    assert "Skip count must be positive integer or 0." in str(exc_info.value)
def test_skip_exception_2():
    """
    Check that skip(-2) is rejected with a ValueError mentioning the allowed interval.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])
    with pytest.raises(ValueError) as exc_info:
        source.skip(-2)
    error_text = str(exc_info.value)
    assert "Input count is not within the required interval" in error_text
if __name__ == "__main__":
    # Script entry point: run every test of this module once, in definition order.
    for test_case in (
            test_tf_skip,
            test_generator_skip,
            test_skip_1,
            test_skip_2,
            test_skip_repeat_1,
            test_skip_repeat_2,
            test_skip_repeat_3,
            test_skip_take_1,
            test_skip_take_2,
            test_skip_filter_1,
            test_skip_filter_2,
            test_skip_exception_1,
            test_skip_exception_2,
    ):
        test_case()