forked from mindspore-Ecosystem/mindspore
support cpp invert operation
This commit is contained in:
parent
60927ef130
commit
35c3a63701
|
@ -54,6 +54,7 @@
|
|||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
|
@ -362,6 +363,10 @@ void bindTensorOps1(py::module *m) {
|
|||
.def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
|
||||
py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
|
||||
|
||||
(void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
|
||||
"Tensor operation to apply invert on RGB images.")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
|
||||
*m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
|
||||
.def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
|
||||
|
|
|
@ -6,6 +6,7 @@ add_library(kernels-image OBJECT
|
|||
decode_op.cc
|
||||
hwc_to_chw_op.cc
|
||||
image_utils.cc
|
||||
invert_op.cc
|
||||
normalize_op.cc
|
||||
pad_op.cc
|
||||
random_color_adjust_op.cc
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// only supports RGB images
|
||||
|
||||
// Inverts an RGB image: the output is cv::Scalar::all(255) minus the input image.
//
// input  - tensor to invert; must be viewable as a CVTensor of rank 3 whose last
//          dimension (channels) is exactly 3.
// output - receives a newly allocated CVTensor with the input's shape and type.
//
// Returns Status::OK() on success, or an unexpected-error Status when the input
// cannot be viewed as a CV tensor, has the wrong rank or channel count, or when
// OpenCV throws while computing the result.
Status InvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);

  try {
    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
    // Validate the underlying buffer before taking any further view of it.
    if (!input_cv->mat().data) {
      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
    }

    // Both shape failures share the "The shape is incorrect" prefix so that a
    // caller can match either of them with a single substring.
    if (input_cv->Rank() != 3) {
      RETURN_STATUS_UNEXPECTED("The shape is incorrect: expected <H,W,C>");
    }
    int num_channels = input_cv->shape()[2];
    if (num_channels != 3) {
      RETURN_STATUS_UNEXPECTED("The shape is incorrect: num of channels != 3");
    }

    auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
    RETURN_UNEXPECTED_IF_NULL(output_cv);

    cv::Mat input_img = input_cv->mat();
    output_cv->mat() = cv::Scalar::all(255) - input_img;
    *output = std::static_pointer_cast<Tensor>(output_cv);
  } catch (const cv::Exception &e) {
    // Keep the OpenCV diagnostic instead of dropping it, so failures are debuggable.
    RETURN_STATUS_UNEXPECTED("Error in invert: " + std::string(e.what()));
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_KERNELS_IMAGE_INVERT_OP_H
|
||||
#define DATASET_KERNELS_IMAGE_INVERT_OP_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class InvertOp : public TensorOp {
|
||||
public:
|
||||
InvertOp() {}
|
||||
~InvertOp() = default;
|
||||
|
||||
// Description: A function that prints info about the node
|
||||
void Print(std::ostream &out) const override { out << Name(); }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kInvertOp; }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_KERNELS_IMAGE_INVERT_OP_H
|
|
@ -92,6 +92,7 @@ constexpr char kDecodeOp[] = "DecodeOp";
|
|||
constexpr char kCenterCropOp[] = "CenterCropOp";
|
||||
constexpr char kCutOutOp[] = "CutOutOp";
|
||||
constexpr char kHwcToChwOp[] = "HwcToChwOp";
|
||||
constexpr char kInvertOp[] = "InvertOp";
|
||||
constexpr char kNormalizeOp[] = "NormalizeOp";
|
||||
constexpr char kPadOp[] = "PadOp";
|
||||
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
|
||||
|
|
|
@ -71,6 +71,13 @@ def parse_padding(padding):
|
|||
return padding
|
||||
|
||||
|
||||
class Invert(cde.InvertOp):
    """
    Invert the colors of the input image in RGB mode.

    This operation takes no arguments.
    """
|
||||
|
||||
|
||||
class Decode(cde.DecodeOp):
|
||||
"""
|
||||
Decode the input image in RGB mode.
|
||||
|
|
Binary file not shown.
|
@ -19,18 +19,20 @@ import numpy as np
|
|||
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, save_and_check_md5
|
||||
from util import visualize_list, save_and_check_md5, diff_mse
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
def test_invert(plot=False):
|
||||
|
||||
def test_invert_py(plot=False):
|
||||
"""
|
||||
Test Invert
|
||||
Test Invert python op
|
||||
"""
|
||||
logger.info("Test Invert")
|
||||
logger.info("Test Invert Python op")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
@ -52,7 +54,7 @@ def test_invert(plot=False):
|
|||
np.transpose(image, (0, 2, 3, 1)),
|
||||
axis=0)
|
||||
|
||||
# Color Inverted Images
|
||||
# Color Inverted Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_invert = F.ComposeOp([F.Decode(),
|
||||
|
@ -83,11 +85,143 @@ def test_invert(plot=False):
|
|||
visualize_list(images_original, images_invert)
|
||||
|
||||
|
||||
def test_invert_md5():
|
||||
def test_invert_c(plot=False):
    """
    Test Invert Cpp op

    Runs the same Decode + Resize pipeline twice, once without and once with
    C.Invert, and checks that each inverted pixel equals 255 minus the
    original pixel.

    Args:
        plot (bool): when True, pass the original and inverted images to
            visualize_list.
    """
    logger.info("Test Invert cpp op")

    # Original Images
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = [C.Decode(), C.Resize(size=[224, 224])]

    ds_original = ds.map(input_columns="image",
                         operations=transforms_original)

    ds_original = ds_original.batch(512)

    # Collect the batches first and join once, rather than re-copying the
    # accumulated array on every iteration.
    images_original = np.concatenate([image for image, _ in ds_original], axis=0)

    # Invert Images
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    transform_invert = [C.Decode(), C.Resize(size=[224, 224]),
                        C.Invert()]

    ds_invert = ds.map(input_columns="image",
                       operations=transform_invert)

    ds_invert = ds_invert.batch(512)

    images_invert = np.concatenate([image for image, _ in ds_invert], axis=0)

    if plot:
        visualize_list(images_original, images_invert)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_invert[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))

    # The op computes 255 - pixel, so the result must match exactly; without
    # this check the test could never fail on a wrong output.
    np.testing.assert_array_equal(images_invert, 255 - images_original)
|
||||
|
||||
|
||||
def test_invert_py_c(plot=False):
    """
    Test Invert Cpp op and python op

    Inverts the same decoded and resized images with the C++ op and with the
    Python op, then logs the mean MSE between the two sets of results.
    """
    logger.info("Test Invert cpp and python op")

    def _gather(dataset):
        # Join every batch of the "image" column along the sample axis.
        batches = [image for image, _ in dataset]
        return np.concatenate(batches, axis=0)

    decode_resize = [C.Decode(), C.Resize((224, 224))]

    # Invert Images in cpp
    source_c = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    source_c = source_c.map(input_columns=["image"],
                            operations=decode_resize)
    ds_c_invert = source_c.map(input_columns="image",
                               operations=C.Invert()).batch(512)
    images_c_invert = _gather(ds_c_invert)

    # invert images in python
    source_p = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    source_p = source_p.map(input_columns=["image"],
                            operations=[C.Decode(), C.Resize((224, 224))])

    transforms_p_invert = F.ComposeOp([lambda img: img.astype(np.uint8),
                                       F.ToPIL(),
                                       F.Invert(),
                                       np.array])
    ds_p_invert = source_p.map(input_columns="image",
                               operations=transforms_p_invert()).batch(512)
    images_p_invert = _gather(ds_p_invert)

    # Per-sample error between the python result and the cpp result.
    num_samples = images_c_invert.shape[0]
    mse = np.array([diff_mse(images_p_invert[i], images_c_invert[i])
                    for i in range(num_samples)], dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_c_invert, images_p_invert, visualize_mode=2)
|
||||
|
||||
|
||||
def test_invert_one_channel():
    """
    Test Invert cpp op with one channel image

    Feeds a single-channel slice of each image into C.Invert and expects the
    pipeline to fail with a RuntimeError that mentions the shape.
    """
    logger.info("Test Invert C Op With One Channel Images")

    c_op = C.Invert()

    error_raised = False
    try:
        ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
        ds = ds.map(input_columns=["image"],
                    operations=[C.Decode(),
                                C.Resize((224, 224)),
                                lambda img: np.array(img[:, :, 0])])

        ds = ds.map(input_columns="image",
                    operations=c_op)

        # Building the pipeline does not run the op; iterate so that the
        # invert op actually sees the one-channel images.
        for _ in ds:
            pass

    except RuntimeError as e:
        error_raised = True
        logger.info("Got an exception in DE: {}".format(str(e)))
        # Case-insensitive so both shape-related messages of the op match.
        assert "shape" in str(e).lower()

    # Without this, a pipeline that silently accepts the input would pass.
    assert error_raised, "Invert accepted a one channel image without raising"
|
||||
|
||||
|
||||
def test_invert_md5_py():
|
||||
"""
|
||||
Test Invert python op with md5 check
|
||||
"""
|
||||
logger.info("Test Invert python op with md5 check")
|
||||
|
||||
# Generate dataset
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
@ -98,10 +232,34 @@ def test_invert_md5():
|
|||
|
||||
data = ds.map(input_columns="image", operations=transforms_invert())
|
||||
# Compare with expected md5 from images
|
||||
filename = "invert_01_result.npz"
|
||||
filename = "invert_01_result_py.npz"
|
||||
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_invert_md5_c():
    """
    Test Invert cpp op with md5 check

    Pushes the images through Decode, Resize, the cpp Invert op and ToTensor,
    then hands the result to save_and_check_md5 with the golden file name.
    """
    logger.info("Test Invert cpp op with md5 check")

    # Generate dataset
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    invert_pipeline = [
        C.Decode(),
        C.Resize(size=[224, 224]),
        C.Invert(),
        F.ToTensor(),
    ]
    inverted = source.map(input_columns="image", operations=invert_pipeline)

    # Compare with expected md5 from images
    golden_file = "invert_01_result_c.npz"
    save_and_check_md5(inverted, golden_file, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every invert test under its current name; the former test_invert /
    # test_invert_md5 entry points are now test_invert_py / test_invert_md5_py.
    test_invert_py(plot=False)
    test_invert_c(plot=False)
    test_invert_py_c(plot=False)
    test_invert_one_channel()
    test_invert_md5_py()
    test_invert_md5_c()
|
||||
|
|
Loading…
Reference in New Issue