forked from mindspore-Ecosystem/mindspore
!23264 python serdes bug fix
Merge pull request !23264 from zetongzhao/deserialize
This commit is contained in:
commit
6dcfe64501
|
@ -15,7 +15,9 @@
|
|||
"""
|
||||
Built-in py_transforms_utils functions.
|
||||
"""
|
||||
import json
|
||||
import random
|
||||
from types import FunctionType
|
||||
import numpy as np
|
||||
|
||||
from ..core.py_util_helpers import is_numpy, ExceptionHandler
|
||||
|
@ -187,4 +189,9 @@ class FuncWrapper:
|
|||
return result
|
||||
|
||||
def to_json(self):
|
||||
if isinstance(self.transform, FunctionType):
|
||||
json_obj = {}
|
||||
json_obj["tensor_op_name"] = self.transform.__name__
|
||||
json_obj["python_module"] = self.__class__.__module__
|
||||
return json.dumps(json_obj)
|
||||
return self.transform.to_json()
|
||||
|
|
|
@ -396,6 +396,17 @@ def test_serdes_pyvision(remove_json_files=True):
|
|||
data1 = data1.map(operations=py.Compose(transforms1), input_columns=["image"])
|
||||
data1 = data1.map(operations=py.RandomApply(transforms2), input_columns=["image"])
|
||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||
data2 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||
data2 = data2.map(operations=(lambda x, y, z: (
|
||||
np.array(x).flatten().reshape(10, 39),
|
||||
np.array(y).flatten().reshape(10, 39),
|
||||
np.array(z).flatten().reshape(10, 1)
|
||||
)))
|
||||
ds.serialize(data2, "pyvision_dataset_pipeline.json")
|
||||
assert validate_jsonfile("pyvision_dataset_pipeline.json") is True
|
||||
|
||||
if remove_json_files:
|
||||
delete_json_files()
|
||||
|
||||
|
||||
def test_serdes_uniform_augment(remove_json_files=True):
|
||||
|
|
Loading…
Reference in New Issue