forked from mindspore-Ecosystem/mindspore
User explicit deepcopy
This commit is contained in:
parent
8d55cd9d89
commit
58193bc469
|
@ -3433,6 +3433,7 @@ class GeneratorDataset(MappableDataset):
|
||||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
|
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
|
||||||
self.source = source
|
self.source = source
|
||||||
|
self.prepared_source = None # source to be sent to C++
|
||||||
|
|
||||||
self.python_multiprocessing = python_multiprocessing
|
self.python_multiprocessing = python_multiprocessing
|
||||||
|
|
||||||
|
@ -3463,9 +3464,9 @@ class GeneratorDataset(MappableDataset):
|
||||||
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
|
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
|
||||||
if new_op.num_parallel_workers > 1:
|
if new_op.num_parallel_workers > 1:
|
||||||
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
|
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
|
||||||
new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
|
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
|
||||||
else:
|
else:
|
||||||
new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
|
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
|
||||||
new_op.sample_fn = sample_fn
|
new_op.sample_fn = sample_fn
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -3476,11 +3477,11 @@ class GeneratorDataset(MappableDataset):
|
||||||
iter(self.source)
|
iter(self.source)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Use generator function if input is callable
|
# Use generator function if input is callable
|
||||||
new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples))
|
new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples))
|
||||||
else:
|
else:
|
||||||
# Use iterator function if input is iterable
|
# Use iterator function if input is iterable
|
||||||
# Randomly accessible input is also iterable
|
# Randomly accessible input is also iterable
|
||||||
new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples))
|
new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples))
|
||||||
|
|
||||||
return new_op
|
return new_op
|
||||||
|
|
||||||
|
@ -3492,12 +3493,12 @@ class GeneratorDataset(MappableDataset):
|
||||||
|
|
||||||
def parse(self, children=None):
|
def parse(self, children=None):
|
||||||
if self.schema is None:
|
if self.schema is None:
|
||||||
return cde.GeneratorNode(self.source, self.column_names, self.column_types, self.source_len,
|
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
|
||||||
self.sampler)
|
self.sampler)
|
||||||
schema = self.schema
|
schema = self.schema
|
||||||
if isinstance(schema, Schema):
|
if isinstance(schema, Schema):
|
||||||
schema = self.schema.cpp_schema
|
schema = self.schema.cpp_schema
|
||||||
return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler)
|
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler)
|
||||||
|
|
||||||
|
|
||||||
class TFRecordDataset(SourceDataset):
|
class TFRecordDataset(SourceDataset):
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -745,6 +746,18 @@ def manual_test_generator_keyboard_interrupt():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_deepcopy():
|
||||||
|
"""
|
||||||
|
Test that a dataset copied explicitly with copy.deepcopy yields the same rows as the original
|
||||||
|
"""
|
||||||
|
logger.info("Test explicit_deepcopy")
|
||||||
|
|
||||||
|
ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
|
||||||
|
ds2 = copy.deepcopy(ds1)
|
||||||
|
for d1, d2 in zip(ds1, ds2):
|
||||||
|
assert d1 == d2
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_generator_0()
|
test_generator_0()
|
||||||
test_generator_1()
|
test_generator_1()
|
||||||
|
@ -780,3 +793,4 @@ if __name__ == "__main__":
|
||||||
test_generator_dataset_size_3()
|
test_generator_dataset_size_3()
|
||||||
test_generator_dataset_size_4()
|
test_generator_dataset_size_4()
|
||||||
test_generator_dataset_size_5()
|
test_generator_dataset_size_5()
|
||||||
|
test_explicit_deepcopy()
|
||||||
|
|
Loading…
Reference in New Issue