mindspore/tests/ut/python/dataset/test_apply.py

246 lines
6.7 KiB
Python

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testPK/data"
# Generate 64 rows, each a 1-D int numpy array holding one value from 0 through 63
def generator_1d():
    """Yield 64 single-column rows.

    Each row is a one-element tuple whose only member is a 1-D numpy array
    containing a single integer; the integers run from 0 through 63 in order.
    """
    value = 0
    while value < 64:
        # The trailing comma keeps the row a tuple (one column per row).
        yield (np.array([value]),)
        value += 1
def test_apply_generator_case():
    """Check that Dataset.apply matches the directly chained pipeline.

    One GeneratorDataset is transformed through ``apply(dataset_fn)`` and a
    second one through the same ``repeat(2)`` / ``batch(4)`` calls made
    directly; the "data" column of both must agree batch by batch.
    """
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    data1 = data1.apply(dataset_fn)
    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
                            data2.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1["data"], item2["data"])
        num_iter += 1

    # zip() stops at the shorter iterator, so without this check an empty or
    # truncated pipeline would pass vacuously.
    # 64 generated rows, repeated twice, in batches of 4 -> 32 batches.
    assert num_iter == 32, "expected 32 batches, compared {}".format(num_iter)
def test_apply_imagefolder_case():
    """Check Dataset.apply against direct chaining for image map operations.

    Two ImageFolderDataset instances read the same shard (shard 3 of 4) of
    DATA_DIR. One goes through ``apply`` with a function that decodes,
    normalizes and repeats twice; the other has the same operations chained
    directly. The "image" column of both must agree row by row.
    """
    # apply dataset map operations
    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds_):
        decoded = ds_.map(operations=decode_op)
        normalized = decoded.map(operations=normalize_op)
        return normalized.repeat(2)

    # Pipeline under test: the whole chain is handed to apply().
    applied = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    applied = applied.apply(dataset_fn)

    # Reference pipeline: the same operations chained by hand.
    reference = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    reference = reference.map(operations=decode_op)
    reference = reference.map(operations=normalize_op)
    reference = reference.repeat(2)

    applied_rows = applied.create_dict_iterator(num_epochs=1)
    reference_rows = reference.create_dict_iterator(num_epochs=1)
    for row_a, row_b in zip(applied_rows, reference_rows):
        np.testing.assert_array_equal(row_a["image"], row_b["image"])
def _check_apply_flow(id_):
    """Run one control-flow variant through Dataset.apply and check its row count.

    The function handed to ``apply`` branches on ``id_``:

    * ``0``  -> ``batch(4)``
    * ``1``  -> ``repeat(2)``
    * ``2``  -> ``batch(4)`` followed by ``repeat(2)``
    * other -> ``shuffle(buffer_size=4)``

    The source is ``generator_1d``, which yields 64 rows, so the number of
    rows produced by one epoch is fixed for every branch.

    Args:
        id_: Selector for the branch taken inside the applied function.

    Raises:
        AssertionError: If the pipeline does not yield the expected number of rows.
    """
    # apply control flow operations
    def dataset_fn(ds_):
        if id_ == 0:
            return ds_.batch(4)
        if id_ == 1:
            return ds_.repeat(2)
        if id_ == 2:
            return ds_.batch(4).repeat(2)
        return ds_.shuffle(buffer_size=4)

    if id_ == 0:
        expected = 16    # 64 rows in batches of 4
    elif id_ == 1:
        expected = 128   # 64 rows repeated twice
    elif id_ == 2:
        expected = 32    # 16 batches repeated twice
    else:
        expected = 64    # shuffling keeps the row count

    data1 = ds.GeneratorDataset(generator_1d, ["data"]).apply(dataset_fn)
    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    assert num_iter == expected, \
        "id_={}: expected {} rows, got {}".format(id_, expected, num_iter)


def test_apply_flow_case_0(id_=0):
    """Dataset.apply with the batch(4) branch (default id_=0): 16 rows expected."""
    _check_apply_flow(id_)


def test_apply_flow_case_1(id_=1):
    """Dataset.apply with the repeat(2) branch (default id_=1): 128 rows expected."""
    _check_apply_flow(id_)


def test_apply_flow_case_2(id_=2):
    """Dataset.apply with the batch(4) + repeat(2) branch (default id_=2): 32 rows expected."""
    _check_apply_flow(id_)


def test_apply_flow_case_3(id_=3):
    """Dataset.apply with the shuffle branch (default id_=3): 64 rows expected."""
    _check_apply_flow(id_)
def test_apply_exception_case():
    """Check the error paths of Dataset.apply.

    Three situations are exercised:

    1. ``apply`` is given a string instead of a function -> TypeError.
    2. ``apply`` is given a function that returns a numpy array instead of a
       dataset -> TypeError.
    3. The same dataset is passed through ``apply(dataset_fn)`` twice and the
       results are iterated -> ValueError (logged).

    Raises:
        AssertionError: If any of the expected exceptions is not raised.
    """
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    def exception_fn(_ds):
        # Must accept the dataset argument: with a zero-parameter function the
        # call itself fails with an arity TypeError, and the "does not return
        # a dataset" check is never reached.
        return np.array([[0], [1], [3], [4], [5]])

    # try/except/else with an explicit raise is used instead of `assert False`,
    # which is stripped when Python runs with -O.
    try:
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
    except TypeError:
        pass
    else:
        raise AssertionError("apply() accepted a non-callable argument")

    try:
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
    except TypeError:
        pass
    else:
        raise AssertionError("apply() accepted a function that does not return a dataset")

    # NOTE(review): which of the statements below raises the ValueError is not
    # visible from this file, so all of them stay inside the try body.
    try:
        data2 = data1.apply(dataset_fn)
        _ = data1.apply(dataset_fn)
        for _, _ in zip(data1.create_dict_iterator(num_epochs=1),
                        data2.create_dict_iterator(num_epochs=1)):
            pass
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        raise AssertionError("applying dataset_fn to the same dataset twice did not raise ValueError")
if __name__ == '__main__':
    # Script entry point: each log banner is emitted once, followed by the
    # tests it announces, in the listed order.
    for banner, cases in (
            ("Running test_apply.py test_apply_generator_case() function",
             (test_apply_generator_case,)),
            ("Running test_apply.py test_apply_imagefolder_case() function",
             (test_apply_imagefolder_case,)),
            ("Running test_apply.py test_apply_flow_case(id) function",
             (test_apply_flow_case_0, test_apply_flow_case_1,
              test_apply_flow_case_2, test_apply_flow_case_3)),
            ("Running test_apply.py test_apply_exception_case() function",
             (test_apply_exception_case,)),
    ):
        logger.info(banner)
        for case in cases:
            case()