From fb6c7ba2e127be6632ddb8f0497df54bfaedb59a Mon Sep 17 00:00:00 2001 From: hesham Date: Thu, 16 Apr 2020 00:07:28 -0400 Subject: [PATCH] Fix two problems when we create multiple instances of the same dataset (2 for-loops) -- Iterator list is keeping all created iterators without cleaning them up -- alter tree modifies the original. --- mindspore/ccsrc/dataset/api/python_bindings.cc | 6 ++++-- mindspore/dataset/engine/iterators.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index e2675ee217c..d9e0ccbba87 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -225,11 +225,13 @@ void bindTensor(py::module *m) { (void)py::class_(*m, "DataType") .def(py::init()) .def(py::self == py::self) - .def("__str__", &DataType::ToString); + .def("__str__", &DataType::ToString) + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); } void bindTensorOps1(py::module *m) { - (void)py::class_>(*m, "TensorOp"); + (void)py::class_>(*m, "TensorOp") + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); (void)py::class_>( *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 268a66c0cf0..69dd9ce0a97 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -15,6 +15,8 @@ """Built-in iterators. 
""" from abc import abstractmethod +import copy +import weakref from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import OpName @@ -27,7 +29,9 @@ ITERATORS_LIST = list() def _cleanup(): for itr in ITERATORS_LIST: - itr.release() + iter_ref = itr() + if iter_ref is not None: + iter_ref.release() def alter_tree(node): @@ -73,8 +77,10 @@ class Iterator: """ def __init__(self, dataset): - ITERATORS_LIST.append(self) - self.dataset = alter_tree(dataset) + ITERATORS_LIST.append(weakref.ref(self)) + # create a copy of tree and work on it. + self.dataset = copy.deepcopy(dataset) + self.dataset = alter_tree(self.dataset) if not self.__is_tree(): raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") self.depipeline = DEPipeline()