forked from mindspore-Ecosystem/mindspore
fix: pyfunc hung when it's a yield func
This commit is contained in:
parent
e00b51ce35
commit
f49bacde09
|
@ -18,7 +18,6 @@ General Validators.
|
|||
import inspect
|
||||
from multiprocessing import cpu_count
|
||||
import os
|
||||
from pickle import dumps
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
|
@ -63,23 +62,6 @@ def is_iterable(obj):
|
|||
return True
|
||||
|
||||
|
||||
def is_serializable(obj):
    """
    Helper function to check if object is serializable.

    The check is performed by attempting to pickle the object with
    ``pickle.dumps``; the pickled bytes are discarded.

    Args:
        obj (any): object to check if serializable.

    Returns:
        bool, true if object is serializable.
    """
    # Local import keeps this block self-contained: the file-level import
    # only brings `dumps` into scope.
    from pickle import PicklingError

    try:
        dumps(obj)
    except (TypeError, AttributeError, PicklingError):
        # pickle reports an unpicklable object in more than one way:
        # TypeError (e.g. generator objects), PicklingError (e.g. lambdas
        # whose attribute lookup fails) and, on some Python versions,
        # AttributeError (e.g. locally defined objects). All of them mean
        # "not serializable" and must not escape a boolean predicate.
        return False
    return True
|
||||
|
||||
|
||||
def pad_arg_name(arg_name):
|
||||
"""
|
||||
Appends a space to the arg_name (if not empty)
|
||||
|
|
|
@ -20,9 +20,10 @@ but it will pass large data through shared memory.
|
|||
|
||||
import multiprocessing.queues
|
||||
import multiprocessing
|
||||
import types
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..core.validator_helpers import is_serializable
|
||||
from ..transforms.py_transforms_util import ExceptionHandler
|
||||
|
||||
|
||||
|
@ -79,7 +80,8 @@ class _SharedQueue(multiprocessing.queues.Queue):
|
|||
raise TypeError("return value of user defined python function in GeneratorDataset or"
|
||||
" map should be numpy array or tuple of numpy array.")
|
||||
for r in data:
|
||||
if not is_serializable(obj=r):
|
||||
# the map pyfunc is a generator function (it uses yield); the generator object it returns can't be serialized
|
||||
if isinstance(r, types.GeneratorType):
|
||||
raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array"
|
||||
.format(type(r)))
|
||||
if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
|
||||
|
|
|
@ -330,6 +330,22 @@ def skip_test_pyfunc_Exception_multiprocess():
|
|||
assert "MP Pyfunc Throw" in str(info.value)
|
||||
|
||||
|
||||
def test_func_with_yield_manifest_dataset_01():
    """
    Map a generator-style (yield) pyfunc over the "image" column of a
    manifest dataset with python_multiprocessing enabled, then iterate it.

    If iteration raises RuntimeError, the message must report that a
    generator object can not be pickled. Note that the test also passes
    when iteration completes without raising.
    """
    def pass_func(_):
        # Using `yield` makes the call return a generator object rather
        # than a tuple of numpy arrays.
        for idx in range(10):
            yield (np.array([idx]),)

    manifest_path = "../data/dataset/testManifestData/test.manifest"
    dataset = ds.ManifestDataset(manifest_path).map(
        operations=pass_func,
        input_columns=["image"],
        num_parallel_workers=1,
        python_multiprocessing=True,
    )

    rows_seen = 0
    try:
        for _ in dataset.create_dict_iterator(output_numpy=True):
            rows_seen += 1
    except RuntimeError as err:
        assert "Can not pickle <class 'generator'> object, " in str(err)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_case_0()
|
||||
test_case_1()
|
||||
|
@ -345,3 +361,4 @@ if __name__ == "__main__":
|
|||
test_pyfunc_implicit_compose()
|
||||
test_pyfunc_exception()
|
||||
skip_test_pyfunc_exception_multiprocess()
|
||||
test_func_with_yield_manifest_dataset_01()
|
||||
|
|
Loading…
Reference in New Issue