[feat][assistant][I3J6VN] add new data operator flowers102

This commit is contained in:
ckczzj 2021-06-02 11:39:16 +08:00
parent 3a50cb8432
commit 9d9f33be88
17 changed files with 613 additions and 1 deletion

View File

@@ -64,7 +64,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset
check_sb_dataset, check_flowers102dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -5418,6 +5418,232 @@ class CSVDataset(SourceDataset):
self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)

class _Flowers102Dataset:
"""
    Loads the Flowers102 dataset and returns one row at a time.
"""
def __init__(self, dataset_dir, task, usage, decode):
self.dataset_dir = os.path.realpath(dataset_dir)
self.task = task
self.usage = usage
self.decode = decode
if self.task == "Classification":
self.column_names = ["image", "label"]
else:
self.column_names = ["image", "segmentation", "label"]
labels_path = os.path.join(self.dataset_dir, "imagelabels.mat")
setid_path = os.path.join(self.dataset_dir, "setid.mat")
        # subtract one to map labels from 1~102 to 0~101
self.labels = (loadmat(labels_path)["labels"][0] - 1).astype(np.uint32)
self.setid = loadmat(setid_path)
if self.usage == 'train':
self.indices = self.setid["trnid"][0].tolist()
elif self.usage == 'test':
self.indices = self.setid["tstid"][0].tolist()
elif self.usage == 'valid':
self.indices = self.setid["valid"][0].tolist()
elif self.usage == 'all':
self.indices = self.setid["trnid"][0].tolist()
self.indices += self.setid["tstid"][0].tolist()
self.indices += self.setid["valid"][0].tolist()
else:
raise ValueError("Input usage is not within the valid set of ['train', 'valid', 'test', 'all'].")

    def __getitem__(self, index):
        # image indices in the dataset range from 1 to 8189
image_path = os.path.join(self.dataset_dir, "jpg", "image_" + str(self.indices[index]).zfill(5) + ".jpg")
        if not os.path.exists(image_path):
            raise RuntimeError("Cannot find image file: " + image_path)
        if self.decode:
            image = np.asarray(Image.open(image_path).convert("RGB"))
        else:
            image = np.fromfile(image_path, dtype=np.uint8)
label = self.labels[self.indices[index] - 1]
if self.task == "Segmentation":
segmentation_path = \
os.path.join(self.dataset_dir, "segmim", "segmim_" + str(self.indices[index]).zfill(5) + ".jpg")
            if not os.path.exists(segmentation_path):
                raise RuntimeError("Cannot find segmentation file: " + segmentation_path)
            if self.decode:
                segmentation = np.asarray(Image.open(segmentation_path).convert("RGB"))
            else:
                segmentation = np.fromfile(segmentation_path, dtype=np.uint8)
return image, segmentation, label
return image, label

    def __len__(self):
return len(self.indices)
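
Any random-accessible source with this `__getitem__`/`__len__` contract can be wrapped by `GeneratorDataset` the same way the public class below does. A minimal, self-contained sketch of the protocol; the `_ToySource` class is illustrative only and not part of this commit:

    import numpy as np
    import mindspore.dataset as ds

    class _ToySource:
        """Illustrative random-accessible source returning (image, label) rows."""
        def __init__(self, num_rows=4):
            self.images = [np.full((2, 2, 3), i, dtype=np.uint8) for i in range(num_rows)]
            self.labels = np.arange(num_rows, dtype=np.uint32)

        def __getitem__(self, index):
            return self.images[index], self.labels[index]

        def __len__(self):
            return len(self.images)

    # GeneratorDataset pulls rows by index and exposes them under column names,
    # exactly as Flowers102Dataset does with _Flowers102Dataset below.
    toy = ds.GeneratorDataset(_ToySource(), column_names=["image", "label"], shuffle=False)
    for row in toy.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(row["image"].shape, row["label"])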

class Flowers102Dataset(GeneratorDataset):
"""
A source dataset for reading and parsing Flowers102 dataset.
    The generated dataset has two columns :py:obj:`[image, label]` or three columns
    :py:obj:`[image, segmentation, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`segmentation` is of the uint8 type.
The tensor of column :py:obj:`label` is a scalar or a tensor of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
task (str): Specify the 'Classification' or 'Segmentation' task (default='Classification').
        usage (str): Specify the 'train', 'valid', 'test' part or 'all' parts of the dataset
            (default='all', will read all samples).
num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
decode (bool, optional): Whether or not to decode the images and segmentations after reading (default=False).
sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
input is required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            Random accessible input is required. When this argument is specified, 'num_samples' reflects the
            maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
when num_shards is also specified. Random accessible input is required.
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> flowers102_dataset_dir = "/path/to/flowers102_dataset_directory"
>>> dataset = ds.Flowers102Dataset(dataset_dir=flowers102_dataset_dir,
... task="Classification",
... usage="all",
... decode=True)
About Flowers102 dataset:
    The Flowers102 dataset consists of 102 flower categories commonly occurring in the United Kingdom.
    Each class contains between 40 and 258 images.
    Here is the original Flowers102 dataset structure.
    You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
        .
        └── flowers102_dataset_dir
             ├── imagelabels.mat
             ├── setid.mat
             ├── jpg
             │    ├── image_00001.jpg
             │    ├── image_00002.jpg
             │    └── ...
             └── segmim
                  ├── segmim_00001.jpg
                  ├── segmim_00002.jpg
                  └── ...
Citation:
.. code-block::
@InProceedings{Nilsback08,
author = "Maria-Elena Nilsback and Andrew Zisserman",
title = "Automated Flower Classification over a Large Number of Classes",
booktitle = "Indian Conference on Computer Vision, Graphics and Image Processing",
month = "Dec",
year = "2008",
}
"""
@check_flowers102dataset
def __init__(self, dataset_dir, task="Classification", usage="all", num_samples=None, num_parallel_workers=1,
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
self.dataset_dir = os.path.realpath(dataset_dir)
self.task = replace_none(task, "Classification")
self.usage = replace_none(usage, "all")
self.decode = replace_none(decode, False)
dataset = _Flowers102Dataset(self.dataset_dir, self.task, self.usage, self.decode)
super().__init__(dataset, column_names=dataset.column_names, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)

    def get_class_indexing(self):
"""
Get the class index.
Returns:
dict, a str-to-int mapping from label name to index.
"""
class_names = [
"pink primrose", "hard-leaved pocket orchid", "canterbury bells",
"sweet pea", "english marigold", "tiger lily", "moon orchid",
"bird of paradise", "monkshood", "globe thistle", "snapdragon",
"colt's foot", "king protea", "spear thistle", "yellow iris",
"globe-flower", "purple coneflower", "peruvian lily", "balloon flower",
"giant white arum lily", "fire lily", "pincushion flower", "fritillary",
"red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers",
"stemless gentian", "artichoke", "sweet william", "carnation",
"garden phlox", "love in the mist", "mexican aster", "alpine sea holly",
"ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip",
"lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia",
"bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy",
"common dandelion", "petunia", "wild pansy", "primula", "sunflower",
"pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia",
"pink-yellow dahlia?", "cautleya spicata", "japanese anemone",
"black-eyed susan", "silverbush", "californian poppy", "osteospermum",
"spring crocus", "bearded iris", "windflower", "tree poppy", "gazania",
"azalea", "water lily", "rose", "thorn apple", "morning glory",
"passion flower", "lotus", "toad lily", "anthurium", "frangipani",
"clematis", "hibiscus", "columbine", "desert-rose", "tree mallow",
"magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum",
"bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow",
"mexican petunia", "bromelia", "blanket flower", "trumpet creeper",
"blackberry lily"
]
        return {class_name: i for i, class_name in enumerate(class_names)}
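
Since get_class_indexing() returns a name-to-index dict, inverting it yields human-readable labels at iteration time. A hedged usage sketch (the dataset path is a placeholder):

    import mindspore.dataset as ds

    flowers102_dataset_dir = "/path/to/flowers102_dataset_directory"  # placeholder path
    dataset = ds.Flowers102Dataset(dataset_dir=flowers102_dataset_dir,
                                   task="Classification", usage="all", decode=True)
    # invert the str-to-int mapping to look up names from numeric labels
    index_to_name = {v: k for k, v in dataset.get_class_indexing().items()}
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(index_to_name[int(row["label"])])  # e.g. "pink primrose" for label 0
        break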

class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.

View File

@@ -953,6 +953,44 @@ def check_csvdataset(method):
return new_method

def check_flowers102dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
check_dir(os.path.join(dataset_dir, "jpg"))
check_file(os.path.join(dataset_dir, "imagelabels.mat"))
check_file(os.path.join(dataset_dir, "setid.mat"))
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "valid", "test", "all"], "usage")
task = param_dict.get('task')
if task is not None:
check_valid_str(task, ["Classification", "Segmentation"], "task")
if task == "Segmentation":
check_dir(os.path.join(dataset_dir, "segmim"))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method
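
The validator above follows the same wrap-and-check pattern as the other check_* decorators: bind the user's arguments, validate them, then call through to the original method. A self-contained sketch of the pattern under illustrative names (check_positive and SimpleDataset are not real MindSpore APIs):

    import inspect
    from functools import wraps

    def check_positive(method):
        """Illustrative checker: reject non-positive num_samples before __init__ runs."""
        @wraps(method)
        def new_method(self, *args, **kwargs):
            # bind positional and keyword args against the wrapped signature
            bound = inspect.signature(method).bind(self, *args, **kwargs)
            bound.apply_defaults()
            num_samples = bound.arguments.get("num_samples")
            if num_samples is not None and num_samples <= 0:
                raise ValueError("num_samples must be positive, got {}.".format(num_samples))
            return method(self, *args, **kwargs)
        return new_method

    class SimpleDataset:
        @check_positive
        def __init__(self, num_samples=None):
            self.num_samples = num_samples

    SimpleDataset(num_samples=3)    # passes validation
    # SimpleDataset(num_samples=0)  # would raise ValueError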

def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""

13 binary test data files added (not shown). Reported sizes: 172, 170, 207, 51, 63, and 53 KiB, each occurring twice.

View File

@@ -0,0 +1,348 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Flowers102 dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
from scipy.io import loadmat
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
DATA_DIR = "../data/dataset/testFlowers102Dataset"
WRONG_DIR = "../data/dataset/testMnistData"

def load_flowers102(path, usage):
    """
    Load Flowers102 data.
    """
assert usage in ["train", "valid", "test", "all"]
imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32)
split = loadmat(os.path.join(path, "setid.mat"))
if usage == 'train':
indices = split["trnid"][0].tolist()
elif usage == 'test':
indices = split["tstid"][0].tolist()
elif usage == 'valid':
indices = split["valid"][0].tolist()
elif usage == 'all':
indices = split["trnid"][0].tolist()
indices += split["tstid"][0].tolist()
indices += split["valid"][0].tolist()
image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices]
segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices]
images = [np.asarray(Image.open(path).convert("RGB")) for path in image_paths]
segmentations = [np.asarray(Image.open(path).convert("RGB")) for path in segmentation_paths]
labels = [imagelabels[index - 1] for index in indices]
return images, segmentations, labels

def visualize_dataset(images, labels):
"""
Helper function to visualize the dataset samples
"""
num_samples = len(images)
for i in range(num_samples):
plt.subplot(1, num_samples, i + 1)
plt.imshow(images[i].squeeze())
plt.title(labels[i])
plt.show()

def test_flowers102_content_check():
"""
Validate Flowers102Dataset image readings
"""
logger.info("Test Flowers102Dataset Op with content check")
all_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="all",
num_samples=6, decode=True, shuffle=False)
images, segmentations, labels = load_flowers102(DATA_DIR, "all")
num_iter = 0
    # in this example, each dictionary has keys "image", "segmentation" and "label"
for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["segmentation"], segmentations[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 6
train_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="train",
num_samples=2, decode=True, shuffle=False)
images, segmentations, labels = load_flowers102(DATA_DIR, "train")
num_iter = 0
    # in this example, each dictionary has keys "image", "segmentation" and "label"
for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["segmentation"], segmentations[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 2
test_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="test",
num_samples=2, decode=True, shuffle=False)
images, segmentations, labels = load_flowers102(DATA_DIR, "test")
num_iter = 0
    # in this example, each dictionary has keys "image", "segmentation" and "label"
for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["segmentation"], segmentations[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 2
val_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="valid",
num_samples=2, decode=True, shuffle=False)
images, segmentations, labels = load_flowers102(DATA_DIR, "valid")
num_iter = 0
    # in this example, each dictionary has keys "image", "segmentation" and "label"
for i, data in enumerate(val_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["segmentation"], segmentations[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 2

def test_flowers102_basic():
"""
Validate Flowers102Dataset
"""
logger.info("Test Flowers102Dataset Op")
# case 1: test decode
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False)
all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1)
all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False)
num_iter = 0
for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(item1["label"], item2["label"])
num_iter += 1
assert num_iter == 6
# case 2: test num_samples
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 4
# case 3: test repeat
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
all_data = all_data.repeat(5)
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 20
    # case 4: test get_dataset_size, resize and batch
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"],
num_parallel_workers=1)
assert all_data.get_dataset_size() == 4
assert all_data.get_batch_size() == 1
    all_data = all_data.batch(batch_size=3)  # drop_remainder defaults to False
assert all_data.get_batch_size() == 3
assert all_data.get_dataset_size() == 2
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 2
    # case 5: test get_class_indexing
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
class_indexing = all_data.get_class_indexing()
assert class_indexing["pink primrose"] == 0
assert class_indexing["blackberry lily"] == 101

def test_flowers102_sequential_sampler():
"""
Test Flowers102Dataset with SequentialSampler
"""
logger.info("Test Flowers102Dataset Op with SequentialSampler")
num_samples = 4
sampler = ds.SequentialSampler(num_samples=num_samples)
all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
decode=True, sampler=sampler)
all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
decode=True, shuffle=False, num_samples=num_samples)
label_list_1, label_list_2 = [], []
num_iter = 0
for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
all_data_2.create_dict_iterator(num_epochs=1)):
label_list_1.append(item1["label"].asnumpy())
label_list_2.append(item2["label"].asnumpy())
num_iter += 1
np.testing.assert_array_equal(label_list_1, label_list_2)
assert num_iter == num_samples

def test_flowers102_exception():
"""
Test error cases for Flowers102Dataset
"""
logger.info("Test error cases for Flowers102Dataset")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", shuffle=False,
decode=True, sampler=ds.SequentialSampler(1))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", sampler=ds.SequentialSampler(1),
decode=True, num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id="0")
error_msg_8 = "does not exist or is not a directory or permission denied!"
with pytest.raises(ValueError, match=error_msg_8):
all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True)
for _ in all_data.create_dict_iterator(num_epochs=1):
pass
error_msg_9 = "is not of type"
with pytest.raises(TypeError, match=error_msg_9):
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123)
for _ in all_data.create_dict_iterator(num_epochs=1):
pass

def test_flowers102_visualize(plot=False):
"""
Visualize Flowers102Dataset results
"""
logger.info("Test Flowers102Dataset visualization")
all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4,
decode=True, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert len(image.shape) == 3
assert image.shape[-1] == 3
assert image.dtype == np.uint8
assert label.dtype == np.uint32
num_iter += 1
assert num_iter == 4
if plot:
visualize_dataset(image_list, label_list)

def test_flowers102_usage():
"""
Validate Flowers102Dataset usage
"""
logger.info("Test Flowers102Dataset usage flag")
def test_config(usage):
try:
data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage=usage, decode=True, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("all") == 6
assert test_config("train") == 2
assert test_config("test") == 2
assert test_config("valid") == 2
assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

def test_flowers102_task():
"""
Validate Flowers102Dataset task
"""
logger.info("Test Flowers102Dataset task flag")
def test_config(task):
try:
data = ds.Flowers102Dataset(DATA_DIR, task=task, usage="all", decode=True, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("Classification") == 6
assert test_config("Segmentation") == 6
assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid")
assert "Argument task with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

if __name__ == '__main__':
test_flowers102_content_check()
test_flowers102_basic()
test_flowers102_sequential_sampler()
test_flowers102_exception()
test_flowers102_visualize(plot=True)
test_flowers102_usage()
test_flowers102_task()
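
The exception tests above only assert that invalid num_shards/shard_id combinations fail. A hedged sketch of valid sharding, assuming the same 6-image test data and the DATA_DIR defined above: each shard iterates a disjoint subset, and together the shards cover the whole dataset.

    import mindspore.dataset as ds

    shard0 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                  decode=True, num_shards=2, shard_id=0)
    shard1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                  decode=True, num_shards=2, shard_id=1)
    # count rows per shard; the two shards partition the 6-image test data
    rows0 = sum(1 for _ in shard0.create_dict_iterator(num_epochs=1, output_numpy=True))
    rows1 = sum(1 for _ in shard1.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert rows0 + rows1 == 6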