# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
import pytest
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR = "../data/dataset/testLFW"


def test_lfw_basic():
    """
    Feature: LFW
    Description: Test basic usage of LFW
    Expectation: The dataset is as expected
    """
logger.info("Test Case basic")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_lfw_task():
    """
    Feature: LFW
    Description: Test LFW dataset with task="pairs" and usage="all"
    Expectation: The dataset is as expected
    """
    logger.info("Test Case task")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, task="pairs", usage="all")
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image1", "image2" and "label".
logger.info("image1 is {}".format(item["image1"]))
logger.info("image2 is {}".format(item["image2"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 16


def test_lfw_usage():
    """
    Feature: LFW
    Description: Test LFW dataset with usage="test"
    Expectation: The dataset is as expected
    """
    logger.info("Test Case usage")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, usage="test")
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 6


def test_lfw_image_set():
    """
    Feature: LFW
    Description: Test LFW dataset with image_set="funneled"
    Expectation: The dataset is as expected
    """
    logger.info("Test Case image_set")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, image_set="funneled")
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_lfw_num_samples():
    """
    Feature: LFW
    Description: Test LFW dataset with num_samples parameter and RandomSampler
    Expectation: The dataset is as expected
    """
logger.info("Test Case numSamples")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, num_samples=4, num_parallel_workers=2)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
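    # RandomSampler with replacement=True may pick the same image more than once;
    # num_samples=2 limits the output to 2 rows regardless.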
random_sampler = ds.RandomSampler(num_samples=2, replacement=True)
data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2,
sampler=random_sampler)
num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 2
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2,
sampler=random_sampler)
num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 3


def test_lfw_num_shards():
    """
    Feature: LFW
    Description: Test LFW dataset with num_shards and shard_id parameters
    Expectation: The dataset is as expected
    """
logger.info("Test Case numShards")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, num_shards=5, shard_id=1)
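    # the 4 images of the default "people" task are split across 5 shards, leaving one image in shard 1;
    # with repeat(2) below this yields 2 rows.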
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_lfw_shard_id():
    """
    Feature: LFW
    Description: Test LFW dataset with shard_id parameter
    Expectation: The dataset is as expected
    """
logger.info("Test Case withShardID")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, num_shards=4, shard_id=1)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 2


def test_lfw_no_shuffle():
    """
    Feature: LFW
    Description: Test LFW dataset with shuffle=False
    Expectation: The dataset is as expected
    """
logger.info("Test Case noShuffle")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_lfw_decode():
    """
    Feature: LFW
    Description: Test LFW dataset with decode=True
    Expectation: The dataset is as expected
    """
logger.info("Test Case decode")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, decode=True)
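    # with decode=True the "image" column holds decoded pixel arrays instead of raw encoded bytes.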
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_sequential_sampler():
    """
    Feature: LFW
    Description: Test LFW dataset with SequentialSampler
    Expectation: The dataset is as expected
    """
logger.info("Test Case SequentialSampler")
# define parameters.
repeat_count = 2
# apply dataset operations.
sampler = ds.SequentialSampler(num_samples=3)
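    # SequentialSampler returns the first 3 samples in order; repeat(2) below doubles that to 6 rows.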
data1 = ds.LFWDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(repeat_count)
result = []
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
result.append(item["label"])
num_iter += 1
assert num_iter == 6
logger.info("Result: {}".format(result))


def test_random_and_sequential_chained_sampler():
    """
    Feature: LFW
    Description: Test LFW dataset with a chained RandomSampler and SequentialSampler
    Expectation: The dataset is as expected
    """
logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
# Create chained sampler, random and sequential.
sampler = ds.RandomSampler()
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
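    # the parent RandomSampler re-orders the indices produced by its SequentialSampler child,
    # so all 4 images are still visited; with repeat(3) the pipeline yields 12 rows.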
# Create LFWDataset with sampler.
data1 = ds.LFWDataset(DATA_DIR, sampler=sampler)
data1 = data1.repeat(count=3)
# Verify dataset size.
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 12
# Verify number of iterations.
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12


def test_lfw_rename():
    """
    Feature: LFW
    Description: Test rename operation on LFW dataset pipeline
    Expectation: The dataset is as expected
    """
logger.info("Test Case rename")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, num_samples=4)
data1 = data1.repeat(repeat_count)
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
# in this example, each dictionary has keys "image" and "label".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
data1 = data1.rename(input_columns=["image"], output_columns="image2")
num_iter = 0
# each data is a dictionary.
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # after the rename, each dictionary has keys "image2" and "label".
logger.info("image is {}".format(item["image2"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8


def test_lfw_zip():
    """
    Feature: LFW
    Description: Test zip operation on LFW dataset pipelines
    Expectation: The dataset is as expected
    """
logger.info("Test Case zip")
# define parameters.
repeat_count = 2
# apply dataset operations.
data1 = ds.LFWDataset(DATA_DIR, num_samples=3)
data2 = ds.LFWDataset(DATA_DIR, num_samples=3)
data1 = data1.repeat(repeat_count)
    # rename the columns of data2 to avoid a column name conflict with data1.
data2 = data2.rename(input_columns=["image", "label"],
output_columns=["image1", "label1"])
data3 = ds.zip((data1, data2))
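    # zip merges the pipelines column-wise and stops at the shorter input (data2 with 3 rows).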
num_iter = 0
# each data is a dictionary.
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image", "label", "image1" and "label1".
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 3


def test_lfw_exception():
    """
    Feature: LFW
    Description: Test error cases of LFW
    Expectation: Throw exception correctly
    """
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.LFWDataset(DATA_DIR, shuffle=False, decode=True, sampler=ds.SequentialSampler(1))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.LFWDataset(DATA_DIR, sampler=ds.SequentialSampler(1), decode=True, num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.LFWDataset(DATA_DIR, decode=True, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.LFWDataset(DATA_DIR, decode=True, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.LFWDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")
error_msg_8 = "does not exist or is not a directory or permission denied!"
with pytest.raises(ValueError, match=error_msg_8):
all_data = ds.LFWDataset("../data/dataset/testLFW2", decode=True)
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
error_msg_9 = "Input task is not within the valid set of ['people', 'pairs']."
with pytest.raises(ValueError, match=re.escape(error_msg_9)):
all_data = ds.LFWDataset(DATA_DIR, task="all")
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
error_msg_10 = "Input usage is not within the valid set of ['10fold', 'train', 'test', 'all']."
with pytest.raises(ValueError, match=re.escape(error_msg_10)):
all_data = ds.LFWDataset(DATA_DIR, usage="many")
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
error_msg_11 = "Input image_set is not within the valid set of ['original', 'funneled', 'deepfunneled']."
with pytest.raises(ValueError, match=re.escape(error_msg_11)):
all_data = ds.LFWDataset(DATA_DIR, image_set="all")
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
error_msg_12 = "Argument decode with value 123 is not of type [<class 'bool'>], but got <class 'int'>."
with pytest.raises(TypeError, match=re.escape(error_msg_12)):
all_data = ds.LFWDataset(DATA_DIR, decode=123)
for _ in all_data.create_dict_iterator(num_epochs=1):
pass


if __name__ == '__main__':
test_lfw_basic()
test_lfw_task()
test_lfw_usage()
test_lfw_image_set()
test_lfw_num_samples()
test_sequential_sampler()
    test_random_and_sequential_chained_sampler()
test_lfw_num_shards()
test_lfw_shard_id()
test_lfw_no_shuffle()
test_lfw_decode()
test_lfw_rename()
test_lfw_zip()
test_lfw_exception()