forked from mindspore-Ecosystem/mindspore
Fix sampler supporting numpy input
This commit is contained in:
parent
692d158f5c
commit
616d1a1f3d
|
@ -16,13 +16,11 @@
|
|||
General Validators.
|
||||
"""
|
||||
import inspect
|
||||
import numbers
|
||||
from multiprocessing import cpu_count
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
from ..engine import samplers
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
|
@ -389,38 +387,5 @@ def check_tensor_op(param, param_name):
|
|||
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
||||
|
||||
|
||||
def check_sampler(sampler):
|
||||
"""
|
||||
Check if the sampler is of valid input.
|
||||
|
||||
Args:
|
||||
param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler
|
||||
|
||||
Returns:
|
||||
Exception: TypeError if error
|
||||
"""
|
||||
builtin = False
|
||||
base_sampler = False
|
||||
list_num = False
|
||||
if sampler is not None:
|
||||
if isinstance(sampler, samplers.BuiltinSampler):
|
||||
builtin = True
|
||||
elif isinstance(sampler, samplers.Sampler):
|
||||
base_sampler = True
|
||||
else:
|
||||
# check for list of numbers
|
||||
list_num = True
|
||||
# subset sampler check
|
||||
subset_sampler = sampler
|
||||
if not isinstance(sampler, list):
|
||||
subset_sampler = [sampler]
|
||||
|
||||
for _, item in enumerate(subset_sampler):
|
||||
if not isinstance(item, numbers.Number):
|
||||
list_num = False
|
||||
if not (builtin or base_sampler or list_num):
|
||||
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
|
||||
|
||||
|
||||
def replace_none(value, default):
|
||||
return value if value is not None else default
|
||||
|
|
|
@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
Sampler, sampler selected based on user input.
|
||||
"""
|
||||
|
||||
def _is_iterable(obj):
|
||||
try:
|
||||
iter(obj)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
||||
if number_of_samples is None:
|
||||
return list(sampler)
|
||||
|
||||
if isinstance(sampler, list):
|
||||
return sampler[:number_of_samples]
|
||||
|
||||
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
||||
|
||||
if input_sampler is not None:
|
||||
# If the user provided a sampler, then it doesn't matter what the other args are because
|
||||
# we are being asked specifically to use the given sampler.
|
||||
|
@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
||||
if isinstance(input_sampler, BuiltinSampler):
|
||||
return input_sampler
|
||||
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
|
||||
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
|
||||
if isinstance(input_sampler, int):
|
||||
return SubsetSampler([input_sampler])
|
||||
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
||||
return SubsetSampler(input_sampler, num_samples)
|
||||
|
||||
if shuffle is None:
|
||||
if num_shards is not None:
|
||||
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
||||
|
@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler):
|
|||
"""
|
||||
|
||||
def __init__(self, indices, num_samples=None):
|
||||
if not isinstance(indices, list):
|
||||
def _is_iterable(obj):
|
||||
try:
|
||||
iter(obj)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
||||
if number_of_samples is None:
|
||||
return list(sampler)
|
||||
|
||||
if isinstance(sampler, list):
|
||||
return sampler[:number_of_samples]
|
||||
|
||||
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
||||
|
||||
if not isinstance(indices, str) and _is_iterable(indices):
|
||||
indices = _get_sample_ids_as_list(indices, num_samples)
|
||||
elif isinstance(indices, int):
|
||||
indices = [indices]
|
||||
else:
|
||||
raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))
|
||||
|
||||
for i, item in enumerate(indices):
|
||||
if not isinstance(item, int):
|
||||
if not isinstance(item, (int, np.integer)):
|
||||
raise TypeError("SubsetSampler: Type of indices element must be int, "
|
||||
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
|
||||
|
||||
|
|
|
@ -177,13 +177,23 @@ def test_subset_sampler():
|
|||
def pipeline():
|
||||
sampler = ds.SubsetSampler(indices, num_samples)
|
||||
data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
|
||||
data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
|
||||
dataset_size = data.get_dataset_size()
|
||||
return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
|
||||
dataset_size2 = data.get_dataset_size()
|
||||
res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
|
||||
res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
|
||||
return res1, res2
|
||||
|
||||
if exception_msg is None:
|
||||
res, size = pipeline()
|
||||
res, res2 = pipeline()
|
||||
res, size = res
|
||||
res2, size2 = res2
|
||||
if not isinstance(indices, list):
|
||||
indices = list(indices)
|
||||
assert indices[:num_samples] == res
|
||||
assert len(indices[:num_samples]) == size
|
||||
assert indices[:num_samples] == res2
|
||||
assert len(indices[:num_samples]) == size2
|
||||
else:
|
||||
with pytest.raises(Exception) as error_info:
|
||||
pipeline()
|
||||
|
@ -205,6 +215,8 @@ def test_subset_sampler():
|
|||
test_config([0, 9, 3, 2], num_samples=2)
|
||||
test_config([0, 9, 3, 2], num_samples=5)
|
||||
|
||||
test_config(np.array([1, 2, 3]))
|
||||
|
||||
test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
|
||||
test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
|
||||
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
|
||||
|
@ -212,6 +224,9 @@ def test_subset_sampler():
|
|||
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
|
||||
test_config([0, 9, 3, 2], num_samples=-1,
|
||||
exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
|
||||
test_config(np.array([[1], [5]]), num_samples=10,
|
||||
exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
|
||||
" type: <class 'numpy.ndarray'>.")
|
||||
|
||||
|
||||
def test_sampler_chain():
|
||||
|
@ -291,8 +306,8 @@ def test_sampler_list():
|
|||
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
|
||||
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
|
||||
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
|
||||
bad_pipeline(sampler=np.array([1, 2]),
|
||||
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
|
||||
bad_pipeline(sampler=np.array([[1, 2]]),
|
||||
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue