forked from mindspore-Ecosystem/mindspore
fix pyfunc pickle issue
This commit is contained in:
parent
c9ce0d371a
commit
41f3e02e87
|
@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
|
||||||
|
|
||||||
Status PyFuncOp::to_json(nlohmann::json *out_json) {
|
Status PyFuncOp::to_json(nlohmann::json *out_json) {
|
||||||
nlohmann::json args;
|
nlohmann::json args;
|
||||||
auto package = pybind11::module::import("pickle");
|
|
||||||
auto module = package.attr("dumps");
|
|
||||||
args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
|
|
||||||
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
|
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
|
||||||
args["is_python_front_end_op"] = true;
|
args["is_python_front_end_op"] = true;
|
||||||
*out_json = args;
|
*out_json = args;
|
||||||
|
|
|
@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
|
||||||
"""
|
"""
|
||||||
Serialize dataset pipeline into a json file.
|
Serialize dataset pipeline into a json file.
|
||||||
|
|
||||||
|
Currently some python objects are not supported to be serialized.
|
||||||
|
For python function serialization of map operator, de.serialize will only return its function name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (Dataset): the starting node.
|
dataset (Dataset): the starting node.
|
||||||
json_filepath (str): a filepath where a serialized json file will be generated.
|
json_filepath (str): a filepath where a serialized json file will be generated.
|
||||||
|
@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
|
||||||
"""
|
"""
|
||||||
Construct a de pipeline from a json file produced by de.serialize().
|
Construct a de pipeline from a json file produced by de.serialize().
|
||||||
|
|
||||||
|
Currently python function deserialization of map operator are not supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dict (dict): a Python dictionary containing a serialized dataset graph
|
input_dict (dict): a Python dictionary containing a serialized dataset graph
|
||||||
json_filepath (str): a path to the json file.
|
json_filepath (str): a path to the json file.
|
||||||
|
@ -349,8 +353,8 @@ def construct_tensor_ops(operations):
|
||||||
op_params = op.get('tensor_op_params')
|
op_params = op.get('tensor_op_params')
|
||||||
|
|
||||||
if op.get('is_python_front_end_op'): # check if it's a py_transform op
|
if op.get('is_python_front_end_op'): # check if it's a py_transform op
|
||||||
result.append(pickle.loads(op_params.encode()))
|
raise NotImplementedError("python function is not yet supported by de.deserialize().")
|
||||||
else:
|
|
||||||
if op_name == "HwcToChw": op_name = "HWC2CHW"
|
if op_name == "HwcToChw": op_name = "HWC2CHW"
|
||||||
if op_name == "UniformAug": op_name = "UniformAugment"
|
if op_name == "UniformAug": op_name = "UniformAugment"
|
||||||
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
|
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
|
||||||
|
|
|
@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
|
||||||
py_vision.ToTensor()
|
py_vision.ToTensor()
|
||||||
]
|
]
|
||||||
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
|
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
|
||||||
|
# Current python function derialization will be failed for pickle, so we disable this testcase
|
||||||
|
# as an exception testcase.
|
||||||
|
try:
|
||||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||||
|
assert False
|
||||||
|
except NotImplementedError as e:
|
||||||
|
assert "python function is not yet supported" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_serdes_uniform_augment(remove_json_files=True):
|
def test_serdes_uniform_augment(remove_json_files=True):
|
||||||
|
|
Loading…
Reference in New Issue