forked from mindspore-Ecosystem/mindspore
!1377 Code Fix for Uniform Augmentation
Merge pull request !1377 from Tinazhang/cc
This commit is contained in:
commit
a528797253
|
@ -830,6 +830,9 @@ def check_uniform_augmentation(method):
|
|||
else:
|
||||
num_ops = 2
|
||||
|
||||
if not isinstance(num_ops, int):
|
||||
raise ValueError("Number of operations should be an integer.")
|
||||
|
||||
if num_ops <= 0:
|
||||
raise ValueError("num_ops should be greater than zero")
|
||||
if num_ops > len(operations):
|
||||
|
|
|
@ -226,16 +226,40 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
|
||||
def test_cpp_uniform_augment_random_crop_ut():
|
||||
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||
"""
|
||||
Test UniformAugment invalid float number of ops
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid float num_ops exception")
|
||||
|
||||
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
C.RandomHorizontalFlip(),
|
||||
C.RandomVerticalFlip(),
|
||||
C.RandomColorAdjust(),
|
||||
C.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except BaseException as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "integer" in str(e)
|
||||
|
||||
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||
"""
|
||||
Test UniformAugment with greater crop size
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment with random_crop bad input")
|
||||
batch_size=2
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
|
||||
transforms_ua = [
|
||||
# Note: crop size [224, 224] > image size [32, 32]
|
||||
C.RandomCrop(size=[224, 224]),
|
||||
C.RandomHorizontalFlip()
|
||||
]
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||
ds1 = ds1.map(input_columns="image", operations=uni_aug)
|
||||
|
||||
# apply DatasetOps
|
||||
|
@ -254,4 +278,5 @@ if __name__ == "__main__":
|
|||
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||
test_cpp_uniform_augment_random_crop_ut()
|
||||
test_cpp_uniform_augment_exception_float_numops(num_ops=2.5)
|
||||
test_cpp_uniform_augment_random_crop_badinput(num_ops=1)
|
||||
|
|
Loading…
Reference in New Issue