forked from mindspore-Ecosystem/mindspore
added input validation to reject python op in C++ uniform augmentation operations list
This commit is contained in:
parent
ea7872b0a3
commit
d15bd04bfe
|
@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops
|
|||
std::shared_ptr<TensorOp> tensor_op;
|
||||
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
|
||||
for (auto op : op_list) {
|
||||
if (py::isinstance<py::function>(op)) {
|
||||
// python op
|
||||
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
|
||||
} else if (py::isinstance<TensorOp>(op)) {
|
||||
// C++ op
|
||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||
}
|
||||
// only C++ op is accepted
|
||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
|
||||
}
|
||||
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
// compute method to apply uniformly random selected augmentations from a list
|
||||
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||
std::vector<std::shared_ptr<Tensor>> *output) {
|
||||
|
@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
|||
continue;
|
||||
}
|
||||
|
||||
// apply python/C++ op
|
||||
// apply C++ ops (note: python OPs are not accepted)
|
||||
if (count == 1) {
|
||||
(**tensor_op).Compute(input, output);
|
||||
} else if (count % 2 == 0) {
|
||||
|
|
|
@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp {
|
|||
static const int kDefNumOps;
|
||||
|
||||
// Constructor for UniformAugOp
|
||||
// @param list op_list: list of candidate python operations
|
||||
// @param list op_list: list of candidate C++ operations
|
||||
// @param list num_ops: number of augemtation operations to applied
|
||||
UniformAugOp(py::list op_list, int32_t num_ops);
|
||||
|
||||
|
|
|
@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp):
|
|||
Tensor operation to perform randomly selected augmentation
|
||||
|
||||
Args:
|
||||
operations: list of python operations.
|
||||
operations: list of C++ operations (python OPs are not accepted).
|
||||
NumOps (int): number of OPs to be selected and applied.
|
||||
|
||||
Examples:
|
||||
>>> transforms_list = [c_transforms.RandomHorizontalFlip(),
|
||||
>>> c_transforms.RandomVerticalFlip(),
|
||||
>>> c_transforms.RandomColorAdjust(),
|
||||
>>> c_transforms.RandomRotation(degrees=45)]
|
||||
>>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2)
|
||||
>>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]),
|
||||
>>> uni_aug, F.ToTensor()]
|
||||
>>> ds_ua = ds.map(input_columns="image",
|
||||
>>> operations=transforms_all, num_parallel_workers=1)
|
||||
"""
|
||||
|
||||
@check_uniform_augmentation
|
||||
|
|
|
@ -837,8 +837,8 @@ def check_uniform_augmentation(method):
|
|||
if not isinstance(operations, list):
|
||||
raise ValueError("operations is not a python list")
|
||||
for op in operations:
|
||||
if not callable(op) and not isinstance(op, TensorOp):
|
||||
raise ValueError("non-callable op in operations list")
|
||||
if not isinstance(op, TensorOp):
|
||||
raise ValueError("operations list only accepts C++ operations.")
|
||||
|
||||
kwargs["num_ops"] = num_ops
|
||||
kwargs["operations"] = operations
|
||||
|
|
|
@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
|
|||
mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||
"""
|
||||
Test UniformAugment invalid op in operations
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid OP exception")
|
||||
|
||||
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
C.RandomHorizontalFlip(),
|
||||
C.RandomVerticalFlip(),
|
||||
C.RandomColorAdjust(),
|
||||
C.RandomRotation(degrees=45),
|
||||
F.Invert()]
|
||||
|
||||
try:
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except BaseException as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "operations" in str(e)
|
||||
|
||||
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
||||
"""
|
||||
Test UniformAugment invalid large number of ops
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid large num_ops exception")
|
||||
|
||||
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
C.RandomHorizontalFlip(),
|
||||
C.RandomVerticalFlip(),
|
||||
C.RandomColorAdjust(),
|
||||
C.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except BaseException as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
|
||||
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||
"""
|
||||
Test UniformAugment invalid non-positive number of ops
|
||||
"""
|
||||
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")
|
||||
|
||||
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||
C.RandomHorizontalFlip(),
|
||||
C.RandomVerticalFlip(),
|
||||
C.RandomColorAdjust(),
|
||||
C.RandomRotation(degrees=45)]
|
||||
|
||||
try:
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||
|
||||
except BaseException as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_uniform_augment(num_ops=1)
|
||||
test_cpp_uniform_augment(num_ops=1)
|
||||
|
||||
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue