forked from mindspore-Ecosystem/mindspore
fix pyfunc pickle issue
This commit is contained in:
parent
c9ce0d371a
commit
41f3e02e87
|
@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
|
|||
|
||||
Status PyFuncOp::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
auto package = pybind11::module::import("pickle");
|
||||
auto module = package.attr("dumps");
|
||||
args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
|
||||
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
|
||||
args["is_python_front_end_op"] = true;
|
||||
*out_json = args;
|
||||
|
|
|
@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
|
|||
"""
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
|
|||
"""
|
||||
Serialize dataset pipeline into a json file.
|
||||
|
||||
Currently some python objects are not supported to be serialized.
|
||||
For python function serialization of map operator, de.serialize will only return its function name.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): the starting node.
|
||||
json_filepath (str): a filepath where a serialized json file will be generated.
|
||||
|
@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
|
|||
"""
|
||||
Construct a de pipeline from a json file produced by de.serialize().
|
||||
|
||||
Currently python function deserialization of map operator are not supported.
|
||||
|
||||
Args:
|
||||
input_dict (dict): a Python dictionary containing a serialized dataset graph
|
||||
json_filepath (str): a path to the json file.
|
||||
|
@ -349,42 +353,42 @@ def construct_tensor_ops(operations):
|
|||
op_params = op.get('tensor_op_params')
|
||||
|
||||
if op.get('is_python_front_end_op'): # check if it's a py_transform op
|
||||
result.append(pickle.loads(op_params.encode()))
|
||||
raise NotImplementedError("python function is not yet supported by de.deserialize().")
|
||||
|
||||
if op_name == "HwcToChw": op_name = "HWC2CHW"
|
||||
if op_name == "UniformAug": op_name = "UniformAugment"
|
||||
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
|
||||
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
|
||||
|
||||
if hasattr(op_module_vis, op_name):
|
||||
op_class = getattr(op_module_vis, op_name, None)
|
||||
elif hasattr(op_module_trans, op_name[:-2]):
|
||||
op_name = op_name[:-2] # to remove op from the back of the name
|
||||
op_class = getattr(op_module_trans, op_name, None)
|
||||
else:
|
||||
if op_name == "HwcToChw": op_name = "HWC2CHW"
|
||||
if op_name == "UniformAug": op_name = "UniformAugment"
|
||||
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
|
||||
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
|
||||
raise RuntimeError(op_name + " is not yet supported by deserialize().")
|
||||
|
||||
if hasattr(op_module_vis, op_name):
|
||||
op_class = getattr(op_module_vis, op_name, None)
|
||||
elif hasattr(op_module_trans, op_name[:-2]):
|
||||
op_name = op_name[:-2] # to remove op from the back of the name
|
||||
op_class = getattr(op_module_trans, op_name, None)
|
||||
else:
|
||||
raise RuntimeError(op_name + " is not yet supported by deserialize().")
|
||||
if op_params is None: # If no parameter is specified, call it directly
|
||||
result.append(op_class())
|
||||
else:
|
||||
# Input parameter type cast
|
||||
for key, val in op_params.items():
|
||||
if key in ['center', 'fill_value']:
|
||||
op_params[key] = tuple(val)
|
||||
elif key in ['interpolation', 'resample']:
|
||||
op_params[key] = Inter(to_interpolation_mode(val))
|
||||
elif key in ['padding_mode']:
|
||||
op_params[key] = Border(to_border_mode(val))
|
||||
elif key in ['data_type']:
|
||||
op_params[key] = to_mstype(val)
|
||||
elif key in ['image_batch_format']:
|
||||
op_params[key] = to_image_batch_format(val)
|
||||
elif key in ['policy']:
|
||||
op_params[key] = to_policy(val)
|
||||
elif key in ['transform', 'transforms']:
|
||||
op_params[key] = construct_tensor_ops(val)
|
||||
|
||||
if op_params is None: # If no parameter is specified, call it directly
|
||||
result.append(op_class())
|
||||
else:
|
||||
# Input parameter type cast
|
||||
for key, val in op_params.items():
|
||||
if key in ['center', 'fill_value']:
|
||||
op_params[key] = tuple(val)
|
||||
elif key in ['interpolation', 'resample']:
|
||||
op_params[key] = Inter(to_interpolation_mode(val))
|
||||
elif key in ['padding_mode']:
|
||||
op_params[key] = Border(to_border_mode(val))
|
||||
elif key in ['data_type']:
|
||||
op_params[key] = to_mstype(val)
|
||||
elif key in ['image_batch_format']:
|
||||
op_params[key] = to_image_batch_format(val)
|
||||
elif key in ['policy']:
|
||||
op_params[key] = to_policy(val)
|
||||
elif key in ['transform', 'transforms']:
|
||||
op_params[key] = construct_tensor_ops(val)
|
||||
|
||||
result.append(op_class(**op_params))
|
||||
result.append(op_class(**op_params))
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
|
|||
py_vision.ToTensor()
|
||||
]
|
||||
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
|
||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||
# Current python function derialization will be failed for pickle, so we disable this testcase
|
||||
# as an exception testcase.
|
||||
try:
|
||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||
assert False
|
||||
except NotImplementedError as e:
|
||||
assert "python function is not yet supported" in str(e)
|
||||
|
||||
|
||||
def test_serdes_uniform_augment(remove_json_files=True):
|
||||
|
|
Loading…
Reference in New Issue