forked from mindspore-Ecosystem/mindspore
!858 Fix gpu issue
Merge pull request !858 from xiefangqi/md_fix_gpu_issue
This commit is contained in:
commit
05676676e9
|
@ -17,6 +17,7 @@
|
|||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
from importlib import import_module
|
||||
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
|
@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName
|
|||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
|
||||
try:
|
||||
context = import_module("mindspore.context")
|
||||
except ModuleNotFoundError:
|
||||
context = None
|
||||
|
||||
ITERATORS_LIST = list()
|
||||
|
||||
|
||||
def _cleanup():
|
||||
"""Release all the Iterator."""
|
||||
for itr_ref in ITERATORS_LIST:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
itr_ref.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
|
@ -85,7 +101,14 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
ITERATORS_LIST.append(self)
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
|
|
Loading…
Reference in New Issue