This commit is contained in:
Tinazhang 2020-05-19 17:57:17 -04:00
parent 6cbde2b3bb
commit e9e40b688b
2 changed files with 25 additions and 3 deletions

View File

@ -55,11 +55,11 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
// apply C++ ops (note: python OPs are not accepted)
if (count == 1) {
(**tensor_op).Compute(input, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output));
} else if (count % 2 == 0) {
(**tensor_op).Compute(*output, even_out_ptr);
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr));
} else {
(**tensor_op).Compute(even_out, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output));
}
count++;
}

View File

@ -226,6 +226,27 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
def test_cpp_uniform_augment_random_crop_ut():
batch_size=2
cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
transforms_ua = [
C.RandomCrop(size=[224, 224]),
C.RandomHorizontalFlip()
]
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
ds1 = ds1.map(input_columns="image", operations=uni_aug)
# apply DatasetOps
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
num_batches = 0
try:
for data in ds1.create_dict_iterator():
num_batches += 1
except BaseException as e:
assert "Crop size" in str(e)
if __name__ == "__main__":
test_uniform_augment(num_ops=1)
@ -233,3 +254,4 @@ if __name__ == "__main__":
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
test_cpp_uniform_augment_random_crop_ut()