fix: pyfunc hung when it's a yield func

This commit is contained in:
Guo Zhijian 2021-09-15 05:37:48 -04:00 committed by jonyguo
parent e00b51ce35
commit f49bacde09
3 changed files with 21 additions and 20 deletions

View File

@ -18,7 +18,6 @@ General Validators.
import inspect
from multiprocessing import cpu_count
import os
from pickle import dumps
import numpy as np
import mindspore._c_dataengine as cde
@ -63,23 +62,6 @@ def is_iterable(obj):
return True
def is_serializable(obj):
"""
Helper function to check if object is serializable.
Args:
obj (any): object to check if serializable.
Returns:
bool, true if object is serializable.
"""
try:
dumps(obj)
except TypeError:
return False
return True
def pad_arg_name(arg_name):
"""
Appends a space to the arg_name (if not empty)

View File

@ -20,9 +20,10 @@ but it will pass large data through shared memory.
import multiprocessing.queues
import multiprocessing
import types
import numpy as np
from mindspore import log as logger
from ..core.validator_helpers import is_serializable
from ..transforms.py_transforms_util import ExceptionHandler
@ -79,7 +80,8 @@ class _SharedQueue(multiprocessing.queues.Queue):
raise TypeError("return value of user defined python function in GeneratorDataset or"
" map should be numpy array or tuple of numpy array.")
for r in data:
if not is_serializable(obj=r):
# the map:pyfunc is a yield generator which can't be serialize
if isinstance(r, types.GeneratorType):
raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array"
.format(type(r)))
if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem

View File

@ -330,6 +330,22 @@ def skip_test_pyfunc_Exception_multiprocess():
assert "MP Pyfunc Throw" in str(info.value)
def test_func_with_yield_manifest_dataset_01():
def pass_func(_):
for i in range(10):
yield (np.array([i]),)
DATA_FILE = "../data/dataset/testManifestData/test.manifest"
data = ds.ManifestDataset(DATA_FILE)
data = data.map(operations=pass_func, input_columns=["image"], num_parallel_workers=1, python_multiprocessing=True)
num_iter = 0
try:
for _ in data.create_dict_iterator(output_numpy=True):
num_iter += 1
except RuntimeError as e:
assert "Can not pickle <class 'generator'> object, " in str(e)
if __name__ == "__main__":
test_case_0()
test_case_1()
@ -345,3 +361,4 @@ if __name__ == "__main__":
test_pyfunc_implicit_compose()
test_pyfunc_exception()
skip_test_pyfunc_exception_multiprocess()
test_func_with_yield_manifest_dataset_01()