forked from mindspore-Ecosystem/mindspore
!3255 Added cpp Equalize op to vision api
Merge pull request !3255 from alashkari/cpp_ops
This commit is contained in:
commit
c451146b14
|
@ -54,6 +54,7 @@
|
|||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/equalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||
|
@ -389,6 +390,10 @@ void bindTensorOps1(py::module *m) {
|
|||
.def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
|
||||
py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
|
||||
|
||||
(void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(
|
||||
*m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
|
||||
"Tensor operation to apply invert on RGB images.")
|
||||
.def(py::init<>());
|
||||
|
|
|
@ -5,6 +5,7 @@ add_library(kernels-image OBJECT
|
|||
center_crop_op.cc
|
||||
cut_out_op.cc
|
||||
decode_op.cc
|
||||
equalize_op.cc
|
||||
hwc_to_chw_op.cc
|
||||
image_utils.cc
|
||||
invert_op.cc
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/equalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// only supports RGB images
|
||||
|
||||
Status EqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
return Equalize(input, output);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class EqualizeOp : public TensorOp {
|
||||
public:
|
||||
EqualizeOp() {}
|
||||
~EqualizeOp() = default;
|
||||
|
||||
// Description: A function that prints info about the node
|
||||
void Print(std::ostream &out) const override { out << Name(); }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kEqualizeOp; }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
|
|
@ -749,6 +749,46 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
}
|
||||
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
|
||||
}
|
||||
// For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2.
|
||||
if (input_cv->Rank() == 2) {
|
||||
RETURN_IF_NOT_OK(input_cv->ExpandDim(2));
|
||||
}
|
||||
// Get number of channels and image matrix
|
||||
std::size_t num_of_channels = input_cv->shape()[2];
|
||||
if (num_of_channels != 1 && num_of_channels != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
|
||||
}
|
||||
cv::Mat image = input_cv->mat();
|
||||
// Separate the image to channels
|
||||
std::vector<cv::Mat> planes(num_of_channels);
|
||||
cv::split(image, planes);
|
||||
// Equalize each channel separately
|
||||
std::vector<cv::Mat> image_result;
|
||||
for (std::size_t layer = 0; layer < planes.size(); layer++) {
|
||||
cv::Mat channel_result;
|
||||
cv::equalizeHist(planes[layer], channel_result);
|
||||
image_result.push_back(channel_result);
|
||||
}
|
||||
cv::Mat result;
|
||||
cv::merge(image_result, result);
|
||||
std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result);
|
||||
if (input_cv->Rank() == 2) output_cv->Squeeze();
|
||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||
} catch (const cv::Exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Error in equalize.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
|
||||
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r,
|
||||
uint8_t fill_g, uint8_t fill_b) {
|
||||
|
|
|
@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
// @param output: Adjusted image of same shape and type.
|
||||
Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue);
|
||||
|
||||
/// \brief Returns image with equalized histogram.
|
||||
/// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and
|
||||
/// any OpenCv compatible type, see CVTensor.
|
||||
/// \param[out] output: Equalized image of same shape and type.
|
||||
Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
// Masks out a random section from the image with set dimension
|
||||
// @param input: input Tensor
|
||||
// @param output: cutOut Tensor
|
||||
|
|
|
@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
|
|||
constexpr char kDecodeOp[] = "DecodeOp";
|
||||
constexpr char kCenterCropOp[] = "CenterCropOp";
|
||||
constexpr char kCutOutOp[] = "CutOutOp";
|
||||
constexpr char kEqualizeOp[] = "EqualizeOp";
|
||||
constexpr char kHwcToChwOp[] = "HwcToChwOp";
|
||||
constexpr char kInvertOp[] = "InvertOp";
|
||||
constexpr char kNormalizeOp[] = "NormalizeOp";
|
||||
|
|
|
@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp):
|
|||
super().__init__(cutoff, ignore)
|
||||
|
||||
|
||||
class Equalize(cde.EqualizeOp):
|
||||
"""
|
||||
Apply histogram equalization on input image.
|
||||
does not have input arguments.
|
||||
"""
|
||||
|
||||
|
||||
class Invert(cde.InvertOp):
|
||||
"""
|
||||
Apply invert on input image in RGB mode.
|
||||
|
|
Binary file not shown.
|
@ -18,6 +18,7 @@ Testing Equalize op in DE
|
|||
import numpy as np
|
||||
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, diff_mse, save_and_check_md5
|
||||
|
@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
|
|||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
def test_equalize(plot=False):
|
||||
def test_equalize_py(plot=False):
|
||||
"""
|
||||
Test Equalize
|
||||
Test Equalize py op
|
||||
"""
|
||||
logger.info("Test Equalize")
|
||||
|
||||
|
@ -83,9 +84,141 @@ def test_equalize(plot=False):
|
|||
visualize_list(images_original, images_equalize)
|
||||
|
||||
|
||||
def test_equalize_md5():
|
||||
def test_equalize_c(plot=False):
|
||||
"""
|
||||
Test Equalize with md5 check
|
||||
Test Equalize Cpp op
|
||||
"""
|
||||
logger.info("Test Equalize cpp op")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = [C.Decode(), C.Resize(size=[224, 224])]
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original)
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
# Equalize Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transform_equalize = [C.Decode(), C.Resize(size=[224, 224]),
|
||||
C.Equalize()]
|
||||
|
||||
ds_equalize = ds.map(input_columns="image",
|
||||
operations=transform_equalize)
|
||||
|
||||
ds_equalize = ds_equalize.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_equalize):
|
||||
if idx == 0:
|
||||
images_equalize = image
|
||||
else:
|
||||
images_equalize = np.append(images_equalize,
|
||||
image,
|
||||
axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_equalize)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_equalize[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_equalize_py_c(plot=False):
|
||||
"""
|
||||
Test Equalize Cpp op and python op
|
||||
"""
|
||||
logger.info("Test Equalize cpp and python op")
|
||||
|
||||
# equalize Images in cpp
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(), C.Resize((224, 224))])
|
||||
|
||||
ds_c_equalize = ds.map(input_columns="image",
|
||||
operations=C.Equalize())
|
||||
|
||||
ds_c_equalize = ds_c_equalize.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_c_equalize):
|
||||
if idx == 0:
|
||||
images_c_equalize = image
|
||||
else:
|
||||
images_c_equalize = np.append(images_c_equalize,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
# Equalize images in python
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(), C.Resize((224, 224))])
|
||||
|
||||
transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8),
|
||||
F.ToPIL(),
|
||||
F.Equalize(),
|
||||
np.array])
|
||||
|
||||
ds_p_equalize = ds.map(input_columns="image",
|
||||
operations=transforms_p_equalize())
|
||||
|
||||
ds_p_equalize = ds_p_equalize.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_p_equalize):
|
||||
if idx == 0:
|
||||
images_p_equalize = image
|
||||
else:
|
||||
images_p_equalize = np.append(images_p_equalize,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
num_samples = images_c_equalize.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2)
|
||||
|
||||
|
||||
def test_equalize_one_channel():
|
||||
"""
|
||||
Test Equalize cpp op with one channel image
|
||||
"""
|
||||
logger.info("Test Equalize C Op With One Channel Images")
|
||||
|
||||
c_op = C.Equalize()
|
||||
|
||||
try:
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224)),
|
||||
lambda img: np.array(img[:, :, 0])])
|
||||
|
||||
ds.map(input_columns="image",
|
||||
operations=c_op)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "The shape" in str(e)
|
||||
|
||||
|
||||
def test_equalize_md5_py():
|
||||
"""
|
||||
Test Equalize py op with md5 check
|
||||
"""
|
||||
logger.info("Test Equalize")
|
||||
|
||||
|
@ -101,6 +234,31 @@ def test_equalize_md5():
|
|||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_equalize_md5_c():
|
||||
"""
|
||||
Test Equalize cpp op with md5 check
|
||||
"""
|
||||
logger.info("Test Equalize cpp op with md5 check")
|
||||
|
||||
# Generate dataset
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_equalize = [C.Decode(),
|
||||
C.Resize(size=[224, 224]),
|
||||
C.Equalize(),
|
||||
F.ToTensor()]
|
||||
|
||||
data = ds.map(input_columns="image", operations=transforms_equalize)
|
||||
# Compare with expected md5 from images
|
||||
filename = "equalize_01_result_c.npz"
|
||||
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_equalize(plot=True)
|
||||
test_equalize_md5()
|
||||
test_equalize_py(plot=False)
|
||||
test_equalize_c(plot=False)
|
||||
test_equalize_py_c(plot=False)
|
||||
test_equalize_one_channel()
|
||||
test_equalize_md5_py()
|
||||
test_equalize_md5_c()
|
||||
|
Loading…
Reference in New Issue