!22784 serdes python operations by pyfunc

Merge pull request !22784 from zetongzhao/deserialize
This commit is contained in:
i-robot 2021-09-08 17:18:57 +00:00 committed by Gitee
commit b21da536d3
8 changed files with 152 additions and 66 deletions

View File

@ -222,18 +222,20 @@ Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<Sampler
Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
std::vector<std::shared_ptr<TensorOperation>> output;
for (nlohmann::json item : json_obj) {
CHECK_FAIL_RETURN_UNEXPECTED(item.find("is_python_front_end_op") == item.end(),
"python operation is not yet supported");
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
std::string op_name = item["tensor_op_name"];
nlohmann::json op_params = item["tensor_op_params"];
std::shared_ptr<TensorOperation> operation = nullptr;
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
output.push_back(operation);
if (item.find("python_module") != item.end()) {
RETURN_IF_NOT_OK(PyFuncOp::from_json(item, result));
} else {
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
std::string op_name = item["tensor_op_name"];
nlohmann::json op_params = item["tensor_op_params"];
std::shared_ptr<TensorOperation> operation = nullptr;
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
output.push_back(operation);
*result = output;
}
}
*result = output;
return Status::OK();
}

View File

@ -77,6 +77,7 @@
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/kernels/py_func_op.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"

View File

@ -20,6 +20,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -124,12 +125,29 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
Status PyFuncOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
args["is_python_front_end_op"] = true;
if (py_func_ptr_.attr("to_json")) {
args = nlohmann::json::parse(py_func_ptr_.attr("to_json")().cast<std::string>());
}
*out_json = args;
return Status::OK();
}
Status PyFuncOp::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
std::vector<std::shared_ptr<TensorOperation>> output;
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_name") != json_obj.end(), "Failed to find tensor_op_name");
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_params") != json_obj.end(), "Failed to find tensor_op_params");
std::string op_name = json_obj["tensor_op_name"];
nlohmann::json op_params = json_obj["tensor_op_params"];
std::string python_module = json_obj["python_module"];
std::shared_ptr<TensorOperation> operation = nullptr;
py::function py_func =
py::module::import(python_module.c_str()).attr(op_name.c_str()).attr("from_json")(op_params.dump());
operation = std::make_shared<transforms::PreBuiltOperation>(std::make_shared<PyFuncOp>(py_func));
output.push_back(operation);
*result = output;
return Status::OK();
}
bool PyFuncOp::IsRandom() {
bool random = true;
if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false)

View File

@ -51,6 +51,8 @@ class PyFuncOp : public TensorOp {
std::string Name() const override { return kPyFuncOp; }
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::vector<std::shared_ptr<TensorOperation>> *result);
/// \brief Check whether this pyfunc op is deterministic
/// \return True if this pyfunc op is random
bool IsRandom();

View File

@ -16,6 +16,10 @@
The module transforms.py_transform is implemented based on Python. It provides common
operations including OneHotOp.
"""
import json
import sys
import numpy as np
from .validators import check_one_hot_op, check_compose_list, check_random_apply, check_transforms_list, \
check_compose_call
from . import py_transforms_util as util
@ -31,7 +35,58 @@ def not_random(function):
return function
class OneHotOp:
class PyTensorOperation:
"""
Base Python Tensor Operations class
"""
def to_json(self):
"""
Base to_json for Python tensor operations class
"""
json_obj = {}
json_trans = {}
if "transforms" in self.__dict__.keys():
# operations which have transforms as input, need to call _to_json() for each transform to serialize
json_list = []
for transform in self.transforms:
json_list.append(json.loads(transform.to_json()))
json_trans["transforms"] = json_list
self.__dict__.pop("transforms")
if "output_type" in self.__dict__.keys():
json_trans["output_type"] = np.dtype(
self.__dict__["output_type"]).name
self.__dict__.pop("output_type")
json_obj["tensor_op_params"] = self.__dict__
# append transforms to the tensor_op_params of the operation
json_obj["tensor_op_params"].update(json_trans)
json_obj["tensor_op_name"] = self.__class__.__name__
json_obj["python_module"] = self.__class__.__module__
return json.dumps(json_obj)
@classmethod
def from_json(cls, json_string):
"""
Base from_json for Python tensor operations class
"""
json_obj = json.loads(json_string)
new_op = cls.__new__(cls)
new_op.__dict__ = json_obj
if "transforms" in json_obj.keys():
# operations which have transforms as input, need to call _from_json() for each transform to deseriallize
transforms = []
for json_op in json_obj["transforms"]:
transforms.append(getattr(
sys.modules[json_op["python_module"]], json_op["tensor_op_name"]).from_json(
json.dumps(json_op["tensor_op_params"])))
new_op.transforms = transforms
if "output_type" in json_obj.keys():
output_type = np.dtype(json_obj["output_type"])
new_op.output_type = output_type
return new_op
class OneHotOp(PyTensorOperation):
"""
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
@ -67,7 +122,7 @@ class OneHotOp:
return util.one_hot_encoding(label, self.num_classes, self.smoothing_rate)
class Compose:
class Compose(PyTensorOperation):
"""
Compose a list of transforms.
@ -170,7 +225,7 @@ class Compose:
return new_ops
class RandomApply:
class RandomApply(PyTensorOperation):
"""
Randomly perform a series of transforms with a given probability.
@ -207,7 +262,7 @@ class RandomApply:
return util.random_apply(img, self.transforms, self.prob)
class RandomChoice:
class RandomChoice(PyTensorOperation):
"""
Randomly select one transform from a series of transforms and applies that on the image.
@ -242,7 +297,7 @@ class RandomChoice:
return util.random_choice(img, self.transforms)
class RandomOrder:
class RandomOrder(PyTensorOperation):
"""
Perform a series of transforms to the input PIL image in a random order.

