diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc index b1a8d0981c0..45e25d17a0c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc @@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) { Status PyFuncOp::to_json(nlohmann::json *out_json) { nlohmann::json args; - auto package = pybind11::module::import("pickle"); - auto module = package.attr("dumps"); - args["tensor_op_params"] = module(py_func_ptr_, 0).cast(); args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast(); args["is_python_front_end_op"] = true; *out_json = args; diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index f52b39238e2..74abf3c465d 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize. """ import json import os -import pickle import sys import mindspore.common.dtype as mstype @@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""): """ Serialize dataset pipeline into a json file. + Currently, serialization of some Python objects is not supported. + For Python function serialization of the map operator, de.serialize will only return its function name. + Args: dataset (Dataset): the starting node. json_filepath (str): a filepath where a serialized json file will be generated. @@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None): """ Construct a de pipeline from a json file produced by de.serialize(). + Currently, Python function deserialization of the map operator is not supported. + Args: input_dict (dict): a Python dictionary containing a serialized dataset graph json_filepath (str): a path to the json file. 
@@ -349,42 +353,42 @@ def construct_tensor_ops(operations): op_params = op.get('tensor_op_params') if op.get('is_python_front_end_op'): # check if it's a py_transform op - result.append(pickle.loads(op_params.encode())) + raise NotImplementedError("python function is not yet supported by de.deserialize().") + + if op_name == "HwcToChw": op_name = "HWC2CHW" + if op_name == "UniformAug": op_name = "UniformAugment" + op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] + op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] + + if hasattr(op_module_vis, op_name): + op_class = getattr(op_module_vis, op_name, None) + elif hasattr(op_module_trans, op_name[:-2]): + op_name = op_name[:-2] # to remove op from the back of the name + op_class = getattr(op_module_trans, op_name, None) else: - if op_name == "HwcToChw": op_name = "HWC2CHW" - if op_name == "UniformAug": op_name = "UniformAugment" - op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] - op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] + raise RuntimeError(op_name + " is not yet supported by deserialize().") - if hasattr(op_module_vis, op_name): - op_class = getattr(op_module_vis, op_name, None) - elif hasattr(op_module_trans, op_name[:-2]): - op_name = op_name[:-2] # to remove op from the back of the name - op_class = getattr(op_module_trans, op_name, None) - else: - raise RuntimeError(op_name + " is not yet supported by deserialize().") + if op_params is None: # If no parameter is specified, call it directly + result.append(op_class()) + else: + # Input parameter type cast + for key, val in op_params.items(): + if key in ['center', 'fill_value']: + op_params[key] = tuple(val) + elif key in ['interpolation', 'resample']: + op_params[key] = Inter(to_interpolation_mode(val)) + elif key in ['padding_mode']: + op_params[key] = Border(to_border_mode(val)) + elif key in ['data_type']: + op_params[key] = to_mstype(val) + elif key in 
['image_batch_format']: + op_params[key] = to_image_batch_format(val) + elif key in ['policy']: + op_params[key] = to_policy(val) + elif key in ['transform', 'transforms']: + op_params[key] = construct_tensor_ops(val) - if op_params is None: # If no parameter is specified, call it directly - result.append(op_class()) - else: - # Input parameter type cast - for key, val in op_params.items(): - if key in ['center', 'fill_value']: - op_params[key] = tuple(val) - elif key in ['interpolation', 'resample']: - op_params[key] = Inter(to_interpolation_mode(val)) - elif key in ['padding_mode']: - op_params[key] = Border(to_border_mode(val)) - elif key in ['data_type']: - op_params[key] = to_mstype(val) - elif key in ['image_batch_format']: - op_params[key] = to_image_batch_format(val) - elif key in ['policy']: - op_params[key] = to_policy(val) - elif key in ['transform', 'transforms']: - op_params[key] = construct_tensor_ops(val) - - result.append(op_class(**op_params)) + result.append(op_class(**op_params)) return result diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 8e0bdd0971f..e68065b6be7 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True): py_vision.ToTensor() ] data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"]) - util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files) + # Currently, Python function deserialization fails due to pickle removal, so we disable this testcase + # as an exception testcase. + try: + util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files) + assert False + except NotImplementedError as e: + assert "python function is not yet supported" in str(e) def test_serdes_uniform_augment(remove_json_files=True):