diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc index 1214345c370..cbc5aaa2e5c 100644 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc @@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops std::shared_ptr tensor_op; // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ for (auto op : op_list) { - if (py::isinstance(op)) { - // python op - tensor_op = std::make_shared(op.cast()); - } else if (py::isinstance(op)) { - // C++ op - tensor_op = op.cast>(); - } + // only C++ op is accepted + tensor_op = op.cast>(); tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); } rnd_.seed(GetSeed()); } + // compute method to apply uniformly random selected augmentations from a list Status UniformAugOp::Compute(const std::vector> &input, std::vector> *output) { @@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector> &input, continue; } - // apply python/C++ op + // apply C++ ops (note: python OPs are not accepted) if (count == 1) { (**tensor_op).Compute(input, output); } else if (count % 2 == 0) { diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h index 336bc8c8598..a70edc2777a 100644 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h @@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp { static const int kDefNumOps; // Constructor for UniformAugOp - // @param list op_list: list of candidate python operations + // @param list op_list: list of candidate C++ operations // @param list num_ops: number of augemtation operations to applied UniformAugOp(py::list op_list, int32_t num_ops); diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 1b495ffe923..1806d22446d 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp): Tensor operation to perform randomly selected augmentation Args: - operations: list of python operations. + operations: list of C++ operations (python OPs are not accepted). NumOps (int): number of OPs to be selected and applied. + + Examples: + >>> transforms_list = [c_transforms.RandomHorizontalFlip(), + >>> c_transforms.RandomVerticalFlip(), + >>> c_transforms.RandomColorAdjust(), + >>> c_transforms.RandomRotation(degrees=45)] + >>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2) + >>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]), + >>> uni_aug, F.ToTensor()] + >>> ds_ua = ds.map(input_columns="image", + >>> operations=transforms_all, num_parallel_workers=1) """ @check_uniform_augmentation diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 2c299b077bd..96d0a3bfdcc 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -837,8 +837,8 @@ def check_uniform_augmentation(method): if not isinstance(operations, list): raise ValueError("operations is not a python list") for op in operations: - if not callable(op) and not isinstance(op, TensorOp): - raise ValueError("non-callable op in operations list") + if not isinstance(op, TensorOp): + raise ValueError("operations list only accepts C++ operations.") kwargs["num_ops"] = num_ops kwargs["operations"] = operations diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py index ea990561165..98c22fb3cb9 100644 --- a/tests/ut/python/dataset/test_uniform_augment.py +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2): mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2) logger.info("MSE= {}".format(str(np.mean(mse)))) +def test_cpp_uniform_augment_exception_pyops(num_ops=2): + """ + Test UniformAugment invalid op in operations + """ + logger.info("Test CPP UniformAugment invalid OP exception") + + transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), + C.RandomHorizontalFlip(), + C.RandomVerticalFlip(), + C.RandomColorAdjust(), + C.RandomRotation(degrees=45), + F.Invert()] + + try: + uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + + except BaseException as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "operations" in str(e) + +def test_cpp_uniform_augment_exception_large_numops(num_ops=6): + """ + Test UniformAugment invalid large number of ops + """ + logger.info("Test CPP UniformAugment invalid large num_ops exception") + + transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), + C.RandomHorizontalFlip(), + C.RandomVerticalFlip(), + C.RandomColorAdjust(), + C.RandomRotation(degrees=45)] + + try: + uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + + except BaseException as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "num_ops" in str(e) + +def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): + """ + Test UniformAugment invalid non-positive number of ops + """ + logger.info("Test CPP UniformAugment invalid non-positive num_ops exception") + + transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), + C.RandomHorizontalFlip(), + C.RandomVerticalFlip(), + C.RandomColorAdjust(), + C.RandomRotation(degrees=45)] + + try: + uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + + except BaseException as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "num_ops" in str(e) + if __name__ == "__main__": test_uniform_augment(num_ops=1) test_cpp_uniform_augment(num_ops=1) - + test_cpp_uniform_augment_exception_pyops(num_ops=1) + test_cpp_uniform_augment_exception_large_numops(num_ops=6) + test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0) +