!3255 Added cpp Equalize op to vision api

Merge pull request !3255 from alashkari/cpp_ops
This commit is contained in:
mindspore-ci-bot 2020-07-21 21:56:41 +08:00 committed by Gitee
commit c451146b14
10 changed files with 298 additions and 6 deletions

View File

@ -54,6 +54,7 @@
#include "minddata/dataset/kernels/image/center_crop_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/equalize_op.h"
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/invert_op.h"
@ -389,6 +390,10 @@ void bindTensorOps1(py::module *m) {
.def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
(void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(
*m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.")
.def(py::init<>());
(void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
"Tensor operation to apply invert on RGB images.")
.def(py::init<>());

View File

@ -5,6 +5,7 @@ add_library(kernels-image OBJECT
center_crop_op.cc
cut_out_op.cc
decode_op.cc
equalize_op.cc
hwc_to_chw_op.cc
image_utils.cc
invert_op.cc

View File

@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/equalize_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
// only supports RGB images
Status EqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return Equalize(input, output);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class EqualizeOp : public TensorOp {
public:
EqualizeOp() {}
~EqualizeOp() = default;
// Description: A function that prints info about the node
void Print(std::ostream &out) const override { out << Name(); }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kEqualizeOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_

View File

@ -749,6 +749,46 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
return Status::OK();
}
Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
}
// For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2.
if (input_cv->Rank() == 2) {
RETURN_IF_NOT_OK(input_cv->ExpandDim(2));
}
// Get number of channels and image matrix
std::size_t num_of_channels = input_cv->shape()[2];
if (num_of_channels != 1 && num_of_channels != 3) {
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
}
cv::Mat image = input_cv->mat();
// Separate the image to channels
std::vector<cv::Mat> planes(num_of_channels);
cv::split(image, planes);
// Equalize each channel separately
std::vector<cv::Mat> image_result;
for (std::size_t layer = 0; layer < planes.size(); layer++) {
cv::Mat channel_result;
cv::equalizeHist(planes[layer], channel_result);
image_result.push_back(channel_result);
}
cv::Mat result;
cv::merge(image_result, result);
std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result);
if (input_cv->Rank() == 2) output_cv->Squeeze();
(*output) = std::static_pointer_cast<Tensor>(output_cv);
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Error in equalize.");
}
return Status::OK();
}
Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r,
uint8_t fill_g, uint8_t fill_b) {

View File

@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
// @param output: Adjusted image of same shape and type.
Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue);
/// \brief Returns image with equalized histogram.
/// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and
/// any OpenCv compatible type, see CVTensor.
/// \param[out] output: Equalized image of same shape and type.
Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
// Masks out a random section from the image with set dimension
// @param input: input Tensor
// @param output: cutOut Tensor

View File

@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
constexpr char kDecodeOp[] = "DecodeOp";
constexpr char kCenterCropOp[] = "CenterCropOp";
constexpr char kCutOutOp[] = "CutOutOp";
constexpr char kEqualizeOp[] = "EqualizeOp";
constexpr char kHwcToChwOp[] = "HwcToChwOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kNormalizeOp[] = "NormalizeOp";

View File

@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp):
super().__init__(cutoff, ignore)
class Equalize(cde.EqualizeOp):
"""
Apply histogram equalization on input image.
does not have input arguments.
"""
class Invert(cde.InvertOp):
"""
Apply invert on input image in RGB mode.

Binary file not shown.

View File

@ -18,6 +18,7 @@ Testing Equalize op in DE
import numpy as np
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.vision.py_transforms as F
from mindspore import log as logger
from util import visualize_list, diff_mse, save_and_check_md5
@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False
def test_equalize(plot=False):
def test_equalize_py(plot=False):
"""
Test Equalize
Test Equalize py op
"""
logger.info("Test Equalize")
@ -83,9 +84,141 @@ def test_equalize(plot=False):
visualize_list(images_original, images_equalize)
def test_equalize_md5():
def test_equalize_c(plot=False):
"""
Test Equalize with md5 check
Test Equalize Cpp op
"""
logger.info("Test Equalize cpp op")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = [C.Decode(), C.Resize(size=[224, 224])]
ds_original = ds.map(input_columns="image",
operations=transforms_original)
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original,
image,
axis=0)
# Equalize Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transform_equalize = [C.Decode(), C.Resize(size=[224, 224]),
C.Equalize()]
ds_equalize = ds.map(input_columns="image",
operations=transform_equalize)
ds_equalize = ds_equalize.batch(512)
for idx, (image, _) in enumerate(ds_equalize):
if idx == 0:
images_equalize = image
else:
images_equalize = np.append(images_equalize,
image,
axis=0)
if plot:
visualize_list(images_original, images_equalize)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_equalize[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_equalize_py_c(plot=False):
"""
Test Equalize Cpp op and python op
"""
logger.info("Test Equalize cpp and python op")
# equalize Images in cpp
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(), C.Resize((224, 224))])
ds_c_equalize = ds.map(input_columns="image",
operations=C.Equalize())
ds_c_equalize = ds_c_equalize.batch(512)
for idx, (image, _) in enumerate(ds_c_equalize):
if idx == 0:
images_c_equalize = image
else:
images_c_equalize = np.append(images_c_equalize,
image,
axis=0)
# Equalize images in python
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(), C.Resize((224, 224))])
transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8),
F.ToPIL(),
F.Equalize(),
np.array])
ds_p_equalize = ds.map(input_columns="image",
operations=transforms_p_equalize())
ds_p_equalize = ds_p_equalize.batch(512)
for idx, (image, _) in enumerate(ds_p_equalize):
if idx == 0:
images_p_equalize = image
else:
images_p_equalize = np.append(images_p_equalize,
image,
axis=0)
num_samples = images_c_equalize.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2)
def test_equalize_one_channel():
"""
Test Equalize cpp op with one channel image
"""
logger.info("Test Equalize C Op With One Channel Images")
c_op = C.Equalize()
try:
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])])
ds.map(input_columns="image",
operations=c_op)
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "The shape" in str(e)
def test_equalize_md5_py():
"""
Test Equalize py op with md5 check
"""
logger.info("Test Equalize")
@ -101,6 +234,31 @@ def test_equalize_md5():
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_equalize_md5_c():
"""
Test Equalize cpp op with md5 check
"""
logger.info("Test Equalize cpp op with md5 check")
# Generate dataset
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_equalize = [C.Decode(),
C.Resize(size=[224, 224]),
C.Equalize(),
F.ToTensor()]
data = ds.map(input_columns="image", operations=transforms_equalize)
# Compare with expected md5 from images
filename = "equalize_01_result_c.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
if __name__ == "__main__":
test_equalize(plot=True)
test_equalize_md5()
test_equalize_py(plot=False)
test_equalize_c(plot=False)
test_equalize_py_c(plot=False)
test_equalize_one_channel()
test_equalize_md5_py()
test_equalize_md5_c()