Adding inheritance to user defined sampler
what if we just do nothing trying wrapper approach Fix yolo
This commit is contained in:
parent
37560318ef
commit
ed373f61ac
|
@ -167,7 +167,6 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
|
||||||
if (py_sampler) {
|
if (py_sampler) {
|
||||||
std::shared_ptr<SamplerObj> sampler_obj;
|
std::shared_ptr<SamplerObj> sampler_obj;
|
||||||
if (!isMindDataset) {
|
if (!isMindDataset) {
|
||||||
// Common Sampler
|
|
||||||
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
|
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
|
||||||
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
|
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -44,6 +44,21 @@ valid_detype = [
|
||||||
"uint32", "uint64", "float16", "float32", "float64", "string"
|
"uint32", "uint64", "float16", "float32", "float64", "string"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def is_iterable(obj):
    """
    Report whether an object can be iterated over.

    The check is done by asking the built-in iter() for an iterator and
    treating a TypeError as "not iterable".

    Args:
        obj (any): object to probe.

    Returns:
        bool, True if iter(obj) succeeds, False if it raises TypeError.
    """
    try:
        iter(obj)
    except TypeError:
        iterable = False
    else:
        iterable = True
    return iterable
|
||||||
|
|
||||||
def pad_arg_name(arg_name):
|
def pad_arg_name(arg_name):
|
||||||
if arg_name != "":
|
if arg_name != "":
|
||||||
|
|
|
@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
from ..core import validator_helpers as validator
|
from ..core import validator_helpers as validator
|
||||||
|
|
||||||
|
|
||||||
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||||
"""
|
"""
|
||||||
Create sampler based on user input.
|
Create sampler based on user input.
|
||||||
|
@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||||
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
||||||
if isinstance(input_sampler, BuiltinSampler):
|
if isinstance(input_sampler, BuiltinSampler):
|
||||||
return input_sampler
|
return input_sampler
|
||||||
|
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)):
|
||||||
return SubsetSampler(input_sampler, num_samples)
|
return SubsetSampler(input_sampler, num_samples)
|
||||||
|
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
|
||||||
|
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler
|
||||||
|
return IterSampler(input_sampler, num_samples)
|
||||||
|
if isinstance(input_sampler, int):
|
||||||
|
return SubsetSampler([input_sampler])
|
||||||
|
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
||||||
if shuffle is None:
|
if shuffle is None:
|
||||||
if num_shards is not None:
|
if num_shards is not None:
|
||||||
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
||||||
|
@ -621,13 +626,6 @@ class SubsetSampler(BuiltinSampler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, num_samples=None):
|
def __init__(self, indices, num_samples=None):
|
||||||
def _is_iterable(obj):
|
|
||||||
try:
|
|
||||||
iter(obj)
|
|
||||||
except TypeError:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
||||||
if number_of_samples is None:
|
if number_of_samples is None:
|
||||||
return list(sampler)
|
return list(sampler)
|
||||||
|
@ -637,7 +635,7 @@ class SubsetSampler(BuiltinSampler):
|
||||||
|
|
||||||
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
||||||
|
|
||||||
if not isinstance(indices, str) and _is_iterable(indices):
|
if not isinstance(indices, str) and validator.is_iterable(indices):
|
||||||
indices = _get_sample_ids_as_list(indices, num_samples)
|
indices = _get_sample_ids_as_list(indices, num_samples)
|
||||||
elif isinstance(indices, int):
|
elif isinstance(indices, int):
|
||||||
indices = [indices]
|
indices = [indices]
|
||||||
|
@ -731,6 +729,42 @@ class SubsetRandomSampler(SubsetSampler):
|
||||||
return c_sampler
|
return c_sampler
|
||||||
|
|
||||||
|
|
||||||
|
class IterSampler(Sampler):
    """
    User provided an iterable object without inheriting from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

    Args:
        sampler (iterable object): a user-defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, all elements).
            When None, the iterable is traversed once at construction time to count its elements.

    Examples:
        >>> class MySampler():
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i
        >>>
        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            # Counting the elements requires one full pass over the iterable.
            sample_ids = list(sampler)
            num_samples = len(sample_ids)
            if iter(sampler) is sampler:
                # A one-shot iterator (e.g. a generator object) returns itself from iter()
                # and has just been exhausted by the counting pass. Keep the collected ids
                # so that __iter__ below still yields them instead of nothing.
                sampler = sample_ids
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        # Delegate to the wrapped object; a re-iterable object produces a fresh iterator each call.
        return iter(self.sampler)
|
||||||
|
|
||||||
|
|
||||||
class WeightedRandomSampler(BuiltinSampler):
|
class WeightedRandomSampler(BuiltinSampler):
|
||||||
"""
|
"""
|
||||||
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
||||||
|
|
|
@ -141,9 +141,23 @@ def test_python_sampler():
|
||||||
assert data[0].asnumpy() == (np.array(i),)
|
assert data[0].asnumpy() == (np.array(i),)
|
||||||
i = i - 1
|
i = i - 1
|
||||||
|
|
||||||
|
# This 2nd case is the one that exhibits the same behavior as the case above without inheritance
|
||||||
|
def test_generator_iter_sampler():
    """Feed GeneratorDataset a plain iterable sampler and check rows come back as 99, 98, ..., in that order."""
    class MySampler():
        def __iter__(self):
            # Same sequence as an explicit countdown loop: 99 down to 0.
            yield from range(99, -1, -1)

    source_rows = [(np.array(value),) for value in range(100)]
    data1 = ds.GeneratorDataset(source_rows, ["data"], sampler=MySampler())

    expected = 99
    for row in data1:
        assert row[0].asnumpy() == (np.array(expected),)
        expected -= 1
|
||||||
|
|
||||||
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
|
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
|
||||||
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
|
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
|
||||||
test_generator()
|
test_generator()
|
||||||
|
test_generator_iter_sampler()
|
||||||
|
|
||||||
|
|
||||||
def test_sequential_sampler2():
|
def test_sequential_sampler2():
|
||||||
|
|
Loading…
Reference in New Issue