diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 62c2898880e..1508a76f0b8 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -3433,6 +3433,7 @@ class GeneratorDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                          shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
         self.source = source
+        self.prepared_source = None  # source to be sent to C++
         self.python_multiprocessing = python_multiprocessing
 
@@ -3463,9 +3464,9 @@ class GeneratorDataset(MappableDataset):
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             if new_op.num_parallel_workers > 1:
                 sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
+                new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
             else:
-                new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+                new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
             new_op.sample_fn = sample_fn
         else:
             try:
@@ -3476,11 +3477,11 @@ class GeneratorDataset(MappableDataset):
                 iter(self.source)
             except TypeError:
                 # Use generator function if input callable
-                new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples))
+                new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples))
             else:
                 # Use iterator function if input is iterable
                 # Random accessible input is also iterable
-                new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples))
+                new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples))
 
         return new_op
 
@@ -3492,12 +3493,12 @@
 
     def parse(self, children=None):
         if self.schema is None:
-            return cde.GeneratorNode(self.source, self.column_names, self.column_types, self.source_len,
-                                     self.sampler)
+            return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
+                                     self.sampler)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
-        return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler)
+        return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler)
 
 
 class TFRecordDataset(SourceDataset):
diff --git a/tests/ut/python/dataset/test_datasets_generator.py b/tests/ut/python/dataset/test_datasets_generator.py
index 057b5b718cc..189c1e19ef4 100644
--- a/tests/ut/python/dataset/test_datasets_generator.py
+++ b/tests/ut/python/dataset/test_datasets_generator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 
 import numpy as np
 import pytest
@@ -745,6 +746,18 @@ def manual_test_generator_keyboard_interrupt():
         pass
 
 
+def test_explicit_deepcopy():
+    """
+    Test explicit_deepcopy
+    """
+    logger.info("Test explicit_deepcopy")
+
+    ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
+    ds2 = copy.deepcopy(ds1)
+    for d1, d2 in zip(ds1, ds2):
+        assert d1 == d2
+
+
 if __name__ == "__main__":
     test_generator_0()
     test_generator_1()
@@ -780,3 +793,4 @@ if __name__ == "__main__":
     test_generator_dataset_size_3()
     test_generator_dataset_size_4()
     test_generator_dataset_size_5()
+    test_explicit_deepcopy()