mindspore/tests/ut/python/dataset/test_take.py

364 lines
8.8 KiB
Python
Raw Normal View History

2020-04-11 18:59:37 +08:00
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2020-05-18 16:42:35 +08:00
import numpy as np
2020-04-11 18:59:37 +08:00
import mindspore.dataset as ds
from mindspore import log as logger
# In generator dataset: Number of rows is 3, its value is 0, 1, 2
def generator():
for i in range(3):
2020-05-22 14:16:07 +08:00
yield (np.array([i]),)
2020-04-11 18:59:37 +08:00
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
def generator_10():
for i in range(10):
2020-05-22 14:16:07 +08:00
yield (np.array([i]),)
2020-04-11 18:59:37 +08:00
2020-04-29 13:41:36 +08:00
def filter_func_ge(data):
if data > 3:
return False
return True
2020-04-11 18:59:37 +08:00
def test_take_01():
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
"""
logger.info("test_take_01")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(1)
data1 = data1.repeat(2)
2020-05-26 16:17:53 +08:00
# Here i refers to index, d refers to data element
2020-05-22 14:16:07 +08:00
for _, d in enumerate(data1):
assert d[0][0] == 0
2020-04-11 18:59:37 +08:00
assert sum([1 for _ in data1]) == 2
def test_take_02():
"""
Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
"""
logger.info("test_take_02")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.repeat(2)
2020-05-26 16:17:53 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 2 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_03():
"""
Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
"""
logger.info("test_take_03")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(3)
data1 = data1.repeat(2)
2020-05-26 16:17:53 +08:00
# Here i refers to index, d refers to data elements
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_04():
"""
Test take: origin there are 3 row, and take 4 row, this is more than the total rows
"""
logger.info("test_take_04")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(4)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_05():
"""
Test take: there is no repeat op
"""
logger.info("test_take_05")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_06():
"""
Test take: repeat is before take
"""
logger.info("test_take_06")
data1 = ds.GeneratorDataset(generator, ["data"])
2020-05-18 10:31:46 +08:00
2020-04-11 18:59:37 +08:00
data1 = data1.repeat(2)
data1 = data1.take(4)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_07():
"""
Test take: take is before batch, that mean take(N), N refer to rows num
"""
logger.info("test_take_07")
data1 = ds.GeneratorDataset(generator, ["data"])
2020-05-18 10:31:46 +08:00
2020-04-11 18:59:37 +08:00
data1 = data1.take(2)
data1 = data1.batch(2)
assert sum([1 for _ in data1]) == 1
def test_take_08():
"""
Test take: take is after batch, that mean take(N), N refer to batches num
"""
logger.info("test_take_08")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.take(2)
assert sum([1 for _ in data1]) == 2
def test_take_09():
"""
Test take: repeat count is -1, and read the whole dataset, take after repeat
"""
logger.info("test_take_09")
data1 = ds.GeneratorDataset(generator, ["data"])
2020-05-18 10:31:46 +08:00
2020-04-11 18:59:37 +08:00
data1 = data1.repeat(2)
data1 = data1.take(-1)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_10():
"""
Test take: repeat count is -1, and read the whole dataset, take before repeat
"""
logger.info("test_take_10")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_11():
"""
Test take: batch first, then do repeat and take operation
"""
logger.info("test_take_11")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.repeat(2)
data1 = data1.take(-1)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert 2 * (i % 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_12():
"""
Test take: take first, then do batch and repeat operation
"""
logger.info("test_take_12")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.batch(2)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
for _, d in enumerate(data1):
assert d[0][0] == 0
2020-04-11 18:59:37 +08:00
assert sum([1 for _ in data1]) == 2
def test_take_13():
"""
Test take: skip first, then do take, batch and repeat operation
"""
logger.info("test_take_13")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.skip(2)
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
for _, d in enumerate(data1):
assert d[0][0] == 2
2020-04-11 18:59:37 +08:00
assert sum([1 for _ in data1]) == 2
def test_take_14():
"""
Test take: take first, then do batch, skip and repeat operation
"""
logger.info("test_take_14")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.skip(1)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
for _, d in enumerate(data1):
assert d[0][0] == 2
2020-04-11 18:59:37 +08:00
assert sum([1 for _ in data1]) == 2
def test_take_15():
"""
Test take: large amount data, take a part, then do skip operation
"""
logger.info("test_take_15")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(6)
data1 = data1.skip(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert (i + 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_16():
"""
Test take: large amount data, skip a part, then do take operation
"""
logger.info("test_take_16")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.skip(3)
data1 = data1.take(5)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-11 18:59:37 +08:00
for i, d in enumerate(data1):
assert (i + 3) == d[0][0]
assert sum([1 for _ in data1]) == 5
2020-04-29 13:41:36 +08:00
def test_take_17():
"""
Test take: take first, then do fiter operation
"""
logger.info("test_take_17")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
2020-04-29 13:41:36 +08:00
for i, d in enumerate(data1):
assert i == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_18():
"""
Test take: take first, then do fiter, skip, batch and repeat operation
"""
logger.info("test_take_18")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
data1 = data1.skip(2)
data1 = data1.batch(2)
data1 = data1.repeat(2)
2020-05-22 14:16:07 +08:00
# Here i refers to index, d refers to data element
for _, d in enumerate(data1):
assert d[0][0] == 2
2020-04-29 13:41:36 +08:00
assert sum([1 for _ in data1]) == 2
2020-04-11 18:59:37 +08:00
if __name__ == '__main__':
test_take_01()
test_take_02()
test_take_03()
test_take_04()
test_take_05()
test_take_06()
test_take_07()
test_take_08()
test_take_09()
test_take_10()
test_take_11()
test_take_12()
test_take_13()
test_take_14()
test_take_15()
test_take_16()
2020-04-29 13:41:36 +08:00
test_take_17()
test_take_18()
2020-05-18 10:31:46 +08:00
logger.info('== test take operation finished ==')