diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index 84f966051f9..178a4bff0a9 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -16,13 +16,11 @@ General Validators. """ import inspect -import numbers from multiprocessing import cpu_count import os import numpy as np import mindspore._c_dataengine as cde -from ..engine import samplers # POS_INT_MIN is used to limit values from starting from 0 POS_INT_MIN = 1 @@ -389,38 +387,5 @@ def check_tensor_op(param, param_name): raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) -def check_sampler(sampler): - """ - Check if the sampler is of valid input. - - Args: - param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler - - Returns: - Exception: TypeError if error - """ - builtin = False - base_sampler = False - list_num = False - if sampler is not None: - if isinstance(sampler, samplers.BuiltinSampler): - builtin = True - elif isinstance(sampler, samplers.Sampler): - base_sampler = True - else: - # check for list of numbers - list_num = True - # subset sampler check - subset_sampler = sampler - if not isinstance(sampler, list): - subset_sampler = [sampler] - - for _, item in enumerate(subset_sampler): - if not isinstance(item, numbers.Number): - list_num = False - if not (builtin or base_sampler or list_num): - raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers") - - def replace_none(value, default): return value if value is not None else default diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 0395d8675ac..39f9651216d 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): Sampler, sampler selected based on user input. """ - def _is_iterable(obj): - try: - iter(obj) - except TypeError: - return False - return True - - def _get_sample_ids_as_list(sampler, number_of_samples=None): - if number_of_samples is None: - return list(sampler) - - if isinstance(sampler, list): - return sampler[:number_of_samples] - - return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] - if input_sampler is not None: # If the user provided a sampler, then it doesn't matter what the other args are because # we are being asked specifically to use the given sampler. @@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) if isinstance(input_sampler, BuiltinSampler): return input_sampler - if not isinstance(input_sampler, str) and _is_iterable(input_sampler): - return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) - if isinstance(input_sampler, int): - return SubsetSampler([input_sampler]) - raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) + return SubsetSampler(input_sampler, num_samples) + if shuffle is None: if num_shards is not None: # If shuffle is not specified, sharding enabled, use distributed random sampler @@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler): """ def __init__(self, indices, num_samples=None): - if not isinstance(indices, list): + def _is_iterable(obj): + try: + iter(obj) + except TypeError: + return False + return True + + def _get_sample_ids_as_list(sampler, number_of_samples=None): + if number_of_samples is None: + return list(sampler) + + if isinstance(sampler, list): + return sampler[:number_of_samples] + + return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] + + if not isinstance(indices, str) and _is_iterable(indices): + indices = _get_sample_ids_as_list(indices, num_samples) + elif isinstance(indices, int): indices = [indices] + else: + raise TypeError('Unsupported sampler object of type ({})'.format(type(indices))) for i, item in enumerate(indices): - if not isinstance(item, int): + if not isinstance(item, (int, np.integer)): raise TypeError("SubsetSampler: Type of indices element must be int, " "but got list[{}]: {}, type: {}.".format(i, item, type(item))) diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index ef71762cf81..031f5d93f3b 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -177,13 +177,23 @@ def test_subset_sampler(): def pipeline(): sampler = ds.SubsetSampler(indices, num_samples) data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) + data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples) dataset_size = data.get_dataset_size() - return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size + dataset_size2 = data.get_dataset_size() + res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size + res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2 + return res1, res2 if exception_msg is None: - res, size = pipeline() + res, res2 = pipeline() + res, size = res + res2, size2 = res2 + if not isinstance(indices, list): + indices = list(indices) assert indices[:num_samples] == res assert len(indices[:num_samples]) == size + assert indices[:num_samples] == res2 + assert len(indices[:num_samples]) == size2 else: with pytest.raises(Exception) as error_info: pipeline() @@ -205,6 +215,8 @@ def test_subset_sampler(): test_config([0, 9, 3, 2], num_samples=2) test_config([0, 9, 3, 2], num_samples=5) + test_config(np.array([1, 2, 3])) + test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") @@ -212,6 +224,9 @@ def test_subset_sampler(): # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset test_config([0, 9, 3, 2], num_samples=-1, exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") + test_config(np.array([[1], [5]]), num_samples=10, + exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1]," + " type: .") def test_sampler_chain(): @@ -291,8 +306,8 @@ def test_sampler_list(): msg="Type of indices element must be int, but got list[0]: a, type: .") bad_pipeline(sampler="a", msg="Unsupported sampler object of type ()") bad_pipeline(sampler="", msg="Unsupported sampler object of type ()") - bad_pipeline(sampler=np.array([1, 2]), - msg="Type of indices element must be int, but got list[0]: 1, type: .") + bad_pipeline(sampler=np.array([[1, 2]]), + msg="Type of indices element must be int, but got list[0]: [1 2], type: .") if __name__ == '__main__':