!40012 [feat] [assistant] [I501PV] Add new operator AdjustContrast
Merge pull request !40012 from 刘赫喃/AdjustContrast
This commit is contained in:
commit
a315a183ff
|
@ -21,6 +21,7 @@
|
|||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_brightness_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_hue_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_saturation_ir.h"
|
||||
|
@ -94,6 +95,16 @@ PYBIND_REGISTER(AdjustBrightnessOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(AdjustContrastOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::AdjustContrastOperation, TensorOperation,
|
||||
std::shared_ptr<vision::AdjustContrastOperation>>(*m, "AdjustContrastOperation")
|
||||
.def(py::init([](float contrast_factor) {
|
||||
auto adjust_contrast = std::make_shared<vision::AdjustContrastOperation>(contrast_factor);
|
||||
THROW_IF_ERROR(adjust_contrast->ValidateParams());
|
||||
return adjust_contrast;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
AdjustGammaOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::AdjustGammaOperation, TensorOperation, std::shared_ptr<vision::AdjustGammaOperation>>(
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#endif
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_brightness_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_hue_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_saturation_ir.h"
|
||||
|
@ -141,6 +142,18 @@ std::shared_ptr<TensorOperation> AdjustBrightness::Parse() {
|
|||
return std::make_shared<AdjustBrightnessOperation>(data_->brightness_factor_);
|
||||
}
|
||||
|
||||
// AdjustContrast Transform Operation.
|
||||
struct AdjustContrast::Data {
|
||||
explicit Data(float contrast_factor) : contrast_factor_(contrast_factor) {}
|
||||
float contrast_factor_;
|
||||
};
|
||||
|
||||
AdjustContrast::AdjustContrast(float contrast_factor) : data_(std::make_shared<Data>(contrast_factor)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> AdjustContrast::Parse() {
|
||||
return std::make_shared<AdjustContrastOperation>(data_->contrast_factor_);
|
||||
}
|
||||
|
||||
// AdjustGamma Transform Operation.
|
||||
struct AdjustGamma::Data {
|
||||
Data(float gamma, float gain) : gamma_(gamma), gain_(gain) {}
|
||||
|
|
|
@ -66,6 +66,36 @@ class MS_API AdjustBrightness final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply contrast adjustment on input image.
|
||||
class MS_API AdjustContrast final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] contrast_factor Adjusts image contrast, non negative real number.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define operations */
|
||||
/// auto decode_op = vision::Decode();
|
||||
/// auto adjust_contrast_op = vision::AdjustContrast(10.0);
|
||||
///
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({decode_op, adjust_contrast_op}, // operations
|
||||
/// {"image"}); // input columns
|
||||
/// \endcode
|
||||
explicit AdjustContrast(float contrast_factor);
|
||||
|
||||
/// \brief Destructor.
|
||||
~AdjustContrast() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief AdjustGamma TensorTransform.
|
||||
/// \note Apply gamma correction on input image.
|
||||
class MS_API AdjustGamma final : public TensorTransform {
|
||||
|
|
|
@ -6,6 +6,7 @@ if(ENABLE_ACL)
|
|||
endif()
|
||||
add_library(kernels-image OBJECT
|
||||
adjust_brightness_op.cc
|
||||
adjust_contrast_op.cc
|
||||
adjust_gamma_op.cc
|
||||
adjust_hue_op.cc
|
||||
adjust_saturation_op.cc
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/adjust_contrast_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status AdjustContrastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
return AdjustContrast(input, output, contrast_factor_);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class AdjustContrastOp : public TensorOp {
|
||||
public:
|
||||
explicit AdjustContrastOp(float contrast_factor) : contrast_factor_(contrast_factor) {}
|
||||
|
||||
~AdjustContrastOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kAdjustContrastOp; }
|
||||
|
||||
private:
|
||||
float contrast_factor_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CONTRAST_OP_H_
|
|
@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
|
|||
|
||||
set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
||||
adjust_brightness_ir.cc
|
||||
adjust_contrast_ir.cc
|
||||
adjust_gamma_ir.cc
|
||||
adjust_hue_ir.cc
|
||||
adjust_saturation_ir.cc
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/adjust_contrast_op.h"
|
||||
#endif
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// AdjustContrastOperation
|
||||
AdjustContrastOperation::AdjustContrastOperation(float contrast_factor) : contrast_factor_(contrast_factor) {}
|
||||
|
||||
Status AdjustContrastOperation::ValidateParams() {
|
||||
// contrast_factor
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("AdjustContrast", "contrast_factor", contrast_factor_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> AdjustContrastOperation::Build() {
|
||||
std::shared_ptr<AdjustContrastOp> tensor_op = std::make_shared<AdjustContrastOp>(contrast_factor_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status AdjustContrastOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["contrast_factor"] = contrast_factor_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AdjustContrastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "contrast_factor", kAdjustContrastOperation));
|
||||
float contrast_factor = op_params["contrast_factor"];
|
||||
*operation = std::make_shared<vision::AdjustContrastOperation>(contrast_factor);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
constexpr char kAdjustContrastOperation[] = "AdjustContrast";
|
||||
|
||||
class AdjustContrastOperation : public TensorOperation {
|
||||
public:
|
||||
explicit AdjustContrastOperation(float contrast_factor);
|
||||
|
||||
~AdjustContrastOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kAdjustContrastOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
float contrast_factor_;
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_
|
|
@ -54,6 +54,7 @@ constexpr char kTensorOp[] = "TensorOp";
|
|||
|
||||
// image
|
||||
constexpr char kAdjustBrightnessOp[] = "AdjustBrightnessOp";
|
||||
constexpr char kAdjustContrastOp[] = "AdjustContrastOp";
|
||||
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
|
||||
constexpr char kAdjustHueOp[] = "AdjustHueOp";
|
||||
constexpr char kAdjustSaturationOp[] = "AdjustSaturationOp";
|
||||
|
|
|
@ -44,15 +44,16 @@ from . import c_transforms
|
|||
from . import py_transforms
|
||||
from . import transforms
|
||||
from . import utils
|
||||
from .transforms import AdjustBrightness, AdjustGamma, AdjustHue, AdjustSaturation, AdjustSharpness, AutoAugment, \
|
||||
AutoContrast, BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, Decode, Equalize, Erase, \
|
||||
FiveCrop, GaussianBlur, Grayscale, HorizontalFlip, HsvToRgb, HWC2CHW, Invert, LinearTransformation, MixUp, \
|
||||
MixUpBatch, Normalize, NormalizePad, Pad, PadToSize, Posterize, RandomAdjustSharpness, RandomAffine, \
|
||||
RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, RandomCropDecodeResize, RandomCropWithBBox, \
|
||||
RandomEqualize, RandomErasing, RandomGrayscale, RandomHorizontalFlip, RandomHorizontalFlipWithBBox, RandomInvert, \
|
||||
RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, RandomResizedCropWithBBox, RandomResize, \
|
||||
RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, RandomSharpness, RandomSolarize, RandomVerticalFlip, \
|
||||
RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, \
|
||||
ToNumpy, ToPIL, ToTensor, ToType, TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
|
||||
from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue, AdjustSaturation, AdjustSharpness, \
|
||||
AutoAugment, AutoContrast, BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, Decode, \
|
||||
Equalize, Erase, FiveCrop, GaussianBlur, Grayscale, HorizontalFlip, HsvToRgb, HWC2CHW, Invert, \
|
||||
LinearTransformation, MixUp, MixUpBatch, Normalize, NormalizePad, Pad, PadToSize, Posterize, \
|
||||
RandomAdjustSharpness, RandomAffine, RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, \
|
||||
RandomCropDecodeResize, RandomCropWithBBox, RandomEqualize, RandomErasing, RandomGrayscale, RandomHorizontalFlip, \
|
||||
RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, \
|
||||
RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, \
|
||||
RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, \
|
||||
RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, TrivialAugmentWide, \
|
||||
UniformAugment, VerticalFlip, not_random
|
||||
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, get_image_num_channels, \
|
||||
get_image_size
|
||||
|
|
|
@ -62,17 +62,18 @@ from mindspore._c_expression import typing
|
|||
from . import py_transforms_util as util
|
||||
from .py_transforms_util import is_pil
|
||||
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, parse_padding
|
||||
from .validators import check_adjust_brightness, check_adjust_gamma, check_adjust_hue, check_adjust_saturation, \
|
||||
check_adjust_sharpness, check_alpha, check_auto_augment, check_auto_contrast, check_bounding_box_augment_cpp, \
|
||||
check_center_crop, check_convert_color, check_crop, check_cut_mix_batch_c, check_cutout_new, check_decode, \
|
||||
check_erase, check_five_crop, check_gaussian_blur, check_hsv_to_rgb, check_linear_transform, check_mix_up, \
|
||||
check_mix_up_batch_c, check_normalize, check_normalizepad, check_num_channels, check_pad, check_pad_to_size, \
|
||||
check_positive_degrees, check_posterize, check_prob, check_random_adjust_sharpness, check_random_affine, \
|
||||
check_random_auto_contrast, check_random_color_adjust, check_random_crop, check_random_erasing, \
|
||||
check_random_perspective, check_random_posterize, check_random_resize_crop, check_random_rotation, \
|
||||
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
|
||||
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_solarize, check_ten_crop, \
|
||||
check_trivial_augment_wide, check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
|
||||
from .validators import check_adjust_brightness, check_adjust_contrast, check_adjust_gamma, check_adjust_hue, \
|
||||
check_adjust_saturation, check_adjust_sharpness, check_alpha, check_auto_augment, check_auto_contrast, \
|
||||
check_bounding_box_augment_cpp, check_center_crop, check_convert_color, check_crop, check_cut_mix_batch_c, \
|
||||
check_cutout_new, check_decode, check_erase, check_five_crop, check_gaussian_blur, check_hsv_to_rgb, \
|
||||
check_linear_transform, check_mix_up, check_mix_up_batch_c, check_normalize, check_normalizepad, \
|
||||
check_num_channels, check_pad, check_pad_to_size, check_positive_degrees, check_posterize, check_prob, \
|
||||
check_random_adjust_sharpness, check_random_affine, check_random_auto_contrast, check_random_color_adjust, \
|
||||
check_random_crop, check_random_erasing, check_random_perspective, check_random_posterize, \
|
||||
check_random_resize_crop, check_random_rotation, check_random_select_subpolicy_op, check_random_solarize, \
|
||||
check_range, check_rescale, check_resize, check_resize_interpolation, check_rgb_to_hsv, check_rotate, \
|
||||
check_slice_patches, check_solarize, check_ten_crop, check_trivial_augment_wide, check_uniform_augment, \
|
||||
check_to_tensor, FLOAT_MAX_INTEGER
|
||||
from ..core.datatypes import mstype_to_detype, nptype_to_detype
|
||||
from ..transforms.py_transforms_util import Implementation
|
||||
from ..transforms.transforms import CompoundOperation, PyTensorOperation, TensorOperation, TypeCast
|
||||
|
@ -101,8 +102,8 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
|
|||
|
||||
Args:
|
||||
brightness_factor (float): How much to adjust the brightness. Can be any non negative number.
|
||||
Non negative real number. 0 gives a black image, 1 gives the
|
||||
original image while 2 increases the brightness by a factor of 2.
|
||||
0 gives a black image, 1 gives the original image,
|
||||
while 2 increases the brightness by a factor of 2.
|
||||
|
||||
Raises:
|
||||
TypeError: If `brightness_factor` is not of type float.
|
||||
|
@ -117,6 +118,7 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
|
|||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
@check_adjust_brightness
|
||||
def __init__(self, brightness_factor):
|
||||
super().__init__()
|
||||
|
@ -138,6 +140,50 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
|
|||
return util.adjust_brightness(img, self.brightness_factor)
|
||||
|
||||
|
||||
class AdjustContrast(ImageTensorOperation, PyTensorOperation):
|
||||
r"""
|
||||
Adjust contrast of input image. Input image is expected to be in [H, W, C] format.
|
||||
|
||||
Args:
|
||||
contrast_factor (float): How much to adjust the contrast. Can be any non negative number.
|
||||
0 gives a solid gray image, 1 gives the original image,
|
||||
while 2 increases the contrast by a factor of 2.
|
||||
|
||||
Raises:
|
||||
TypeError: If `contrast_factor` is not of type float.
|
||||
ValueError: If `contrast_factor` is less than 0.
|
||||
RuntimeError: If given tensor shape is not <H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> transforms_list = [vision.Decode(), vision.AdjustContrast(contrast_factor=2.0)]
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
@check_adjust_contrast
|
||||
def __init__(self, contrast_factor):
|
||||
super().__init__()
|
||||
self.contrast_factor = contrast_factor
|
||||
|
||||
def parse(self):
|
||||
return cde.AdjustContrastOperation(self.contrast_factor)
|
||||
|
||||
def execute_py(self, img):
|
||||
"""
|
||||
Execute method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be contrast adjusted.
|
||||
|
||||
Returns:
|
||||
PIL Image, contrast adjusted image.
|
||||
"""
|
||||
return util.adjust_contrast(img, self.contrast_factor)
|
||||
|
||||
|
||||
class AdjustGamma(ImageTensorOperation, PyTensorOperation):
|
||||
r"""
|
||||
Apply gamma correction on input image. Input image is expected to be in [..., H, W, C] or [H, W] format.
|
||||
|
|
|
@ -1021,6 +1021,19 @@ def check_adjust_brightness(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_adjust_contrast(method):
|
||||
"""Wrapper method to check the parameters of AdjustContrast ops (Python and C++)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[contrast_factor], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(contrast_factor, (float, int), "contrast_factor")
|
||||
check_value(contrast_factor, (0, FLOAT_MAX_INTEGER), "contrast_factor")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_adjust_gamma(method):
|
||||
"""Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""
|
||||
|
||||
|
|
|
@ -2608,3 +2608,55 @@ TEST_F(MindDataTestPipeline, TestAdjustHueParamCheck) {
|
|||
// Expect failure: invalid value of AdjustHue
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: AdjustContrast op
|
||||
/// Description: Test AdjustContrast C implementation Pipeline
|
||||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestPipeline, TestAdjustContrast) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAdjustContrast.";
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto adjustcontrast_op = vision::AdjustContrast(2.0);
|
||||
|
||||
ds = ds->Map({adjustcontrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: AdjustContrast op
|
||||
/// Description: Test improper parameters for AdjustContrast C implementation
|
||||
/// Expectation: Throw ValueError exception
|
||||
TEST_F(MindDataTestPipeline, TestAdjustContrastParamCheck) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAdjustContrastParamCheck.";
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Case 1: Negative contrast_factor
|
||||
// Create objects for the tensor ops
|
||||
auto adjustcontrast_op = vision::AdjustContrast(-1);
|
||||
auto ds1 = ds->Map({adjustcontrast_op});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid value of AdjustContrast
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
|
|
@ -2835,7 +2835,7 @@ TEST_F(MindDataTestExecute, TestEraseEager) {
|
|||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
/// Feature: Execute Transform op
|
||||
/// Feature: AdjustBrightness
|
||||
/// Description: Test executing AdjustBrightness op in eager mode
|
||||
/// Expectation: The data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestAdjustBrightness) {
|
||||
|
@ -2919,3 +2919,20 @@ TEST_F(MindDataTestExecute, TestAdjustHue) {
|
|||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
/// Feature: AdjustContrast
|
||||
/// Description: Test executing AdjustContrast op in eager mode
|
||||
/// Expectation: The data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestAdjustContrast) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestAdjustContrast.";
|
||||
// Read images
|
||||
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||
|
||||
// Transform params
|
||||
auto decode = vision::Decode();
|
||||
auto adjust_contrast_op = vision::AdjustContrast(1);
|
||||
|
||||
auto transform = Execute({decode, adjust_contrast_op});
|
||||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
|
|
@ -17,10 +17,10 @@ Testing AdjustBrightness op in DE
|
|||
"""
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.transforms
|
||||
import mindspore.dataset.vision as vision
|
||||
from mindspore.dataset.vision import Decode
|
||||
from mindspore import log as logger
|
||||
from util import diff_mse
|
||||
|
||||
|
@ -51,7 +51,7 @@ def test_adjust_brightness_eager(plot=False):
|
|||
img = np.fromfile(image_file, dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
img = Decode()(img)
|
||||
img = vision.Decode()(img)
|
||||
img_adjustbrightness = vision.AdjustBrightness(1)(img)
|
||||
if plot:
|
||||
visualize_image(img, img_adjustbrightness)
|
||||
|
@ -64,7 +64,6 @@ def test_adjust_brightness_eager(plot=False):
|
|||
|
||||
def test_adjust_brightness_invalid_brightness_factor_param():
|
||||
"""
|
||||
Test AdjustBrightness implementation with invalid ignore parameter
|
||||
Feature: AdjustBrightness op
|
||||
Description: Test improper parameters for AdjustBrightness implementation
|
||||
Expectation: Throw ValueError exception and TypeError exception
|
||||
|
@ -141,6 +140,7 @@ def test_adjust_brightness_pipeline():
|
|||
logger.info("MSE= {}".format(str(mse)))
|
||||
assert mse == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_adjust_brightness_eager()
|
||||
test_adjust_brightness_invalid_brightness_factor_param()
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing AdjustContrast op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.transforms
|
||||
import mindspore.dataset.vision as vision
|
||||
from mindspore import log as logger
|
||||
from util import diff_mse
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
|
||||
DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def generate_numpy_random_rgb(shape):
|
||||
"""
|
||||
Only generate floating points that are fractions like n / 256, since they
|
||||
are RGB pixels. Some low-precision floating point types in this test can't
|
||||
handle arbitrary precision floating points well.
|
||||
"""
|
||||
return np.random.randint(0, 256, shape) / 255.
|
||||
|
||||
|
||||
def test_adjust_contrast_eager(plot=False):
|
||||
"""
|
||||
Feature: AdjustContrast op
|
||||
Description: Test AdjustContrast in eager mode
|
||||
Expectation: Output is the same as expected output
|
||||
"""
|
||||
# Eager 3-channel
|
||||
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
|
||||
img = np.fromfile(image_file, dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
img = vision.Decode()(img)
|
||||
img_adjustcontrast = vision.AdjustContrast(1)(img)
|
||||
if plot:
|
||||
visualize_image(img, img_adjustcontrast)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img_adjustcontrast),
|
||||
img_adjustcontrast.shape))
|
||||
mse = diff_mse(img_adjustcontrast, img)
|
||||
logger.info("MSE= {}".format(str(mse)))
|
||||
assert mse == 0
|
||||
|
||||
|
||||
def test_adjust_contrast_invalid_contrast_factor_param():
|
||||
"""
|
||||
Feature: AdjustContrast op
|
||||
Description: Test improper parameters for AdjustContrast implementation
|
||||
Expectation: Throw ValueError exception and TypeError exception
|
||||
"""
|
||||
logger.info("Test AdjustContrast Python implementation with invalid ignore parameter")
|
||||
try:
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||
trans = mindspore.dataset.transforms.transforms.Compose([
|
||||
vision.Decode(True),
|
||||
vision.Resize((224, 224)),
|
||||
vision.AdjustContrast(contrast_factor=-10.0),
|
||||
vision.ToTensor()
|
||||
])
|
||||
data_set = data_set.map(operations=[trans], input_columns=["image"])
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in AdjustContrast: {}".format(str(error)))
|
||||
assert "Input contrast_factor is not within the required interval of " in str(error)
|
||||
try:
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
|
||||
trans = ds.transforms.transforms.Compose([
|
||||
vision.Decode(True),
|
||||
vision.Resize((224, 224)),
|
||||
vision.AdjustContrast(contrast_factor=[1, 2]),
|
||||
vision.ToTensor()
|
||||
])
|
||||
data_set = data_set.map(operations=[trans], input_columns=["image"])
|
||||
except TypeError as error:
|
||||
logger.info("Got an exception in AdjustContrast: {}".format(str(error)))
|
||||
assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
|
||||
|
||||
|
||||
def test_adjust_contrast_pipeline():
|
||||
"""
|
||||
Feature: AdjustContrast op
|
||||
Description: Test AdjustContrast in pipeline mode
|
||||
Expectation: Output is the same as expected output
|
||||
"""
|
||||
# First dataset
|
||||
transforms1 = [vision.Decode(True), vision.Resize([64, 64]), vision.ToTensor()]
|
||||
transforms1 = mindspore.dataset.transforms.transforms.Compose(
|
||||
transforms1)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR_2,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
|
||||
|
||||
# Second dataset
|
||||
transforms2 = [
|
||||
vision.Decode(True),
|
||||
vision.Resize([64, 64]),
|
||||
vision.AdjustContrast(1.0),
|
||||
vision.ToTensor()
|
||||
]
|
||||
transform2 = mindspore.dataset.transforms.transforms.Compose(
|
||||
transforms2)
|
||||
ds2 = ds.TFRecordDataset(DATA_DIR_2,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds2 = ds2.map(operations=transform2, input_columns=["image"])
|
||||
|
||||
num_iter = 0
|
||||
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
|
||||
ds2.create_dict_iterator(num_epochs=1)):
|
||||
num_iter += 1
|
||||
ori_img = data1["image"].asnumpy()
|
||||
cvt_img = data2["image"].asnumpy()
|
||||
assert_allclose(ori_img.flatten(),
|
||||
cvt_img.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
mse = diff_mse(ori_img, cvt_img)
|
||||
logger.info("MSE= {}".format(str(mse)))
|
||||
assert mse == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_adjust_contrast_eager()
|
||||
test_adjust_contrast_invalid_contrast_factor_param()
|
||||
test_adjust_contrast_pipeline()
|
Loading…
Reference in New Issue