forked from mindspore-Ecosystem/mindspore
!9076 [MD] Support vision c_transform python eager: Resize, Rescale, Normalize, HWC2CHW, Pad
From: @luoyang42 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
f0671ce513
|
@ -6,6 +6,7 @@ if (ENABLE_PYTHON)
|
|||
python/pybind_conversion.cc
|
||||
python/bindings/dataset/include/datasets_bindings.cc
|
||||
python/bindings/dataset/include/iterator_bindings.cc
|
||||
python/bindings/dataset/include/execute_binding.cc
|
||||
python/bindings/dataset/include/schema_bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/include/execute.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Python binding for Execute: exposes it as a class named "Execute" that is built from a
// Python-side operation object and can be called on a single Tensor.
PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
                  (void)py::class_<Execute, std::shared_ptr<Execute>>(*m, "Execute")
                    // Constructor: the Python operation is first converted by toTensorOperation.
                    .def(py::init([](py::object operation) {
                      return std::make_shared<Execute>(toTensorOperation(operation));
                    }))
                    // Call operator: forwards the input tensor to Execute::operator().
                    .def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) -> std::shared_ptr<Tensor> {
                      return self(in);
                    });
                }));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -104,6 +104,18 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<p
|
|||
return vector;
|
||||
}
|
||||
|
||||
// Converts a single Python-side operation handle into a TensorOperation.
// The handle must be an instance of TensorOp; anything else is reported through
// THROW_IF_ERROR with an "unexpected" status. An accepted op is wrapped in a
// transforms::PreBuiltOperation, which is what gets returned.
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation) {
  std::shared_ptr<TensorOp> prebuilt;
  if (!py::isinstance<TensorOp>(operation)) {
    THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }());
  } else {
    prebuilt = operation.cast<std::shared_ptr<TensorOp>>();
  }
  return std::make_shared<transforms::PreBuiltOperation>(prebuilt);
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) {
|
||||
std::vector<std::shared_ptr<DatasetNode>> vector;
|
||||
vector.push_back(self);
|
||||
|
|
|
@ -59,6 +59,8 @@ std::vector<std::pair<int, int>> toPairVector(const py::list list);
|
|||
|
||||
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);
|
||||
|
||||
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation);
|
||||
|
||||
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);
|
||||
|
||||
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);
|
||||
|
|
|
@ -2218,7 +2218,7 @@ class MapDataset(Dataset):
|
|||
# wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations
|
||||
new_ops, start_ind, end_ind = [], 0, 0
|
||||
for i, op in enumerate(operations):
|
||||
if not callable(op):
|
||||
if str(op).find("c_transform") >= 0:
|
||||
# reset counts
|
||||
if start_ind != end_ind:
|
||||
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
|
||||
|
|
|
@ -36,11 +36,11 @@ def compose(transforms, *args):
|
|||
Compose a list of transforms and apply on the image.
|
||||
|
||||
Args:
|
||||
img (numpy.ndarray): An image in Numpy ndarray.
|
||||
img (numpy.ndarray): An image in NumPy ndarray.
|
||||
transforms (list): A list of transform Class objects to be composed.
|
||||
|
||||
Returns:
|
||||
img (numpy.ndarray), An augmented image in Numpy ndarray.
|
||||
img (numpy.ndarray), An augmented image in NumPy ndarray.
|
||||
"""
|
||||
if all_numpy(args):
|
||||
for transform in transforms:
|
||||
|
@ -49,8 +49,8 @@ def compose(transforms, *args):
|
|||
|
||||
if all_numpy(args):
|
||||
return args
|
||||
raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args)))
|
||||
raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args)))
|
||||
raise TypeError('args should be NumPy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args)))
|
||||
raise TypeError('args should be NumPy ndarray. Got {}.'.format(type(args)))
|
||||
|
||||
|
||||
def one_hot_encoding(label, num_classes, epsilon):
|
||||
|
|
|
@ -44,6 +44,8 @@ Examples:
|
|||
>>> data1 = data1.map(operations=onehot_op, input_columns="label")
|
||||
"""
|
||||
import numbers
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .utils import Inter, Border, ImageBatchFormat
|
||||
|
@ -280,6 +282,22 @@ class Normalize(cde.NormalizeOp):
|
|||
self.std = std
|
||||
super().__init__(*mean, *std)
|
||||
|
||||
def __call__(self, img):
    """
    Apply the normalization eagerly to one image.

    Args:
        img (NumPy or PIL image): Image array to be normalized.

    Returns:
        img (NumPy), Normalized Image array.

    Raises:
        TypeError: If img is neither a NumPy array nor a PIL image.
    """
    accepted_types = (np.ndarray, Image.Image)
    if isinstance(img, accepted_types):
        # Build a fresh C++ op from the stored mean/std and run it through the executor.
        executor = cde.Execute(cde.NormalizeOp(*self.mean, *self.std))
        return executor(cde.Tensor(np.asarray(img))).as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
