diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 527c468e8ff..6e9229cab63 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -41,6 +41,8 @@ import weakref import platform import psutil import numpy as np +from scipy.io import loadmat +from PIL import Image import mindspore._c_dataengine as cde from mindspore._c_expression import typing @@ -61,7 +63,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_celebadataset, check_minddataset, check_generatordataset, check_sync_wait, check_zip_dataset, \ check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \ check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \ - check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send + check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_sb_dataset from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ get_prefetch_size from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist @@ -5668,3 +5670,190 @@ class PaddedDataset(GeneratorDataset): super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False) self._dataset_size = len(dataset.padded_samples) self.padded_samples = padded_samples + + +class SBDataset(GeneratorDataset): + """ + A source dataset for reading and parsing Semantic Boundaries Dataset. + + The generated dataset has two columns: :py:obj:`[image, task]`. + The tensor of column :py:obj:`image` is of the uint8 type. + The tensor of column :py:obj:`task` contains 20 images of the uint8 type if `task` is `Boundaries` otherwise + contains 1 image of the uint8 type. + + Args: + dataset_dir (str): Path to the root directory that contains the dataset. + task (str, optional): Acceptable tasks include `Boundaries` or `Segmentation` (default=`Boundaries`). + usage (str, optional): Acceptable usages include `train`, `val`, `train_noval` and `all` (default=`all`). + num_samples (int, optional): The number of images to be included in the dataset. + (default=None, all images). + num_parallel_workers (int, optional): Number of workers to read the data + (default=None, number set in the config). + shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected + order behavior shown in the table). + sampler (Sampler, optional): Object used to choose samples from the + dataset (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset will be divided + into (default=None). When this argument is specified, `num_samples` reflects + the max sample number of per shard. + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument can only be specified when num_shards is also specified. + + Raises: + RuntimeError: If dataset_dir is not valid or does not contain data files. + RuntimeError: If num_parallel_workers exceeds the max thread numbers. + RuntimeError: If sampler and shuffle are specified at the same time. + RuntimeError: If sampler and sharding are specified at the same time. + RuntimeError: If num_shards is specified but shard_id is None. + RuntimeError: If shard_id is specified but num_shards is None. + ValueError: If dataset_dir is not exist. + ValueError: If task is not in [`Boundaries`, `Segmentation`]. + ValueError: If usage is not in [`train`, `val`, `train_noval`, `all`]. + ValueError: If shard_id is invalid (< 0 or >= num_shards). + + Note: + - This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive. + The table below shows what input arguments are allowed and their expected behavior. + + .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` + :widths: 25 25 50 + :header-rows: 1 + + * - Parameter `sampler` + - Parameter `shuffle` + - Expected Order Behavior + * - None + - None + - random order + * - None + - True + - random order + * - None + - False + - sequential order + * - Sampler object + - None + - order defined by sampler + * - Sampler object + - True + - not allowed + * - Sampler object + - False + - not allowed + + Examples: + >>> sb_dataset_dir = "/path/to/sb_dataset_directory" + >>> + >>> # 1) Get all samples from Semantic Boundaries Dataset in sequence + >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, shuffle=False) + >>> + >>> # 2) Randomly select 350 samples from Semantic Boundaries Dataset + >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_samples=350, shuffle=True) + >>> + >>> # 3) Get samples from Semantic Boundaries Dataset for shard 0 in a 2-way distributed training + >>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_shards=2, shard_id=0) + >>> + >>> # In Semantic Boundaries Dataset, each dictionary has keys "image" and "task" + + About Semantic Boundaries Dataset. + | The Semantic Boundaries Dataset consists of 11355 colour images. There are 8498 images' name in the train.txt, + 2857 images' name in the val.txt and 5623 images' name in the train_noval.txt. The category cls/ + contains the Segmentation and Boundaries results of category-level, the category inst/ catains the + Segmentation and Boundaries results of instance-level. + + | You can unzip the dataset files into the following structure and read by MindSpore's API, + | . + | └── benchmark_RELEASE + | └── dataset + | ├── img + | | ├── 2008_000002.jpg + | | ├── 2008_000003.jpg + | | ├── ... + | ├── cls + | | ├── 2008_000002.mat + | | ├── 2008_000003.mat + | | ├── ... + | ├── inst + | | ├── 2008_000002.mat + | | ├── 2008_000003.mat + | | ├── ... + | ├── train.txt + | └── val.txt + + .. code-block:: + + @InProceedings{BharathICCV2011, + author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and + Subhransu Maji and Jitendra Malik", + title = "Semantic Contours from Inverse Detectors", + booktitle = "International Conference on Computer Vision (ICCV)", + year = "2011", + """ + + @check_sb_dataset + def __init__(self, dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1, + shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None): + dataset = _SBDataset(dataset_dir, task, usage, decode) + super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, + num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, + num_shards=num_shards, shard_id=shard_id) + + +class _SBDataset(): + """ + Dealing with the data file with .mat extension, and return one row in tuple (image, task) each time. + """ + + def __init__(self, dataset_dir, task, usage, decode): + self.column_list = ['image', 'task'] + self.task = task + self.images_path = os.path.join(dataset_dir, 'img') + self.cls_path = os.path.join(dataset_dir, 'cls') + self._loadmat = loadmat + self.categories = 20 + self.decode = replace_none(decode, False) + + if usage == "all": + image_names = [] + for item in ["train", "val"]: + usage_path = os.path.join(dataset_dir, item + '.txt') + if not os.path.exists(usage_path): + raise FileNotFoundError("SBDataset: {0} not found".format(usage_path)) + with open(usage_path, 'r') as f: + image_names += [x.strip() for x in f.readlines()] + else: + usage_path = os.path.join(dataset_dir, usage + '.txt') + if not os.path.exists(usage_path): + raise FileNotFoundError("SBDataset: {0} not found".format(usage_path)) + with open(usage_path, 'r') as f: + image_names = [x.strip() for x in f.readlines()] + + self.images = [os.path.join(self.images_path, i + ".jpg") for i in image_names] + self.clss = [os.path.join(self.cls_path, i + ".mat") for i in image_names] + + if len(self.images) != len(self.clss): + raise ValueError("SBDataset: images count not equal to cls count") + + self._get_data = self._get_boundaries_data if self.task == "Boundaries" else self._get_segmentation_data + self._get_item = self._get_decode_item if self.decode else self._get_undecode_item + + def _get_boundaries_data(self, mat_path): + mat_data = self._loadmat(mat_path) + return np.concatenate([np.expand_dims(mat_data['GTcls'][0][self.task][0][i][0].toarray(), axis=0) + for i in range(self.categories)], axis=0) + + def _get_segmentation_data(self, mat_path): + mat_data = self._loadmat(mat_path) + return Image.fromarray(mat_data['GTcls'][0][self.task][0]) + + def _get_decode_item(self, idx): + return Image.open(self.images[idx]).convert('RGB'), self._get_data(self.clss[idx]) + + def _get_undecode_item(self, idx): + return np.fromfile(self.images[idx], dtype=np.uint8), self._get_data(self.clss[idx]) + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + return self._get_item(idx) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ad892842889..753a449ab05 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1329,3 +1329,34 @@ def check_to_device_send(method): return method(self, *args, **kwargs) return new_method + + +def check_sb_dataset(method): + """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + nreq_param_bool = ['shuffle', 'decode'] + + dataset_dir = param_dict.get('dataset_dir') + check_dir(dataset_dir) + + usage = param_dict.get('usage') + if usage is not None: + check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage") + + task = param_dict.get('task') + if task is not None: + check_valid_str(task, ["Boundaries", "Segmentation"], "task") + + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + + check_sampler_shuffle_shard_options(param_dict) + + return method(self, *args, **kwargs) + + return new_method diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000001.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000001.mat new file mode 100644 index 00000000000..3b990edccb1 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000001.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000002.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000002.mat new file mode 100644 index 00000000000..ec01c88271d Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000002.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000003.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000003.mat new file mode 100644 index 00000000000..c90c1221f02 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000003.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000004.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000004.mat new file mode 100644 index 00000000000..ec01c88271d Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000004.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000005.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000005.mat new file mode 100644 index 00000000000..ec01c88271d Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000005.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/cls/000006.mat b/tests/ut/data/dataset/testSBData/sbd/cls/000006.mat new file mode 100644 index 00000000000..3b990edccb1 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/cls/000006.mat differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000001.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000001.jpg new file mode 100644 index 00000000000..95e14358a4e Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000001.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000002.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000002.jpg new file mode 100644 index 00000000000..66bac3d23a6 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000002.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000003.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000003.jpg new file mode 100644 index 00000000000..1d9cf152348 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000003.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000004.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000004.jpg new file mode 100644 index 00000000000..15933d38e3a Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000004.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000005.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000005.jpg new file mode 100644 index 00000000000..7fbcd6bbc09 Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000005.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/img/000006.jpg b/tests/ut/data/dataset/testSBData/sbd/img/000006.jpg new file mode 100644 index 00000000000..95e14358a4e Binary files /dev/null and b/tests/ut/data/dataset/testSBData/sbd/img/000006.jpg differ diff --git a/tests/ut/data/dataset/testSBData/sbd/train.txt b/tests/ut/data/dataset/testSBData/sbd/train.txt new file mode 100644 index 00000000000..a8057cda5e4 --- /dev/null +++ b/tests/ut/data/dataset/testSBData/sbd/train.txt @@ -0,0 +1,4 @@ +000001 +000002 +000003 +000004 \ No newline at end of file diff --git a/tests/ut/data/dataset/testSBData/sbd/train_noval.txt b/tests/ut/data/dataset/testSBData/sbd/train_noval.txt new file mode 100644 index 00000000000..36b453a8fbe --- /dev/null +++ b/tests/ut/data/dataset/testSBData/sbd/train_noval.txt @@ -0,0 +1,4 @@ +000001 +000003 +000005 +000006 \ No newline at end of file diff --git a/tests/ut/data/dataset/testSBData/sbd/val.txt b/tests/ut/data/dataset/testSBData/sbd/val.txt new file mode 100644 index 00000000000..30acc75d240 --- /dev/null +++ b/tests/ut/data/dataset/testSBData/sbd/val.txt @@ -0,0 +1,2 @@ +000005 +000006 \ No newline at end of file diff --git a/tests/ut/python/dataset/test_datasets_sbd.py b/tests/ut/python/dataset/test_datasets_sbd.py new file mode 100644 index 00000000000..3801cfa669b --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_sbd.py @@ -0,0 +1,219 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import math + +import matplotlib.pyplot as plt +import numpy as np +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.vision.c_transforms as c_vision + + +DATASET_DIR = "../data/dataset/testSBData/sbd" + + +def visualize_dataset(images, labels, task): + """ + Helper function to visualize the dataset samples + """ + image_num = len(images) + subplot_rows = 1 if task == "Segmentation" else 4 + for i in range(image_num): + plt.imshow(images[i]) + plt.title('Original') + plt.savefig('./sbd_original_{}.jpg'.format(str(i))) + if task == "Segmentation": + plt.imshow(labels[i]) + plt.title(task) + plt.savefig('./sbd_segmentation_{}.jpg'.format(str(i))) + else: + b_num = labels[i].shape[0] + for j in range(b_num): + plt.subplot(subplot_rows, math.ceil(b_num / subplot_rows), j + 1) + plt.imshow(labels[i][j]) + plt.savefig('./sbd_boundaries_{}.jpg'.format(str(i))) + plt.close() + + +def test_sbd_basic01(plot=False): + """ + Validate SBDataset with different usage + """ + task = 'Segmentation' # Boundaries, Segmentation + data = ds.SBDataset(DATASET_DIR, task=task, usage='all', shuffle=False, decode=True) + count = 0 + images_list = [] + task_list = [] + for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): + images_list.append(item['image']) + task_list.append(item['task']) + count = count + 1 + assert count == 6 + if plot: + visualize_dataset(images_list, task_list, task) + + data2 = ds.SBDataset(DATASET_DIR, task=task, usage='train', shuffle=False, decode=False) + count = 0 + for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True): + count = count + 1 + assert count == 4 + + data3 = ds.SBDataset(DATASET_DIR, task=task, usage='val', shuffle=False, decode=False) + count = 0 + for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True): + count = count + 1 + assert count == 2 + + +def test_sbd_basic02(): + """ + Validate SBDataset with repeat and batch operation + """ + # Boundaries, Segmentation + # case 1: test num_samples + data1 = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=3, shuffle=False) + num_iter1 = 0 + for _ in data1.create_dict_iterator(num_epochs=1): + num_iter1 += 1 + assert num_iter1 == 3 + + # case 2: test repeat + data2 = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=4, shuffle=False) + data2 = data2.repeat(5) + num_iter2 = 0 + for _ in data2.create_dict_iterator(num_epochs=1): + num_iter2 += 1 + assert num_iter2 == 20 + + # case 3: test batch with drop_remainder=False + data3 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True) + resize_op = c_vision.Resize((100, 100)) + data3 = data3.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1) + data3 = data3.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1) + assert data3.get_dataset_size() == 4 + assert data3.get_batch_size() == 1 + data3 = data3.batch(batch_size=3) # drop_remainder is default to be False + assert data3.get_dataset_size() == 2 + assert data3.get_batch_size() == 3 + num_iter3 = 0 + for _ in data3.create_dict_iterator(num_epochs=1): + num_iter3 += 1 + assert num_iter3 == 2 + + # case 4: test batch with drop_remainder=True + data4 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True) + resize_op = c_vision.Resize((100, 100)) + data4 = data4.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1) + data4 = data4.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1) + assert data4.get_dataset_size() == 4 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped + assert data4.get_dataset_size() == 1 + assert data4.get_batch_size() == 3 + num_iter4 = 0 + for _ in data4.create_dict_iterator(num_epochs=1): + num_iter4 += 1 + assert num_iter4 == 1 + + +def test_sbd_sequential_sampler(): + """ + Test SBDataset with SequentialSampler + """ + logger.info("Test SBDataset Op with SequentialSampler") + num_samples = 5 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', sampler=sampler) + data2 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', shuffle=False, num_samples=num_samples) + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), + data2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1["task"], item2["task"]) + num_iter += 1 + assert num_iter == num_samples + + +def test_sbd_exception(): + """ + Validate SBDataset with error parameters + """ + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=0, + sampler=ds.PKSampler(3)) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id="0") + + +def test_sbd_usage(): + """ + Validate SBDataset image readings + """ + def test_config(usage): + try: + data = ds.SBDataset(DATASET_DIR, task='Segmentation', usage=usage) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("train") == 4 + assert test_config("train_noval") == 4 + assert test_config("val") == 2 + assert test_config("all") == 6 + assert "usage is not within the valid set of ['train', 'val', 'train_noval', 'all']" in test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + +if __name__ == "__main__": + test_sbd_basic01() + test_sbd_basic02() + test_sbd_sequential_sampler() + test_sbd_exception() + test_sbd_usage()