diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index c57afc4a73b..84f966051f9 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -23,6 +23,7 @@ import numpy as np import mindspore._c_dataengine as cde from ..engine import samplers + # POS_INT_MIN is used to limit values from starting from 0 POS_INT_MIN = 1 UINT8_MAX = 255 @@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict): shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') num_samples = param_dict.get('num_samples') - check_sampler(sampler) if sampler is not None: if shuffle is not None: @@ -348,6 +348,7 @@ def check_num_samples(value): raise ValueError( "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX)) + def validate_dataset_param_value(param_list, param_dict, param_type): for param_name in param_list: if param_dict.get(param_name) is not None: @@ -387,6 +388,7 @@ def check_tensor_op(param, param_name): if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None): raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) + def check_sampler(sampler): """ Check if the sampler is of valid input. @@ -419,5 +421,6 @@ def check_sampler(sampler): if not (builtin or base_sampler or list_num): raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers") + def replace_none(value, default): return value if value is not None else default diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 75e44ea77e6..0395d8675ac 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) if isinstance(input_sampler, BuiltinSampler): return input_sampler - if _is_iterable(input_sampler): + if not isinstance(input_sampler, str) and _is_iterable(input_sampler): return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) if isinstance(input_sampler, int): - return [input_sampler] - raise ValueError('Unsupported sampler object ({})'.format(input_sampler)) + return SubsetSampler([input_sampler]) + raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) if shuffle is None: if num_shards is not None: # If shuffle is not specified, sharding enabled, use distributed random sampler @@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler): indices = [indices] for i, item in enumerate(indices): - if not isinstance(item, numbers.Number): - raise TypeError("type of indices element must be number, " - "but got w[{}]: {}, type: {}.".format(i, item, type(item))) + if not isinstance(item, int): + raise TypeError("SubsetSampler: Type of indices element must be int, " + "but got list[{}]: {}, type: {}.".format(i, item, type(item))) if num_samples is not None: if not isinstance(num_samples, int): diff --git a/tests/ut/python/dataset/test_datasets_celeba.py b/tests/ut/python/dataset/test_datasets_celeba.py index 1af9b970d96..4f20d6873fe 100644 --- a/tests/ut/python/dataset/test_datasets_celeba.py +++ b/tests/ut/python/dataset/test_datasets_celeba.py @@ -179,7 +179,7 @@ def test_celeba_sampler_exception(): pass assert False except TypeError as e: - assert "Argument" in str(e) + assert "Unsupported sampler object of type ()" in str(e) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 3555409ee75..ef71762cf81 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -274,6 +274,26 @@ def test_sampler_list(): dataset_equal(data1, data21 + data22 + data23, 0) + data3 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=1) + dataset_equal(data3, data21, 0) + + def bad_pipeline(sampler, msg): + with pytest.raises(Exception) as info: + data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler) + for _ in data1: + pass + assert msg in str(info.value) + + bad_pipeline(sampler=[1.5, 7], + msg="Type of indices element must be int, but got list[0]: 1.5, type: ") + + bad_pipeline(sampler=["a", "b"], + msg="Type of indices element must be int, but got list[0]: a, type: .") + bad_pipeline(sampler="a", msg="Unsupported sampler object of type ()") + bad_pipeline(sampler="", msg="Unsupported sampler object of type ()") + bad_pipeline(sampler=np.array([1, 2]), + msg="Type of indices element must be int, but got list[0]: 1, type: .") + if __name__ == '__main__': test_sequential_sampler(True)