Adding inheritance to user defined sampler

what if we just do nothing

trying wrapper approach

Fix yolo
This commit is contained in:
Eric 2021-03-05 17:38:58 -05:00
parent 37560318ef
commit ed373f61ac
4 changed files with 74 additions and 12 deletions

View File

@ -167,7 +167,6 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
if (py_sampler) { if (py_sampler) {
std::shared_ptr<SamplerObj> sampler_obj; std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) { if (!isMindDataset) {
// Common Sampler
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse"); auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>(); sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
} else { } else {

View File

@ -44,6 +44,21 @@ valid_detype = [
"uint32", "uint64", "float16", "float32", "float64", "string" "uint32", "uint64", "float16", "float32", "float64", "string"
] ]
def is_iterable(obj):
    """
    Report whether an object supports iteration.

    Args:
        obj (any): Candidate object, probed with the built-in iter().

    Returns:
        bool, True when iter(obj) succeeds, False when it raises TypeError.
    """
    try:
        iter(obj)
    except TypeError:
        # iter() signals a non-iterable argument with TypeError.
        return False
    else:
        return True
def pad_arg_name(arg_name): def pad_arg_name(arg_name):
if arg_name != "": if arg_name != "":

View File

@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde
import mindspore.dataset as ds import mindspore.dataset as ds
from ..core import validator_helpers as validator from ..core import validator_helpers as validator
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
""" """
Create sampler based on user input. Create sampler based on user input.
@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler): if isinstance(input_sampler, BuiltinSampler):
return input_sampler return input_sampler
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)):
return SubsetSampler(input_sampler, num_samples) return SubsetSampler(input_sampler, num_samples)
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler
return IterSampler(input_sampler, num_samples)
if isinstance(input_sampler, int):
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
if shuffle is None: if shuffle is None:
if num_shards is not None: if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler # If shuffle is not specified, sharding enabled, use distributed random sampler
@ -621,13 +626,6 @@ class SubsetSampler(BuiltinSampler):
""" """
def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True
def _get_sample_ids_as_list(sampler, number_of_samples=None): def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None: if number_of_samples is None:
return list(sampler) return list(sampler)
@ -637,7 +635,7 @@ class SubsetSampler(BuiltinSampler):
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if not isinstance(indices, str) and _is_iterable(indices): if not isinstance(indices, str) and validator.is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples) indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int): elif isinstance(indices, int):
indices = [indices] indices = [indices]
@ -731,6 +729,42 @@ class SubsetRandomSampler(SubsetSampler):
return c_sampler return c_sampler
class IterSampler(Sampler):
    """
    User provided an iterable object without inheriting from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

        When num_samples is None the sampler is iterated once at construction time to count
        its elements. A single-pass iterator (for example a generator object) would be left
        empty by that count, so its elements are captured in a list and that list is what
        gets iterated afterwards.

    Args:
        sampler (iterable object): A user defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, all elements).

    Examples:
        >>> class MySampler():
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i
        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            if iter(sampler) is sampler:
                # An object that returns itself from iter() is an iterator and can only be
                # walked once. Counting it would exhaust it and every later __iter__ call
                # would produce nothing, so keep the drained values instead.
                sampler = list(sampler)
                num_samples = len(sampler)
            else:
                # Re-iterable object: count the elements without building a throwaway list.
                num_samples = sum(1 for _ in sampler)
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        # Delegate to the wrapped object so a re-iterable sampler yields a fresh pass each time.
        return iter(self.sampler)
class WeightedRandomSampler(BuiltinSampler): class WeightedRandomSampler(BuiltinSampler):
""" """
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

View File

@ -141,9 +141,23 @@ def test_python_sampler():
assert data[0].asnumpy() == (np.array(i),) assert data[0].asnumpy() == (np.array(i),)
i = i - 1 i = i - 1
# Second case: expects the same behavior as the case above, but with a sampler that does not inherit from Sampler
def test_generator_iter_sampler():
    """
    Check that GeneratorDataset accepts a plain iterable object as its sampler.

    MySampler only defines __iter__ and yields the indices 99 down to 0, so the rows are
    expected to come back in that descending order.
    """
    class MySampler():
        def __iter__(self):
            for i in range(99, -1, -1):
                yield i

    data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
    i = 99
    for data in data1:
        assert data[0].asnumpy() == (np.array(i),)
        i = i - 1
    # Without this check the test passes vacuously when the dataset yields no rows
    # (or fewer than 100), because the loop body would simply run fewer times.
    assert i == -1
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
test_generator() test_generator()
test_generator_iter_sampler()
def test_sequential_sampler2(): def test_sequential_sampler2():