[MD] Unified UniformAugment op: allow Pyfunc in transforms list
This commit is contained in:
parent
4c9252c21b
commit
a61c709d00
|
@ -40,12 +40,12 @@ std::string UniformAugOperation::Name() const { return kUniformAugOperation; }
|
|||
|
||||
Status UniformAugOperation::ValidateParams() {
|
||||
// transforms
|
||||
RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAug", transforms_));
|
||||
RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAugment", transforms_));
|
||||
// num_ops
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("UniformAug", "num_ops", num_ops_));
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("UniformAugment", "num_ops", num_ops_));
|
||||
if (num_ops_ > transforms_.size()) {
|
||||
std::string err_msg =
|
||||
"UniformAug: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_);
|
||||
"UniformAugment: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace dataset {
|
|||
|
||||
namespace vision {
|
||||
|
||||
constexpr char kUniformAugOperation[] = "UniformAug";
|
||||
constexpr char kUniformAugOperation[] = "UniformAugment";
|
||||
|
||||
class UniformAugOperation : public TensorOperation {
|
||||
public:
|
||||
|
|
|
@ -118,7 +118,7 @@ constexpr char kSlicePatchesOp[] = "SlicePatchesOp";
|
|||
constexpr char kSolarizeOp[] = "SolarizeOp";
|
||||
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
||||
constexpr char kToTensorOp[] = "ToTensorOp";
|
||||
constexpr char kUniformAugOp[] = "UniformAugOp";
|
||||
constexpr char kUniformAugOp[] = "UniformAugmentOp";
|
||||
constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
|
||||
|
||||
// video
|
||||
|
|
|
@ -69,7 +69,7 @@ from .validators import check_adjust_gamma, check_alpha, check_auto_augment, che
|
|||
check_random_erasing, check_random_perspective, check_random_resize_crop, check_random_rotation, \
|
||||
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
|
||||
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_ten_crop, \
|
||||
check_uniform_augment_cpp, check_to_tensor, FLOAT_MAX_INTEGER
|
||||
check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
|
||||
from ..core.datatypes import mstype_to_detype, nptype_to_detype
|
||||
from ..transforms.py_transforms_util import Implementation
|
||||
from ..transforms.transforms import CompoundOperation, PyTensorOperation, TensorOperation, TypeCast
|
||||
|
@ -3541,7 +3541,7 @@ class UniformAugment(CompoundOperation):
|
|||
... input_columns="image")
|
||||
"""
|
||||
|
||||
@check_uniform_augment_cpp
|
||||
@check_uniform_augment
|
||||
def __init__(self, transforms, num_ops=2):
|
||||
super().__init__(transforms)
|
||||
self.num_ops = num_ops
|
||||
|
|
|
@ -24,7 +24,8 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
|
|||
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
||||
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
|
||||
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
|
||||
check_pos_int32, deprecator_factory
|
||||
check_pos_int32, check_tensor_op, deprecator_factory
|
||||
from mindspore.dataset.transforms.validators import check_transform_op_type
|
||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
||||
|
||||
|
||||
|
@ -932,6 +933,30 @@ def check_uniform_augment_cpp(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_uniform_augment(method):
|
||||
"""Wrapper method to check the parameters of UniformAugment Unified op."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(num_ops, (int,), "num_ops")
|
||||
check_positive(num_ops, "num_ops")
|
||||
|
||||
if num_ops > len(transforms):
|
||||
raise ValueError("num_ops is greater than transforms list size.")
|
||||
|
||||
type_check(transforms, (list, tuple,), "transforms list")
|
||||
if not transforms:
|
||||
raise ValueError("transforms list can not be empty.")
|
||||
for ind, op in enumerate(transforms):
|
||||
check_tensor_op(op, "transforms[{0}]".format(ind))
|
||||
check_transform_op_type(ind, op)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_bounding_box_augment_cpp(method):
|
||||
"""Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
Testing UniformAugment in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms
|
||||
|
@ -29,7 +30,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
|
|||
def test_uniform_augment_callable(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment under normal test case
|
||||
Description: Test UniformAugment under execute mode
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_uniform_augment_callable")
|
||||
|
@ -40,11 +41,75 @@ def test_uniform_augment_callable(num_ops=2):
|
|||
img = decode_op(img)
|
||||
assert img.shape == (2268, 4032, 3)
|
||||
|
||||
transforms_ua = [vision.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
|
||||
vision.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
|
||||
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.shape == (2268, 4032, 3) or img.shape == (400, 400, 3)
|
||||
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||
|
||||
|
||||
def test_uniform_augment_callable_pil(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment under execute mode, with PIL input.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = vision.Decode(to_pil=True)
|
||||
img = decode_op(img)
|
||||
assert img.size == (4032, 2268)
|
||||
|
||||
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||
|
||||
|
||||
def test_uniform_augment_callable_pil_pyfunc(num_ops=3):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment under execute mode, with PIL input. Include pyfunc in transforms list.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = vision.Decode(to_pil=True)
|
||||
img = decode_op(img)
|
||||
assert img.size == (4032, 2268)
|
||||
|
||||
transforms_ua = [vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
lambda x: x,
|
||||
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||
|
||||
|
||||
def test_uniform_augment_callable_tuple(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment under execute mode. Use tuple for transforms list argument.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = vision.Decode()
|
||||
img = decode_op(img)
|
||||
assert img.shape == (2268, 4032, 3)
|
||||
|
||||
transforms_ua = (vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
vision.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]))
|
||||
uni_aug = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||
|
||||
|
||||
def test_uniform_augment(plot=False, num_ops=2):
|
||||
|
@ -74,7 +139,7 @@ def test_uniform_augment(plot=False, num_ops=2):
|
|||
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||
axis=0)
|
||||
|
||||
# UniformAugment Images
|
||||
# UniformAugment Images
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transform_list = [vision.RandomRotation(45),
|
||||
|
@ -188,12 +253,10 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
|||
vision.RandomColorAdjust(),
|
||||
vision.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||
assert "num_ops is greater than transforms list size" in str(error_info)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||
|
@ -202,7 +265,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
|||
Description: Test UniformAugment using invalid non-positive num_ops
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")
|
||||
logger.info("Test UniformAugment invalid non-positive num_ops exception")
|
||||
|
||||
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
vision.RandomHorizontalFlip(),
|
||||
|
@ -210,12 +273,10 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
|||
vision.RandomColorAdjust(),
|
||||
vision.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input num_ops must be greater than 0" in str(e)
|
||||
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||
assert "Input num_ops must be greater than 0" in str(error_info)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||
|
@ -224,7 +285,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
|||
Description: Test UniformAugment using invalid float num_ops
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid float num_ops exception")
|
||||
logger.info("Test UniformAugment invalid float num_ops exception")
|
||||
|
||||
transforms_ua = [vision.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
vision.RandomHorizontalFlip(),
|
||||
|
@ -232,12 +293,10 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
|||
vision.RandomColorAdjust(),
|
||||
vision.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
_ = vision.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Argument num_ops with value 2.5 is not of type [<class 'int'>]" in str(e)
|
||||
logger.info("Got an exception in DE: {}".format(str(error_info)))
|
||||
assert "Argument num_ops with value 2.5 is not of type [<class 'int'>]" in str(error_info)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||
|
@ -246,7 +305,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
|||
Description: Test UniformAugment with greater crop size
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment with random_crop bad input")
|
||||
logger.info("Test UniformAugment with random_crop bad input")
|
||||
batch_size = 2
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
ds1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
|
@ -262,16 +321,18 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
|||
# apply DatasetOps
|
||||
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
|
||||
num_batches = 0
|
||||
try:
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
for _ in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_batches += 1
|
||||
except Exception as e:
|
||||
assert "crop size" in str(e)
|
||||
assert "Shape is incorrect. map operation: [UniformAugment] failed." in str(error_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_uniform_augment_callable(num_ops=2)
|
||||
test_uniform_augment(num_ops=1, plot=True)
|
||||
test_uniform_augment_callable()
|
||||
test_uniform_augment_callable_pil()
|
||||
test_uniform_augment_callable_pil_pyfunc()
|
||||
test_uniform_augment_callable_tuple()
|
||||
test_uniform_augment(num_ops=6, plot=True)
|
||||
test_cpp_uniform_augment(num_ops=1, plot=True)
|
||||
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||
|
|
|
@ -15,15 +15,148 @@
|
|||
"""
|
||||
Test UniformAugment op in Dataset
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.py_transforms as PT
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
from mindspore import log as logger
|
||||
from ..dataset.util import visualize_list, diff_mse
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_callable(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment C++ op under under execute mode. Use list for transforms list argument.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_cpp_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = C.Decode()
|
||||
img = decode_op(img)
|
||||
assert img.shape == (2268, 4032, 3)
|
||||
|
||||
transforms_ua = [C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_callable_tuple(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment C++ op under under execute mode. Use tuple for transforms list argument.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_cpp_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = C.Decode()
|
||||
img = decode_op(img)
|
||||
assert img.shape == (2268, 4032, 3)
|
||||
|
||||
transforms_ua = (C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
C.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]))
|
||||
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.shape == (2268, 4032, 3) or img.shape == (200, 400, 3)
|
||||
|
||||
|
||||
def test_py_uniform_augment_callable(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test UniformAugment Python op under under execute mode. Use list for transforms list argument.
|
||||
Expectation: Output's shape is the same as expected output's shape
|
||||
"""
|
||||
logger.info("test_cpp_uniform_augment_callable")
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
decode_op = F.Decode()
|
||||
img = decode_op(img)
|
||||
assert img.size == (4032, 2268)
|
||||
|
||||
transforms_ua = [F.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32]),
|
||||
F.RandomCrop(size=[200, 400], padding=[32, 32, 32, 32])]
|
||||
uni_aug = F.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
|
||||
img = uni_aug(img)
|
||||
assert img.size == (4032, 2268) or img.size == (400, 200)
|
||||
|
||||
|
||||
def test_py_uniform_augment_pyfunc(plot=False, num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment Op
|
||||
Description: Test Python op with valid Python function in transforms list. Include pyfunc in transforms list.
|
||||
Expectation: Pipeline is successfully executed
|
||||
"""
|
||||
logger.info("Test UniformAugment")
|
||||
|
||||
# Original Images
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = PT.Compose([F.Decode(),
|
||||
F.Resize((224, 224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = data_set.map(operations=transforms_original, input_columns="image")
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image.asnumpy(), (0, 2, 3, 1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||
axis=0)
|
||||
|
||||
# UniformAugment Images
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transform_list = [F.RandomRotation(45),
|
||||
F.RandomColor(),
|
||||
F.RandomSharpness(),
|
||||
F.Invert(),
|
||||
lambda x: x,
|
||||
F.AutoContrast(),
|
||||
F.Equalize()]
|
||||
|
||||
transforms_ua = PT.Compose([F.Decode(),
|
||||
F.Resize((224, 224)),
|
||||
F.UniformAugment(transforms=transform_list,
|
||||
num_ops=num_ops),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_ua = data_set.map(operations=transforms_ua, input_columns="image")
|
||||
|
||||
ds_ua = ds_ua.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_ua):
|
||||
if idx == 0:
|
||||
images_ua = np.transpose(image.asnumpy(), (0, 2, 3, 1))
|
||||
else:
|
||||
images_ua = np.append(images_ua,
|
||||
np.transpose(image.asnumpy(), (0, 2, 3, 1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_ua[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize_list(images_original, images_ua)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||
"""
|
||||
Feature: UniformAugment Op
|
||||
|
@ -46,5 +179,81 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
|||
assert "Type of Transforms[5] must be c_transform" in str(e.value)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_pyfunc():
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test C++ op with pyfunc in transforms list
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
pyfunc = lambda x: x
|
||||
transforms_list = [C.RandomVerticalFlip(), pyfunc]
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
_ = C.UniformAugment(transforms_list, 1)
|
||||
error_msg = "Type of Transforms[1] must be c_transform, but got <class 'function'>"
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
def test_c_uniform_augment_exception_num_ops():
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test C++ op with more ops than number of ops in transforms list
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
transforms_list = [C.RandomVerticalFlip()]
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = C.UniformAugment(transforms_list, 3)
|
||||
error_msg = "num_ops is greater than transforms list size"
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
def test_py_uniform_augment_exception_num_ops():
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test Python op with more ops than number of ops in transforms list
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
pyfunc = lambda x: x
|
||||
transforms_list = [F.RandomVerticalFlip(), pyfunc]
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = F.UniformAugment(transforms_list, 9)
|
||||
error_msg = "num_ops cannot be greater than the length of transforms list."
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
def test_py_uniform_augment_exception_tuple1():
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test Python op with transforms argument as tuple
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
transforms_list = (F.RandomVerticalFlip())
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
_ = F.UniformAugment(transforms_list, 1)
|
||||
error_msg = "not of type [<class 'list'>], but got"
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
def test_py_uniform_augment_exception_tuple2():
|
||||
"""
|
||||
Feature: UniformAugment
|
||||
Description: Test Python op with transforms argument as tuple
|
||||
Expectation: Exception is raised as expected
|
||||
"""
|
||||
transforms_list = (F.RandomHorizontalFlip(), F.RandomVerticalFlip())
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
_ = F.UniformAugment(transforms_list, 1)
|
||||
error_msg = "not of type [<class 'list'>], but got <class 'tuple'>."
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cpp_uniform_augment_callable()
|
||||
test_cpp_uniform_augment_callable_tuple()
|
||||
test_py_uniform_augment_callable()
|
||||
test_py_uniform_augment_pyfunc(plot=True, num_ops=7)
|
||||
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||
test_cpp_uniform_augment_exception_pyfunc()
|
||||
test_c_uniform_augment_exception_num_ops()
|
||||
test_py_uniform_augment_exception_num_ops()
|
||||
test_py_uniform_augment_exception_tuple1()
|
||||
test_py_uniform_augment_exception_tuple2()
|
||||
|
|
Loading…
Reference in New Issue