This commit is contained in:
Tinazhang 2020-05-19 17:57:17 -04:00
parent 6f733ec113
commit b390883c6a
2 changed files with 31 additions and 3 deletions

View File

@ -830,6 +830,9 @@ def check_uniform_augmentation(method):
else:
num_ops = 2
if not isinstance(num_ops, int):
raise ValueError("Number of operations should be an integer.")
if num_ops <= 0:
raise ValueError("num_ops should be greater than zero")
if num_ops > len(operations):

View File

@ -226,16 +226,40 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
def test_cpp_uniform_augment_random_crop_ut():
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
"""
Test UniformAugment invalid float number of ops
"""
logger.info("Test CPP UniformAugment invalid float num_ops exception")
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45)]
try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "integer" in str(e)
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
"""
Test UniformAugment with greater crop size
"""
logger.info("Test CPP UniformAugment with random_crop bad input")
batch_size=2
cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
transforms_ua = [
# Note: crop size [224, 224] > image size [32, 32]
C.RandomCrop(size=[224, 224]),
C.RandomHorizontalFlip()
]
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
ds1 = ds1.map(input_columns="image", operations=uni_aug)
# apply DatasetOps
@ -254,4 +278,5 @@ if __name__ == "__main__":
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
test_cpp_uniform_augment_random_crop_ut()
test_cpp_uniform_augment_exception_float_numops(num_ops=2.5)
test_cpp_uniform_augment_random_crop_badinput(num_ops=1)