diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc index 196eb1ed1fa..9d289f306e2 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc @@ -422,7 +422,7 @@ PYBIND_REGISTER( PYBIND_REGISTER(RandomSolarizeOp, 1, ([](const py::module *m) { (void)py::class_>(*m, "RandomSolarizeOp") - .def(py::init()); + .def(py::init>()); })); } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index c9cf66cdec1..f339855dd81 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -241,8 +241,8 @@ std::shared_ptr RandomRotation(std::vector degre } // Function to create RandomSolarizeOperation. -std::shared_ptr RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) { - auto op = std::make_shared(threshold_min, threshold_max); +std::shared_ptr RandomSolarize(std::vector threshold) { + auto op = std::make_shared(threshold); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -811,19 +811,22 @@ std::shared_ptr RandomSharpnessOperation::Build() { } // RandomSolarizeOperation. -RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max) - : threshold_min_(threshold_min), threshold_max_(threshold_max) {} +RandomSolarizeOperation::RandomSolarizeOperation(std::vector threshold) : threshold_(threshold) {} bool RandomSolarizeOperation::ValidateParams() { - if (threshold_max_ < threshold_min_) { - MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min"; + if (threshold_.size() != 2) { + MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size(); + return false; + } + if (threshold_.at(0) > threshold_.at(1)) { + MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format"; return false; } return true; } std::shared_ptr RandomSolarizeOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(threshold_min_, threshold_max_); + std::shared_ptr tensor_op = std::make_shared(threshold_); return tensor_op; } diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 3bfa27148c1..bcc6452a358 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -249,10 +249,9 @@ std::shared_ptr RandomSharpness(std::vector deg /// \brief Function to create a RandomSolarize TensorOperation. /// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold -/// \param[in] threshold_min - lower limit -/// \param[in] threshold_max - upper limit +/// \param[in] threshold - a vector with two elements specifying the pixel range to invert. /// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255); +std::shared_ptr RandomSolarize(std::vector threshold = {0, 255}); /// \brief Function to create a RandomVerticalFlip TensorOperation. /// \notes Tensor operation to perform random vertical flip. @@ -657,7 +656,7 @@ class SwapRedBlueOperation : public TensorOperation { class RandomSolarizeOperation : public TensorOperation { public: - explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max); + explicit RandomSolarizeOperation(std::vector threshold); ~RandomSolarizeOperation() = default; @@ -666,8 +665,7 @@ class RandomSolarizeOperation : public TensorOperation { bool ValidateParams() override; private: - uint8_t threshold_min_; - uint8_t threshold_max_; + std::vector threshold_; }; } // namespace vision } // namespace api diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc index 83860febc86..b4a22b971e9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include + #include "minddata/dataset/kernels/image/random_solarize_op.h" #include "minddata/dataset/kernels/image/solarize_op.h" #include "minddata/dataset/kernels/image/image_utils.h" @@ -24,6 +26,9 @@ namespace dataset { Status RandomSolarizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + + uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1]; + CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_, "threshold_min must be smaller or equal to threshold_max."); @@ -35,7 +40,8 @@ Status RandomSolarizeOp::Compute(const std::shared_ptr &input, std::shar threshold_min = threshold_max; threshold_max = temp; } - std::unique_ptr op(new SolarizeOp(threshold_min, threshold_max)); + std::vector inputs = {threshold_min, threshold_max}; + std::unique_ptr op(new SolarizeOp(inputs)); return op->Compute(input, output); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h index 1be2c527561..d85fd1889ac 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h @@ -19,6 +19,7 @@ #include #include +#include #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/tensor_op.h" @@ -31,10 +32,7 @@ namespace dataset { class RandomSolarizeOp : public SolarizeOp { public: // Pick a random threshold value to solarize the image with - explicit RandomSolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255) - : threshold_min_(threshold_min), threshold_max_(threshold_max) { - rnd_.seed(GetSeed()); - } + explicit RandomSolarizeOp(std::vector threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); } ~RandomSolarizeOp() = default; @@ -43,8 +41,7 @@ class RandomSolarizeOp : public SolarizeOp { std::string Name() const override { return kRandomSolarizeOp; } private: - uint8_t threshold_min_; - uint8_t threshold_max_; + std::vector threshold_; std::mt19937 rnd_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc index 41970b54709..874817b3777 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc @@ -27,6 +27,8 @@ const uint8_t kPixelValue = 255; Status SolarizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1]; + CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_, "threshold_min must be smaller or equal to threshold_max."); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h index 18399f7d304..b69d91106de 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h @@ -19,6 +19,7 @@ #include #include +#include #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/tensor_op.h" @@ -28,8 +29,7 @@ namespace mindspore { namespace dataset { class SolarizeOp : public TensorOp { public: - explicit SolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255) - : threshold_min_(threshold_min), threshold_max_(threshold_max) {} + explicit SolarizeOp(std::vector threshold = {0, 255}) : threshold_(threshold) {} ~SolarizeOp() = default; @@ -38,8 +38,7 @@ class SolarizeOp : public TensorOp { std::string Name() const override { return kSolarizeOp; } private: - uint8_t threshold_min_; - uint8_t threshold_max_; + std::vector threshold_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 2bd027812d3..5252d758d09 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -1045,4 +1045,4 @@ class RandomSolarize(cde.RandomSolarizeOp): @check_random_solarize def __init__(self, threshold=(0, 255)): - super().__init__(*threshold) + super().__init__(threshold) diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 316129b000c..e751c8a55ff 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -1109,31 +1109,23 @@ TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestRandomSolarize) { +TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); - // Create a Repeat operation on ds - int32_t repeat_num = 2; - ds = ds->Repeat(repeat_num); - EXPECT_NE(ds, nullptr); - // Create objects for the tensor ops - std::shared_ptr random_solarize = - mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize(); + std::vector threshold = {10, 100}; + std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); EXPECT_NE(random_solarize, nullptr); // Create a Map operation on ds ds = ds->Map({random_solarize}); EXPECT_NE(ds, nullptr); - // Create a Batch operation on ds - int32_t batch_size = 1; - ds = ds->Batch(batch_size); - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); @@ -1151,8 +1143,61 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) { iter->GetNextRow(&row); } - EXPECT_EQ(i, 20); + EXPECT_EQ(i, 10); // Manually terminate the pipeline iter->Stop(); } + +TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with default params."; + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(); + EXPECT_NE(random_solarize, nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_solarize}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRandomSolarizeFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with invalid params."; + std::vector threshold = {13, 1}; + std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + EXPECT_EQ(random_solarize, nullptr); + + threshold = {1, 2, 3}; + random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + EXPECT_EQ(random_solarize, nullptr); + + threshold = {1}; + random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + EXPECT_EQ(random_solarize, nullptr); +} diff --git a/tests/ut/cpp/dataset/random_solarize_op_test.cc b/tests/ut/cpp/dataset/random_solarize_op_test.cc index 3277da79127..391eb94b0d4 100644 --- a/tests/ut/cpp/dataset/random_solarize_op_test.cc +++ b/tests/ut/cpp/dataset/random_solarize_op_test.cc @@ -38,7 +38,9 @@ TEST_F(MindDataTestRandomSolarizeOp, TestOp1) { uint32_t curr_seed = GlobalContext::config_manager()->seed(); GlobalContext::config_manager()->set_seed(0); - std::unique_ptr op(new RandomSolarizeOp(100, 100)); + std::vector threshold = {100, 100}; + std::unique_ptr op(new RandomSolarizeOp(threshold)); + EXPECT_TRUE(op->OneToOne()); Status s = op->Compute(input_tensor_, &output_tensor_); diff --git a/tests/ut/cpp/dataset/solarize_op_test.cc b/tests/ut/cpp/dataset/solarize_op_test.cc index 6f1bdad6dac..bbcd3d2a537 100644 --- a/tests/ut/cpp/dataset/solarize_op_test.cc +++ b/tests/ut/cpp/dataset/solarize_op_test.cc @@ -74,7 +74,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) { MS_LOG(INFO) << "Doing testSolarizeOp3 - Pass in only threshold_min parameter"; // unsigned int threshold = 128; - std::unique_ptr op(new SolarizeOp(1)); + std::vector threshold ={1, 255}; + std::unique_ptr op(new SolarizeOp(threshold)); std::vector test_vector = {3, 4, 59, 210, 255}; std::vector expected_output_vector = {252, 251, 196, 45, 0}; @@ -98,8 +99,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) { TEST_F(MindDataTestSolarizeOp, TestOp4) { MS_LOG(INFO) << "Doing testSolarizeOp4 - Pass in both threshold parameters."; - // unsigned int threshold = 128; - std::unique_ptr op(new SolarizeOp(1, 230)); + std::vector threshold ={1, 230}; + std::unique_ptr op(new SolarizeOp(threshold)); std::vector test_vector = {3, 4, 59, 210, 255}; std::vector expected_output_vector = {252, 251, 196, 45, 255}; @@ -123,8 +124,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp4) { TEST_F(MindDataTestSolarizeOp, TestOp5) { MS_LOG(INFO) << "Doing testSolarizeOp5 - Rank 2 input tensor."; - // unsigned int threshold = 128; - std::unique_ptr op(new SolarizeOp(1, 230)); + std::vector threshold ={1, 230}; + std::unique_ptr op(new SolarizeOp(threshold)); std::vector test_vector = {3, 4, 59, 210, 255}; std::vector expected_output_vector = {252, 251, 196, 45, 255}; @@ -149,7 +150,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp5) { TEST_F(MindDataTestSolarizeOp, TestOp6) { MS_LOG(INFO) << "Doing testSolarizeOp6 - Bad Input."; - std::unique_ptr op(new SolarizeOp(10, 1)); + std::vector threshold ={10, 1}; + std::unique_ptr op(new SolarizeOp(threshold)); std::vector test_vector = {3, 4, 59, 210, 255}; std::shared_ptr test_input_tensor; diff --git a/tests/ut/data/dataset/golden/random_solarize_02_result.npz b/tests/ut/data/dataset/golden/random_solarize_02_result.npz new file mode 100644 index 00000000000..683ed673a01 Binary files /dev/null and b/tests/ut/data/dataset/golden/random_solarize_02_result.npz differ diff --git a/tests/ut/python/dataset/test_random_solarize_op.py b/tests/ut/python/dataset/test_random_solarize_op.py index 1c9a9cd678b..f39abcc1728 100644 --- a/tests/ut/python/dataset/test_random_solarize_op.py +++ b/tests/ut/python/dataset/test_random_solarize_op.py @@ -17,68 +17,89 @@ Testing RandomSolarizeOp op in DE """ import pytest import mindspore.dataset as ds +import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore import log as logger -from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers +from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \ + visualize_one_channel_dataset GENERATE_GOLDEN = False +MNIST_DATA_DIR = "../data/dataset/testMnistData" DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" -def test_random_solarize_op(threshold=None, plot=False): +def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True): """ Test RandomSolarize """ logger.info("Test RandomSolarize") # First dataset - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) decode_op = vision.Decode() + original_seed = config_get_set_seed(0) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + if threshold is None: solarize_op = vision.RandomSolarize() else: solarize_op = vision.RandomSolarize(threshold) + data1 = data1.map(input_columns=["image"], operations=decode_op) data1 = data1.map(input_columns=["image"], operations=solarize_op) # Second dataset - data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data2 = data2.map(input_columns=["image"], operations=decode_op) + if run_golden: + filename = "random_solarize_01_result.npz" + save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) + image_solarized = [] image = [] + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): image_solarized.append(item1["image"].copy()) image.append(item2["image"].copy()) if plot: visualize_list(image, image_solarized) - -def test_random_solarize_md5(): - """ - Test RandomSolarize - """ - logger.info("Test RandomSolarize") - original_seed = config_get_set_seed(0) - original_num_parallel_workers = config_get_set_num_parallel_workers(1) - - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) - decode_op = vision.Decode() - random_solarize_op = vision.RandomSolarize((10, 150)) - data1 = data1.map(input_columns=["image"], operations=decode_op) - data1 = data1.map(input_columns=["image"], operations=random_solarize_op) - # Compare with expected md5 from images - filename = "random_solarize_01_result.npz" - save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) - - # Restore config setting ds.config.set_seed(original_seed) ds.config.set_num_parallel_workers(original_num_parallel_workers) +def test_random_solarize_mnist(plot=False, run_golden=True): + """ + Test RandomSolarize op with MNIST dataset (Grayscale images) + """ + + mnist_1 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) + mnist_2 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) + mnist_2 = mnist_2.map(input_columns="image", operations=vision.RandomSolarize((0, 255))) + + images = [] + images_trans = [] + labels = [] + + for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)): + image_orig, label_orig = data_orig + image_trans, _ = data_trans + images.append(image_orig) + labels.append(label_orig) + images_trans.append(image_trans) + + if plot: + visualize_one_channel_dataset(images, images_trans, labels) + + if run_golden: + filename = "random_solarize_02_result.npz" + save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN) + + def test_random_solarize_errors(): """ Test that RandomSolarize errors with bad input @@ -105,8 +126,8 @@ def test_random_solarize_errors(): if __name__ == "__main__": - test_random_solarize_op((100, 100), plot=True) - test_random_solarize_op((12, 120), plot=True) - test_random_solarize_op(plot=True) + test_random_solarize_op((10, 150), plot=True, run_golden=True) + test_random_solarize_op((12, 120), plot=True, run_golden=False) + test_random_solarize_op(plot=True, run_golden=False) + test_random_solarize_mnist(plot=True, run_golden=True) test_random_solarize_errors() - test_random_solarize_md5()