forked from mindspore-Ecosystem/mindspore
!12558 Fix sampler error messages
From: @hfarahat Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c8defc2a8a
|
@ -23,6 +23,7 @@ import numpy as np
|
|||
|
||||
import mindspore._c_dataengine as cde
|
||||
from ..engine import samplers
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
UINT8_MAX = 255
|
||||
|
@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
num_samples = param_dict.get('num_samples')
|
||||
check_sampler(sampler)
|
||||
|
||||
if sampler is not None:
|
||||
if shuffle is not None:
|
||||
|
@ -348,6 +348,7 @@ def check_num_samples(value):
|
|||
raise ValueError(
|
||||
"num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
|
||||
|
||||
|
||||
def validate_dataset_param_value(param_list, param_dict, param_type):
|
||||
for param_name in param_list:
|
||||
if param_dict.get(param_name) is not None:
|
||||
|
@ -387,6 +388,7 @@ def check_tensor_op(param, param_name):
|
|||
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
||||
|
||||
|
||||
def check_sampler(sampler):
|
||||
"""
|
||||
Check if the sampler is of valid input.
|
||||
|
@ -419,5 +421,6 @@ def check_sampler(sampler):
|
|||
if not (builtin or base_sampler or list_num):
|
||||
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
|
||||
|
||||
|
||||
def replace_none(value, default):
|
||||
return value if value is not None else default
|
||||
|
|
|
@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
||||
if isinstance(input_sampler, BuiltinSampler):
|
||||
return input_sampler
|
||||
if _is_iterable(input_sampler):
|
||||
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
|
||||
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
|
||||
if isinstance(input_sampler, int):
|
||||
return [input_sampler]
|
||||
raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
|
||||
return SubsetSampler([input_sampler])
|
||||
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
||||
if shuffle is None:
|
||||
if num_shards is not None:
|
||||
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
||||
|
@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler):
|
|||
indices = [indices]
|
||||
|
||||
for i, item in enumerate(indices):
|
||||
if not isinstance(item, numbers.Number):
|
||||
raise TypeError("type of indices element must be number, "
|
||||
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
|
||||
if not isinstance(item, int):
|
||||
raise TypeError("SubsetSampler: Type of indices element must be int, "
|
||||
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
|
||||
|
||||
if num_samples is not None:
|
||||
if not isinstance(num_samples, int):
|
||||
|
|
|
@ -179,7 +179,7 @@ def test_celeba_sampler_exception():
|
|||
pass
|
||||
assert False
|
||||
except TypeError as e:
|
||||
assert "Argument" in str(e)
|
||||
assert "Unsupported sampler object of type (<class 'str'>)" in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -274,6 +274,26 @@ def test_sampler_list():
|
|||
|
||||
dataset_equal(data1, data21 + data22 + data23, 0)
|
||||
|
||||
data3 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=1)
|
||||
dataset_equal(data3, data21, 0)
|
||||
|
||||
def bad_pipeline(sampler, msg):
|
||||
with pytest.raises(Exception) as info:
|
||||
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler)
|
||||
for _ in data1:
|
||||
pass
|
||||
assert msg in str(info.value)
|
||||
|
||||
bad_pipeline(sampler=[1.5, 7],
|
||||
msg="Type of indices element must be int, but got list[0]: 1.5, type: <class 'float'>")
|
||||
|
||||
bad_pipeline(sampler=["a", "b"],
|
||||
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
|
||||
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
|
||||
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
|
||||
bad_pipeline(sampler=np.array([1, 2]),
|
||||
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
|
|
Loading…
Reference in New Issue