|
||||
|
||||
|
||||
class RandomAffine(cde.RandomAffineOp):
|
||||
"""
|
||||
|
@ -676,13 +694,29 @@ class Resize(cde.ResizeOp):
|
|||
|
||||
@check_resize_interpolation
def __init__(self, size, interpolation=Inter.LINEAR):
    # A bare int is expanded to the (size, 0) pair so it can be unpacked positionally below.
    if isinstance(size, int):
        size = (size, 0)
    # Keep the Python-level arguments; __call__ rebuilds the C++ op from them.
    self.size = size
    self.interpolation = interpolation
    # Translate the Python interpolation enum into the value the C++ op takes.
    interpoltn = DE_C_INTER_MODE[interpolation]
    super().__init__(*size, interpoltn)
|
||||
|
||||
def __call__(self, img):
    """
    Apply the resize eagerly to one image.

    Args:
        img (NumPy or PIL image): Image to be resized.

    Returns:
        img (NumPy), Resized image.

    Raises:
        TypeError: If img is neither a NumPy array nor a PIL image.
    """
    accepted_types = (np.ndarray, Image.Image)
    if isinstance(img, accepted_types):
        # Build a fresh C++ op from the stored size/interpolation and run it through the executor.
        executor = cde.Execute(cde.ResizeOp(*self.size, DE_C_INTER_MODE[self.interpolation]))
        return executor(cde.Tensor(np.asarray(img))).as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
|
||||
|
||||
|
||||
class ResizeWithBBox(cde.ResizeWithBBoxOp):
|
||||
"""
|
||||
|
@ -995,6 +1029,22 @@ class Rescale(cde.RescaleOp):
|
|||
self.shift = shift
|
||||
super().__init__(rescale, shift)
|
||||
|
||||
def __call__(self, img):
    """
    Apply the rescale eagerly to one image.

    Args:
        img (NumPy or PIL image): Image to be rescaled.

    Returns:
        img (NumPy), Rescaled image.

    Raises:
        TypeError: If img is neither a NumPy array nor a PIL image.
    """
    accepted_types = (np.ndarray, Image.Image)
    if isinstance(img, accepted_types):
        # Build a fresh C++ op from the stored rescale/shift and run it through the executor.
        executor = cde.Execute(cde.RescaleOp(self.rescale, self.shift))
        return executor(cde.Tensor(np.asarray(img))).as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
|
||||
|
||||
|
||||
class RandomResize(cde.RandomResizeOp):
|
||||
"""
|
||||
|
@ -1067,6 +1117,22 @@ class HWC2CHW(cde.ChannelSwapOp):
|
|||
>>> data1 = data1.map(operations=transforms_list, input_columns=["image"])
|
||||
"""
|
||||
|
||||
def __call__(self, img):
    """
    Apply the channel swap eagerly to one image.

    Args:
        img (NumPy or PIL image): Image array, of shape (H, W, C), to have channels swapped.

    Returns:
        img (NumPy), Image array, of shape (C, H, W), with channels swapped.

    Raises:
        TypeError: If img is neither a NumPy array nor a PIL image.
    """
    accepted_types = (np.ndarray, Image.Image)
    if isinstance(img, accepted_types):
        # The op takes no arguments, so a fresh instance is handed to the executor.
        executor = cde.Execute(cde.ChannelSwapOp())
        return executor(cde.Tensor(np.asarray(img))).as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
|
||||
|
||||
|
||||
class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp):
|
||||
"""
|
||||
|
@ -1156,13 +1222,28 @@ class Pad(cde.PadOp):
|
|||
padding = parse_padding(padding)
|
||||
if isinstance(fill_value, int):
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
padding_mode = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
self.padding = padding
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = padding_mode
|
||||
padding_mode = DE_C_BORDER_TYPE[padding_mode]
|
||||
super().__init__(*padding, padding_mode, *fill_value)
|
||||
|
||||
def __call__(self, img):
    """
    Apply the padding eagerly to one image.

    Args:
        img (NumPy or PIL image): Image to be padded.

    Returns:
        img (NumPy), Padded image.

    Raises:
        TypeError: If img is neither a NumPy array nor a PIL image.
    """
    accepted_types = (np.ndarray, Image.Image)
    if isinstance(img, accepted_types):
        # Build a fresh C++ op from the stored padding, border mode and fill value.
        border_type = DE_C_BORDER_TYPE[self.padding_mode]
        executor = cde.Execute(cde.PadOp(*self.padding, border_type, *self.fill_value))
        return executor(cde.Tensor(np.asarray(img))).as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
|
||||
|
||||
|
||||
class UniformAugment(cde.UniformAugOp):
|
||||
"""
|
||||
|
|
|
@ -235,15 +235,15 @@ def test_py_transforms_with_c_vision():
|
|||
return res
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomApply([c_vision.Resize(200)]))
|
||||
test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
|
||||
assert "transforms[0] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomChoice([c_vision.Resize(200)]))
|
||||
test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
|
||||
assert "transforms[0] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomOrder([np.array, c_vision.Resize(200)]))
|
||||
test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
|
||||
assert "transforms[1] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
|
||||
def test_eager_resize():
    """Eager Resize on an image loaded with cv2 yields the requested 32x32 size with 3 channels."""
    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info("Image.type: {}, Image.shape: {}".format(type(source), source.shape))

    resize_op = C.Resize(size=(32, 32))
    resized = resize_op(source)
    logger.info("Image.type: {}, Image.shape: {}".format(type(resized), resized.shape))

    assert resized.shape == (32, 32, 3)
