[feat][assistant][I3J6VJ] add new data operator SVHN

RainWang6188 2021-12-04 08:19:47 +00:00
parent b2b89ad575
commit 340b3b4d5d
6 changed files with 544 additions and 1 deletion


@@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
-    check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset
+    check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size, get_auto_offload
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -8962,3 +8962,149 @@ class TedliumDataset(MappableDataset):
def parse(self, children=None):
return cde.TedliumNode(self.dataset_dir, self.release, self.usage, self.extensions, self.sampler)
class _SVHNDataset:
"""
Mainly for loading the SVHN dataset; returns the two columns [image, label] for each sample.
"""
def __init__(self, dataset_dir, usage):
self.dataset_dir = os.path.realpath(dataset_dir)
self.usage = usage
self.column_names = ["image", "label"]
self.usage_all = ["train", "test", "extra"]
self.data = np.array([], dtype=np.uint8)
self.labels = np.array([], dtype=np.uint32)
if self.usage == "all":
for _usage in self.usage_all:
data, label = self._load_mat(_usage)
self.data = np.concatenate((self.data, data)) if self.data.size else data
self.labels = np.concatenate((self.labels, label)) if self.labels.size else label
else:
self.data, self.labels = self._load_mat(self.usage)
def _load_mat(self, mode):
filename = mode + "_32x32.mat"
mat_data = loadmat(os.path.join(self.dataset_dir, filename))
data = np.transpose(mat_data['X'], [3, 0, 1, 2])
label = mat_data['y'].astype(np.uint32).squeeze()
np.place(label, label == 10, 0)
return data, label
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.data)
class SVHNDataset(GeneratorDataset):
"""
A source dataset for reading and parsing the SVHN dataset.
The generated dataset has two columns: :py:obj:`[image, label]`.
The tensor of column :py:obj:`image` is of the uint8 type.
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Specify the 'train', 'test', 'extra' or 'all' part of the dataset
(default=None, will read all samples).
num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset; random accessible input is required
(default=None, expected order behavior shown in the table).
sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
input is required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
Random accessible input is required. When this argument is specified, 'num_samples' reflects the
maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
when num_shards is also specified. Random accessible input is required.
Raises:
RuntimeError: If dataset_dir is invalid, does not exist, or does not contain data files.
RuntimeError: If num_parallel_workers exceeds the maximum number of threads.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If usage is invalid.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Note:
- This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
The table below shows what input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Examples:
>>> svhn_dataset_dir = "/path/to/svhn_dataset_directory"
>>> dataset = ds.SVHNDataset(dataset_dir=svhn_dataset_dir, usage="train")
About SVHN dataset:
The SVHN dataset is obtained from house numbers in Google Street View images and consists of 10 digit classes.
It provides 73257 digits for training, 26032 digits for testing, and 531131 additional samples as extra training data.
Here is the original SVHN dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
└── svhn_dataset_dir
     ├── train_32x32.mat
     ├── test_32x32.mat
     └── extra_32x32.mat
Citation:
.. code-block::
@article{
title={Reading Digits in Natural Images with Unsupervised Feature Learning},
author={Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, Andrew Y. Ng},
conference={NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011.},
year={2011},
publisher={NIPS},
url={http://ufldl.stanford.edu/housenumbers}
}
"""
@check_svhn_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=1, shuffle=None,
sampler=None, num_shards=None, shard_id=None):
self.dataset_dir = os.path.realpath(dataset_dir)
self.usage = replace_none(usage, "all")
dataset = _SVHNDataset(self.dataset_dir, self.usage)
super().__init__(dataset, column_names=dataset.column_names, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)
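
The two steps in _load_mat that are easy to misread are the axis transpose and the label remap: scipy.io.loadmat returns the SVHN images under key 'X' with the sample axis last, shape (32, 32, 3, N), and the official files encode the digit 0 as class 10. A minimal, MindSpore-independent sketch of the same logic on synthetic data:

import numpy as np

# Fake loadmat() output with 5 samples: images are (32, 32, 3, N),
# labels are (N, 1) with the digit 0 stored as class 10.
mat_data = {
    'X': np.random.randint(0, 256, size=(32, 32, 3, 5), dtype=np.uint8),
    'y': np.array([[10], [1], [2], [10], [9]]),
}

data = np.transpose(mat_data['X'], [3, 0, 1, 2])   # (5, 32, 32, 3): sample axis first
label = mat_data['y'].astype(np.uint32).squeeze()  # (5,)
np.place(label, label == 10, 0)                    # remap class 10 back to digit 0

assert data.shape == (5, 32, 32, 3)
assert label.tolist() == [0, 1, 2, 0, 9]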


@@ -1899,3 +1899,34 @@ def check_tedlium_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_svhn_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(SVHNDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle']
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "extra", "all"], "usage")
if usage == "all":
for _usage in ["train", "test", "extra"]:
check_file(os.path.join(dataset_dir, _usage + "_32x32.mat"))
else:
check_file(os.path.join(dataset_dir, usage + "_32x32.mat"))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method
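
The checker above depends on MindSpore-internal helpers (parse_user_args, check_dir, check_file, check_valid_str, validate_dataset_param_value). A self-contained sketch of the same wraps-based pattern, with simplified stand-in checks rather than MindSpore's internals, showing that validation runs before the wrapped __init__ does:

import os
from functools import wraps


def check_toy_dataset(method):
    """A simplified, hypothetical stand-in for check_svhn_dataset (illustration only)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        dataset_dir = kwargs.get("dataset_dir", args[0] if args else None)
        if dataset_dir is None or not os.path.isdir(dataset_dir):
            raise ValueError("dataset_dir does not exist or permission denied!")
        usage = kwargs.get("usage")
        if usage is not None and usage not in ("train", "test", "extra", "all"):
            raise ValueError("usage is not within the valid set of ['train', 'test', 'extra', 'all'].")
        return method(self, *args, **kwargs)  # all checks passed; run the real __init__
    return new_method


class ToyDataset:
    @check_toy_dataset
    def __init__(self, dataset_dir, usage=None):
        self.dataset_dir, self.usage = dataset_dir, usage


ToyDataset("/tmp", usage="train")    # passes
# ToyDataset("/tmp", usage="valid") would raise ValueError before __init__ runs.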

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,366 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SVHN dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
from scipy.io import loadmat
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = "../data/dataset/testSVHNData"
WRONG_DIR = "../data/dataset/testMnistData"
def load_mat(mode, path):
"""
Feature: load_mat.
Description: load .mat file.
Expectation: get .mat of svhn dataset.
"""
filename = mode + "_32x32.mat"
mat_data = loadmat(os.path.realpath(os.path.join(path, filename)))
data = np.transpose(mat_data['X'], [3, 0, 1, 2])
label = mat_data['y'].astype(np.uint32).squeeze()
np.place(label, label == 10, 0)
return data, label
def load_svhn(path, usage):
"""
Feature: load_svhn.
Description: load svhn.
Expectation: get data of svhn dataset.
"""
assert usage in ["train", "test", "extra", "all"]
usage_all = ["train", "test", "extra"]
data = np.array([], dtype=np.uint8)
label = np.array([], dtype=np.uint32)
if usage == "all":
for _usage in usage_all:
current_data, current_label = load_mat(_usage, path)
data = np.concatenate((data, current_data)) if data.size else current_data
label = np.concatenate((label, current_label)) if label.size else current_label
else:
data, label = load_mat(usage, path)
return data, label
def visualize_dataset(images, labels):
"""
Feature: visualize_dataset.
Description: visualize svhn dataset.
Expectation: plot images.
"""
num_samples = len(images)
for i in range(num_samples):
plt.subplot(1, num_samples, i + 1)
plt.imshow(images[i])
plt.title(labels[i])
plt.show()
def test_svhn_content_check():
"""
Feature: test_svhn_content_check.
Description: validate SVHNDataset image readings.
Expectation: get correct number of data and correct content.
"""
logger.info("Test SVHNDataset Op with content check")
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2, shuffle=False)
images, labels = load_svhn(DATA_DIR, "train")
num_iter = 0
# in this example, each dictionary has keys "image" and "label".
for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 2
test_data = ds.SVHNDataset(DATA_DIR, "test", num_samples=4, shuffle=False)
images, labels = load_svhn(DATA_DIR, "test")
num_iter = 0
# in this example, each dictionary has keys "image" and "label".
for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 4
extra_data = ds.SVHNDataset(DATA_DIR, "extra", num_samples=6, shuffle=False)
images, labels = load_svhn(DATA_DIR, "extra")
num_iter = 0
# in this example, each dictionary has keys "image" and "label".
for i, data in enumerate(extra_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 6
all_data = ds.SVHNDataset(DATA_DIR, "all", num_samples=12, shuffle=False)
images, labels = load_svhn(DATA_DIR, "all")
num_iter = 0
# in this example, each dictionary has keys "image" and "label".
for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data["image"], images[i])
np.testing.assert_array_equal(data["label"], labels[i])
num_iter += 1
assert num_iter == 12
def test_svhn_basic():
"""
Feature: test_svhn_basic.
Description: test basic usage of SVHNDataset.
Expectation: get correct number of data.
"""
logger.info("Test SVHNDataset Op")
# case 1: test loading whole dataset.
default_data = ds.SVHNDataset(DATA_DIR)
num_iter = 0
for _ in default_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 12
all_data = ds.SVHNDataset(DATA_DIR, "all")
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 12
# case 2: test num_samples.
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
num_iter = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 2
# case 3: test repeat.
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
train_data = train_data.repeat(5)
num_iter = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 10
# case 4: test batch with drop_remainder=False.
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
assert train_data.get_dataset_size() == 2
assert train_data.get_batch_size() == 1
train_data = train_data.batch(batch_size=2)  # drop_remainder defaults to False.
assert train_data.get_batch_size() == 2
assert train_data.get_dataset_size() == 1
num_iter = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 1
# case 5: test batch with drop_remainder=True.
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
assert train_data.get_dataset_size() == 2
assert train_data.get_batch_size() == 1
train_data = train_data.batch(batch_size=2, drop_remainder=True)  # the remainder of an incomplete batch is dropped.
assert train_data.get_dataset_size() == 1
assert train_data.get_batch_size() == 2
num_iter = 0
for _ in train_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 1
# case 6: test num_parallel_workers>1
shared_mem_flag = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
all_data = ds.SVHNDataset(DATA_DIR, "all", num_parallel_workers=2)
num_iter = 0
for _ in all_data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 12
ds.config.set_enable_shared_mem(shared_mem_flag)
# case 7: test map method
input_columns = ["image"]
image1, image2 = [], []
dataset = ds.SVHNDataset(DATA_DIR, "all")
for data in dataset.create_dict_iterator(output_numpy=True):
image1.extend(data['image'])
operations = [(lambda x: x + x)]
dataset = dataset.map(input_columns=input_columns, operations=operations)
for data in dataset.create_dict_iterator(output_numpy=True):
image2.extend(data['image'])
assert len(image1) == len(image2)
# case 8: test batch
dataset = ds.SVHNDataset(DATA_DIR, "all")
dataset = dataset.batch(batch_size=3)
num_iter = 0
for data in dataset.create_dict_iterator(output_numpy=True):
num_iter += 1
assert num_iter == 4
def test_svhn_sequential_sampler():
"""
Feature: test_svhn_sequential_sampler.
Description: test usage of SVHNDataset with SequentialSampler.
Expectation: get correct number of data.
"""
logger.info("Test SVHNDataset Op with SequentialSampler")
num_samples = 2
sampler = ds.SequentialSampler(num_samples=num_samples)
train_data_1 = ds.SVHNDataset(DATA_DIR, "train", sampler=sampler)
train_data_2 = ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_samples=num_samples)
label_list_1, label_list_2 = [], []
num_iter = 0
for item1, item2 in zip(train_data_1.create_dict_iterator(num_epochs=1),
train_data_2.create_dict_iterator(num_epochs=1)):
label_list_1.append(item1["label"].asnumpy())
label_list_2.append(item2["label"].asnumpy())
num_iter += 1
np.testing.assert_array_equal(label_list_1, label_list_2)
assert num_iter == num_samples
def test_svhn_exception():
"""
Feature: test_svhn_exception.
Description: test error cases for SVHNDataset.
Expectation: raise exception.
"""
logger.info("Test error cases for SVHNDataset")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.SVHNDataset(DATA_DIR, "train", shuffle=False, sampler=ds.SequentialSampler(1))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.SVHNDataset(DATA_DIR, "train", sampler=ds.SequentialSampler(1), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.SVHNDataset(DATA_DIR, "train", num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.SVHNDataset(DATA_DIR, "train", shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.SVHNDataset(DATA_DIR, "train", num_shards=5, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.SVHNDataset(DATA_DIR, "train", num_shards=5, shard_id=5)
with pytest.raises(ValueError, match=error_msg_5):
ds.SVHNDataset(DATA_DIR, "train", num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256)
with pytest.raises(ValueError, match=error_msg_6):
ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2)
error_msg_7 = "Argument shard_id"
with pytest.raises(TypeError, match=error_msg_7):
ds.SVHNDataset(DATA_DIR, "train", num_shards=2, shard_id="0")
error_msg_8 = "does not exist or permission denied!"
with pytest.raises(ValueError, match=error_msg_8):
train_data = ds.SVHNDataset(WRONG_DIR, "train")
for _ in train_data.__iter__():
pass
def test_svhn_visualize(plot=False):
"""
Feature: test_svhn_visualize.
Description: visualize SVHNDataset results.
Expectation: get correct number of data and plot them.
"""
logger.info("Test SVHNDataset visualization")
train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert image.shape == (32, 32, 3)
assert image.dtype == np.uint8
assert label.dtype == np.uint32
num_iter += 1
assert num_iter == 2
if plot:
visualize_dataset(image_list, label_list)
def test_svhn_usage():
"""
Feature: test_svhn_usage.
Description: validate SVHNDataset image readings.
Expectation: get correct number of data.
"""
logger.info("Test SVHNDataset usage flag")
def test_config(usage, path=None):
path = DATA_DIR if path is None else path
try:
data = ds.SVHNDataset(path, usage=usage, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("train") == 2
assert test_config("test") == 4
assert test_config("extra") == 6
assert test_config("all") == 12
assert "usage is not within the valid set of ['train', 'test', 'extra', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
data_path = None
# The following tests run on the complete datasets; set data_path to enable them.
if data_path is not None:
assert test_config("train", data_path) == 2
assert test_config("test", data_path) == 4
assert test_config("extra", data_path) == 6
assert test_config("all", data_path) == 12
assert ds.SVHNDataset(data_path, usage="train").get_dataset_size() == 2
assert ds.SVHNDataset(data_path, usage="test").get_dataset_size() == 4
assert ds.SVHNDataset(data_path, usage="extra").get_dataset_size() == 6
assert ds.SVHNDataset(data_path, usage="all").get_dataset_size() == 12
if __name__ == '__main__':
test_svhn_content_check()
test_svhn_basic()
test_svhn_sequential_sampler()
test_svhn_exception()
test_svhn_visualize(plot=True)
test_svhn_usage()
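
One path the suite does not exercise is a successful sharded read: the num_shards/shard_id cases above only assert errors. A hypothetical companion test (not part of this commit), assuming MindSpore splits the 12-row "all" split evenly across shards:

def test_svhn_shards_sketch():
    """Hypothetical positive sharding check: two shards over the 12-row 'all' split."""
    per_shard = []
    for shard_id in range(2):
        data = ds.SVHNDataset(DATA_DIR, "all", num_shards=2, shard_id=shard_id)
        per_shard.append(sum(1 for _ in data.create_dict_iterator(num_epochs=1)))
    assert per_shard == [6, 6]  # 12 rows split evenly across 2 shards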