mindspore/tests/ut/python/dataset/test_datasets_celeba.py

94 lines
3.4 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger
# Dataset directory handed to ds.CelebADataset by every test in this module.
# NOTE(review): the path is relative, so it presumably resolves against the
# directory the tests are launched from -- confirm with the test runner setup.
DATA_DIR = "../data/dataset/testCelebAData/"
def test_celeba_dataset_label():
    """Compare the "attr" column of each decoded row with the expected labels.

    Rows are matched to ``expect_labels`` by iteration order, and the test
    finally asserts that exactly two rows were produced.
    """
    data = ds.CelebADataset(DATA_DIR, decode=True)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
         0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
         0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
    ]
    rows_seen = 0
    for row in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        # Only the positions present in the expected row are compared.
        for position, expected in enumerate(expect_labels[rows_seen]):
            assert row["attr"][position] == expected
        rows_seen += 1
    assert rows_seen == 2
def test_celeba_dataset_op():
    """Run a repeat -> center-crop -> resize pipeline over the dataset.

    The dataset is read as a single shard, repeated twice, and the "image"
    column is cropped to 80x80 then resized to 24x24. The test asserts that
    iterating the pipeline yields four rows.
    """
    dataset = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)

    # define map operations
    dataset = dataset.repeat(2)
    crop_op = vision.CenterCrop((80, 80))
    scale_op = vision.Resize((24, 24), Inter.LINEAR)  # Bilinear mode
    dataset = (
        dataset
        .map(input_columns=["image"], operations=crop_op)
        .map(input_columns=["image"], operations=scale_op)
    )

    rows_seen = 0
    for row in dataset.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(row["image"])
        rows_seen += 1
    assert rows_seen == 4
def test_celeba_dataset_ext():
    """Check the labels read when the dataset is limited to ".JPEG" files.

    The dataset is created with ``extensions=[".JPEG"]``; the test asserts
    that exactly one row is produced and that its "attr" column matches the
    expected labels.
    """
    ext = [".JPEG"]
    data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
    # One expected label row per dataset row. This used to be a flat list
    # followed by a stray trailing comma, which silently wrapped it in a
    # 1-tuple; the nesting is now explicit so ``expect_labels[count]`` is
    # clearly a row lookup.
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
         0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
    ]
    count = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        # Only the positions present in the expected row are compared.
        for index, expected in enumerate(expect_labels[count]):
            assert item["attr"][index] == expected
        count += 1
    assert count == 1
def test_celeba_dataset_distribute():
    """Read shard 0 of a two-shard split and assert it yields one row."""
    shard = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    rows_seen = 0
    for row in shard.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        rows_seen += 1
    assert rows_seen == 1
if __name__ == '__main__':
    # Script entry point: run each test once, in declaration order.
    for _run_test in (
            test_celeba_dataset_label,
            test_celeba_dataset_op,
            test_celeba_dataset_ext,
            test_celeba_dataset_distribute,
    ):
        _run_test()