forked from mindspore-Ecosystem/mindspore
!3101 Implementing AutoContrast Op in CPP
Merge pull request !3101 from islam_amin/autocontrast_op
This commit is contained in:
commit
f47ad8535f
|
@ -48,6 +48,7 @@
|
|||
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||
#include "minddata/dataset/kernels/data/to_float16_op.h"
|
||||
#include "minddata/dataset/kernels/data/type_cast_op.h"
|
||||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
|
@ -362,6 +363,11 @@ void bindTensorOps1(py::module *m) {
|
|||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
|
||||
(void)py::class_<AutoContrastOp, TensorOp, std::shared_ptr<AutoContrastOp>>(
|
||||
*m, "AutoContrastOp", "Tensor operation to apply autocontrast on an image.")
|
||||
.def(py::init<float, std::vector<uint32_t>>(), py::arg("cutoff") = AutoContrastOp::kCutOff,
|
||||
py::arg("ignore") = AutoContrastOp::kIgnore);
|
||||
|
||||
(void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
|
||||
*m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
|
||||
.def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(kernels-image OBJECT
|
||||
auto_contrast_op.cc
|
||||
center_crop_op.cc
|
||||
cut_out_op.cc
|
||||
decode_op.cc
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const float AutoContrastOp::kCutOff = 0.0;
|
||||
const std::vector<uint32_t> AutoContrastOp::kIgnore = {};
|
||||
|
||||
Status AutoContrastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
return AutoContrast(input, output, cutoff_, ignore_);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_
|
||||
#define DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class AutoContrastOp : public TensorOp {
|
||||
public:
|
||||
/// Default cutoff to be used
|
||||
static const float kCutOff;
|
||||
/// Default ignore to be used
|
||||
static const std::vector<uint32_t> kIgnore;
|
||||
|
||||
AutoContrastOp(const float &cutoff, const std::vector<uint32_t> &ignore) : cutoff_(cutoff), ignore_(ignore) {}
|
||||
|
||||
~AutoContrastOp() override = default;
|
||||
|
||||
/// Provide stream operator for displaying it
|
||||
friend std::ostream &operator<<(std::ostream &out, const AutoContrastOp &so) {
|
||||
so.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
void Print(std::ostream &out) const override { out << Name(); }
|
||||
|
||||
std::string Name() const override { return kAutoContrastOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
private:
|
||||
float cutoff_;
|
||||
std::vector<uint32_t> ignore_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_
|
|
@ -585,6 +585,109 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
|
||||
const std::vector<uint32_t> &ignore) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
}
|
||||
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
|
||||
}
|
||||
// Reshape to extend dimension if rank is 2 for algorithm to work. then reshape output to be of rank 2 like input
|
||||
if (input_cv->Rank() == 2) {
|
||||
RETURN_IF_NOT_OK(input_cv->ExpandDim(2));
|
||||
}
|
||||
// Get number of channels and image matrix
|
||||
std::size_t num_of_channels = input_cv->shape()[2];
|
||||
if (num_of_channels != 1 && num_of_channels != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
|
||||
}
|
||||
cv::Mat image = input_cv->mat();
|
||||
// Separate the image to channels
|
||||
std::vector<cv::Mat> planes(num_of_channels);
|
||||
cv::split(image, planes);
|
||||
cv::Mat b_hist, g_hist, r_hist;
|
||||
// Establish the number of bins and set variables for histogram
|
||||
int32_t hist_size = 256;
|
||||
int32_t channels = 0;
|
||||
float range[] = {0, 256};
|
||||
const float *hist_range[] = {range};
|
||||
bool uniform = true, accumulate = false;
|
||||
// Set up lookup table for LUT(Look up table algorithm)
|
||||
std::vector<int32_t> table;
|
||||
std::vector<cv::Mat> image_result;
|
||||
for (std::size_t layer = 0; layer < planes.size(); layer++) {
|
||||
// Reset lookup table
|
||||
table = std::vector<int32_t>{};
|
||||
// Calculate Histogram for channel
|
||||
cv::Mat hist;
|
||||
cv::calcHist(&planes[layer], 1, &channels, cv::Mat(), hist, 1, &hist_size, hist_range, uniform, accumulate);
|
||||
hist.convertTo(hist, CV_32SC1);
|
||||
std::vector<int32_t> hist_vec;
|
||||
hist.col(0).copyTo(hist_vec);
|
||||
// Ignore values in ignore
|
||||
for (const auto &item : ignore) hist_vec[item] = 0;
|
||||
int32_t n = std::accumulate(hist_vec.begin(), hist_vec.end(), 0);
|
||||
// Find pixel values that are in the low cutoff and high cutoff.
|
||||
int32_t cut = static_cast<int32_t>((cutoff / 100.0) * n);
|
||||
if (cut != 0) {
|
||||
for (int32_t lo = 0; lo < 256 && cut > 0; lo++) {
|
||||
if (cut > hist_vec[lo]) {
|
||||
cut -= hist_vec[lo];
|
||||
hist_vec[lo] = 0;
|
||||
} else {
|
||||
hist_vec[lo] -= cut;
|
||||
cut = 0;
|
||||
}
|
||||
}
|
||||
cut = static_cast<int32_t>((cutoff / 100.0) * n);
|
||||
for (int32_t hi = 255; hi >= 0 && cut > 0; hi--) {
|
||||
if (cut > hist_vec[hi]) {
|
||||
cut -= hist_vec[hi];
|
||||
hist_vec[hi] = 0;
|
||||
} else {
|
||||
hist_vec[hi] -= cut;
|
||||
cut = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t lo = 0;
|
||||
int32_t hi = 255;
|
||||
for (; lo < 256 && !hist_vec[lo]; lo++) {
|
||||
}
|
||||
for (; hi >= 0 && !hist_vec[hi]; hi--) {
|
||||
}
|
||||
if (hi <= lo) {
|
||||
for (int32_t i = 0; i < 256; i++) {
|
||||
table.push_back(i);
|
||||
}
|
||||
} else {
|
||||
float scale = 255.0 / (hi - lo);
|
||||
float offset = -1 * lo * scale;
|
||||
for (int32_t i = 0; i < 256; i++) {
|
||||
int32_t ix = static_cast<int32_t>(i * scale + offset);
|
||||
ix = std::max(ix, 0);
|
||||
ix = std::min(ix, 255);
|
||||
table.push_back(ix);
|
||||
}
|
||||
}
|
||||
cv::Mat result_layer;
|
||||
cv::LUT(planes[layer], table, result_layer);
|
||||
image_result.push_back(result_layer);
|
||||
}
|
||||
cv::Mat result;
|
||||
cv::merge(image_result, result);
|
||||
std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result);
|
||||
if (input_cv->Rank() == 2) output_cv->Squeeze();
|
||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||
} catch (const cv::Exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Error in auto contrast");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
|
|
|
@ -175,6 +175,14 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
// @param output: Adjusted image of same shape and type.
|
||||
Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha);
|
||||
|
||||
// Returns image with contrast maximized.
|
||||
// @param input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and any OpenCv compatible type, see CVTensor.
|
||||
// @param cutoff: Cutoff percentage of how many pixels are to be removed (high pixels change to 255 and low change to 0)
|
||||
// from the high and low ends of the histogram.
|
||||
// @param ignore: Pixel values to be ignored in the algorithm.
|
||||
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
|
||||
const std::vector<uint32_t> &ignore);
|
||||
|
||||
// Returns image with adjusted saturation.
|
||||
// @param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
|
||||
// @param alpha: Alpha value to adjust saturation by. Should be a positive number.
|
||||
|
|
|
@ -87,6 +87,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// image
|
||||
constexpr char kAutoContrastOp[] = "AutoContrastOp";
|
||||
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
|
||||
constexpr char kDecodeOp[] = "DecodeOp";
|
||||
constexpr char kCenterCropOp[] = "CenterCropOp";
|
||||
|
|
|
@ -47,7 +47,7 @@ from .utils import Inter, Border
|
|||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
|
||||
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
|
||||
check_random_select_subpolicy_op, FLOAT_MAX_INTEGER
|
||||
check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER
|
||||
|
||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
|
@ -71,6 +71,24 @@ def parse_padding(padding):
|
|||
return padding
|
||||
|
||||
|
||||
class AutoContrast(cde.AutoContrastOp):
|
||||
"""
|
||||
Apply auto contrast on input image.
|
||||
|
||||
Args:
|
||||
cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
|
||||
ignore (int or sequence, optional): Pixel values to ignore (default=None).
|
||||
"""
|
||||
|
||||
@check_auto_contrast
|
||||
def __init__(self, cutoff=0.0, ignore=None):
|
||||
if ignore is None:
|
||||
ignore = []
|
||||
if isinstance(ignore, int):
|
||||
ignore = [ignore]
|
||||
super().__init__(cutoff, ignore)
|
||||
|
||||
|
||||
class Invert(cde.InvertOp):
|
||||
"""
|
||||
Apply invert on input image in RGB mode.
|
||||
|
|
|
@ -530,6 +530,27 @@ def check_bounding_box_augment_cpp(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_auto_contrast(method):
|
||||
"""Wrapper method to check the parameters of AutoContrast ops (python and cpp)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(cutoff, (int, float), "cutoff")
|
||||
check_value(cutoff, [0, 100], "cutoff")
|
||||
if ignore is not None:
|
||||
type_check(ignore, (list, tuple, int), "ignore")
|
||||
if isinstance(ignore, int):
|
||||
check_value(ignore, [0, 255], "ignore")
|
||||
if isinstance(ignore, (list, tuple)):
|
||||
for item in ignore:
|
||||
type_check(item, (int,), "item")
|
||||
check_value(item, [0, 255], "ignore")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_uniform_augment_py(method):
|
||||
"""Wrapper method to check the parameters of python UniformAugment op."""
|
||||
|
||||
|
|
Binary file not shown.
|
@ -16,20 +16,22 @@
|
|||
Testing AutoContrast op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, diff_mse
|
||||
from util import visualize_list, diff_mse, save_and_check_md5
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
def test_auto_contrast(plot=False):
|
||||
|
||||
def test_auto_contrast_py(plot=False):
|
||||
"""
|
||||
Test AutoContrast
|
||||
"""
|
||||
logger.info("Test AutoContrast")
|
||||
logger.info("Test AutoContrast Python Op")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
@ -78,9 +80,156 @@ def test_auto_contrast(plot=False):
|
|||
mse[i] = diff_mse(images_auto_contrast[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
# Compare with expected md5 from images
|
||||
filename = "autcontrast_01_result_py.npz"
|
||||
save_and_check_md5(ds_auto_contrast, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(images_original, images_auto_contrast)
|
||||
|
||||
|
||||
def test_auto_contrast_c(plot=False):
|
||||
"""
|
||||
Test AutoContrast C Op
|
||||
"""
|
||||
logger.info("Test AutoContrast C Op")
|
||||
|
||||
# AutoContrast Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224))])
|
||||
python_op = F.AutoContrast()
|
||||
c_op = C.AutoContrast()
|
||||
transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)),
|
||||
python_op,
|
||||
np.array])()
|
||||
|
||||
ds_auto_contrast_py = ds.map(input_columns="image",
|
||||
operations=transforms_op)
|
||||
|
||||
ds_auto_contrast_py = ds_auto_contrast_py.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_auto_contrast_py):
|
||||
if idx == 0:
|
||||
images_auto_contrast_py = image
|
||||
else:
|
||||
images_auto_contrast_py = np.append(images_auto_contrast_py,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224))])
|
||||
|
||||
ds_auto_contrast_c = ds.map(input_columns="image",
|
||||
operations=c_op)
|
||||
|
||||
ds_auto_contrast_c = ds_auto_contrast_c.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_auto_contrast_c):
|
||||
if idx == 0:
|
||||
images_auto_contrast_c = image
|
||||
else:
|
||||
images_auto_contrast_c = np.append(images_auto_contrast_c,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
num_samples = images_auto_contrast_c.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
np.testing.assert_equal(np.mean(mse), 0.0)
|
||||
|
||||
if plot:
|
||||
visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2)
|
||||
|
||||
|
||||
def test_auto_contrast_one_channel_c(plot=False):
|
||||
"""
|
||||
Test AutoContrast C op with one channel
|
||||
"""
|
||||
logger.info("Test AutoContrast C Op With One Channel Images")
|
||||
|
||||
# AutoContrast Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224))])
|
||||
python_op = F.AutoContrast()
|
||||
c_op = C.AutoContrast()
|
||||
# not using F.ToTensor() since it converts to floats
|
||||
transforms_op = F.ComposeOp([lambda img: (np.array(img)[:, :, 0]).astype(np.uint8),
|
||||
F.ToPIL(),
|
||||
python_op,
|
||||
np.array])()
|
||||
|
||||
ds_auto_contrast_py = ds.map(input_columns="image",
|
||||
operations=transforms_op)
|
||||
|
||||
ds_auto_contrast_py = ds_auto_contrast_py.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_auto_contrast_py):
|
||||
if idx == 0:
|
||||
images_auto_contrast_py = image
|
||||
else:
|
||||
images_auto_contrast_py = np.append(images_auto_contrast_py,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224)),
|
||||
lambda img: np.array(img[:, :, 0])])
|
||||
|
||||
ds_auto_contrast_c = ds.map(input_columns="image",
|
||||
operations=c_op)
|
||||
|
||||
ds_auto_contrast_c = ds_auto_contrast_c.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_auto_contrast_c):
|
||||
if idx == 0:
|
||||
images_auto_contrast_c = image
|
||||
else:
|
||||
images_auto_contrast_c = np.append(images_auto_contrast_c,
|
||||
image,
|
||||
axis=0)
|
||||
|
||||
num_samples = images_auto_contrast_c.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
np.testing.assert_equal(np.mean(mse), 0.0)
|
||||
|
||||
if plot:
|
||||
visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2)
|
||||
|
||||
|
||||
def test_auto_contrast_invalid_input_c():
|
||||
"""
|
||||
Test AutoContrast C Op with invalid params
|
||||
"""
|
||||
logger.info("Test AutoContrast C Op with invalid params")
|
||||
try:
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
ds = ds.map(input_columns=["image"],
|
||||
operations=[C.Decode(),
|
||||
C.Resize((224, 224)),
|
||||
lambda img: np.array(img[:, :, 0])])
|
||||
# invalid ignore
|
||||
ds = ds.map(input_columns="image",
|
||||
operations=C.AutoContrast(ignore=255.5))
|
||||
except TypeError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert "Argument ignore with value 255.5 is not of type" in str(error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_auto_contrast(plot=True)
|
||||
test_auto_contrast_py(plot=True)
|
||||
test_auto_contrast_c(plot=True)
|
||||
test_auto_contrast_one_channel_c(plot=True)
|
||||
test_auto_contrast_invalid_input_c()
|
||||
|
|
Loading…
Reference in New Issue