Adding inheritance to user defined sampler
what if we just do nothing trying wrapper approach Fix yolo
This commit is contained in:
parent
37560318ef
commit
ed373f61ac
|
@ -167,7 +167,6 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
|
|||
if (py_sampler) {
|
||||
std::shared_ptr<SamplerObj> sampler_obj;
|
||||
if (!isMindDataset) {
|
||||
// Common Sampler
|
||||
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
|
||||
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
|
||||
} else {
|
||||
|
|
|
@ -44,6 +44,21 @@ valid_detype = [
|
|||
"uint32", "uint64", "float16", "float32", "float64", "string"
|
||||
]
|
||||
|
||||
def is_iterable(obj):
|
||||
"""
|
||||
Helper function to check if object is iterable.
|
||||
|
||||
Args:
|
||||
obj (any): object to check if iterable
|
||||
|
||||
Returns:
|
||||
bool, true if object iteratable
|
||||
"""
|
||||
try:
|
||||
iter(obj)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def pad_arg_name(arg_name):
|
||||
if arg_name != "":
|
||||
|
|
|
@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde
|
|||
import mindspore.dataset as ds
|
||||
from ..core import validator_helpers as validator
|
||||
|
||||
|
||||
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||
"""
|
||||
Create sampler based on user input.
|
||||
|
@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
||||
if isinstance(input_sampler, BuiltinSampler):
|
||||
return input_sampler
|
||||
return SubsetSampler(input_sampler, num_samples)
|
||||
|
||||
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)):
|
||||
return SubsetSampler(input_sampler, num_samples)
|
||||
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
|
||||
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler
|
||||
return IterSampler(input_sampler, num_samples)
|
||||
if isinstance(input_sampler, int):
|
||||
return SubsetSampler([input_sampler])
|
||||
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
||||
if shuffle is None:
|
||||
if num_shards is not None:
|
||||
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
||||
|
@ -621,13 +626,6 @@ class SubsetSampler(BuiltinSampler):
|
|||
"""
|
||||
|
||||
def __init__(self, indices, num_samples=None):
|
||||
def _is_iterable(obj):
|
||||
try:
|
||||
iter(obj)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
||||
if number_of_samples is None:
|
||||
return list(sampler)
|
||||
|
@ -637,7 +635,7 @@ class SubsetSampler(BuiltinSampler):
|
|||
|
||||
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
||||
|
||||
if not isinstance(indices, str) and _is_iterable(indices):
|
||||
if not isinstance(indices, str) and validator.is_iterable(indices):
|
||||
indices = _get_sample_ids_as_list(indices, num_samples)
|
||||
elif isinstance(indices, int):
|
||||
indices = [indices]
|
||||
|
@ -731,6 +729,42 @@ class SubsetRandomSampler(SubsetSampler):
|
|||
return c_sampler
|
||||
|
||||
|
||||
class IterSampler(Sampler):
|
||||
"""
|
||||
User provided an iterable object without inheriting from our Sampler class.
|
||||
|
||||
Note:
|
||||
This class exists to allow handshake logic between dataset operators and user defined samplers.
|
||||
By constructing this object we avoid the user having to inherit from our Sampler class.
|
||||
|
||||
Args:
|
||||
sampler (iterable object): an user defined iterable object.
|
||||
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
||||
|
||||
Examples:
|
||||
>>> class MySampler():
|
||||
>>> def __iter__(self):
|
||||
>>> for i in range(99, -1, -1):
|
||||
>>> yield i
|
||||
|
||||
>>> # creates an IterSampler
|
||||
>>> sampler = ds.IterSampler(sampler=MySampler())
|
||||
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
|
||||
... num_parallel_workers=8,
|
||||
... sampler=sampler)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, sampler, num_samples=None):
|
||||
if num_samples is None:
|
||||
num_samples = len(list(sampler))
|
||||
super().__init__(num_samples=num_samples)
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.sampler)
|
||||
|
||||
|
||||
class WeightedRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
||||
|
|
|
@ -141,9 +141,23 @@ def test_python_sampler():
|
|||
assert data[0].asnumpy() == (np.array(i),)
|
||||
i = i - 1
|
||||
|
||||
# This 2nd case is the one that exhibits the same behavior as the case above without inheritance
|
||||
def test_generator_iter_sampler():
|
||||
class MySampler():
|
||||
def __iter__(self):
|
||||
for i in range(99, -1, -1):
|
||||
yield i
|
||||
|
||||
data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
|
||||
i = 99
|
||||
for data in data1:
|
||||
assert data[0].asnumpy() == (np.array(i),)
|
||||
i = i - 1
|
||||
|
||||
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
|
||||
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
|
||||
test_generator()
|
||||
test_generator_iter_sampler()
|
||||
|
||||
|
||||
def test_sequential_sampler2():
|
||||
|
|
Loading…
Reference in New Issue