From 572750cd403ef0ea08db69814e5a918c6c3b02a4 Mon Sep 17 00:00:00 2001 From: Amir Lashkari Date: Mon, 20 Jul 2020 15:59:05 -0400 Subject: [PATCH] Added cpp Equalize op adding missed files Fixed Errors fixed bugs fixed bugs fixed bugs fixed bugs fixed bugs Updated Python UT for Equalize op Added cpp Equalize op adding missed files Fixed Errors fixed bugs fixed bugs fixed bugs fixed bugs fixed bugs Updated Python UT for Equalize op Fixed comment style update files deleted files added files added image_utils.h Fixed PyLint and CPPLint Errors updated tensor_op.h --- .../minddata/dataset/api/python_bindings.cc | 5 + .../dataset/kernels/image/CMakeLists.txt | 1 + .../dataset/kernels/image/equalize_op.cc | 29 +++ .../dataset/kernels/image/equalize_op.h | 45 +++++ .../dataset/kernels/image/image_utils.cc | 40 +++++ .../dataset/kernels/image/image_utils.h | 6 + .../minddata/dataset/kernels/tensor_op.h | 1 + .../dataset/transforms/vision/c_transforms.py | 7 + .../dataset/golden/equalize_01_result_c.npz | Bin 0 -> 713 bytes tests/ut/python/dataset/test_equalize.py | 170 +++++++++++++++++- 10 files changed, 298 insertions(+), 6 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h create mode 100644 tests/ut/data/dataset/golden/equalize_01_result_c.npz diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index 08016ee0613..457ff8b1b5e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -53,6 +53,7 @@ #include "minddata/dataset/kernels/image/center_crop_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/equalize_op.h" #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/invert_op.h" @@ -374,6 +375,10 @@ void bindTensorOps1(py::module *m) { .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); + (void)py::class_>( + *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.") + .def(py::init<>()); + (void)py::class_>(*m, "InvertOp", "Tensor operation to apply invert on RGB images.") .def(py::init<>()); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 743fc83c149..a7777302847 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(kernels-image OBJECT center_crop_op.cc cut_out_op.cc decode_op.cc + equalize_op.cc hwc_to_chw_op.cc image_utils.cc invert_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc new file mode 100644 index 00000000000..e5bf0fd6282 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/equalize_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { + +// only supports RGB images + +Status EqualizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return Equalize(input, output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h new file mode 100644 index 00000000000..9fd030f5852 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class EqualizeOp : public TensorOp { + public: + EqualizeOp() {} + ~EqualizeOp() = default; + + // Description: A function that prints info about the node + void Print(std::ostream &out) const override { out << Name(); } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kEqualizeOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index dac076a5f43..97e73525640 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -749,6 +749,46 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * return Status::OK(); } +Status Equalize(const std::shared_ptr &input, std::shared_ptr *output) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + // For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2. + if (input_cv->Rank() == 2) { + RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); + } + // Get number of channels and image matrix + std::size_t num_of_channels = input_cv->shape()[2]; + if (num_of_channels != 1 && num_of_channels != 3) { + RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); + } + cv::Mat image = input_cv->mat(); + // Separate the image to channels + std::vector planes(num_of_channels); + cv::split(image, planes); + // Equalize each channel separately + std::vector image_result; + for (std::size_t layer = 0; layer < planes.size(); layer++) { + cv::Mat channel_result; + cv::equalizeHist(planes[layer], channel_result); + image_result.push_back(channel_result); + } + cv::Mat result; + cv::merge(image_result, result); + std::shared_ptr output_cv = std::make_shared(result); + if (input_cv->Rank() == 2) output_cv->Squeeze(); + (*output) = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in equalize."); + } + return Status::OK(); +} + Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index c1426338954..9a90bec61eb 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &hue); +/// \brief Returns image with equalized histogram. +/// \param[in] input: Tensor of shape // in RGB/Grayscale and +/// any OpenCv compatible type, see CVTensor. +/// \param[out] output: Equalized image of same shape and type. +Status Equalize(const std::shared_ptr &input, std::shared_ptr *output); + // Masks out a random section from the image with set dimension // @param input: input Tensor // @param output: cutOut Tensor diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 638ce49dbf5..9c12759422e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kDecodeOp[] = "DecodeOp"; constexpr char kCenterCropOp[] = "CenterCropOp"; constexpr char kCutOutOp[] = "CutOutOp"; +constexpr char kEqualizeOp[] = "EqualizeOp"; constexpr char kHwcToChwOp[] = "HwcToChwOp"; constexpr char kInvertOp[] = "InvertOp"; constexpr char kNormalizeOp[] = "NormalizeOp"; diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 9a73a900739..3c2e7aeecb3 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp): super().__init__(cutoff, ignore) +class Equalize(cde.EqualizeOp): + """ + Apply histogram equalization on input image. + does not have input arguments. + """ + + class Invert(cde.InvertOp): """ Apply invert on input image in RGB mode. diff --git a/tests/ut/data/dataset/golden/equalize_01_result_c.npz b/tests/ut/data/dataset/golden/equalize_01_result_c.npz new file mode 100644 index 0000000000000000000000000000000000000000..2c3a37eb4dcf3b010e926530e3ca11a3e74c5e53 GIT binary patch literal 713 zcmWIWW@Zs#fB;2?fCuxq9hn#yK$w$3gdwr0DBeIXub`5VK>#cWQV5a+fysWMz5$Vp z3}p<}>M5zk$wlf`3hFif>N*PQY57GZMTvRw`9&$IAYr$}oZ?iVcyUHzK`M~1VWgvA zq^YA&t3W>BYG6*zE6pva)Jx7UO4Z9P%_+$Qx;L?sE50Z-IX|zsq^LBxgsYGNqKYdo z1tMF>=*`et$mGnJRLI<3$P!e@s^QJ(&E(D0R>%fbno?3(kjhoa9s%;HzeOR3H-k50 zdm(2~A(w_Xa|9z$w5E{T&(F{6KM;TkZ~Kx$o}|v$LSBssR-k-NVp3{OAzy4EzeWZ_ z2G~l044{32L4`sf`&e2Fg)<-)q?r_oKr9dqDiniRU{ffLY5_w@p+r)rv%ts8|680j z%Z=J}Kgk|>T9i~MSyCvK1hYgMWQj~rp)AxAxePIEdgMzA6_QfHF3rqMOiwLTj4f1x zyHpuuh)PhQD%21)upuHh-)3F8|M*N|=Cg|*zlEGGR4*yifLp8y(xVkrs14PlgI$kq wNugd+YHCTLev)2*HzSh>Gp^JJ3`8hs00lTAEe3eAvVjB`fzTXC*MQR_0M3obPyhe` literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py index 0a5f2f93d50..26102ae809f 100644 --- a/tests/ut/python/dataset/test_equalize.py +++ b/tests/ut/python/dataset/test_equalize.py @@ -18,6 +18,7 @@ Testing Equalize op in DE import numpy as np import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.py_transforms as F from mindspore import log as logger from util import visualize_list, diff_mse, save_and_check_md5 @@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" GENERATE_GOLDEN = False -def test_equalize(plot=False): +def test_equalize_py(plot=False): """ - Test Equalize + Test Equalize py op """ logger.info("Test Equalize") @@ -83,9 +84,141 @@ def test_equalize(plot=False): visualize_list(images_original, images_equalize) -def test_equalize_md5(): +def test_equalize_c(plot=False): """ - Test Equalize with md5 check + Test Equalize Cpp op + """ + logger.info("Test Equalize cpp op") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = [C.Decode(), C.Resize(size=[224, 224])] + + ds_original = ds.map(input_columns="image", + operations=transforms_original) + + ds_original = ds_original.batch(512) + + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, + image, + axis=0) + + # Equalize Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_equalize = [C.Decode(), C.Resize(size=[224, 224]), + C.Equalize()] + + ds_equalize = ds.map(input_columns="image", + operations=transform_equalize) + + ds_equalize = ds_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_equalize): + if idx == 0: + images_equalize = image + else: + images_equalize = np.append(images_equalize, + image, + axis=0) + if plot: + visualize_list(images_original, images_equalize) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_equalize[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_equalize_py_c(plot=False): + """ + Test Equalize Cpp op and python op + """ + logger.info("Test Equalize cpp and python op") + + # equalize Images in cpp + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + ds_c_equalize = ds.map(input_columns="image", + operations=C.Equalize()) + + ds_c_equalize = ds_c_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_c_equalize): + if idx == 0: + images_c_equalize = image + else: + images_c_equalize = np.append(images_c_equalize, + image, + axis=0) + + # Equalize images in python + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8), + F.ToPIL(), + F.Equalize(), + np.array]) + + ds_p_equalize = ds.map(input_columns="image", + operations=transforms_p_equalize()) + + ds_p_equalize = ds_p_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_p_equalize): + if idx == 0: + images_p_equalize = image + else: + images_p_equalize = np.append(images_p_equalize, + image, + axis=0) + + num_samples = images_c_equalize.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2) + + +def test_equalize_one_channel(): + """ + Test Equalize cpp op with one channel image + """ + logger.info("Test Equalize C Op With One Channel Images") + + c_op = C.Equalize() + + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + + ds.map(input_columns="image", + operations=c_op) + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "The shape" in str(e) + + +def test_equalize_md5_py(): + """ + Test Equalize py op with md5 check """ logger.info("Test Equalize") @@ -101,6 +234,31 @@ def test_equalize_md5(): save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) +def test_equalize_md5_c(): + """ + Test Equalize cpp op with md5 check + """ + logger.info("Test Equalize cpp op with md5 check") + + # Generate dataset + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_equalize = [C.Decode(), + C.Resize(size=[224, 224]), + C.Equalize(), + F.ToTensor()] + + data = ds.map(input_columns="image", operations=transforms_equalize) + # Compare with expected md5 from images + filename = "equalize_01_result_c.npz" + save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) + + if __name__ == "__main__": - test_equalize(plot=True) - test_equalize_md5() + test_equalize_py(plot=False) + test_equalize_c(plot=False) + test_equalize_py_c(plot=False) + test_equalize_one_channel() + test_equalize_md5_py() + test_equalize_md5_c() + \ No newline at end of file