From 1185218335656dc73562afd757ed0ac764253a18 Mon Sep 17 00:00:00 2001
From: hesham
Date: Fri, 29 Jan 2021 11:23:53 -0500
Subject: [PATCH] Support list of IDs as a sampler

---
 mindspore/dataset/core/validator_helpers.py |   5 +-
 mindspore/dataset/engine/datasets.py        | 175 ++------------------
 mindspore/dataset/engine/samplers.py        |  76 +++++++++
 tests/ut/python/dataset/test_sampler.py     |  11 ++
 4 files changed, 105 insertions(+), 162 deletions(-)

diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py
index 6497bb6021a..fd1e4a1db63 100644
--- a/mindspore/dataset/core/validator_helpers.py
+++ b/mindspore/dataset/core/validator_helpers.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@ import os
 import numpy as np
 
 import mindspore._c_dataengine as cde
-from ..engine import samplers
 
 # POS_INT_MIN is used to limit values from starting from 0
 POS_INT_MIN = 1
@@ -290,8 +289,6 @@ def check_sampler_shuffle_shard_options(param_dict):
     num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
     num_samples = param_dict.get('num_samples')
 
-    type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
-
     if sampler is not None:
         if shuffle is not None:
             raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 2718a313090..681ca37d9ef 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -2708,7 +2708,7 @@ class ConcatDataset(Dataset):
 
         self.dataset_size = None
 
-        self._sampler = _select_sampler(None, sampler, None, None, None)
+        self._sampler = samplers.select_sampler(None, sampler, None, None, None)
         cumulative_samples_nums = 0
         for index, child in enumerate(self.children):
             if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
@@ -2990,65 +2990,6 @@ class RangeDataset(MappableDataset):
         return self.dataset_size
 
 
-def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
-    """
-    Create sampler based on user input.
-
-    Args:
-        num_samples (int): Number of samples.
-        input_sampler (Union[Iterable, Sampler]): Sampler from user.
-        shuffle (bool): Shuffle.
-        num_shards (int): Number of shard for sharding.
-        shard_id (int): Shard ID.
-        non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
-
-    Returns:
-        Sampler, sampler selected based on user input.
-    """
-    if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
-        return None
-
-    if input_sampler is not None:
-        # If the user provided a sampler, then it doesn't matter what the other args are because
-        # we are being asked specifically to use the given sampler.
-        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
-        # be None. Consider this example:
-        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
-        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
-        # In this case, the user has given different sample-related arguments that contradict each other.
-        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
-        if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
-                                       samplers.RandomSampler, samplers.SubsetRandomSampler,
-                                       samplers.WeightedRandomSampler, samplers.Sampler)) and
-                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
-            raise ValueError(
-                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
-                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
-        return input_sampler
-    if shuffle is None:
-        if num_shards is not None:
-            # If shuffle is not specified, sharding enabled, use distributed random sampler
-            shuffle = True
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if shuffle is True:
-        if num_shards is not None:
-            # If shuffle enabled, sharding enabled, use distributed random sampler
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle enabled, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if num_shards is not None:
-        # If shuffle disabled, sharding enabled, use distributed sequential sampler
-        return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-    # If shuffle disabled, sharding disabled, use sequential sampler
-    return samplers.SequentialSampler(num_samples=num_samples)
-
-
 class ImageFolderDataset(MappableDataset):
     """
     A source dataset that reads images from a tree of directories.
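
As the removed helper above shows, sampler construction is an either/or contract: pass a prebuilt sampler, or pass the loose arguments (shuffle, num_samples, num_shards, shard_id), never both. A minimal sketch of the two valid call styles (data_dir is a placeholder path, not part of this patch):

    import mindspore.dataset as ds

    # Either let the dataset derive the sampler from the loose arguments...
    data1 = ds.ImageFolderDataset(data_dir, shuffle=True, num_shards=4, shard_id=1)

    # ...or build the sampler explicitly and pass nothing else;
    # mixing the two styles raises the ValueError quoted above.
    sampler = ds.DistributedSampler(num_shards=4, shard_id=1, shuffle=True)
    data2 = ds.ImageFolderDataset(data_dir, sampler=sampler)
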
@@ -3144,7 +3085,7 @@ class ImageFolderDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
 
         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.extensions = replace_none(extensions, [])
@@ -3293,7 +3234,7 @@ class MnistDataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.num_shards = num_shards
@@ -3386,7 +3327,7 @@ class MindDataset(MappableDataset):
                               samplers.SequentialSampler)) is False:
             raise ValueError("The sampler is not supported yet.")
 
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
 
         self.padded_sample = padded_sample
@@ -3470,27 +3411,6 @@ def _generator_fn(generator, num_samples):
         yield val
 
 
-def _py_sampler_fn(sampler, num_samples, dataset):
-    """
-    Generator function wrapper for mappable dataset with Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        for _ in range(num_samples):
-            try:
-                idx = next(sampler_iter)
-            except StopIteration:
-                return
-            val = dataset[idx]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-    else:
-        for i in sampler:
-            val = dataset[i]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-
-
 def _cpp_sampler_fn(sample_ids, dataset):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
@@ -3518,31 +3438,6 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     return sample_fn.process(sample_ids)
 
 
-def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
-    """
-    Multiprocessing generator function wrapper for mappable dataset with Python sampler.
-    """
-    indices = _fetch_py_sampler_indices(sampler, num_samples)
-    return sample_fn.process(indices)
-
-
-def _fetch_py_sampler_indices(sampler, num_samples):
-    """
-    Indice fetcher for Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        ret = []
-        for _ in range(num_samples):
-            try:
-                val = next(sampler_iter)
-                ret.append(val)
-            except StopIteration:
-                break
-        return ret
-    return [i for i in sampler]
-
-
 def _fill_worker_indices(workers, indices, idx):
     """
     Worker index queue filler, fill worker index queue in round robin order.
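
With the Python-side wrappers (_py_sampler_fn and friends) deleted above, every sampler, whether built-in or derived from a list of IDs, reaches the generator through the C++ sampler path. A rough sketch of what the surviving _cpp_sampler_fn wrapper does, reconstructed from the removed Python counterpart (the actual body is outside this hunk, so treat this as an approximation, not the verbatim implementation):

    import numpy as np

    def cpp_sampler_fn_sketch(sample_ids, dataset):
        # sample_ids is the batch of row indices emitted by the C++ sampler;
        # each row of the mappable dataset is converted to a tuple of ndarrays.
        for i in sample_ids:
            val = dataset[i]
            yield tuple(np.array(x, copy=False) for x in val)
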
@@ -3865,7 +3760,7 @@ class GeneratorDataset(MappableDataset):
                  python_multiprocessing=True):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.source = source
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.python_multiprocessing = python_multiprocessing
@@ -3912,26 +3807,11 @@ class GeneratorDataset(MappableDataset):
         if hasattr(self, "__total_batch__"):
             new_op.__total_batch__ = self.__total_batch__
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
-            if isinstance(new_op.sampler, samplers.BuiltinSampler):
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
-                else:
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+            if new_op.num_parallel_workers > 1:
+                sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
             else:
-                # the sampler provided is not a built-in sampler, it is a list of sample_ids
-                new_op.sample_ids = new_op.sampler
-                # since list of sample_ids are not passed to c++, we need to find the proper len here
-                new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len(
-                    new_op.sample_ids)
-                new_op.source_len = min(self.source_len,
-                                        new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
-                new_op.sampler = None
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
-                else:
-                    new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
             new_op.sample_fn = sample_fn
         else:
             try:
@@ -4089,13 +3969,6 @@ class TFRecordDataset(SourceDataset):
         self.shuffle_level = shuffle
         self.shuffle_files = True
 
-        # The TF record dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)
         self.shard_equal_rows = replace_none(shard_equal_rows, False)
 
     def get_args(self):
@@ -4231,7 +4104,7 @@ class ManifestDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
 
         self.dataset_file = dataset_file
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
 
         if class_indexing is not None and not isinstance(class_indexing, dict):
             raise RuntimeError("class_indexing must be a dictionary.")
@@ -4396,7 +4269,7 @@ class Cifar10Dataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4535,7 +4408,7 @@ class Cifar100Dataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4607,8 +4480,6 @@ class RandomDataset(SourceDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.schema = schema
         self.columns_list = replace_none(columns_list, [])
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
         self.num_samples = num_samples
         self.total_rows = total_rows
@@ -4900,7 +4771,7 @@ class VOCDataset(MappableDataset):
         self.task = replace_none(task, "Segmentation")
         self.usage = replace_none(usage, "train")
         self.class_indexing = class_indexing
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5092,7 +4963,7 @@ class CocoDataset(MappableDataset):
         self.dataset_dir = dataset_dir
         self.annotation_file = annotation_file
         self.task = replace_none(task, "Detection")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5224,7 +5095,7 @@ class CelebADataset(MappableDataset):
                  extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_parallel_workers = num_parallel_workers
         self.decode = replace_none(decode, False)
         self.extensions = replace_none(extensions, [])
@@ -5596,12 +5467,7 @@ class CSVDataset(SourceDataset):
         self.shuffle_files = True
         self.cache = cache
 
-        # The CSV dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, self.shuffle_files, num_shards, shard_id,
-                                       non_mappable=True)
+
         self.num_shards = replace_none(num_shards, 1)
         self.shard_id = replace_none(shard_id, 0)
         self.num_samples = replace_none(num_samples, 0)
@@ -5715,13 +5581,6 @@ class TextFileDataset(SourceDataset):
         self.shard_id = replace_none(shard_id, 0)
         self.cache = cache
 
-        # The text file dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)
 
     def get_args(self):
         args = super().get_args()
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index aeecb2695ae..ac9c00a7951 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
 import mindspore.dataset as ds
 
 
+def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
+    """
+    Create sampler based on user input.
+
+    Args:
+        num_samples (int): Number of samples.
+        input_sampler (Union[Iterable, Sampler]): Sampler from user.
+        shuffle (bool): Shuffle.
+        num_shards (int): Number of shards for sharding.
+        shard_id (int): Shard ID.
+
+    Returns:
+        Sampler, sampler selected based on user input.
+    """
+
+    def _is_iterable(obj):
+        try:
+            iter(obj)
+        except TypeError:
+            return False
+        return True
+
+    def _get_sample_ids_as_list(sampler, number_of_samples=None):
+        if number_of_samples is None:
+            return list(sampler)
+
+        if isinstance(sampler, list):
+            return sampler[:number_of_samples]
+
+        return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
+
+    if input_sampler is not None:
+        # If the user provided a sampler, then it doesn't matter what the other args are because
+        # we are being asked specifically to use the given sampler.
+        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
+        # be None. Consider this example:
+        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
+        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
+        # In this case, the user has given different sample-related arguments that contradict each other.
+        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None.
+        if (isinstance(input_sampler, BuiltinSampler) and
+                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
+            raise ValueError(
+                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
+                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
+        if isinstance(input_sampler, BuiltinSampler):
+            return input_sampler
+        if _is_iterable(input_sampler):
+            return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
+        if isinstance(input_sampler, int):
+            return SubsetSampler([input_sampler])
+        raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
+    if shuffle is None:
+        if num_shards is not None:
+            # If shuffle is not specified, sharding enabled, use distributed random sampler
+            shuffle = True
+            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+        # If shuffle is not specified, sharding disabled, use random sampler
+        if num_samples is not None:
+            return RandomSampler(replacement=True, num_samples=num_samples)
+        return RandomSampler(num_samples=num_samples)
+    if shuffle is True:
+        if num_shards is not None:
+            # If shuffle enabled, sharding enabled, use distributed random sampler
+            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+        # If shuffle enabled, sharding disabled, use random sampler
+        if num_samples is not None:
+            return RandomSampler(replacement=True, num_samples=num_samples)
+        return RandomSampler(num_samples=num_samples)
+    if num_shards is not None:
+        # If shuffle disabled, sharding enabled, use distributed sequential sampler
+        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+    # If shuffle disabled, sharding disabled, use sequential sampler
+    return SequentialSampler(num_samples=num_samples)
+
+
 class BuiltinSampler:
     """
     Base class for BuiltinSampler.
diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py
index 83e0ec49d8e..9a2e686f5e2 100644
--- a/tests/ut/python/dataset/test_sampler.py
+++ b/tests/ut/python/dataset/test_sampler.py
@@ -17,6 +17,7 @@ import pytest
 
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import dataset_equal
 
 # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
 
@@ -265,6 +266,15 @@ def test_distributed_sampler_invalid_offset():
     assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
 
 
+def test_sampler_list():
+    data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
+    data21 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(2).skip(1)
+    data22 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(4).skip(3)
+    data23 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(6).skip(5)
+
+    dataset_equal(data1, data21 + data22 + data23, 0)
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
@@ -276,3 +286,4 @@ if __name__ == '__main__':
     test_sampler_chain()
     test_add_sampler_invalid_input()
     test_distributed_sampler_invalid_offset()
+    test_sampler_list()
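
End to end, the new behavior mirrors test_sampler_list above: a plain Python list of row IDs is now accepted anywhere a sampler is. A minimal usage sketch (the dataset path comes from the test tree; the iteration loop is an assumption, not part of the patch):

    import mindspore.dataset as ds

    # select_sampler wraps the list in a SubsetSampler, so rows 1, 3 and 5
    # of the unshuffled dataset are returned, in that order.
    data = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
    for row in data.create_dict_iterator():
        print(row["label"])

A single integer is handled the same way, as a one-element list of sample IDs.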