fix pyfunc pickle issue

This commit is contained in:
xiefangqi 2021-03-10 10:41:03 +08:00
parent c9ce0d371a
commit 41f3e02e87
3 changed files with 45 additions and 38 deletions

View File

@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
Status PyFuncOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
auto package = pybind11::module::import("pickle");
auto module = package.attr("dumps");
args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
args["is_python_front_end_op"] = true;
*out_json = args;

View File

@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
"""
import json
import os
import pickle
import sys
import mindspore.common.dtype as mstype
@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
"""
Serialize dataset pipeline into a json file.
Currently some python objects are not supported to be serialized.
For python function serialization of map operator, de.serialize will only return its function name.
Args:
dataset (Dataset): the starting node.
json_filepath (str): a filepath where a serialized json file will be generated.
@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
"""
Construct a de pipeline from a json file produced by de.serialize().
Currently python function deserialization of map operator are not supported.
Args:
input_dict (dict): a Python dictionary containing a serialized dataset graph
json_filepath (str): a path to the json file.
@ -349,42 +353,42 @@ def construct_tensor_ops(operations):
op_params = op.get('tensor_op_params')
if op.get('is_python_front_end_op'): # check if it's a py_transform op
result.append(pickle.loads(op_params.encode()))
raise NotImplementedError("python function is not yet supported by de.deserialize().")
if op_name == "HwcToChw": op_name = "HWC2CHW"
if op_name == "UniformAug": op_name = "UniformAugment"
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name, None)
else:
if op_name == "HwcToChw": op_name = "HWC2CHW"
if op_name == "UniformAug": op_name = "UniformAugment"
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
raise RuntimeError(op_name + " is not yet supported by deserialize().")
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name, None)
else:
raise RuntimeError(op_name + " is not yet supported by deserialize().")
if op_params is None: # If no parameter is specified, call it directly
result.append(op_class())
else:
# Input parameter type cast
for key, val in op_params.items():
if key in ['center', 'fill_value']:
op_params[key] = tuple(val)
elif key in ['interpolation', 'resample']:
op_params[key] = Inter(to_interpolation_mode(val))
elif key in ['padding_mode']:
op_params[key] = Border(to_border_mode(val))
elif key in ['data_type']:
op_params[key] = to_mstype(val)
elif key in ['image_batch_format']:
op_params[key] = to_image_batch_format(val)
elif key in ['policy']:
op_params[key] = to_policy(val)
elif key in ['transform', 'transforms']:
op_params[key] = construct_tensor_ops(val)
if op_params is None: # If no parameter is specified, call it directly
result.append(op_class())
else:
# Input parameter type cast
for key, val in op_params.items():
if key in ['center', 'fill_value']:
op_params[key] = tuple(val)
elif key in ['interpolation', 'resample']:
op_params[key] = Inter(to_interpolation_mode(val))
elif key in ['padding_mode']:
op_params[key] = Border(to_border_mode(val))
elif key in ['data_type']:
op_params[key] = to_mstype(val)
elif key in ['image_batch_format']:
op_params[key] = to_image_batch_format(val)
elif key in ['policy']:
op_params[key] = to_policy(val)
elif key in ['transform', 'transforms']:
op_params[key] = construct_tensor_ops(val)
result.append(op_class(**op_params))
result.append(op_class(**op_params))
return result

View File

@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
py_vision.ToTensor()
]
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
# Current python function derialization will be failed for pickle, so we disable this testcase
# as an exception testcase.
try:
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
assert False
except NotImplementedError as e:
assert "python function is not yet supported" in str(e)
def test_serdes_uniform_augment(remove_json_files=True):