forked from mindspore-Ecosystem/mindspore
Fix two problem when we create multiple instances of the same dataset (2 for-loops)
-- Iterator list is keeping all created iterators wihtout cleaning them up -- alter tree modifies the original.
This commit is contained in:
parent
5ed799d7b2
commit
fb6c7ba2e1
|
@ -225,11 +225,13 @@ void bindTensor(py::module *m) {
|
|||
(void)py::class_<DataType>(*m, "DataType")
|
||||
.def(py::init<std::string>())
|
||||
.def(py::self == py::self)
|
||||
.def("__str__", &DataType::ToString);
|
||||
.def("__str__", &DataType::ToString)
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}
|
||||
|
||||
void bindTensorOps1(py::module *m) {
|
||||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp");
|
||||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
|
||||
(void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
|
||||
*m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
"""Built-in iterators.
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
|
@ -27,7 +29,9 @@ ITERATORS_LIST = list()
|
|||
|
||||
def _cleanup():
|
||||
for itr in ITERATORS_LIST:
|
||||
itr.release()
|
||||
iter_ref = itr()
|
||||
if itr is not None:
|
||||
iter_ref.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
|
@ -73,8 +77,10 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
ITERATORS_LIST.append(self)
|
||||
self.dataset = alter_tree(dataset)
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
if not self.__is_tree():
|
||||
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
|
||||
self.depipeline = DEPipeline()
|
||||
|
|
Loading…
Reference in New Issue