!23264 python serdes bug fix

Merge pull request !23264 from zetongzhao/deserialize
This commit is contained in:
i-robot 2021-09-13 15:53:05 +00:00 committed by Gitee
commit 6dcfe64501
2 changed files with 18 additions and 0 deletions

View File

@ -15,7 +15,9 @@
"""
Built-in py_transforms_utils functions.
"""
import json
import random
from types import FunctionType
import numpy as np
from ..core.py_util_helpers import is_numpy, ExceptionHandler
@ -187,4 +189,9 @@ class FuncWrapper:
return result
def to_json(self):
if isinstance(self.transform, FunctionType):
json_obj = {}
json_obj["tensor_op_name"] = self.transform.__name__
json_obj["python_module"] = self.__class__.__module__
return json.dumps(json_obj)
return self.transform.to_json()

View File

@ -396,6 +396,17 @@ def test_serdes_pyvision(remove_json_files=True):
data1 = data1.map(operations=py.Compose(transforms1), input_columns=["image"])
data1 = data1.map(operations=py.RandomApply(transforms2), input_columns=["image"])
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
data2 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
data2 = data2.map(operations=(lambda x, y, z: (
np.array(x).flatten().reshape(10, 39),
np.array(y).flatten().reshape(10, 39),
np.array(z).flatten().reshape(10, 1)
)))
ds.serialize(data2, "pyvision_dataset_pipeline.json")
assert validate_jsonfile("pyvision_dataset_pipeline.json") is True
if remove_json_files:
delete_json_files()
def test_serdes_uniform_augment(remove_json_files=True):