!12558 Fix sampler error messages

From: @hfarahat
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-26 00:10:17 +08:00 committed by Gitee
commit c8defc2a8a
4 changed files with 31 additions and 8 deletions

View File

@ -23,6 +23,7 @@ import numpy as np
import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict):
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
num_samples = param_dict.get('num_samples')
check_sampler(sampler)
if sampler is not None:
if shuffle is not None:
@ -348,6 +348,7 @@ def check_num_samples(value):
raise ValueError(
"num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
def validate_dataset_param_value(param_list, param_dict, param_type):
for param_name in param_list:
if param_dict.get(param_name) is not None:
@ -387,6 +388,7 @@ def check_tensor_op(param, param_name):
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def check_sampler(sampler):
"""
Check if the sampler is of valid input.
@ -419,5 +421,6 @@ def check_sampler(sampler):
if not (builtin or base_sampler or list_num):
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
def replace_none(value, default):
return value if value is not None else default

View File

@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler):
return input_sampler
if _is_iterable(input_sampler):
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return [input_sampler]
raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler):
indices = [indices]
for i, item in enumerate(indices):
if not isinstance(item, numbers.Number):
raise TypeError("type of indices element must be number, "
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
if not isinstance(item, int):
raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
if num_samples is not None:
if not isinstance(num_samples, int):

View File

@ -179,7 +179,7 @@ def test_celeba_sampler_exception():
pass
assert False
except TypeError as e:
assert "Argument" in str(e)
assert "Unsupported sampler object of type (<class 'str'>)" in str(e)
if __name__ == '__main__':

View File

@ -274,6 +274,26 @@ def test_sampler_list():
dataset_equal(data1, data21 + data22 + data23, 0)
data3 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=1)
dataset_equal(data3, data21, 0)
def bad_pipeline(sampler, msg):
with pytest.raises(Exception) as info:
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler)
for _ in data1:
pass
assert msg in str(info.value)
bad_pipeline(sampler=[1.5, 7],
msg="Type of indices element must be int, but got list[0]: 1.5, type: <class 'float'>")
bad_pipeline(sampler=["a", "b"],
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler=np.array([1, 2]),
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
if __name__ == '__main__':
test_sequential_sampler(True)