!9472 Add hook and prevent c transforms from being called in python multithreading

From: @ezphlow
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-07 10:17:58 +08:00 committed by Gitee
commit 31b3ddcac4
1 changed files with 10 additions and 7 deletions

View File

@ -1921,6 +1921,8 @@ class BatchDataset(Dataset):
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.hook = copy.deepcopy(self.hook, memodict)
new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
if hasattr(self, "__total_batch__"):
new_op.__total_batch__ = self.__total_batch__
return new_op
# Iterator bootstrap will be called on iterator construction.
@ -1939,6 +1941,7 @@ class BatchDataset(Dataset):
idx = 0
# Wrap per_batch_map into _PythonCallable
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
self.hook = _ExceptHookHandler()
def __del__(self):
if hasattr(self, 'process_pool') and self.process_pool is not None:
@ -2205,15 +2208,12 @@ class _PythonCallable:
class _ExceptHookHandler:
def __init__(self, pool):
self.__pool = pool
def __init__(self):
sys.excepthook = self.__handler_exception
def __handler_exception(self, type, value, tb):
logger.error("Uncaught exception: ", exc_info=(type, value, tb))
if self.__pool is not None:
_set_iterator_cleanup()
self.__pool.terminate()
_set_iterator_cleanup()
class MapDataset(Dataset):
@ -2350,7 +2350,8 @@ class MapDataset(Dataset):
# Pass #1, look for Python callables and build list
for op in self.operations:
if callable(op):
# our c transforms is now callable and should not be run in python multithreading
if callable(op) and str(op).find("c_transform") < 0:
callable_list.append(op)
if callable_list:
@ -2362,7 +2363,8 @@ class MapDataset(Dataset):
# Pass #2
idx = 0
for op in self.operations:
if callable(op):
# our c transforms is now callable and should not be run in python multithreading
if callable(op) and str(op).find("c_transform") < 0:
# Wrap Python callable into _PythonCallable
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
idx += 1
@ -2370,6 +2372,7 @@ class MapDataset(Dataset):
# CPP ops remain the same
iter_specific_operations.append(op)
self.operations = iter_specific_operations
self.hook = _ExceptHookHandler()
def __del__(self):
if hasattr(self, 'process_pool') and self.process_pool is not None: