fix pyfunc pickle issue

xiefangqi 2021-03-10 10:41:03 +08:00
parent c9ce0d371a
commit 41f3e02e87
3 changed files with 45 additions and 38 deletions

View File

@@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
 Status PyFuncOp::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
-  auto package = pybind11::module::import("pickle");
-  auto module = package.attr("dumps");
-  args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
   args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
   args["is_python_front_end_op"] = true;
   *out_json = args;
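For context, the removed branch pickled the Python callable itself into the op's JSON. That is fragile for the most common pyfunc shapes; a minimal Python sketch of the failure mode, independent of MindSpore:

import pickle

def make_local_fn():
    def local_fn(x):  # nested function: not importable by qualified name
        return x
    return local_fn

# Lambdas and locally defined functions cannot be pickled by reference,
# so serializing an arbitrary user callable is not generally possible.
for fn in (lambda x: x + 1, make_local_fn()):
    try:
        pickle.dumps(fn)
    except Exception as err:  # PicklingError or AttributeError, depending on the case
        print(type(err).__name__, err)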

View File

@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
""" """
import json import json
import os import os
import pickle
import sys import sys
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
     """
     Serialize dataset pipeline into a json file.
 
+    Currently, some Python objects are not supported for serialization.
+    For Python functions used in the map operator, de.serialize will only record the function name.
+
     Args:
         dataset (Dataset): the starting node.
         json_filepath (str): a filepath where a serialized json file will be generated.
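A hedged usage sketch of the behavior documented above (the dataset path and the lambda are placeholders, not part of this commit):

import mindspore.dataset as ds

data = ds.MnistDataset("path/to/mnist")  # placeholder path
data = data.map(operations=[lambda x: x], input_columns=["image"])
# The pipeline serializes, but the lambda is recorded by name only,
# so the resulting JSON cannot rebuild the callable.
ds.serialize(data, json_filepath="pipeline.json")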
@@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
     """
     Construct a de pipeline from a json file produced by de.serialize().
 
+    Currently, Python function deserialization for the map operator is not supported.
+
     Args:
         input_dict (dict): a Python dictionary containing a serialized dataset graph
         json_filepath (str): a path to the json file.
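Correspondingly, rebuilding a pipeline whose JSON contains a Python front-end op now fails loudly instead of unpickling. A sketch, reusing the file name assumed above:

import mindspore.dataset as ds

try:
    data = ds.deserialize(json_filepath="pipeline.json")
except NotImplementedError as err:
    # raised by construct_tensor_ops for any is_python_front_end_op entry
    print(err)  # python function is not yet supported by de.deserialize().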
@@ -349,42 +353,42 @@ def construct_tensor_ops(operations):
         op_params = op.get('tensor_op_params')
 
         if op.get('is_python_front_end_op'):  # check if it's a py_transform op
-            result.append(pickle.loads(op_params.encode()))
-        else:
-            if op_name == "HwcToChw": op_name = "HWC2CHW"
-            if op_name == "UniformAug": op_name = "UniformAugment"
-            op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
-            op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
-
-            if hasattr(op_module_vis, op_name):
-                op_class = getattr(op_module_vis, op_name, None)
-            elif hasattr(op_module_trans, op_name[:-2]):
-                op_name = op_name[:-2]  # to remove op from the back of the name
-                op_class = getattr(op_module_trans, op_name, None)
-            else:
-                raise RuntimeError(op_name + " is not yet supported by deserialize().")
-
-            if op_params is None:  # If no parameter is specified, call it directly
-                result.append(op_class())
-            else:
-                # Input parameter type cast
-                for key, val in op_params.items():
-                    if key in ['center', 'fill_value']:
-                        op_params[key] = tuple(val)
-                    elif key in ['interpolation', 'resample']:
-                        op_params[key] = Inter(to_interpolation_mode(val))
-                    elif key in ['padding_mode']:
-                        op_params[key] = Border(to_border_mode(val))
-                    elif key in ['data_type']:
-                        op_params[key] = to_mstype(val)
-                    elif key in ['image_batch_format']:
-                        op_params[key] = to_image_batch_format(val)
-                    elif key in ['policy']:
-                        op_params[key] = to_policy(val)
-                    elif key in ['transform', 'transforms']:
-                        op_params[key] = construct_tensor_ops(val)
-                result.append(op_class(**op_params))
+            raise NotImplementedError("python function is not yet supported by de.deserialize().")
+
+        if op_name == "HwcToChw": op_name = "HWC2CHW"
+        if op_name == "UniformAug": op_name = "UniformAugment"
+        op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
+        op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
+
+        if hasattr(op_module_vis, op_name):
+            op_class = getattr(op_module_vis, op_name, None)
+        elif hasattr(op_module_trans, op_name[:-2]):
+            op_name = op_name[:-2]  # to remove op from the back of the name
+            op_class = getattr(op_module_trans, op_name, None)
+        else:
+            raise RuntimeError(op_name + " is not yet supported by deserialize().")
+
+        if op_params is None:  # If no parameter is specified, call it directly
+            result.append(op_class())
+        else:
+            # Input parameter type cast
+            for key, val in op_params.items():
+                if key in ['center', 'fill_value']:
+                    op_params[key] = tuple(val)
+                elif key in ['interpolation', 'resample']:
+                    op_params[key] = Inter(to_interpolation_mode(val))
+                elif key in ['padding_mode']:
+                    op_params[key] = Border(to_border_mode(val))
+                elif key in ['data_type']:
+                    op_params[key] = to_mstype(val)
+                elif key in ['image_batch_format']:
+                    op_params[key] = to_image_batch_format(val)
+                elif key in ['policy']:
+                    op_params[key] = to_policy(val)
+                elif key in ['transform', 'transforms']:
+                    op_params[key] = construct_tensor_ops(val)
+            result.append(op_class(**op_params))
     return result
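The name resolution above is easy to misread, so here is a self-contained Python sketch of just that logic (the set arguments stand in for the real module lookups):

# 1) legacy aliases map to current class names;
# 2) names found only in transforms.c_transforms drop their trailing "Op".
def normalize(op_name, vision_names, transform_names):
    aliases = {"HwcToChw": "HWC2CHW", "UniformAug": "UniformAugment"}
    op_name = aliases.get(op_name, op_name)
    if op_name in vision_names:
        return op_name
    if op_name[:-2] in transform_names:
        return op_name[:-2]  # e.g. "OneHotOp" -> "OneHot"
    raise RuntimeError(op_name + " is not yet supported by deserialize().")

print(normalize("OneHotOp", {"Resize"}, {"OneHot"}))  # OneHot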

View File

@@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
         py_vision.ToTensor()
     ]
     data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
-    util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
+    # Python function deserialization currently fails (pickle is no longer used),
+    # so we treat this testcase as an exception testcase.
+    try:
+        util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
+        assert False
+    except NotImplementedError as e:
+        assert "python function is not yet supported" in str(e)
 
 
 def test_serdes_uniform_augment(remove_json_files=True):
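The same assertion reads a little tighter with pytest's context manager; a sketch that would slot into the test above, assuming pytest (which these dataset UTs already run under):

import pytest

with pytest.raises(NotImplementedError, match="python function is not yet supported"):
    util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)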