forked from mindspore-Ecosystem/mindspore
!9472 Add hook and prevent c transforms from being called in python multithreading
From: @ezphlow Reviewed-by: Signed-off-by:
This commit is contained in:
commit
31b3ddcac4
|
@ -1921,6 +1921,8 @@ class BatchDataset(Dataset):
|
|||
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
|
||||
new_op.hook = copy.deepcopy(self.hook, memodict)
|
||||
new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
|
||||
if hasattr(self, "__total_batch__"):
|
||||
new_op.__total_batch__ = self.__total_batch__
|
||||
return new_op
|
||||
|
||||
# Iterator bootstrap will be called on iterator construction.
|
||||
|
@ -1939,6 +1941,7 @@ class BatchDataset(Dataset):
|
|||
idx = 0
|
||||
# Wrap per_batch_map into _PythonCallable
|
||||
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
|
||||
self.hook = _ExceptHookHandler()
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'process_pool') and self.process_pool is not None:
|
||||
|
@ -2205,15 +2208,12 @@ class _PythonCallable:
|
|||
|
||||
|
||||
class _ExceptHookHandler:
|
||||
def __init__(self, pool):
|
||||
self.__pool = pool
|
||||
def __init__(self):
|
||||
sys.excepthook = self.__handler_exception
|
||||
|
||||
def __handler_exception(self, type, value, tb):
|
||||
logger.error("Uncaught exception: ", exc_info=(type, value, tb))
|
||||
if self.__pool is not None:
|
||||
_set_iterator_cleanup()
|
||||
self.__pool.terminate()
|
||||
_set_iterator_cleanup()
|
||||
|
||||
|
||||
class MapDataset(Dataset):
|
||||
|
@ -2350,7 +2350,8 @@ class MapDataset(Dataset):
|
|||
|
||||
# Pass #1, look for Python callables and build list
|
||||
for op in self.operations:
|
||||
if callable(op):
|
||||
# our c transforms is now callable and should not be run in python multithreading
|
||||
if callable(op) and str(op).find("c_transform") < 0:
|
||||
callable_list.append(op)
|
||||
|
||||
if callable_list:
|
||||
|
@ -2362,7 +2363,8 @@ class MapDataset(Dataset):
|
|||
# Pass #2
|
||||
idx = 0
|
||||
for op in self.operations:
|
||||
if callable(op):
|
||||
# our c transforms is now callable and should not be run in python multithreading
|
||||
if callable(op) and str(op).find("c_transform") < 0:
|
||||
# Wrap Python callable into _PythonCallable
|
||||
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
|
||||
idx += 1
|
||||
|
@ -2370,6 +2372,7 @@ class MapDataset(Dataset):
|
|||
# CPP ops remain the same
|
||||
iter_specific_operations.append(op)
|
||||
self.operations = iter_specific_operations
|
||||
self.hook = _ExceptHookHandler()
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'process_pool') and self.process_pool is not None:
|
||||
|
|
Loading…
Reference in New Issue