diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc index a9c07d3aba5..068cb072d42 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc @@ -44,6 +44,7 @@ #include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h" #include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h" #include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h" +#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h" #include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h" #include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h" #include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h" @@ -373,6 +374,17 @@ PYBIND_REGISTER(RandomHorizontalFlipWithBBoxOperation, 1, ([](const py::module * })); })); +PYBIND_REGISTER( + RandomInvertOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomInvertOperation") + .def(py::init([](float prob) { + auto random_invert = std::make_shared(prob); + THROW_IF_ERROR(random_invert->ValidateParams()); + return random_invert; + })); + })); + PYBIND_REGISTER(RandomPosterizeOperation, 1, ([](const py::module *m) { (void)py::class_>(*m, "RandomPosterizeOperation") diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index bca04791f51..77a31ffe79b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -48,6 +48,7 @@ #include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h" #include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h" #include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h" +#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h" #include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h" #include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h" #include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h" @@ -608,6 +609,18 @@ std::shared_ptr RandomHorizontalFlipWithBBox::Parse() { return std::make_shared(data_->probability_); } +// RandomInvert Operation. +struct RandomInvert::Data { + explicit Data(float prob) : probability_(prob) {} + float probability_; +}; + +RandomInvert::RandomInvert(float prob) : data_(std::make_shared(prob)) {} + +std::shared_ptr RandomInvert::Parse() { + return std::make_shared(data_->probability_); +} + // RandomPosterize Transform Operation. struct RandomPosterize::Data { explicit Data(const std::vector &bit_range) : bit_range_(bit_range) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h index b5eee10143a..827ba58a0af 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h @@ -535,6 +535,27 @@ class RandomHorizontalFlipWithBBox final : public TensorTransform { std::shared_ptr data_; }; +/// \brief Randomly invert the input image with a given probability. +class RandomInvert final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] prob A float representing the probability of the image being inverted, which + /// must be in range of [0, 1] (default=0.5). + explicit RandomInvert(float prob = 0.5); + + /// \brief Destructor. + ~RandomInvert() = default; + + protected: + /// \brief The function to convert a TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + /// \brief Reduce the number of bits for each color channel randomly. class RandomPosterize final : public TensorTransform { public: diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index dc903e3003b..437bd329128 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(kernels-image OBJECT random_crop_with_bbox_op.cc random_horizontal_flip_op.cc random_horizontal_flip_with_bbox_op.cc + random_invert_op.cc bounding_box_augment_op.cc random_posterize_op.cc random_resize_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc new file mode 100644 index 00000000000..4ba135c7188 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_invert_op.h" + +namespace mindspore { +namespace dataset { +const float RandomInvertOp::kDefProbability = 0.5; + +Status RandomInvertOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (distribution_(rnd_)) { + return InvertOp::Compute(input, output); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.h new file mode 100644 index 00000000000..9e8e4c0c03e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.h @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/image/invert_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomInvertOp : public InvertOp { + public: + static const float kDefProbability; + + explicit RandomInvertOp(float prob = kDefProbability) : distribution_(prob) { + is_deterministic_ = false; + rnd_.seed(GetSeed()); + } + + ~RandomInvertOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RandomInvertOp &so) { + so.Print(out); + return out; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomInvertOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/CMakeLists.txt index 7e68e1ad195..eea71dcd9cb 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/CMakeLists.txt @@ -29,6 +29,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES random_crop_with_bbox_ir.cc random_horizontal_flip_ir.cc random_horizontal_flip_with_bbox_ir.cc + random_invert_ir.cc random_posterize_ir.cc random_resized_crop_ir.cc random_resized_crop_with_bbox_ir.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.cc new file mode 100644 index 00000000000..d07c8cbb501 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h" + +#ifndef ENABLE_ANDROID +#include "minddata/dataset/kernels/image/random_invert_op.h" +#endif + +#include "minddata/dataset/kernels/ir/validators.h" + +namespace mindspore { +namespace dataset { +namespace vision { +#ifndef ENABLE_ANDROID +// RandomInvertOperation +RandomInvertOperation::RandomInvertOperation(float prob) : TensorOperation(true), probability_(prob) {} + +RandomInvertOperation::~RandomInvertOperation() = default; + +std::string RandomInvertOperation::Name() const { return kRandomInvertOperation; } + +Status RandomInvertOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateProbability("RandomInvert", probability_)); + return Status::OK(); +} + +std::shared_ptr RandomInvertOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +Status RandomInvertOperation::to_json(nlohmann::json *out_json) { + (*out_json)["prob"] = probability_; + return Status::OK(); +} + +Status RandomInvertOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { + CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("prob") != op_params.end(), "Failed to find prob"); + float prob = op_params["prob"]; + *operation = std::make_shared(prob); + return Status::OK(); +} + +#endif +} // namespace vision +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.h b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.h new file mode 100644 index 00000000000..a6c20bf05fd --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/random_invert_ir.h @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_ + +#include +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { + +namespace vision { + +constexpr char kRandomInvertOperation[] = "RandomInvert"; + +class RandomInvertOperation : public TensorOperation { + public: + explicit RandomInvertOperation(float prob); + + ~RandomInvertOperation(); + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override; + + Status to_json(nlohmann::json *out_json) override; + + static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); + + private: + float probability_; +}; + +} // namespace vision +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index a7711604ee5..59e92658432 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -89,6 +89,7 @@ constexpr char kRandomCropOp[] = "RandomCropOp"; constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp"; constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp"; constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp"; +constexpr char kRandomInvertOp[] = "RandomInvertOp"; constexpr char kRandomResizeOp[] = "RandomResizeOp"; constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; constexpr char kRandomRotationOp[] = "RandomRotationOp"; diff --git a/mindspore/dataset/vision/c_transforms.py b/mindspore/dataset/vision/c_transforms.py index ba2f7ef2a5e..add7140ca1e 100644 --- a/mindspore/dataset/vision/c_transforms.py +++ b/mindspore/dataset/vision/c_transforms.py @@ -1092,6 +1092,27 @@ class RandomHorizontalFlipWithBBox(ImageTensorOperation): return cde.RandomHorizontalFlipWithBBoxOperation(self.prob) +class RandomInvert(ImageTensorOperation): + """ + Randomly invert the colors of image with a given probability. + + Args: + prob (float, optional): Probability of the image being inverted, which must be in range of [0, 1] (default=0.5). + + Examples: + >>> transforms_list = [c_vision.Decode(), c_vision.RandomInvert(0.5)] + >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list, + ... input_columns=["image"]) + """ + + @check_prob + def __init__(self, prob=0.5): + self.prob = prob + + def parse(self): + return cde.RandomInvertOperation(self.prob) + + class RandomPosterize(ImageTensorOperation): """ Reduce the number of bits for each color channel to posterize the input image randomly with a given probability. diff --git a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc index 3fbb4ea0d59..506e2f0324b 100644 --- a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc @@ -1270,3 +1270,49 @@ TEST_F(MindDataTestPipeline, TestConvertColorFail) { std::shared_ptr iter = ds->CreateIterator(); EXPECT_EQ(iter, nullptr); } + +TEST_F(MindDataTestPipeline, TestRandomInvert) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvert."; + + std::string MindDataPath = "data/dataset"; + std::string folder_path = MindDataPath + "/testImageNetData/train/"; + std::shared_ptr ds = ImageFolder(folder_path, true, std::make_shared(false, 2)); + EXPECT_NE(ds, nullptr); + + auto random_invert_op = vision::RandomInvert(0.5); + + ds = ds->Map({random_invert_op}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + iter->GetNextRow(&row); + } + EXPECT_EQ(i, 2); + + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRandomInvertInvalidProb) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvertInvalidProb."; + + std::string MindDataPath = "data/dataset"; + std::string folder_path = MindDataPath + "/testImageNetData/train/"; + std::shared_ptr ds = ImageFolder(folder_path, true, std::make_shared(false, 2)); + EXPECT_NE(ds, nullptr); + + auto random_invert_op = vision::RandomInvert(1.5); + + ds = ds->Map({random_invert_op}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc index b4bf81e0d1b..f49ac2196e0 100644 --- a/tests/ut/cpp/dataset/execute_test.cc +++ b/tests/ut/cpp/dataset/execute_test.cc @@ -1057,6 +1057,7 @@ TEST_F(MindDataTestExecute, TestFadeWithInvalidArg) { Status s04 = Transform04(input_04, &input_04); EXPECT_FALSE(s04.IsOk()); } + TEST_F(MindDataTestExecute, TestVolDefalutValue) { MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefalutValue."; std::shared_ptr input_tensor_; @@ -1097,3 +1098,17 @@ TEST_F(MindDataTestExecute, TestMagphaseEager) { Status rc = transform({input_tensor}, &output_tensor); ASSERT_TRUE(rc.IsOk()); } + +TEST_F(MindDataTestExecute, TestRandomInvertEager) { + MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomInvertEager."; + // Read images + auto image = ReadFileToTensor("data/dataset/apple.jpg"); + + // Transform params + auto decode = vision::Decode(); + auto random_invert_op = vision::RandomInvert(0.6); + + auto transform = Execute({decode, random_invert_op}); + Status rc = transform(image, &image); + EXPECT_EQ(rc, Status::OK()); +} diff --git a/tests/ut/python/dataset/test_random_invert.py b/tests/ut/python/dataset/test_random_invert.py new file mode 100644 index 00000000000..67088ec0f8c --- /dev/null +++ b/tests/ut/python/dataset/test_random_invert.py @@ -0,0 +1,129 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing RandomInvert in DE +""" +import numpy as np + +import mindspore.dataset as ds +from mindspore.dataset.vision.c_transforms import Decode, Resize, RandomInvert, Invert +from mindspore import log as logger +from util import visualize_list, visualize_image, diff_mse + +image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg" +data_dir = "../data/dataset/testImageNetData/train/" + + +def test_random_invert_pipeline(plot=False): + """ + Test RandomInvert pipeline + """ + logger.info("Test RandomInvert pipeline") + + # Original Images + data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False) + transforms_original = [Decode(), Resize(size=[224, 224])] + ds_original = data_set.map(operations=transforms_original, input_columns="image") + ds_original = ds_original.batch(512) + + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image.asnumpy() + else: + images_original = np.append(images_original, + image.asnumpy(), + axis=0) + + # Randomly Inverted Images + data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False) + transform_random_invert = [Decode(), Resize(size=[224, 224]), RandomInvert(0.6)] + ds_random_invert = data_set1.map(operations=transform_random_invert, input_columns="image") + ds_random_invert = ds_random_invert.batch(512) + for idx, (image, _) in enumerate(ds_random_invert): + if idx == 0: + images_random_invert = image.asnumpy() + else: + images_random_invert = np.append(images_random_invert, + image.asnumpy(), + axis=0) + if plot: + visualize_list(images_original, images_random_invert) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_random_invert[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_random_invert_eager(): + """ + Test RandomInvert eager. + """ + img = np.fromfile(image_file, dtype=np.uint8) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + + img = Decode()(img) + img_inverted = Invert()(img) + img_random_inverted = RandomInvert(1.0)(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img_random_inverted), img_random_inverted.shape)) + + assert img_random_inverted.all() == img_inverted.all() + + +def test_random_invert_comp(plot=False): + """ + Test RandomInvert op compared with Invert op. + """ + random_invert_op = RandomInvert(prob=1.0) + invert_op = Invert() + + dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True) + for item in dataset1.create_dict_iterator(output_numpy=True): + image = item['image'] + dataset1.map(operations=random_invert_op, input_columns=['image']) + dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True) + dataset2.map(operations=invert_op, input_columns=['image']) + for item1, item2 in zip(dataset1.create_dict_iterator(output_numpy=True), + dataset2.create_dict_iterator(output_numpy=True)): + image_random_inverted = item1['image'] + image_inverted = item2['image'] + + mse = diff_mse(image_inverted, image_random_inverted) + assert mse == 0 + logger.info("mse: {}".format(mse)) + if plot: + visualize_image(image, image_random_inverted, mse, image_inverted) + + +def test_random_invert_invalid_prob(): + """ + Test invalid prob. prob out of range. + """ + logger.info("test_random_invert_invalid_prob") + dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True) + try: + random_invert_op = RandomInvert(1.5) + dataset = dataset.map(operations=random_invert_op, input_columns=['image']) + except ValueError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e) + + +if __name__ == "__main__": + test_random_invert_pipeline(plot=True) + test_random_invert_eager() + test_random_invert_comp(plot=True) + test_random_invert_invalid_prob()