View File

@ -159,3 +159,6 @@ class FuncWrapper:
result = ExceptionHandler(where="in map(or batch) worker and execute python function")
result.reraise()
return result
def to_json(self):
return self.transform.to_json()

View File

@ -25,6 +25,7 @@ import random
import numpy as np
from PIL import Image
import mindspore.dataset.transforms.py_transforms as py_transforms
from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_center_crop, check_five_crop, check_resize_interpolation, check_random_resize_crop, \
@ -56,7 +57,7 @@ def not_random(function):
return function
class ToTensor:
class ToTensor(py_transforms.PyTensorOperation):
"""
Convert the input PIL Image or numpy.ndarray of shape (H, W, C) in the range [0, 255] to numpy.ndarray of
shape (C, H, W) in the range [0.0, 1.0] with the desired dtype.
@ -101,7 +102,7 @@ class ToTensor:
return util.to_tensor(img, self.output_type)
class ToType:
class ToType(py_transforms.PyTensorOperation):
"""
Convert the input numpy.ndarray image to the desired dtype.
@ -140,7 +141,7 @@ class ToType:
return util.to_type(img, self.output_type)
class HWC2CHW:
class HWC2CHW(py_transforms.PyTensorOperation):
"""
Transpose the input numpy.ndarray image of shape (H, W, C) to (C, H, W).
@ -173,7 +174,7 @@ class HWC2CHW:
return util.hwc_to_chw(img)
class ToPIL:
class ToPIL(py_transforms.PyTensorOperation):
"""
Convert the input decoded numpy.ndarray image to PIL Image.
@ -210,7 +211,7 @@ class ToPIL:
return util.to_pil(img)
class Decode:
class Decode(py_transforms.PyTensorOperation):
"""
Decode the input raw image to PIL Image format in RGB mode.
@ -244,7 +245,7 @@ class Decode:
return util.decode(img)
class Normalize:
class Normalize(py_transforms.PyTensorOperation):
r"""
Normalize the input numpy.ndarray image of shape (C, H, W) with the specified mean and standard deviation.
@ -300,7 +301,7 @@ class Normalize:
return util.normalize(img, self.mean, self.std)
class NormalizePad:
class NormalizePad(py_transforms.PyTensorOperation):
r"""
Normalize the input numpy.ndarray image of shape (C, H, W) with the specified mean and standard deviation,
then pad an extra channel filled with zeros.
@ -362,7 +363,7 @@ class NormalizePad:
return util.normalize(img, self.mean, self.std, pad_channel=True, dtype=self.dtype)
class RandomCrop:
class RandomCrop(py_transforms.PyTensorOperation):
"""
Crop the input PIL Image at a random location with the specified size.
@ -431,7 +432,7 @@ class RandomCrop:
self.fill_value, self.padding_mode)
class RandomHorizontalFlip:
class RandomHorizontalFlip(py_transforms.PyTensorOperation):
"""
Randomly flip the input image horizontally with a given probability.
@ -465,7 +466,7 @@ class RandomHorizontalFlip:
return util.random_horizontal_flip(img, self.prob)
class RandomVerticalFlip:
class RandomVerticalFlip(py_transforms.PyTensorOperation):
"""
Randomly flip the input image vertically with a given probability.
@ -499,7 +500,7 @@ class RandomVerticalFlip:
return util.random_vertical_flip(img, self.prob)
class Resize:
class Resize(py_transforms.PyTensorOperation):
"""
Resize the input PIL image to the given size.
@ -548,7 +549,7 @@ class Resize:
return util.resize(img, self.size, self.interpolation)
class RandomResizedCrop:
class RandomResizedCrop(py_transforms.PyTensorOperation):
"""
Extract crop from the input image and resize it to a random size and aspect ratio.
@ -606,7 +607,7 @@ class RandomResizedCrop:
self.interpolation, self.max_attempts)
class CenterCrop:
class CenterCrop(py_transforms.PyTensorOperation):
"""
Crop the central reigion of the input PIL image to the given size.
@ -643,7 +644,7 @@ class CenterCrop:
return util.center_crop(img, self.size)
class RandomColorAdjust:
class RandomColorAdjust(py_transforms.PyTensorOperation):
"""
Perform a random brightness, contrast, saturation, and hue adjustment on the input PIL image.
@ -691,7 +692,7 @@ class RandomColorAdjust:
return util.random_color_adjust(img, self.brightness, self.contrast, self.saturation, self.hue)
class RandomRotation:
class RandomRotation(py_transforms.PyTensorOperation):
"""
Rotate the input PIL image by a random angle.
@ -756,7 +757,7 @@ class RandomRotation:
return util.random_rotation(img, self.degrees, self.resample, self.expand, self.center, self.fill_value)
class FiveCrop:
class FiveCrop(py_transforms.PyTensorOperation):
"""
Generate 5 cropped images (one central image and four corners images).
@ -795,7 +796,7 @@ class FiveCrop:
return util.five_crop(img, self.size)
class TenCrop:
class TenCrop(py_transforms.PyTensorOperation):
"""
Generate 10 cropped images (first 5 images from FiveCrop, second 5 images from their flipped version
as per input flag to flip vertically or horizontally).
@ -841,7 +842,7 @@ class TenCrop:
return util.ten_crop(img, self.size, self.use_vertical_flip)
class Grayscale:
class Grayscale(py_transforms.PyTensorOperation):
"""
Convert the input PIL image to grayscale image.
@ -877,7 +878,7 @@ class Grayscale:
return util.grayscale(img, num_output_channels=self.num_output_channels)
class RandomGrayscale:
class RandomGrayscale(py_transforms.PyTensorOperation):
"""
Randomly convert the input image into grayscale image with a given probability.
@ -920,7 +921,7 @@ class RandomGrayscale:
return img
class Pad:
class Pad(py_transforms.PyTensorOperation):
"""
Pad the input PIL image according to padding parameters.
@ -981,7 +982,7 @@ class Pad:
return util.pad(img, self.padding, self.fill_value, self.padding_mode)
class RandomPerspective:
class RandomPerspective(py_transforms.PyTensorOperation):
"""
Randomly apply perspective transformation to the input PIL image with a given probability.
@ -1026,12 +1027,13 @@ class RandomPerspective:
if not is_pil(img):
raise ValueError("Input image should be a Pillow image.")
if self.prob > random.random():
start_points, end_points = util.get_perspective_params(img, self.distortion_scale)
start_points, end_points = util.get_perspective_params(
img, self.distortion_scale)
return util.perspective(img, start_points, end_points, self.interpolation)
return img
class RandomErasing:
class RandomErasing(py_transforms.PyTensorOperation):
"""
Erase the pixels, within a selected rectangle region, to the given value.
@ -1090,7 +1092,7 @@ class RandomErasing:
return np_img
class Cutout:
class Cutout(py_transforms.PyTensorOperation):
"""
Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W).
@ -1128,9 +1130,11 @@ class Cutout:
np_img (numpy.ndarray), NumPy image array with square patches cut out.
"""
if not isinstance(np_img, np.ndarray):
raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
raise TypeError(
"img should be NumPy array. Got {}.".format(type(np_img)))
if np_img.ndim != 3:
raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim))
raise TypeError(
'img dimension should be 3. Got {}.'.format(np_img.ndim))
_, image_h, image_w = np_img.shape
scale = (self.length * self.length) / (image_h * image_w)
@ -1143,7 +1147,7 @@ class Cutout:
return np_img
class LinearTransformation:
class LinearTransformation(py_transforms.PyTensorOperation):
r"""
Apply linear transformation to the input NumPy image array, given a square transformation matrix and
a mean vector.
@ -1191,7 +1195,7 @@ class LinearTransformation:
return util.linear_transform(np_img, self.transformation_matrix, self.mean_vector)
class RandomAffine:
class RandomAffine(py_transforms.PyTensorOperation):
"""
Apply Random affine transformation to the input PIL image.
@ -1292,7 +1296,7 @@ class RandomAffine:
self.fill_value)
class MixUp:
class MixUp(py_transforms.PyTensorOperation):
"""
Apply mix up transformation to the input image and label. Make one input data combined with others.
@ -1338,7 +1342,7 @@ class MixUp:
return util.mix_up_muti(self, self.batch_size, image, label, self.alpha)
class RgbToBgr:
class RgbToBgr(py_transforms.PyTensorOperation):
"""
Convert a NumPy RGB image or a batch of NumPy RGB images to BGR images.
@ -1376,7 +1380,7 @@ class RgbToBgr:
return util.rgb_to_bgrs(rgb_imgs, self.is_hwc)
class RgbToHsv:
class RgbToHsv(py_transforms.PyTensorOperation):
"""
Convert a NumPy RGB image or a batch of NumPy RGB images to HSV images.
@ -1414,7 +1418,7 @@ class RgbToHsv:
return util.rgb_to_hsvs(rgb_imgs, self.is_hwc)
class HsvToRgb:
class HsvToRgb(py_transforms.PyTensorOperation):
"""
Convert a NumPy HSV image or one batch NumPy HSV images to RGB images.
@ -1452,7 +1456,7 @@ class HsvToRgb:
return util.hsv_to_rgbs(hsv_imgs, self.is_hwc)
class RandomColor:
class RandomColor(py_transforms.PyTensorOperation):
"""
Adjust the color of the input PIL image by a random degree.
@ -1488,7 +1492,7 @@ class RandomColor:
return util.random_color(img, self.degrees)
class RandomSharpness:
class RandomSharpness(py_transforms.PyTensorOperation):
"""
Adjust the sharpness of the input PIL image by a fixed or random degree. Degree of 0.0 gives a blurred image,
degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image.
@ -1525,7 +1529,7 @@ class RandomSharpness:
return util.random_sharpness(img, self.degrees)
class AdjustGamma:
class AdjustGamma(py_transforms.PyTensorOperation):
"""
Adjust gamma of the input PIL image.
@ -1563,7 +1567,7 @@ class AdjustGamma:
return util.adjust_gamma(img, self.gamma, self.gain)
class AutoContrast:
class AutoContrast(py_transforms.PyTensorOperation):
"""
Automatically maximize the contrast of the input PIL image.
@ -1602,7 +1606,7 @@ class AutoContrast:
return util.auto_contrast(img, self.cutoff, self.ignore)
class Invert:
class Invert(py_transforms.PyTensorOperation):
"""
Invert colors of input PIL image.
@ -1633,7 +1637,7 @@ class Invert:
return util.invert_color(img)
class Equalize:
class Equalize(py_transforms.PyTensorOperation):
"""
Equalize the histogram of input PIL image.
@ -1665,7 +1669,7 @@ class Equalize:
return util.equalize(img)
class UniformAugment:
class UniformAugment(py_transforms.PyTensorOperation):
"""
Uniformly select and apply a number of transforms sequentially from
a list of transforms. Randomly assign a probability to each transform for

View File

@ -382,19 +382,20 @@ def test_serdes_pyvision(remove_json_files=True):
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
transforms = [
transforms1 = [
py_vision.Decode(),
py_vision.CenterCrop([32, 32]),
py_vision.ToTensor()
]
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
# Current python function derialization will be failed for pickle, so we disable this testcase
# as an exception testcase.
try:
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
assert False
except RuntimeError as e:
assert "python operation is not yet supported" in str(e)
transforms2 = [
py_vision.RandomColorAdjust(),
py_vision.FiveCrop(1),
py_vision.Grayscale(),
py.OneHotOp(1)
]
data1 = data1.map(operations=py.Compose(transforms1), input_columns=["image"])
data1 = data1.map(operations=py.RandomApply(transforms2), input_columns=["image"])
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
def test_serdes_uniform_augment(remove_json_files=True):