[MD] Unified UniformAugment op: allow Pyfunc in transforms list
This commit is contained in:
parent
4c9252c21b
commit
a61c709d00
|
@ -40,12 +40,12 @@ std::string UniformAugOperation::Name() const { return kUniformAugOperation; }
|
||||||
|
|
||||||
Status UniformAugOperation::ValidateParams() {
|
Status UniformAugOperation::ValidateParams() {
|
||||||
// transforms
|
// transforms
|
||||||
RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAug", transforms_));
|
RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAugment", transforms_));
|
||||||
// num_ops
|
// num_ops
|
||||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("UniformAug", "num_ops", num_ops_));
|
RETURN_IF_NOT_OK(ValidateIntScalarPositive("UniformAugment", "num_ops", num_ops_));
|
||||||
if (num_ops_ > transforms_.size()) {
|
if (num_ops_ > transforms_.size()) {
|
||||||
std::string err_msg =
|
std::string err_msg =
|
||||||
"UniformAug: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_);
|
"UniformAugment: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_);
|
||||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -33,7 +33,7 @@ namespace dataset {
|
||||||
|
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
constexpr char kUniformAugOperation[] = "UniformAug";
|
constexpr char kUniformAugOperation[] = "UniformAugment";
|
||||||
|
|
||||||
class UniformAugOperation : public TensorOperation {
|
class UniformAugOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -118,7 +118,7 @@ constexpr char kSlicePatchesOp[] = "SlicePatchesOp";
|
||||||
constexpr char kSolarizeOp[] = "SolarizeOp";
|
constexpr char kSolarizeOp[] = "SolarizeOp";
|
||||||
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
||||||
constexpr char kToTensorOp[] = "ToTensorOp";
|
constexpr char kToTensorOp[] = "ToTensorOp";
|
||||||
constexpr char kUniformAugOp[] = "UniformAugOp";
|
constexpr char kUniformAugOp[] = "UniformAugmentOp";
|
||||||
constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
|
constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
|
||||||
|
|
||||||
// video
|
// video
|
||||||
|
|
|
@ -69,7 +69,7 @@ from .validators import check_adjust_gamma, check_alpha, check_auto_augment, che
|
||||||
check_random_erasing, check_random_perspective, check_random_resize_crop, check_random_rotation, \
|
check_random_erasing, check_random_perspective, check_random_resize_crop, check_random_rotation, \
|
||||||
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
|
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
|
||||||
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_ten_crop, \
|
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_ten_crop, \
|
||||||
check_uniform_augment_cpp, check_to_tensor, FLOAT_MAX_INTEGER
|
check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
|
||||||
from ..core.datatypes import mstype_to_detype, nptype_to_detype
|
from ..core.datatypes import mstype_to_detype, nptype_to_detype
|
||||||
from ..transforms.py_transforms_util import Implementation
|
from ..transforms.py_transforms_util import Implementation
|
||||||
from ..transforms.transforms import CompoundOperation, PyTensorOperation, TensorOperation, TypeCast
|
from ..transforms.transforms import CompoundOperation, PyTensorOperation, TensorOperation, TypeCast
|
||||||
|
@ -3541,7 +3541,7 @@ class UniformAugment(CompoundOperation):
|
||||||
... input_columns="image")
|
... input_columns="image")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_uniform_augment_cpp
|
@check_uniform_augment
|
||||||
def __init__(self, transforms, num_ops=2):
|
def __init__(self, transforms, num_ops=2):
|
||||||
super().__init__(transforms)
|
super().__init__(transforms)
|
||||||
self.num_ops = num_ops
|
self.num_ops = num_ops
|
||||||
|
|
|
@ -24,7 +24,8 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
|
||||||
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
||||||
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
|
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
|
||||||
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
|
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
|
||||||
check_pos_int32, deprecator_factory
|
check_pos_int32, check_tensor_op, deprecator_factory
|
||||||
|
from mindspore.dataset.transforms.validators import check_transform_op_type
|
||||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
||||||
|
|
||||||
|
|
||||||
|
@ -932,6 +933,30 @@ def check_uniform_augment_cpp(method):
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_uniform_augment(method):
|
||||||
|
"""Wrapper method to check the parameters of UniformAugment Unified op."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
[transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
type_check(num_ops, (int,), "num_ops")
|
||||||
|
check_positive(num_ops, "num_ops")
|
||||||
|
|
||||||
|
if num_ops > len(transforms):
|
||||||
|
raise ValueError("num_ops is greater than transforms list size.")
|
||||||
|
|
||||||
|
type_check(transforms, (list, tuple,), "transforms list")
|
||||||
|
if not transforms:
|
||||||
|
raise ValueError("transforms list can not be empty.")
|
||||||
|
for ind, op in enumerate(transforms):
|
||||||
|
check_tensor_op(op, "transforms[{0}]".format(ind))
|
||||||
|
check_transform_op_type(ind, op)
|
||||||
|
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_bounding_box_augment_cpp(method):
|
def check_bounding_box_augment_cpp(method):
|
||||||
"""Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
|
"""Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
Testing UniformAugment in DE
|
Testing UniformAugment in DE
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.transforms
|
import mindspore.dataset.transforms
|
||||||
|
@ -29,7 +30,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||||
def test_uniform_augment_callable(num_ops=2):
|
def test_uniform_augment_callable(num_ops=2):
|
||||||
"""
|
"""
|
||||||
Feature: UniformAugment
|
Feature: UniformAugment
|
||||||
Description: Test UniformAugment under normal test case
|
Description: Test UniformAugment under execute mode
|
||||||
Expectation: Output's shape is the same as expected output's shape
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
"""
|
"""
|
||||||
logger.info("test_uniform_augment_callable")
|
logger.info("test_uniform_augment_callable")
|
||||||
|
@ -40,11 +41,75 @@ def test_uniform_augment_callable(num_ops=2):
|
||||||
img = decode_op(img)
|
img = decode_op(img)
|
||||||
assert img.shape == (2268, 4032, 3)
|
assert img.shape == (2268, 4032, 3)
|
||||||
|
|
||||||
transforms_ua = [vision.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
|
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
vision.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
|
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||||
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
img = uni_aug(img)
|
img = uni_aug(img)
|
||||||
assert img.shape == (2268, 4032, 3) or img.shape == (400, 400, 3)
|
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_augment_callable_pil(num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment under execute mode, with PIL input.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = vision.Decode(to_pil=True)
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.size == (4032, 2268)
|
||||||
|
|
||||||
|
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||||
|
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_augment_callable_pil_pyfunc(num_ops=3):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment under execute mode, with PIL input. Include pyfunc in transforms list.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = vision.Decode(to_pil=True)
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.size == (4032, 2268)
|
||||||
|
|
||||||
|
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
lambda x: x,
|
||||||
|
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||||
|
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_augment_callable_tuple(num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment under execute mode. Use tuple for transforms list argument.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = vision.Decode()
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.shape == (2268, 4032, 3)
|
||||||
|
|
||||||
|
transforms_ua = (vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]))
|
||||||
|
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_uniform_augment(plot=False, num_ops=2):
|
def test_uniform_augment(plot=False, num_ops=2):
|
||||||
|
@ -74,7 +139,7 @@ def test_uniform_augment(plot=False, num_ops=2):
|
||||||
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||||
axis=0)
|
axis=0)
|
||||||
|
|
||||||
# UniformAugment Images
|
# UniformAugment Images
|
||||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
transform_list = [vision.RandomRotation(45),
|
transform_list = [vision.RandomRotation(45),
|
||||||
|
@ -188,12 +253,10 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
||||||
vision.RandomColorAdjust(),
|
vision.RandomColorAdjust(),
|
||||||
vision.RandomRotation(degrees=45)]
|
vision.RandomRotation(degrees=45)]
|
||||||
|
|
||||||
try:
|
with pytest.raises(ValueError) as error_info:
|
||||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||||
except Exception as e:
|
assert "num_ops is greater than transforms list size" in str(error_info)
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
||||||
assert "num_ops" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||||
|
@ -202,7 +265,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||||
Description: Test UniformAugment using invalid non-positive num_ops
|
Description: Test UniformAugment using invalid non-positive num_ops
|
||||||
Expectation: Exception is raised as expected
|
Expectation: Exception is raised as expected
|
||||||
"""
|
"""
|
||||||
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")
|
logger.info("Test UniformAugment invalid non-positive num_ops exception")
|
||||||
|
|
||||||
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||||
vision.RandomHorizontalFlip(),
|
vision.RandomHorizontalFlip(),
|
||||||
|
@ -210,12 +273,10 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||||
vision.RandomColorAdjust(),
|
vision.RandomColorAdjust(),
|
||||||
vision.RandomRotation(degrees=45)]
|
vision.RandomRotation(degrees=45)]
|
||||||
|
|
||||||
try:
|
with pytest.raises(ValueError) as error_info:
|
||||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||||
except Exception as e:
|
assert "Input num_ops must be greater than 0" in str(error_info)
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
||||||
assert "Input num_ops must be greater than 0" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||||
|
@ -224,7 +285,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||||
Description: Test UniformAugment using invalid float num_ops
|
Description: Test UniformAugment using invalid float num_ops
|
||||||
Expectation: Exception is raised as expected
|
Expectation: Exception is raised as expected
|
||||||
"""
|
"""
|
||||||
logger.info("Test CPP UniformAugment invalid float num_ops exception")
|
logger.info("Test UniformAugment invalid float num_ops exception")
|
||||||
|
|
||||||
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||||
vision.RandomHorizontalFlip(),
|
vision.RandomHorizontalFlip(),
|
||||||
|
@ -232,12 +293,10 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||||
vision.RandomColorAdjust(),
|
vision.RandomColorAdjust(),
|
||||||
vision.RandomRotation(degrees=45)]
|
vision.RandomRotation(degrees=45)]
|
||||||
|
|
||||||
try:
|
with pytest.raises(TypeError) as error_info:
|
||||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||||
except Exception as e:
|
assert "Argument num_ops with value 2.5 is not of type [<class 'int'>]" in str(error_info)
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
||||||
assert "Argument num_ops with value 2.5 is not of type [<class 'int'>]" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||||
|
@ -246,7 +305,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||||
Description: Test UniformAugment with greater crop size
|
Description: Test UniformAugment with greater crop size
|
||||||
Expectation: Exception is raised as expected
|
Expectation: Exception is raised as expected
|
||||||
"""
|
"""
|
||||||
logger.info("Test CPP UniformAugment with random_crop bad input")
|
logger.info("Test UniformAugment with random_crop bad input")
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||||
ds1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
ds1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||||
|
@ -262,16 +321,18 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||||
# apply DatasetOps
|
# apply DatasetOps
|
||||||
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
|
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
|
||||||
num_batches = 0
|
num_batches = 0
|
||||||
try:
|
with pytest.raises(RuntimeError) as error_info:
|
||||||
for _ in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
for _ in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
except Exception as e:
|
assert "Shape is incorrect. map operation: [UniformAugment] failed." in str(error_info)
|
||||||
assert "crop size" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_uniform_augment_callable(num_ops=2)
|
test_uniform_augment_callable()
|
||||||
test_uniform_augment(num_ops=1, plot=True)
|
test_uniform_augment_callable_pil()
|
||||||
|
test_uniform_augment_callable_pil_pyfunc()
|
||||||
|
test_uniform_augment_callable_tuple()
|
||||||
|
test_uniform_augment(num_ops=6, plot=True)
|
||||||
test_cpp_uniform_augment(num_ops=1, plot=True)
|
test_cpp_uniform_augment(num_ops=1, plot=True)
|
||||||
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||||
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||||
|
|
|
@ -15,15 +15,148 @@
|
||||||
"""
|
"""
|
||||||
Test UniformAugment op in Dataset
|
Test UniformAugment op in Dataset
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.py_transforms as PT
|
||||||
import mindspore.dataset.vision.c_transforms as C
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
import mindspore.dataset.vision.py_transforms as F
|
import mindspore.dataset.vision.py_transforms as F
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
from ..dataset.util import visualize_list, diff_mse
|
||||||
|
|
||||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_callable(num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment C++ op under under execute mode. Use list for transforms list argument.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_cpp_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = C.Decode()
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.shape == (2268, 4032, 3)
|
||||||
|
|
||||||
|
transforms_ua = [C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||||
|
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_callable_tuple(num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment C++ op under under execute mode. Use tuple for transforms list argument.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_cpp_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = C.Decode()
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.shape == (2268, 4032, 3)
|
||||||
|
|
||||||
|
transforms_ua = (C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]))
|
||||||
|
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_py_uniform_augment_callable(num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test UniformAugment Python op under under execute mode. Use list for transforms list argument.
|
||||||
|
Expectation: Output's shape is the same as expected output's shape
|
||||||
|
"""
|
||||||
|
logger.info("test_cpp_uniform_augment_callable")
|
||||||
|
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
decode_op = F.Decode()
|
||||||
|
img = decode_op(img)
|
||||||
|
assert img.size == (4032, 2268)
|
||||||
|
|
||||||
|
transforms_ua = [F.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||||
|
F.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||||
|
uni_aug = F.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||||
|
img = uni_aug(img)
|
||||||
|
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||||
|
|
||||||
|
|
||||||
|
def test_py_uniform_augment_pyfunc(plot=False, num_ops=2):
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment Op
|
||||||
|
Description: Test Python op with valid Python function in transforms list. Include pyfunc in transforms list.
|
||||||
|
Expectation: Pipeline is successfully executed
|
||||||
|
"""
|
||||||
|
logger.info("Test UniformAugment")
|
||||||
|
|
||||||
|
# Original Images
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
|
transforms_original = PT.Compose([F.Decode(),
|
||||||
|
F.Resize((224, 224)),
|
||||||
|
F.ToTensor()])
|
||||||
|
|
||||||
|
ds_original = data_set.map(operations=transforms_original, input_columns="image")
|
||||||
|
|
||||||
|
ds_original = ds_original.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_original):
|
||||||
|
if idx == 0:
|
||||||
|
images_original = np.transpose(image.asnumpy(), (0, 2, 3, 1))
|
||||||
|
else:
|
||||||
|
images_original = np.append(images_original,
|
||||||
|
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
# UniformAugment Images
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
|
transform_list = [F.RandomRotation(45),
|
||||||
|
F.RandomColor(),
|
||||||
|
F.RandomSharpness(),
|
||||||
|
F.Invert(),
|
||||||
|
lambda x: x,
|
||||||
|
F.AutoContrast(),
|
||||||
|
F.Equalize()]
|
||||||
|
|
||||||
|
transforms_ua = PT.Compose([F.Decode(),
|
||||||
|
F.Resize((224, 224)),
|
||||||
|
F.UniformAugment(transforms=transform_list,
|
||||||
|
num_ops=num_ops),
|
||||||
|
F.ToTensor()])
|
||||||
|
|
||||||
|
ds_ua = data_set.map(operations=transforms_ua, input_columns="image")
|
||||||
|
|
||||||
|
ds_ua = ds_ua.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_ua):
|
||||||
|
if idx == 0:
|
||||||
|
images_ua = np.transpose(image.asnumpy(), (0, 2, 3, 1))
|
||||||
|
else:
|
||||||
|
images_ua = np.append(images_ua,
|
||||||
|
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
num_samples = images_original.shape[0]
|
||||||
|
mse = np.zeros(num_samples)
|
||||||
|
for i in range(num_samples):
|
||||||
|
mse[i] = diff_mse(images_ua[i], images_original[i])
|
||||||
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
visualize_list(images_original, images_ua)
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||||
"""
|
"""
|
||||||
Feature: UniformAugment Op
|
Feature: UniformAugment Op
|
||||||
|
@ -46,5 +179,81 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||||
assert "Type of Transforms[5] must be c_transform" in str(e.value)
|
assert "Type of Transforms[5] must be c_transform" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_exception_pyfunc():
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test C++ op with pyfunc in transforms list
|
||||||
|
Expectation: Exception is raised as expected
|
||||||
|
"""
|
||||||
|
pyfunc = lambda x: x
|
||||||
|
transforms_list = [C.RandomVerticalFlip(), pyfunc]
|
||||||
|
with pytest.raises(TypeError) as error_info:
|
||||||
|
_ = C.UniformAugment(transforms_list, 1)
|
||||||
|
error_msg = "Type of Transforms[1] must be c_transform, but got <class 'function'>"
|
||||||
|
assert error_msg in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_c_uniform_augment_exception_num_ops():
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test C++ op with more ops than number of ops in transforms list
|
||||||
|
Expectation: Exception is raised as expected
|
||||||
|
"""
|
||||||
|
transforms_list = [C.RandomVerticalFlip()]
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = C.UniformAugment(transforms_list, 3)
|
||||||
|
error_msg = "num_ops is greater than transforms list size"
|
||||||
|
assert error_msg in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_py_uniform_augment_exception_num_ops():
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test Python op with more ops than number of ops in transforms list
|
||||||
|
Expectation: Exception is raised as expected
|
||||||
|
"""
|
||||||
|
pyfunc = lambda x: x
|
||||||
|
transforms_list = [F.RandomVerticalFlip(), pyfunc]
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = F.UniformAugment(transforms_list, 9)
|
||||||
|
error_msg = "num_ops cannot be greater than the length of transforms list."
|
||||||
|
assert error_msg in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_py_uniform_augment_exception_tuple1():
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test Python op with transforms argument as tuple
|
||||||
|
Expectation: Exception is raised as expected
|
||||||
|
"""
|
||||||
|
transforms_list = (F.RandomVerticalFlip())
|
||||||
|
with pytest.raises(TypeError) as error_info:
|
||||||
|
_ = F.UniformAugment(transforms_list, 1)
|
||||||
|
error_msg = "not of type [<class 'list'>], but got"
|
||||||
|
assert error_msg in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_py_uniform_augment_exception_tuple2():
|
||||||
|
"""
|
||||||
|
Feature: UniformAugment
|
||||||
|
Description: Test Python op with transforms argument as tuple
|
||||||
|
Expectation: Exception is raised as expected
|
||||||
|
"""
|
||||||
|
transforms_list = (F.RandomHorizontalFlip(), F.RandomVerticalFlip())
|
||||||
|
with pytest.raises(TypeError) as error_info:
|
||||||
|
_ = F.UniformAugment(transforms_list, 1)
|
||||||
|
error_msg = "not of type [<class 'list'>], but got <class 'tuple'>."
|
||||||
|
assert error_msg in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
test_cpp_uniform_augment_callable()
|
||||||
|
test_cpp_uniform_augment_callable_tuple()
|
||||||
|
test_py_uniform_augment_callable()
|
||||||
|
test_py_uniform_augment_pyfunc(plot=True, num_ops=7)
|
||||||
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||||
|
test_cpp_uniform_augment_exception_pyfunc()
|
||||||
|
test_c_uniform_augment_exception_num_ops()
|
||||||
|
test_py_uniform_augment_exception_num_ops()
|
||||||
|
test_py_uniform_augment_exception_tuple1()
|
||||||
|
test_py_uniform_augment_exception_tuple2()
|
||||||
|
|
Loading…
Reference in New Issue