|
||||
|
||||
def test_eager_rescale():
    """Eager Rescale with shift 0 multiplies the first pixel value by the rescale factor."""
    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info("Image.type: {}, Image.shape: {}".format(type(source), source.shape))
    pixel = source[0][0][0]

    rescale_factor = 0.5
    rescale_op = C.Rescale(rescale=rescale_factor, shift=0)
    rescaled = rescale_op(source)
    logger.info("Image.type: {}, Image.shape: {}".format(type(rescaled), rescaled.shape))
    pixel_rescaled = rescaled[0][0][0]

    assert pixel*rescale_factor == pixel_rescaled
|
||||
|
||||
def test_eager_normalize():
    """Eager Normalize on a PIL image maps the first pixel to (pixel - mean) / std."""
    source = Image.open("../data/dataset/apple.jpg").convert("RGB")
    logger.info("Image.type: {}, Image.shape: {}".format(type(source), source.size))
    pixel = source.getpixel((0, 0))[0]

    mean_vec = [100, 100, 100]
    std_vec = [2, 2, 2]
    normalize_op = C.Normalize(mean=mean_vec, std=std_vec)
    normalized = normalize_op(source)
    logger.info("Image.type: {}, Image.shape: {}".format(type(normalized), normalized.shape))
    pixel_normalized = normalized[0][0][0]

    assert (pixel - mean_vec[0]) / std_vec[0] == pixel_normalized
|
||||
|
||||
def test_eager_HWC2CHW():
    """Eager HWC2CHW moves the channel axis of the image shape from last to first."""
    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info("Image.type: {}, Image.shape: {}".format(type(source), source.shape))
    hwc_shape = source.shape

    swapped = C.HWC2CHW()(source)
    logger.info("Image.type: {}, Image.shape: {}".format(type(swapped), swapped.shape))
    chw_shape = swapped.shape

    # Re-order the swapped shape back to (H, W, C) and compare with the input shape.
    assert hwc_shape == (chw_shape[1], chw_shape[2], chw_shape[0])
|
||||
|
||||
def test_eager_pad():
    """Eager Pad grows height and width by twice the padding and leaves the channel count alone."""
    img = Image.open("../data/dataset/apple.jpg").convert("RGB")
    # Still a PIL image here, so .size is its (width, height) pair.
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size))

    img = C.Resize(size=(32, 32))(img)
    # From here on img is a NumPy array: log .shape, since ndarray.size is the element count.
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
    size = img.shape

    pad = 4
    img = C.Pad(padding=pad)(img)
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
    size_padded = img.shape

    assert size_padded == (size[0] + 2 * pad, size[1] + 2 * pad, size[2])
|
||||
|
||||
def test_eager_exceptions():
    """Invalid Resize sizes raise ValueError and non-image inputs to Pad raise TypeError."""
    # Case 1: a negative size must be rejected with a ValueError.
    try:
        img = cv2.imread("../data/dataset/apple.jpg")
        img = C.Resize(size=(-32, 32))(img)
    except ValueError as e:
        assert "not within the required interval" in str(e)
    else:
        # No exception at all is a test failure.
        assert False

    # Case 2: a file path string is neither NumPy nor PIL, so the call must raise TypeError.
    try:
        img = "../data/dataset/apple.jpg"
        img = C.Pad(padding=4)(img)
    except TypeError as e:
        assert "Input should be NumPy or PIL image" in str(e)
    else:
        assert False
|
||||
|
||||
if __name__ == '__main__':
    # Run every eager-mode test in the same order when the file is executed as a script.
    for eager_test in (test_eager_resize,
                       test_eager_rescale,
                       test_eager_normalize,
                       test_eager_HWC2CHW,
                       test_eager_pad,
                       test_eager_exceptions):
        eager_test()
|
||||
|
Loading…
Reference in New Issue