Fix sampler supporting numpy input

This commit is contained in:
hesham 2021-03-01 12:28:43 -05:00
parent 692d158f5c
commit 616d1a1f3d
3 changed files with 43 additions and 62 deletions

View File

@ -16,13 +16,11 @@
General Validators.
"""
import inspect
import numbers
from multiprocessing import cpu_count
import os
import numpy as np
import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
@ -389,38 +387,5 @@ def check_tensor_op(param, param_name):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def check_sampler(sampler):
"""
Check if the sampler is of valid input.
Args:
param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler
Returns:
Exception: TypeError if error
"""
builtin = False
base_sampler = False
list_num = False
if sampler is not None:
if isinstance(sampler, samplers.BuiltinSampler):
builtin = True
elif isinstance(sampler, samplers.Sampler):
base_sampler = True
else:
# check for list of numbers
list_num = True
# subset sampler check
subset_sampler = sampler
if not isinstance(sampler, list):
subset_sampler = [sampler]
for _, item in enumerate(subset_sampler):
if not isinstance(item, numbers.Number):
list_num = False
if not (builtin or base_sampler or list_num):
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
def replace_none(value, default):
return value if value is not None else default

View File

@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
Sampler, sampler selected based on user input.
"""
def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True
def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)
if isinstance(sampler, list):
return sampler[:number_of_samples]
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler):
return input_sampler
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
return SubsetSampler(input_sampler, num_samples)
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler):
"""
def __init__(self, indices, num_samples=None):
if not isinstance(indices, list):
def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True
def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)
if isinstance(sampler, list):
return sampler[:number_of_samples]
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if not isinstance(indices, str) and _is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int):
indices = [indices]
else:
raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))
for i, item in enumerate(indices):
if not isinstance(item, int):
if not isinstance(item, (int, np.integer)):
raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))

View File

@ -177,13 +177,23 @@ def test_subset_sampler():
def pipeline():
sampler = ds.SubsetSampler(indices, num_samples)
data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
dataset_size = data.get_dataset_size()
return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
dataset_size2 = data.get_dataset_size()
res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
return res1, res2
if exception_msg is None:
res, size = pipeline()
res, res2 = pipeline()
res, size = res
res2, size2 = res2
if not isinstance(indices, list):
indices = list(indices)
assert indices[:num_samples] == res
assert len(indices[:num_samples]) == size
assert indices[:num_samples] == res2
assert len(indices[:num_samples]) == size2
else:
with pytest.raises(Exception) as error_info:
pipeline()
@ -205,6 +215,8 @@ def test_subset_sampler():
test_config([0, 9, 3, 2], num_samples=2)
test_config([0, 9, 3, 2], num_samples=5)
test_config(np.array([1, 2, 3]))
test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
@ -212,6 +224,9 @@ def test_subset_sampler():
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
test_config([0, 9, 3, 2], num_samples=-1,
exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
test_config(np.array([[1], [5]]), num_samples=10,
exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
" type: <class 'numpy.ndarray'>.")
def test_sampler_chain():
@ -291,8 +306,8 @@ def test_sampler_list():
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler=np.array([1, 2]),
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
bad_pipeline(sampler=np.array([[1, 2]]),
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
if __name__ == '__main__':