forked from mindspore-Ecosystem/mindspore
fix numpyslice issue
This commit is contained in:
parent
7cb567ebbe
commit
cb9c6fad86
|
@ -3219,33 +3219,9 @@ class GeneratorDataset(MappableDataset):
|
||||||
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
||||||
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
|
self.source = source
|
||||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||||
if self.sampler is not None and hasattr(source, "__getitem__"):
|
self.num_samples = num_samples
|
||||||
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
|
||||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
|
||||||
samplers.WeightedRandomSampler, samplers.Sampler)):
|
|
||||||
sampler_instance = self.sampler.create()
|
|
||||||
sampler_instance.set_num_rows(len(source))
|
|
||||||
sampler_instance.initialize()
|
|
||||||
if num_parallel_workers > 1:
|
|
||||||
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
|
|
||||||
else:
|
|
||||||
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
|
|
||||||
else:
|
|
||||||
if num_parallel_workers > 1:
|
|
||||||
self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
|
|
||||||
else:
|
|
||||||
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
iter(source)
|
|
||||||
except TypeError:
|
|
||||||
# Use generator function if input callable
|
|
||||||
self.source = (lambda: _generator_fn(source, num_samples))
|
|
||||||
else:
|
|
||||||
# Use iterator function if input is iterable
|
|
||||||
# Random accessible input is also iterable
|
|
||||||
self.source = (lambda: _iter_fn(source, num_samples))
|
|
||||||
|
|
||||||
if column_names is not None and not isinstance(column_names, list):
|
if column_names is not None and not isinstance(column_names, list):
|
||||||
column_names = [column_names]
|
column_names = [column_names]
|
||||||
|
@ -3310,9 +3286,35 @@ class GeneratorDataset(MappableDataset):
|
||||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||||
new_op.column_types = copy.deepcopy(self.column_types, memodict)
|
new_op.column_types = copy.deepcopy(self.column_types, memodict)
|
||||||
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
||||||
|
new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
|
||||||
|
|
||||||
new_op.source = self.source
|
new_op.sampler = copy.deepcopy(self.sampler)
|
||||||
new_op.sampler = self.sampler
|
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
|
||||||
|
if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||||
|
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||||
|
samplers.WeightedRandomSampler, samplers.Sampler)):
|
||||||
|
sampler_instance = new_op.sampler.create()
|
||||||
|
sampler_instance.set_num_rows(len(self.source))
|
||||||
|
sampler_instance.initialize()
|
||||||
|
if new_op.num_parallel_workers > 1:
|
||||||
|
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers))
|
||||||
|
else:
|
||||||
|
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
|
||||||
|
else:
|
||||||
|
if new_op.num_parallel_workers > 1:
|
||||||
|
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers))
|
||||||
|
else:
|
||||||
|
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
iter(self.source)
|
||||||
|
except TypeError:
|
||||||
|
# Use generator function if input callable
|
||||||
|
new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples))
|
||||||
|
else:
|
||||||
|
# Use iterator function if input is iterable
|
||||||
|
# Random accessible input is also iterable
|
||||||
|
new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples))
|
||||||
|
|
||||||
return new_op
|
return new_op
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue