[MD] RandomResizedCrop - handle Interpolation and image input type
mismatch
This commit is contained in:
parent
666baf94ee
commit
b77e8b9f08
|
@ -1480,6 +1480,8 @@ class RandomAffine(TensorOperation, PyTensorOperation):
|
|||
self.fill_value = fill_value
|
||||
|
||||
def parse(self):
    """
    Build the cde operation object for this RandomAffine transform.

    Returns:
        The result of cde.RandomAffineOperation, constructed from degrees, translate, scale,
        shear, c_resample and fill_value.

    Raises:
        TypeError: If self.c_resample is None (reported as an interpolation that is not
            supported with NumPy input).
    """
    resample = self.c_resample
    if resample is None:
        raise TypeError("Current Interpolation is not supported with NumPy input.")
    return cde.RandomAffineOperation(self.degrees, self.translate, self.scale, self.shear, resample,
                                     self.fill_value)
|
||||
|
||||
|
@ -1493,7 +1495,8 @@ class RandomAffine(TensorOperation, PyTensorOperation):
|
|||
Returns:
|
||||
PIL Image, randomly affine transformed image.
|
||||
"""
|
||||
|
||||
if self.py_resample is None:
|
||||
raise TypeError("Current Interpolation is not supported with PIL input.")
|
||||
return util.random_affine(img,
|
||||
self.degrees,
|
||||
self.translate,
|
||||
|
@ -2397,6 +2400,8 @@ class RandomResizedCrop(TensorOperation, PyTensorOperation):
|
|||
- Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow, input
|
||||
should be in 3 channels format.
|
||||
|
||||
- Inter.ANTIALIAS, means the interpolation method is antialias interpolation.
|
||||
|
||||
max_attempts (int, optional): The maximum number of attempts to propose a valid
|
||||
crop_area (default=10). If exceeded, fall back to use center_crop instead.
|
||||
|
||||
|
@ -2446,6 +2451,8 @@ class RandomResizedCrop(TensorOperation, PyTensorOperation):
|
|||
self.max_attempts = max_attempts
|
||||
|
||||
def parse(self):
    """
    Build the cde operation object for this RandomResizedCrop transform.

    Returns:
        The result of cde.RandomResizedCropOperation, constructed from size, scale, ratio,
        c_interpolation and max_attempts.

    Raises:
        TypeError: If self.c_interpolation is None (reported as an interpolation that is not
            supported with NumPy input).
    """
    interpolation = self.c_interpolation
    if interpolation is None:
        raise TypeError("Current Interpolation is not supported with NumPy input.")
    return cde.RandomResizedCropOperation(self.size, self.scale, self.ratio, interpolation, self.max_attempts)
|
||||
|
||||
|
@ -2459,6 +2466,8 @@ class RandomResizedCrop(TensorOperation, PyTensorOperation):
|
|||
Returns:
|
||||
PIL Image, randomly cropped and resized image.
|
||||
"""
|
||||
if self.py_interpolation is None:
|
||||
raise TypeError("Current Interpolation is not supported with PIL input.")
|
||||
return util.random_resize_crop(img, self.size, self.scale, self.ratio,
|
||||
self.py_interpolation, self.max_attempts)
|
||||
|
||||
|
@ -2701,6 +2710,8 @@ class RandomRotation(TensorOperation, PyTensorOperation):
|
|||
self.fill_value = fill_value
|
||||
|
||||
def parse(self):
    """
    Build the cde operation object for this RandomRotation transform.

    Returns:
        The result of cde.RandomRotationOperation, constructed from degrees, c_resample,
        expand, c_center and fill_value.

    Raises:
        TypeError: If self.c_resample is None (reported as an interpolation that is not
            supported with NumPy input).
    """
    resample = self.c_resample
    if resample is None:
        raise TypeError("Current Interpolation is not supported with NumPy input.")
    return cde.RandomRotationOperation(self.degrees, resample, self.expand, self.c_center, self.fill_value)
|
||||
|
||||
|
@ -2714,6 +2725,8 @@ class RandomRotation(TensorOperation, PyTensorOperation):
|
|||
Returns:
|
||||
PIL Image, randomly rotated image.
|
||||
"""
|
||||
if self.py_resample is None:
|
||||
raise TypeError("Current Interpolation is not supported with PIL input.")
|
||||
return util.random_rotation(img, self.degrees, self.py_resample, self.expand, self.py_center, self.fill_value)
|
||||
|
||||
|
||||
|
@ -3011,6 +3024,8 @@ class Resize(TensorOperation, PyTensorOperation):
|
|||
self.random = False
|
||||
|
||||
def parse(self):
    """
    Build the cde operation object for this Resize transform.

    Returns:
        The result of cde.ResizeOperation, constructed from c_size and c_interpolation.

    Raises:
        TypeError: If self.c_interpolation is None (reported as an interpolation that is not
            supported with NumPy input).
    """
    interpolation = self.c_interpolation
    if interpolation is None:
        raise TypeError("Current Interpolation is not supported with NumPy input.")
    return cde.ResizeOperation(self.c_size, interpolation)
|
||||
|
||||
def execute_py(self, img):
|
||||
|
@ -3023,6 +3038,8 @@ class Resize(TensorOperation, PyTensorOperation):
|
|||
Returns:
|
||||
PIL Image, resized image.
|
||||
"""
|
||||
if self.py_interpolation is None:
|
||||
raise TypeError("Current Interpolation is not supported with PIL input.")
|
||||
return util.resize(img, self.py_size, self.py_interpolation)
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ Testing RandomCropAndResize op in DE
|
|||
"""
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset.transforms as ops
|
||||
import mindspore.dataset.vision as vision
|
||||
|
@ -33,13 +35,13 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
|||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
def test_random_crop_and_resize_callable():
|
||||
def test_random_crop_and_resize_callable_numpy():
|
||||
"""
|
||||
Feature: RandomCropAndResize op
|
||||
Description: Test RandomCropAndResize is callable
|
||||
Description: Test RandomCropAndResize is callable with NumPy input
|
||||
Expectation: Passes the shape equality test
|
||||
"""
|
||||
logger.info("test_random_crop_and_resize_callable")
|
||||
logger.info("test_random_crop_and_resize_callable_numpy")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
|
@ -48,11 +50,31 @@ def test_random_crop_and_resize_callable():
|
|||
assert img.shape == (2268, 4032, 3)
|
||||
|
||||
# test one tensor
|
||||
random_crop_and_resize_op1 = vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
|
||||
random_crop_and_resize_op1 = vision.RandomResizedCrop(size=(256, 512), scale=(2, 2), ratio=(1, 3),
|
||||
interpolation=Inter.AREA)
|
||||
img1 = random_crop_and_resize_op1(img)
|
||||
assert img1.shape == (256, 512, 3)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_callable_pil():
    """
    Feature: RandomCropAndResize op
    Description: Call RandomCropAndResize eagerly on a PIL image
    Expectation: The returned image has the requested size
    """
    logger.info("test_random_crop_and_resize_callable_pil")

    source_image = Image.open("../data/dataset/apple.jpg").convert("RGB")
    # PIL reports size as (width, height)
    assert source_image.size == (4032, 2268)

    # Apply the op to a single PIL image
    crop_op = vision.RandomResizedCrop(size=(256, 512), scale=(2, 2), ratio=(1, 3),
                                       interpolation=Inter.ANTIALIAS)
    cropped = crop_op(source_image)
    # The requested size=(256, 512) shows up as PIL size (512, 256)
    assert cropped.size == (512, 256)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_op_c(plot=False):
|
||||
"""
|
||||
Feature: RandomCropAndResize op
|
||||
|
@ -135,13 +157,14 @@ def test_random_crop_and_resize_op_py(plot=False):
|
|||
if plot:
|
||||
visualize_list(original_images, crop_and_resize_images)
|
||||
|
||||
def test_random_crop_and_resize_op_py_ANTIALIAS():
|
||||
|
||||
def test_random_crop_and_resize_op_py_antialias():
|
||||
"""
|
||||
Feature: RandomCropAndResize op
|
||||
Description: Test RandomCropAndResize with Python transformations where image interpolation mode is Inter.ANTIALIAS
|
||||
Expectation: The dataset is processed as expected
|
||||
"""
|
||||
logger.info("test_random_crop_and_resize_op_py_ANTIALIAS")
|
||||
logger.info("test_random_crop_and_resize_op_py_antialias")
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
# With these inputs we expect the code to crop the whole image
|
||||
|
@ -157,6 +180,7 @@ def test_random_crop_and_resize_op_py_ANTIALIAS():
|
|||
num_iter += 1
|
||||
logger.info("use RandomResizedCrop by Inter.ANTIALIAS process {} images.".format(num_iter))
|
||||
|
||||
|
||||
def test_random_crop_and_resize_01():
|
||||
"""
|
||||
Feature: RandomCropAndResize op
|
||||
|
@ -424,6 +448,7 @@ def test_random_crop_and_resize_06():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Argument scale[1] with value 2 is not of type [<class 'float'>, <class 'int'>]" in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_07():
|
||||
"""
|
||||
Feature: RandomCropAndResize op
|
||||
|
@ -452,11 +477,40 @@ def test_random_crop_and_resize_07():
|
|||
num_iter += 1
|
||||
|
||||
|
||||
def test_random_crop_and_resize_eager_error_01():
    """
    Feature: RandomCropAndResize op
    Description: Test RandomCropAndResize in eager mode with PIL input and C++ only interpolation AREA
    Expectation: Correct error is thrown as expected
    """
    img = Image.open("../data/dataset/apple.jpg").convert("RGB")
    # Construct the op outside the raises-context: only the eager call on the PIL image is
    # expected to fail, so an unexpected TypeError from the constructor is not swallowed.
    random_crop_and_resize_op = vision.RandomResizedCrop(size=(100, 200), scale=[1.0, 2.0],
                                                         interpolation=Inter.AREA)
    with pytest.raises(TypeError) as error_info:
        random_crop_and_resize_op(img)
    assert "Current Interpolation is not supported with PIL input." in str(error_info.value)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_eager_error_02():
    """
    Feature: RandomCropAndResize op
    Description: Test RandomCropAndResize in eager mode with NumPy input and Python only interpolation ANTIALIAS
    Expectation: Correct error is thrown as expected
    """
    # randint's upper bound is exclusive; use 256 so the image holds real uint8 values
    # instead of the all-zero array that randint(0, 1) produces.
    img = np.random.randint(0, 256, (100, 100, 3)).astype(np.uint8)
    # Construct the op outside the raises-context: only the eager call on the NumPy image is
    # expected to fail, so an unexpected TypeError from the constructor is not swallowed.
    random_crop_and_resize_op = vision.RandomResizedCrop(size=(100, 200), scale=[1.0, 2.0],
                                                         interpolation=Inter.ANTIALIAS)
    with pytest.raises(TypeError) as error_info:
        random_crop_and_resize_op(img)
    assert "Current Interpolation is not supported with NumPy input." in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_crop_and_resize_callable()
|
||||
test_random_crop_and_resize_callable_numpy()
|
||||
test_random_crop_and_resize_callable_pil()
|
||||
test_random_crop_and_resize_op_c(True)
|
||||
test_random_crop_and_resize_op_py(True)
|
||||
test_random_crop_and_resize_op_py_ANTIALIAS()
|
||||
test_random_crop_and_resize_op_py_antialias()
|
||||
test_random_crop_and_resize_01()
|
||||
test_random_crop_and_resize_02()
|
||||
test_random_crop_and_resize_03()
|
||||
|
@ -467,3 +521,5 @@ if __name__ == "__main__":
|
|||
test_random_crop_and_resize_06()
|
||||
test_random_crop_and_resize_comp(True)
|
||||
test_random_crop_and_resize_07()
|
||||
test_random_crop_and_resize_eager_error_01()
|
||||
test_random_crop_and_resize_eager_error_02()
|
||||
|
|
|
@ -18,6 +18,7 @@ Testing ToTensor op in DE
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
|
||||
|
@ -178,7 +179,7 @@ def test_to_tensor_float64_eager():
|
|||
Expectation: Test runs successfully and results are verified
|
||||
"""
|
||||
|
||||
def test_config(my_output_type, output_dtype):
|
||||
def test_config(my_output_type, output_dtype, result_output_type=None):
|
||||
image = np.random.randn(128, 128, 3).astype(np.float64)
|
||||
op = vision.ToTensor(output_type=my_output_type)
|
||||
out = op(image)
|
||||
|
@ -186,7 +187,10 @@ def test_to_tensor_float64_eager():
|
|||
assert out.dtype == output_dtype
|
||||
|
||||
image = image / 255
|
||||
image = image.astype(my_output_type)
|
||||
if result_output_type is None:
|
||||
image = image.astype(my_output_type)
|
||||
else:
|
||||
image = image.astype(result_output_type)
|
||||
image = np.transpose(image, (2, 0, 1))
|
||||
|
||||
np.testing.assert_almost_equal(out, image, 5)
|
||||
|
@ -204,6 +208,19 @@ def test_to_tensor_float64_eager():
|
|||
test_config(np.uint64, "uint64")
|
||||
test_config(np.bool, "bool")
|
||||
|
||||
test_config(mstype.float16, "float16", np.float16)
|
||||
test_config(mstype.float32, "float32", np.float32)
|
||||
test_config(mstype.float64, "float64", np.float64)
|
||||
test_config(mstype.int8, "int8", np.int8)
|
||||
test_config(mstype.int16, "int16", np.int16)
|
||||
test_config(mstype.int32, "int32", np.int32)
|
||||
test_config(mstype.int64, "int64", np.int64)
|
||||
test_config(mstype.uint8, "uint8", np.uint8)
|
||||
test_config(mstype.uint16, "uint16", np.uint16)
|
||||
test_config(mstype.uint32, "uint32", np.uint32)
|
||||
test_config(mstype.uint64, "uint64", np.uint64)
|
||||
test_config(mstype.bool_, "bool", np.bool)
|
||||
|
||||
|
||||
def test_to_tensor_int32_eager():
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RandomCropAndResize op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.vision.utils import Inter
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
def test_random_crop_and_resize_callable_numpy():
    """
    Feature: RandomCropAndResize op
    Description: Call the RandomCropAndResize C++ op eagerly on a decoded NumPy image
    Expectation: The returned array has the requested shape
    """
    logger.info("test_random_crop_and_resize_callable_numpy")
    # Read the encoded JPEG file as a flat uint8 buffer
    raw_bytes = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
    logger.info("Image.type: {}, Image.shape: {}".format(type(raw_bytes), raw_bytes.shape))

    decoded = c_vision.Decode()(raw_bytes)
    assert decoded.shape == (2268, 4032, 3)

    # Apply the op to the single decoded image
    crop_op = c_vision.RandomResizedCrop(size=(256, 512), scale=(2, 2), ratio=(1, 3),
                                         interpolation=Inter.AREA)
    cropped = crop_op(decoded)
    assert cropped.shape == (256, 512, 3)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_callable_pil():
    """
    Feature: RandomCropAndResize op
    Description: Call the RandomCropAndResize Python op eagerly on a PIL image
    Expectation: The returned image has the requested size
    """
    logger.info("test_random_crop_and_resize_callable_pil")

    source_image = Image.open("../data/dataset/apple.jpg").convert("RGB")
    # PIL reports size as (width, height)
    assert source_image.size == (4032, 2268)

    # Apply the op to a single PIL image
    crop_op = py_vision.RandomResizedCrop(size=(256, 512), scale=(2, 2), ratio=(1, 3),
                                          interpolation=Inter.ANTIALIAS)
    cropped = crop_op(source_image)
    # The requested size=(256, 512) shows up as PIL size (512, 256)
    assert cropped.size == (512, 256)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_op_py_antialias():
    """
    Feature: RandomCropAndResize op
    Description: Run RandomCropAndResize with Python transformations and Inter.ANTIALIAS
        interpolation inside a dataset pipeline
    Expectation: All three images of the dataset are processed
    """
    logger.info("test_random_crop_and_resize_op_py_antialias")
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # With these inputs we expect the code to crop the whole image
    pipeline = [
        py_vision.Decode(),
        py_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3), Inter.ANTIALIAS),
        py_vision.ToTensor()
    ]
    dataset = dataset.map(operations=pipeline, input_columns=["image"])
    # Drain the iterator once and count how many rows it yields
    num_iter = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert num_iter == 3
    logger.info("use RandomResizedCrop by Inter.ANTIALIAS process {} images.".format(num_iter))
|
||||
|
||||
|
||||
def test_random_crop_and_resize_eager_error_02():
    """
    Feature: RandomCropAndResize op
    Description: Test RandomCropAndResize Python op in eager mode with NumPy input and
        Python only interpolation ANTIALIAS
    Expectation: Correct error is thrown as expected
    """
    # randint's upper bound is exclusive; use 256 so the image holds real uint8 values
    # instead of the all-zero array that randint(0, 1) produces.
    img = np.random.randint(0, 256, (100, 100, 3)).astype(np.uint8)
    # Construct the op outside the raises-context: only the eager call on the NumPy image is
    # expected to fail, so an unexpected TypeError from the constructor is not swallowed.
    random_crop_and_resize_op = py_vision.RandomResizedCrop(size=(100, 200), scale=[1.0, 2.0],
                                                            interpolation=Inter.ANTIALIAS)
    with pytest.raises(TypeError) as error_info:
        random_crop_and_resize_op(img)
    assert "img should be PIL image. Got <class 'numpy.ndarray'>." in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # When executed as a script, run the tests of this module in their original order
    for _test_case in (
            test_random_crop_and_resize_callable_numpy,
            test_random_crop_and_resize_callable_pil,
            test_random_crop_and_resize_op_py_antialias,
            test_random_crop_and_resize_eager_error_02,
    ):
        _test_case()
|
Loading…
Reference in New Issue