added input validation to reject python op in C++ uniform augmentation operations list

This commit is contained in:
Adel Shafiei 2020-04-29 15:42:36 -04:00
parent ea7872b0a3
commit d15bd04bfe
5 changed files with 81 additions and 13 deletions

View File

@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops
std::shared_ptr<TensorOp> tensor_op;
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
for (auto op : op_list) {
if (py::isinstance<py::function>(op)) {
// python op
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else if (py::isinstance<TensorOp>(op)) {
// C++ op
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
}
// only C++ op is accepted
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
}
rnd_.seed(GetSeed());
}
// compute method to apply uniformly random selected augmentations from a list
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) {
@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
continue;
}
// apply python/C++ op
// apply C++ ops (note: python OPs are not accepted)
if (count == 1) {
(**tensor_op).Compute(input, output);
} else if (count % 2 == 0) {

View File

@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp {
static const int kDefNumOps;
// Constructor for UniformAugOp
// @param list op_list: list of candidate python operations
// @param list op_list: list of candidate C++ operations
// @param list num_ops: number of augemtation operations to applied
UniformAugOp(py::list op_list, int32_t num_ops);

View File

@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp):
Tensor operation to perform randomly selected augmentation
Args:
operations: list of python operations.
operations: list of C++ operations (python OPs are not accepted).
NumOps (int): number of OPs to be selected and applied.
Examples:
>>> transforms_list = [c_transforms.RandomHorizontalFlip(),
>>> c_transforms.RandomVerticalFlip(),
>>> c_transforms.RandomColorAdjust(),
>>> c_transforms.RandomRotation(degrees=45)]
>>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2)
>>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]),
>>> uni_aug, F.ToTensor()]
>>> ds_ua = ds.map(input_columns="image",
>>> operations=transforms_all, num_parallel_workers=1)
"""
@check_uniform_augmentation

View File

@ -837,8 +837,8 @@ def check_uniform_augmentation(method):
if not isinstance(operations, list):
raise ValueError("operations is not a python list")
for op in operations:
if not callable(op) and not isinstance(op, TensorOp):
raise ValueError("non-callable op in operations list")
if not isinstance(op, TensorOp):
raise ValueError("operations list only accepts C++ operations.")
kwargs["num_ops"] = num_ops
kwargs["operations"] = operations

View File

@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2)
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
"""
Test UniformAugment invalid op in operations
"""
logger.info("Test CPP UniformAugment invalid OP exception")
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45),
F.Invert()]
try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "operations" in str(e)
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
"""
Test UniformAugment invalid large number of ops
"""
logger.info("Test CPP UniformAugment invalid large num_ops exception")
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45)]
try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
"""
Test UniformAugment invalid non-positive number of ops
"""
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45)]
try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
if __name__ == "__main__":
test_uniform_augment(num_ops=1)
test_cpp_uniform_augment(num_ops=1)
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)