mindspore/tests/ut/python/dataset/test_take.py

317 lines
7.8 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
import numpy as np
# In generator dataset: Number of rows is 3, its value is 0, 1, 2
def generator():
for i in range(3):
yield np.array([i]),
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
def generator_10():
for i in range(10):
yield np.array([i]),
def test_take_01():
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
"""
logger.info("test_take_01")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 0 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_02():
"""
Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
"""
logger.info("test_take_02")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 2 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_03():
"""
Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
"""
logger.info("test_take_03")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(3)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_04():
"""
Test take: origin there are 3 row, and take 4 row, this is more than the total rows
"""
logger.info("test_take_04")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(4)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_05():
"""
Test take: there is no repeat op
"""
logger.info("test_take_05")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_06():
"""
Test take: repeat is before take
"""
logger.info("test_take_06")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.repeat(2)
data1 = data1.take(4)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_07():
"""
Test take: take is before batch, that mean take(N), N refer to rows num
"""
logger.info("test_take_07")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.batch(2)
assert sum([1 for _ in data1]) == 1
def test_take_08():
"""
Test take: take is after batch, that mean take(N), N refer to batches num
"""
logger.info("test_take_08")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.take(2)
assert sum([1 for _ in data1]) == 2
def test_take_09():
"""
Test take: repeat count is -1, and read the whole dataset, take after repeat
"""
logger.info("test_take_09")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.repeat(2)
data1 = data1.take(-1)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_10():
"""
Test take: repeat count is -1, and read the whole dataset, take before repeat
"""
logger.info("test_take_10")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_11():
"""
Test take: batch first, then do repeat and take operation
"""
logger.info("test_take_11")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.repeat(2)
data1 = data1.take(-1)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 * (i % 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_12():
"""
Test take: take first, then do batch and repeat operation
"""
logger.info("test_take_12")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.batch(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 0 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_13():
"""
Test take: skip first, then do take, batch and repeat operation
"""
logger.info("test_take_13")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.skip(2)
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_14():
"""
Test take: take first, then do batch, skip and repeat operation
"""
logger.info("test_take_14")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.skip(1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_15():
"""
Test take: large amount data, take a part, then do skip operation
"""
logger.info("test_take_15")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(6)
data1 = data1.skip(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert (i + 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_16():
"""
Test take: large amount data, skip a part, then do take operation
"""
logger.info("test_take_16")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.skip(3)
data1 = data1.take(5)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert (i + 3) == d[0][0]
assert sum([1 for _ in data1]) == 5
if __name__ == '__main__':
test_take_01()
test_take_02()
test_take_03()
test_take_04()
test_take_05()
test_take_06()
test_take_07()
test_take_08()
test_take_09()
test_take_10()
test_take_11()
test_take_12()
test_take_13()
test_take_14()
test_take_15()
test_take_16()
logger.info('== test take operation finished ==')