From f49bacde095f4bfea9ca5bb98d0d1e56f434ab71 Mon Sep 17 00:00:00 2001 From: Guo Zhijian Date: Wed, 15 Sep 2021 05:37:48 -0400 Subject: [PATCH] fix: pyfunc hung when it's a yield func --- mindspore/dataset/core/validator_helpers.py | 18 ------------------ mindspore/dataset/engine/queue.py | 6 ++++-- tests/ut/python/dataset/test_pyfunc.py | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index a5934582e73..bea968e2ad4 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -18,7 +18,6 @@ General Validators. import inspect from multiprocessing import cpu_count import os -from pickle import dumps import numpy as np import mindspore._c_dataengine as cde @@ -63,23 +62,6 @@ def is_iterable(obj): return True -def is_serializable(obj): - """ - Helper function to check if object is serializable. - - Args: - obj (any): object to check if serializable. - - Returns: - bool, true if object is serializable. - """ - try: - dumps(obj) - except TypeError: - return False - return True - - def pad_arg_name(arg_name): """ Appends a space to the arg_name (if not empty) diff --git a/mindspore/dataset/engine/queue.py b/mindspore/dataset/engine/queue.py index ec5a93d3ee6..679841d60a9 100644 --- a/mindspore/dataset/engine/queue.py +++ b/mindspore/dataset/engine/queue.py @@ -20,9 +20,10 @@ but it will pass large data through shared memory. 
import multiprocessing.queues import multiprocessing +import types import numpy as np + from mindspore import log as logger -from ..core.validator_helpers import is_serializable from ..transforms.py_transforms_util import ExceptionHandler @@ -79,7 +80,8 @@ class _SharedQueue(multiprocessing.queues.Queue): raise TypeError("return value of user defined python function in GeneratorDataset or" " map should be numpy array or tuple of numpy array.") for r in data: - if not is_serializable(obj=r): + # the map:pyfunc is a yield generator which can't be serialized + if isinstance(r, types.GeneratorType): raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array" .format(type(r))) if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py index 2c769d46fa6..96ebc3fc5c9 100644 --- a/tests/ut/python/dataset/test_pyfunc.py +++ b/tests/ut/python/dataset/test_pyfunc.py @@ -330,6 +330,22 @@ def skip_test_pyfunc_Exception_multiprocess(): assert "MP Pyfunc Throw" in str(info.value) +def test_func_with_yield_manifest_dataset_01(): + def pass_func(_): + for i in range(10): + yield (np.array([i]),) + + DATA_FILE = "../data/dataset/testManifestData/test.manifest" + data = ds.ManifestDataset(DATA_FILE) + data = data.map(operations=pass_func, input_columns=["image"], num_parallel_workers=1, python_multiprocessing=True) + num_iter = 0 + try: + for _ in data.create_dict_iterator(output_numpy=True): + num_iter += 1 + except RuntimeError as e: + assert "Can not pickle <class 'generator'> object, " in str(e) + + if __name__ == "__main__": test_case_0() test_case_1() @@ -345,3 +361,4 @@ if __name__ == "__main__": test_pyfunc_implicit_compose() test_pyfunc_exception() skip_test_pyfunc_exception_multiprocess() + test_func_with_yield_manifest_dataset_01()