forked from mindspore-Ecosystem/mindspore
!22784 serdes python operations by pyfunc
Merge pull request !22784 from zetongzhao/deserialize
This commit is contained in:
commit
b21da536d3
|
@ -222,18 +222,20 @@ Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<Sampler
|
|||
Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> output;
|
||||
for (nlohmann::json item : json_obj) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("is_python_front_end_op") == item.end(),
|
||||
"python operation is not yet supported");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
|
||||
std::string op_name = item["tensor_op_name"];
|
||||
nlohmann::json op_params = item["tensor_op_params"];
|
||||
std::shared_ptr<TensorOperation> operation = nullptr;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
|
||||
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
|
||||
output.push_back(operation);
|
||||
if (item.find("python_module") != item.end()) {
|
||||
RETURN_IF_NOT_OK(PyFuncOp::from_json(item, result));
|
||||
} else {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
|
||||
std::string op_name = item["tensor_op_name"];
|
||||
nlohmann::json op_params = item["tensor_op_params"];
|
||||
std::shared_ptr<TensorOperation> operation = nullptr;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(), "Failed to find " + op_name);
|
||||
RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
|
||||
output.push_back(operation);
|
||||
*result = output;
|
||||
}
|
||||
}
|
||||
*result = output;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -77,6 +77,7 @@
|
|||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/include/dataset/vision.h"
|
||||
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -124,12 +125,29 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
|
|||
|
||||
Status PyFuncOp::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
|
||||
args["is_python_front_end_op"] = true;
|
||||
if (py_func_ptr_.attr("to_json")) {
|
||||
args = nlohmann::json::parse(py_func_ptr_.attr("to_json")().cast<std::string>());
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PyFuncOp::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> output;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_name") != json_obj.end(), "Failed to find tensor_op_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_params") != json_obj.end(), "Failed to find tensor_op_params");
|
||||
std::string op_name = json_obj["tensor_op_name"];
|
||||
nlohmann::json op_params = json_obj["tensor_op_params"];
|
||||
std::string python_module = json_obj["python_module"];
|
||||
std::shared_ptr<TensorOperation> operation = nullptr;
|
||||
py::function py_func =
|
||||
py::module::import(python_module.c_str()).attr(op_name.c_str()).attr("from_json")(op_params.dump());
|
||||
operation = std::make_shared<transforms::PreBuiltOperation>(std::make_shared<PyFuncOp>(py_func));
|
||||
output.push_back(operation);
|
||||
*result = output;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool PyFuncOp::IsRandom() {
|
||||
bool random = true;
|
||||
if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false)
|
||||
|
|
|
@ -51,6 +51,8 @@ class PyFuncOp : public TensorOp {
|
|||
std::string Name() const override { return kPyFuncOp; }
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::vector<std::shared_ptr<TensorOperation>> *result);
|
||||
|
||||
/// \brief Check whether this pyfunc op is deterministic
|
||||
/// \return True if this pyfunc op is random
|
||||
bool IsRandom();
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
The module transforms.py_transform is implemented based on Python. It provides common
|
||||
operations including OneHotOp.
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from .validators import check_one_hot_op, check_compose_list, check_random_apply, check_transforms_list, \
|
||||
check_compose_call
|
||||
from . import py_transforms_util as util
|
||||
|
@ -31,7 +35,58 @@ def not_random(function):
|
|||
return function
|
||||
|
||||
|
||||
class OneHotOp:
|
||||
class PyTensorOperation:
|
||||
"""
|
||||
Base Python Tensor Operations class
|
||||
"""
|
||||
|
||||
def to_json(self):
|
||||
"""
|
||||
Base to_json for Python tensor operations class
|
||||
"""
|
||||
json_obj = {}
|
||||
json_trans = {}
|
||||
if "transforms" in self.__dict__.keys():
|
||||
# operations which have transforms as input, need to call _to_json() for each transform to serialize
|
||||
json_list = []
|
||||
for transform in self.transforms:
|
||||
json_list.append(json.loads(transform.to_json()))
|
||||
json_trans["transforms"] = json_list
|
||||
self.__dict__.pop("transforms")
|
||||
if "output_type" in self.__dict__.keys():
|
||||
json_trans["output_type"] = np.dtype(
|
||||
self.__dict__["output_type"]).name
|
||||
self.__dict__.pop("output_type")
|
||||
json_obj["tensor_op_params"] = self.__dict__
|
||||
# append transforms to the tensor_op_params of the operation
|
||||
json_obj["tensor_op_params"].update(json_trans)
|
||||
json_obj["tensor_op_name"] = self.__class__.__name__
|
||||
json_obj["python_module"] = self.__class__.__module__
|
||||
return json.dumps(json_obj)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_string):
|
||||
"""
|
||||
Base from_json for Python tensor operations class
|
||||
"""
|
||||
json_obj = json.loads(json_string)
|
||||
new_op = cls.__new__(cls)
|
||||
new_op.__dict__ = json_obj
|
||||
if "transforms" in json_obj.keys():
|
||||
# operations which have transforms as input, need to call _from_json() for each transform to deseriallize
|
||||
transforms = []
|
||||
for json_op in json_obj["transforms"]:
|
||||
transforms.append(getattr(
|
||||
sys.modules[json_op["python_module"]], json_op["tensor_op_name"]).from_json(
|
||||
json.dumps(json_op["tensor_op_params"])))
|
||||
new_op.transforms = transforms
|
||||
if "output_type" in json_obj.keys():
|
||||
output_type = np.dtype(json_obj["output_type"])
|
||||
new_op.output_type = output_type
|
||||
return new_op
|
||||
|
||||
|
||||
class OneHotOp(PyTensorOperation):
|
||||
"""
|
||||
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
|
||||
|
||||
|
@ -67,7 +122,7 @@ class OneHotOp:
|
|||
return util.one_hot_encoding(label, self.num_classes, self.smoothing_rate)
|
||||
|
||||
|
||||
class Compose:
|
||||
class Compose(PyTensorOperation):
|
||||
"""
|
||||
Compose a list of transforms.
|
||||
|
||||
|
@ -170,7 +225,7 @@ class Compose:
|
|||
return new_ops
|
||||
|
||||
|
||||
class RandomApply:
|
||||
class RandomApply(PyTensorOperation):
|
||||
"""
|
||||
Randomly perform a series of transforms with a given probability.
|
||||
|
||||
|
@ -207,7 +262,7 @@ class RandomApply:
|
|||
return util.random_apply(img, self.transforms, self.prob)
|
||||
|
||||
|
||||
class RandomChoice:
|
||||
class RandomChoice(PyTensorOperation):
|
||||
"""
|
||||
Randomly select one transform from a series of transforms and applies that on the image.
|
||||
|
||||
|
@ -242,7 +297,7 @@ class RandomChoice:
|
|||
return util.random_choice(img, self.transforms)
|
||||
|
||||
|
||||
class RandomOrder:
|
||||
class RandomOrder(PyTensorOperation):
|
||||
"""
|
||||
Perform a series of transforms to the input PIL image in a random order.
|
||||
|
||||
|
|
|
@ -159,3 +159,6 @@ class FuncWrapper:
|
|||
result = ExceptionHandler(where="in map(or batch) worker and execute python function")
|
||||
result.reraise()
|
||||
return result
|
||||
|
||||
def to_json(self):
|
||||
return self.transform.to_json()
|
||||
|
|
|
@ -25,6 +25,7 @@ import random
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
from . import py_transforms_util as util
|
||||
from .c_transforms import parse_padding
|
||||
from .validators import check_prob, check_center_crop, check_five_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||
|
@ -56,7 +57,7 @@ def not_random(function):
|
|||
return function
|
||||
|
||||
|
||||
class ToTensor:
|
||||
class ToTensor(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert the input PIL Image or numpy.ndarray of shape (H, W, C) in the range [0, 255] to numpy.ndarray of
|
||||
shape (C, H, W) in the range [0.0, 1.0] with the desired dtype.
|
||||
|
@ -101,7 +102,7 @@ class ToTensor:
|
|||
return util.to_tensor(img, self.output_type)
|
||||
|
||||
|
||||
class ToType:
|
||||
class ToType(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert the input numpy.ndarray image to the desired dtype.
|
||||
|
||||
|
@ -140,7 +141,7 @@ class ToType:
|
|||
return util.to_type(img, self.output_type)
|
||||
|
||||
|
||||
class HWC2CHW:
|
||||
class HWC2CHW(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Transpose the input numpy.ndarray image of shape (H, W, C) to (C, H, W).
|
||||
|
||||
|
@ -173,7 +174,7 @@ class HWC2CHW:
|
|||
return util.hwc_to_chw(img)
|
||||
|
||||
|
||||
class ToPIL:
|
||||
class ToPIL(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert the input decoded numpy.ndarray image to PIL Image.
|
||||
|
||||
|
@ -210,7 +211,7 @@ class ToPIL:
|
|||
return util.to_pil(img)
|
||||
|
||||
|
||||
class Decode:
|
||||
class Decode(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Decode the input raw image to PIL Image format in RGB mode.
|
||||
|
||||
|
@ -244,7 +245,7 @@ class Decode:
|
|||
return util.decode(img)
|
||||
|
||||
|
||||
class Normalize:
|
||||
class Normalize(py_transforms.PyTensorOperation):
|
||||
r"""
|
||||
Normalize the input numpy.ndarray image of shape (C, H, W) with the specified mean and standard deviation.
|
||||
|
||||
|
@ -300,7 +301,7 @@ class Normalize:
|
|||
return util.normalize(img, self.mean, self.std)
|
||||
|
||||
|
||||
class NormalizePad:
|
||||
class NormalizePad(py_transforms.PyTensorOperation):
|
||||
r"""
|
||||
Normalize the input numpy.ndarray image of shape (C, H, W) with the specified mean and standard deviation,
|
||||
then pad an extra channel filled with zeros.
|
||||
|
@ -362,7 +363,7 @@ class NormalizePad:
|
|||
return util.normalize(img, self.mean, self.std, pad_channel=True, dtype=self.dtype)
|
||||
|
||||
|
||||
class RandomCrop:
|
||||
class RandomCrop(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Crop the input PIL Image at a random location with the specified size.
|
||||
|
||||
|
@ -431,7 +432,7 @@ class RandomCrop:
|
|||
self.fill_value, self.padding_mode)
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
|
||||
class RandomHorizontalFlip(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Randomly flip the input image horizontally with a given probability.
|
||||
|
||||
|
@ -465,7 +466,7 @@ class RandomHorizontalFlip:
|
|||
return util.random_horizontal_flip(img, self.prob)
|
||||
|
||||
|
||||
class RandomVerticalFlip:
|
||||
class RandomVerticalFlip(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Randomly flip the input image vertically with a given probability.
|
||||
|
||||
|
@ -499,7 +500,7 @@ class RandomVerticalFlip:
|
|||
return util.random_vertical_flip(img, self.prob)
|
||||
|
||||
|
||||
class Resize:
|
||||
class Resize(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Resize the input PIL image to the given size.
|
||||
|
||||
|
@ -548,7 +549,7 @@ class Resize:
|
|||
return util.resize(img, self.size, self.interpolation)
|
||||
|
||||
|
||||
class RandomResizedCrop:
|
||||
class RandomResizedCrop(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Extract crop from the input image and resize it to a random size and aspect ratio.
|
||||
|
||||
|
@ -606,7 +607,7 @@ class RandomResizedCrop:
|
|||
self.interpolation, self.max_attempts)
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
class CenterCrop(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Crop the central reigion of the input PIL image to the given size.
|
||||
|
||||
|
@ -643,7 +644,7 @@ class CenterCrop:
|
|||
return util.center_crop(img, self.size)
|
||||
|
||||
|
||||
class RandomColorAdjust:
|
||||
class RandomColorAdjust(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Perform a random brightness, contrast, saturation, and hue adjustment on the input PIL image.
|
||||
|
||||
|
@ -691,7 +692,7 @@ class RandomColorAdjust:
|
|||
return util.random_color_adjust(img, self.brightness, self.contrast, self.saturation, self.hue)
|
||||
|
||||
|
||||
class RandomRotation:
|
||||
class RandomRotation(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Rotate the input PIL image by a random angle.
|
||||
|
||||
|
@ -756,7 +757,7 @@ class RandomRotation:
|
|||
return util.random_rotation(img, self.degrees, self.resample, self.expand, self.center, self.fill_value)
|
||||
|
||||
|
||||
class FiveCrop:
|
||||
class FiveCrop(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Generate 5 cropped images (one central image and four corners images).
|
||||
|
||||
|
@ -795,7 +796,7 @@ class FiveCrop:
|
|||
return util.five_crop(img, self.size)
|
||||
|
||||
|
||||
class TenCrop:
|
||||
class TenCrop(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Generate 10 cropped images (first 5 images from FiveCrop, second 5 images from their flipped version
|
||||
as per input flag to flip vertically or horizontally).
|
||||
|
@ -841,7 +842,7 @@ class TenCrop:
|
|||
return util.ten_crop(img, self.size, self.use_vertical_flip)
|
||||
|
||||
|
||||
class Grayscale:
|
||||
class Grayscale(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert the input PIL image to grayscale image.
|
||||
|
||||
|
@ -877,7 +878,7 @@ class Grayscale:
|
|||
return util.grayscale(img, num_output_channels=self.num_output_channels)
|
||||
|
||||
|
||||
class RandomGrayscale:
|
||||
class RandomGrayscale(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Randomly convert the input image into grayscale image with a given probability.
|
||||
|
||||
|
@ -920,7 +921,7 @@ class RandomGrayscale:
|
|||
return img
|
||||
|
||||
|
||||
class Pad:
|
||||
class Pad(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Pad the input PIL image according to padding parameters.
|
||||
|
||||
|
@ -981,7 +982,7 @@ class Pad:
|
|||
return util.pad(img, self.padding, self.fill_value, self.padding_mode)
|
||||
|
||||
|
||||
class RandomPerspective:
|
||||
class RandomPerspective(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Randomly apply perspective transformation to the input PIL image with a given probability.
|
||||
|
||||
|
@ -1026,12 +1027,13 @@ class RandomPerspective:
|
|||
if not is_pil(img):
|
||||
raise ValueError("Input image should be a Pillow image.")
|
||||
if self.prob > random.random():
|
||||
start_points, end_points = util.get_perspective_params(img, self.distortion_scale)
|
||||
start_points, end_points = util.get_perspective_params(
|
||||
img, self.distortion_scale)
|
||||
return util.perspective(img, start_points, end_points, self.interpolation)
|
||||
return img
|
||||
|
||||
|
||||
class RandomErasing:
|
||||
class RandomErasing(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Erase the pixels, within a selected rectangle region, to the given value.
|
||||
|
||||
|
@ -1090,7 +1092,7 @@ class RandomErasing:
|
|||
return np_img
|
||||
|
||||
|
||||
class Cutout:
|
||||
class Cutout(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W).
|
||||
|
||||
|
@ -1128,9 +1130,11 @@ class Cutout:
|
|||
np_img (numpy.ndarray), NumPy image array with square patches cut out.
|
||||
"""
|
||||
if not isinstance(np_img, np.ndarray):
|
||||
raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
|
||||
raise TypeError(
|
||||
"img should be NumPy array. Got {}.".format(type(np_img)))
|
||||
if np_img.ndim != 3:
|
||||
raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim))
|
||||
raise TypeError(
|
||||
'img dimension should be 3. Got {}.'.format(np_img.ndim))
|
||||
|
||||
_, image_h, image_w = np_img.shape
|
||||
scale = (self.length * self.length) / (image_h * image_w)
|
||||
|
@ -1143,7 +1147,7 @@ class Cutout:
|
|||
return np_img
|
||||
|
||||
|
||||
class LinearTransformation:
|
||||
class LinearTransformation(py_transforms.PyTensorOperation):
|
||||
r"""
|
||||
Apply linear transformation to the input NumPy image array, given a square transformation matrix and
|
||||
a mean vector.
|
||||
|
@ -1191,7 +1195,7 @@ class LinearTransformation:
|
|||
return util.linear_transform(np_img, self.transformation_matrix, self.mean_vector)
|
||||
|
||||
|
||||
class RandomAffine:
|
||||
class RandomAffine(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Apply Random affine transformation to the input PIL image.
|
||||
|
||||
|
@ -1292,7 +1296,7 @@ class RandomAffine:
|
|||
self.fill_value)
|
||||
|
||||
|
||||
class MixUp:
|
||||
class MixUp(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Apply mix up transformation to the input image and label. Make one input data combined with others.
|
||||
|
||||
|
@ -1338,7 +1342,7 @@ class MixUp:
|
|||
return util.mix_up_muti(self, self.batch_size, image, label, self.alpha)
|
||||
|
||||
|
||||
class RgbToBgr:
|
||||
class RgbToBgr(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert a NumPy RGB image or a batch of NumPy RGB images to BGR images.
|
||||
|
||||
|
@ -1376,7 +1380,7 @@ class RgbToBgr:
|
|||
return util.rgb_to_bgrs(rgb_imgs, self.is_hwc)
|
||||
|
||||
|
||||
class RgbToHsv:
|
||||
class RgbToHsv(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert a NumPy RGB image or a batch of NumPy RGB images to HSV images.
|
||||
|
||||
|
@ -1414,7 +1418,7 @@ class RgbToHsv:
|
|||
return util.rgb_to_hsvs(rgb_imgs, self.is_hwc)
|
||||
|
||||
|
||||
class HsvToRgb:
|
||||
class HsvToRgb(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Convert a NumPy HSV image or one batch NumPy HSV images to RGB images.
|
||||
|
||||
|
@ -1452,7 +1456,7 @@ class HsvToRgb:
|
|||
return util.hsv_to_rgbs(hsv_imgs, self.is_hwc)
|
||||
|
||||
|
||||
class RandomColor:
|
||||
class RandomColor(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Adjust the color of the input PIL image by a random degree.
|
||||
|
||||
|
@ -1488,7 +1492,7 @@ class RandomColor:
|
|||
return util.random_color(img, self.degrees)
|
||||
|
||||
|
||||
class RandomSharpness:
|
||||
class RandomSharpness(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Adjust the sharpness of the input PIL image by a fixed or random degree. Degree of 0.0 gives a blurred image,
|
||||
degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image.
|
||||
|
@ -1525,7 +1529,7 @@ class RandomSharpness:
|
|||
return util.random_sharpness(img, self.degrees)
|
||||
|
||||
|
||||
class AdjustGamma:
|
||||
class AdjustGamma(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Adjust gamma of the input PIL image.
|
||||
|
||||
|
@ -1563,7 +1567,7 @@ class AdjustGamma:
|
|||
return util.adjust_gamma(img, self.gamma, self.gain)
|
||||
|
||||
|
||||
class AutoContrast:
|
||||
class AutoContrast(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Automatically maximize the contrast of the input PIL image.
|
||||
|
||||
|
@ -1602,7 +1606,7 @@ class AutoContrast:
|
|||
return util.auto_contrast(img, self.cutoff, self.ignore)
|
||||
|
||||
|
||||
class Invert:
|
||||
class Invert(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Invert colors of input PIL image.
|
||||
|
||||
|
@ -1633,7 +1637,7 @@ class Invert:
|
|||
return util.invert_color(img)
|
||||
|
||||
|
||||
class Equalize:
|
||||
class Equalize(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Equalize the histogram of input PIL image.
|
||||
|
||||
|
@ -1665,7 +1669,7 @@ class Equalize:
|
|||
return util.equalize(img)
|
||||
|
||||
|
||||
class UniformAugment:
|
||||
class UniformAugment(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Uniformly select and apply a number of transforms sequentially from
|
||||
a list of transforms. Randomly assign a probability to each transform for
|
||||
|
|
|
@ -382,19 +382,20 @@ def test_serdes_pyvision(remove_json_files=True):
|
|||
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||
transforms = [
|
||||
transforms1 = [
|
||||
py_vision.Decode(),
|
||||
py_vision.CenterCrop([32, 32]),
|
||||
py_vision.ToTensor()
|
||||
]
|
||||
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
|
||||
# Current python function derialization will be failed for pickle, so we disable this testcase
|
||||
# as an exception testcase.
|
||||
try:
|
||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "python operation is not yet supported" in str(e)
|
||||
transforms2 = [
|
||||
py_vision.RandomColorAdjust(),
|
||||
py_vision.FiveCrop(1),
|
||||
py_vision.Grayscale(),
|
||||
py.OneHotOp(1)
|
||||
]
|
||||
data1 = data1.map(operations=py.Compose(transforms1), input_columns=["image"])
|
||||
data1 = data1.map(operations=py.RandomApply(transforms2), input_columns=["image"])
|
||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||
|
||||
|
||||
def test_serdes_uniform_augment(remove_json_files=True):
|
||||
|
|
Loading…
Reference in New Issue