forked from mindspore-Ecosystem/mindspore
!875 Reject python OP in operations argument for C++ uniform augmentation OP
Merge pull request !875 from AdelShafiei/ua_py
This commit is contained in:
commit
8af10eb51e
|
@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops
|
||||||
std::shared_ptr<TensorOp> tensor_op;
|
std::shared_ptr<TensorOp> tensor_op;
|
||||||
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
|
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
|
||||||
for (auto op : op_list) {
|
for (auto op : op_list) {
|
||||||
if (py::isinstance<py::function>(op)) {
|
// only C++ op is accepted
|
||||||
// python op
|
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||||
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
|
|
||||||
} else if (py::isinstance<TensorOp>(op)) {
|
|
||||||
// C++ op
|
|
||||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
|
||||||
}
|
|
||||||
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
|
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
rnd_.seed(GetSeed());
|
rnd_.seed(GetSeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute method to apply uniformly random selected augmentations from a list
|
// compute method to apply uniformly random selected augmentations from a list
|
||||||
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||||
std::vector<std::shared_ptr<Tensor>> *output) {
|
std::vector<std::shared_ptr<Tensor>> *output) {
|
||||||
|
@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply python/C++ op
|
// apply C++ ops (note: python OPs are not accepted)
|
||||||
if (count == 1) {
|
if (count == 1) {
|
||||||
(**tensor_op).Compute(input, output);
|
(**tensor_op).Compute(input, output);
|
||||||
} else if (count % 2 == 0) {
|
} else if (count % 2 == 0) {
|
||||||
|
|
|
@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp {
|
||||||
static const int kDefNumOps;
|
static const int kDefNumOps;
|
||||||
|
|
||||||
// Constructor for UniformAugOp
|
// Constructor for UniformAugOp
|
||||||
// @param list op_list: list of candidate python operations
|
// @param list op_list: list of candidate C++ operations
|
||||||
// @param list num_ops: number of augemtation operations to applied
|
// @param list num_ops: number of augemtation operations to applied
|
||||||
UniformAugOp(py::list op_list, int32_t num_ops);
|
UniformAugOp(py::list op_list, int32_t num_ops);
|
||||||
|
|
||||||
|
|
|
@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp):
|
||||||
Tensor operation to perform randomly selected augmentation
|
Tensor operation to perform randomly selected augmentation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
operations: list of python operations.
|
operations: list of C++ operations (python OPs are not accepted).
|
||||||
NumOps (int): number of OPs to be selected and applied.
|
NumOps (int): number of OPs to be selected and applied.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> transforms_list = [c_transforms.RandomHorizontalFlip(),
|
||||||
|
>>> c_transforms.RandomVerticalFlip(),
|
||||||
|
>>> c_transforms.RandomColorAdjust(),
|
||||||
|
>>> c_transforms.RandomRotation(degrees=45)]
|
||||||
|
>>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2)
|
||||||
|
>>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]),
|
||||||
|
>>> uni_aug, F.ToTensor()]
|
||||||
|
>>> ds_ua = ds.map(input_columns="image",
|
||||||
|
>>> operations=transforms_all, num_parallel_workers=1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_uniform_augmentation
|
@check_uniform_augmentation
|
||||||
|
|
|
@ -837,8 +837,8 @@ def check_uniform_augmentation(method):
|
||||||
if not isinstance(operations, list):
|
if not isinstance(operations, list):
|
||||||
raise ValueError("operations is not a python list")
|
raise ValueError("operations is not a python list")
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if not callable(op) and not isinstance(op, TensorOp):
|
if not isinstance(op, TensorOp):
|
||||||
raise ValueError("non-callable op in operations list")
|
raise ValueError("operations list only accepts C++ operations.")
|
||||||
|
|
||||||
kwargs["num_ops"] = num_ops
|
kwargs["num_ops"] = num_ops
|
||||||
kwargs["operations"] = operations
|
kwargs["operations"] = operations
|
||||||
|
|
|
@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
|
||||||
mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2)
|
mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2)
|
||||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||||
|
"""
|
||||||
|
Test UniformAugment invalid op in operations
|
||||||
|
"""
|
||||||
|
logger.info("Test CPP UniformAugment invalid OP exception")
|
||||||
|
|
||||||
|
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||||
|
C.RandomHorizontalFlip(),
|
||||||
|
C.RandomVerticalFlip(),
|
||||||
|
C.RandomColorAdjust(),
|
||||||
|
C.RandomRotation(degrees=45),
|
||||||
|
F.Invert()]
|
||||||
|
|
||||||
|
try:
|
||||||
|
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "operations" in str(e)
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
||||||
|
"""
|
||||||
|
Test UniformAugment invalid large number of ops
|
||||||
|
"""
|
||||||
|
logger.info("Test CPP UniformAugment invalid large num_ops exception")
|
||||||
|
|
||||||
|
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||||
|
C.RandomHorizontalFlip(),
|
||||||
|
C.RandomVerticalFlip(),
|
||||||
|
C.RandomColorAdjust(),
|
||||||
|
C.RandomRotation(degrees=45)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "num_ops" in str(e)
|
||||||
|
|
||||||
|
def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||||
|
"""
|
||||||
|
Test UniformAugment invalid non-positive number of ops
|
||||||
|
"""
|
||||||
|
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")
|
||||||
|
|
||||||
|
transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
|
||||||
|
C.RandomHorizontalFlip(),
|
||||||
|
C.RandomVerticalFlip(),
|
||||||
|
C.RandomColorAdjust(),
|
||||||
|
C.RandomRotation(degrees=45)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "num_ops" in str(e)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_uniform_augment(num_ops=1)
|
test_uniform_augment(num_ops=1)
|
||||||
test_cpp_uniform_augment(num_ops=1)
|
test_cpp_uniform_augment(num_ops=1)
|
||||||
|
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||||
|
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||||
|
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue