!40012 [feat] [assistant] [I501PV] Add new operator AdjustContrast

Merge pull request !40012 from 刘赫喃/AdjustContrast
This commit is contained in:
i-robot 2022-08-16 01:52:08 +00:00 committed by Gitee
commit a315a183ff
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 552 additions and 27 deletions

View File

@ -21,6 +21,7 @@
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/ir/vision/adjust_brightness_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_hue_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_saturation_ir.h"
@ -94,6 +95,16 @@ PYBIND_REGISTER(AdjustBrightnessOperation, 1, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(AdjustContrastOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AdjustContrastOperation, TensorOperation,
std::shared_ptr<vision::AdjustContrastOperation>>(*m, "AdjustContrastOperation")
.def(py::init([](float contrast_factor) {
auto adjust_contrast = std::make_shared<vision::AdjustContrastOperation>(contrast_factor);
THROW_IF_ERROR(adjust_contrast->ValidateParams());
return adjust_contrast;
}));
}));
PYBIND_REGISTER(
AdjustGammaOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AdjustGammaOperation, TensorOperation, std::shared_ptr<vision::AdjustGammaOperation>>(

View File

@ -21,6 +21,7 @@
#endif
#include "minddata/dataset/kernels/ir/vision/adjust_brightness_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_hue_ir.h"
#include "minddata/dataset/kernels/ir/vision/adjust_saturation_ir.h"
@ -141,6 +142,18 @@ std::shared_ptr<TensorOperation> AdjustBrightness::Parse() {
return std::make_shared<AdjustBrightnessOperation>(data_->brightness_factor_);
}
// AdjustContrast Transform Operation.
struct AdjustContrast::Data {
explicit Data(float contrast_factor) : contrast_factor_(contrast_factor) {}
float contrast_factor_;
};
AdjustContrast::AdjustContrast(float contrast_factor) : data_(std::make_shared<Data>(contrast_factor)) {}
std::shared_ptr<TensorOperation> AdjustContrast::Parse() {
return std::make_shared<AdjustContrastOperation>(data_->contrast_factor_);
}
// AdjustGamma Transform Operation.
struct AdjustGamma::Data {
Data(float gamma, float gain) : gamma_(gamma), gain_(gain) {}

View File

@ -66,6 +66,36 @@ class MS_API AdjustBrightness final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Apply contrast adjustment on input image.
class MS_API AdjustContrast final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] contrast_factor Adjusts image contrast, non negative real number.
/// \par Example
/// \code
/// /* Define operations */
/// auto decode_op = vision::Decode();
/// auto adjust_contrast_op = vision::AdjustContrast(10.0);
///
/// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({decode_op, adjust_contrast_op}, // operations
/// {"image"}); // input columns
/// \endcode
explicit AdjustContrast(float contrast_factor);
/// \brief Destructor.
~AdjustContrast() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief AdjustGamma TensorTransform.
/// \note Apply gamma correction on input image.
class MS_API AdjustGamma final : public TensorTransform {

View File

@ -6,6 +6,7 @@ if(ENABLE_ACL)
endif()
add_library(kernels-image OBJECT
adjust_brightness_op.cc
adjust_contrast_op.cc
adjust_gamma_op.cc
adjust_hue_op.cc
adjust_saturation_op.cc

View File

@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/adjust_contrast_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
Status AdjustContrastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return AdjustContrast(input, output, contrast_factor_);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class AdjustContrastOp : public TensorOp {
public:
explicit AdjustContrastOp(float contrast_factor) : contrast_factor_(contrast_factor) {}
~AdjustContrastOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kAdjustContrastOp; }
private:
float contrast_factor_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CONTRAST_OP_H_

View File

@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
set(DATASET_KERNELS_IR_VISION_SRC_FILES
adjust_brightness_ir.cc
adjust_contrast_ir.cc
adjust_gamma_ir.cc
adjust_hue_ir.cc
adjust_saturation_ir.cc

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/adjust_contrast_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/adjust_contrast_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
#include "minddata/dataset/util/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// AdjustContrastOperation
AdjustContrastOperation::AdjustContrastOperation(float contrast_factor) : contrast_factor_(contrast_factor) {}
Status AdjustContrastOperation::ValidateParams() {
// contrast_factor
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("AdjustContrast", "contrast_factor", contrast_factor_));
return Status::OK();
}
std::shared_ptr<TensorOp> AdjustContrastOperation::Build() {
std::shared_ptr<AdjustContrastOp> tensor_op = std::make_shared<AdjustContrastOp>(contrast_factor_);
return tensor_op;
}
Status AdjustContrastOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["contrast_factor"] = contrast_factor_;
*out_json = args;
return Status::OK();
}
Status AdjustContrastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "contrast_factor", kAdjustContrastOperation));
float contrast_factor = op_params["contrast_factor"];
*operation = std::make_shared<vision::AdjustContrastOperation>(contrast_factor);
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kAdjustContrastOperation[] = "AdjustContrast";
class AdjustContrastOperation : public TensorOperation {
public:
explicit AdjustContrastOperation(float contrast_factor);
~AdjustContrastOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kAdjustContrastOperation; }
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
float contrast_factor_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_CONTRAST_IR_H_

View File

@ -54,6 +54,7 @@ constexpr char kTensorOp[] = "TensorOp";
// image
constexpr char kAdjustBrightnessOp[] = "AdjustBrightnessOp";
constexpr char kAdjustContrastOp[] = "AdjustContrastOp";
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
constexpr char kAdjustHueOp[] = "AdjustHueOp";
constexpr char kAdjustSaturationOp[] = "AdjustSaturationOp";

View File

@ -44,15 +44,16 @@ from . import c_transforms
from . import py_transforms
from . import transforms
from . import utils
from .transforms import AdjustBrightness, AdjustGamma, AdjustHue, AdjustSaturation, AdjustSharpness, AutoAugment, \
AutoContrast, BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, Decode, Equalize, Erase, \
FiveCrop, GaussianBlur, Grayscale, HorizontalFlip, HsvToRgb, HWC2CHW, Invert, LinearTransformation, MixUp, \
MixUpBatch, Normalize, NormalizePad, Pad, PadToSize, Posterize, RandomAdjustSharpness, RandomAffine, \
RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, RandomCropDecodeResize, RandomCropWithBBox, \
RandomEqualize, RandomErasing, RandomGrayscale, RandomHorizontalFlip, RandomHorizontalFlipWithBBox, RandomInvert, \
RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, RandomResizedCropWithBBox, RandomResize, \
RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, RandomSharpness, RandomSolarize, RandomVerticalFlip, \
RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, \
ToNumpy, ToPIL, ToTensor, ToType, TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue, AdjustSaturation, AdjustSharpness, \
AutoAugment, AutoContrast, BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, Decode, \
Equalize, Erase, FiveCrop, GaussianBlur, Grayscale, HorizontalFlip, HsvToRgb, HWC2CHW, Invert, \
LinearTransformation, MixUp, MixUpBatch, Normalize, NormalizePad, Pad, PadToSize, Posterize, \
RandomAdjustSharpness, RandomAffine, RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, \
RandomCropDecodeResize, RandomCropWithBBox, RandomEqualize, RandomErasing, RandomGrayscale, RandomHorizontalFlip, \
RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, \
RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, \
RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, \
RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, TrivialAugmentWide, \
UniformAugment, VerticalFlip, not_random
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, get_image_num_channels, \
get_image_size

View File

@ -62,17 +62,18 @@ from mindspore._c_expression import typing
from . import py_transforms_util as util
from .py_transforms_util import is_pil
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, parse_padding
from .validators import check_adjust_brightness, check_adjust_gamma, check_adjust_hue, check_adjust_saturation, \
check_adjust_sharpness, check_alpha, check_auto_augment, check_auto_contrast, check_bounding_box_augment_cpp, \
check_center_crop, check_convert_color, check_crop, check_cut_mix_batch_c, check_cutout_new, check_decode, \
check_erase, check_five_crop, check_gaussian_blur, check_hsv_to_rgb, check_linear_transform, check_mix_up, \
check_mix_up_batch_c, check_normalize, check_normalizepad, check_num_channels, check_pad, check_pad_to_size, \
check_positive_degrees, check_posterize, check_prob, check_random_adjust_sharpness, check_random_affine, \
check_random_auto_contrast, check_random_color_adjust, check_random_crop, check_random_erasing, \
check_random_perspective, check_random_posterize, check_random_resize_crop, check_random_rotation, \
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_solarize, check_ten_crop, \
check_trivial_augment_wide, check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
from .validators import check_adjust_brightness, check_adjust_contrast, check_adjust_gamma, check_adjust_hue, \
check_adjust_saturation, check_adjust_sharpness, check_alpha, check_auto_augment, check_auto_contrast, \
check_bounding_box_augment_cpp, check_center_crop, check_convert_color, check_crop, check_cut_mix_batch_c, \
check_cutout_new, check_decode, check_erase, check_five_crop, check_gaussian_blur, check_hsv_to_rgb, \
check_linear_transform, check_mix_up, check_mix_up_batch_c, check_normalize, check_normalizepad, \
check_num_channels, check_pad, check_pad_to_size, check_positive_degrees, check_posterize, check_prob, \
check_random_adjust_sharpness, check_random_affine, check_random_auto_contrast, check_random_color_adjust, \
check_random_crop, check_random_erasing, check_random_perspective, check_random_posterize, \
check_random_resize_crop, check_random_rotation, check_random_select_subpolicy_op, check_random_solarize, \
check_range, check_rescale, check_resize, check_resize_interpolation, check_rgb_to_hsv, check_rotate, \
check_slice_patches, check_solarize, check_ten_crop, check_trivial_augment_wide, check_uniform_augment, \
check_to_tensor, FLOAT_MAX_INTEGER
from ..core.datatypes import mstype_to_detype, nptype_to_detype
from ..transforms.py_transforms_util import Implementation
from ..transforms.transforms import CompoundOperation, PyTensorOperation, TensorOperation, TypeCast
@ -101,8 +102,8 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
Args:
brightness_factor (float): How much to adjust the brightness. Can be any non negative number.
Non negative real number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
0 gives a black image, 1 gives the original image,
while 2 increases the brightness by a factor of 2.
Raises:
TypeError: If `brightness_factor` is not of type float.
@ -117,6 +118,7 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_adjust_brightness
def __init__(self, brightness_factor):
super().__init__()
@ -138,6 +140,50 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
return util.adjust_brightness(img, self.brightness_factor)
class AdjustContrast(ImageTensorOperation, PyTensorOperation):
r"""
Adjust contrast of input image. Input image is expected to be in [H, W, C] format.
Args:
contrast_factor (float): How much to adjust the contrast. Can be any non negative number.
0 gives a solid gray image, 1 gives the original image,
while 2 increases the contrast by a factor of 2.
Raises:
TypeError: If `contrast_factor` is not of type float.
ValueError: If `contrast_factor` is less than 0.
RuntimeError: If given tensor shape is not <H, W, C>.
Supported Platforms:
``CPU``
Examples:
>>> transforms_list = [vision.Decode(), vision.AdjustContrast(contrast_factor=2.0)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_adjust_contrast
def __init__(self, contrast_factor):
super().__init__()
self.contrast_factor = contrast_factor
def parse(self):
return cde.AdjustContrastOperation(self.contrast_factor)
def execute_py(self, img):
"""
Execute method.
Args:
img (PIL Image): Image to be contrast adjusted.
Returns:
PIL Image, contrast adjusted image.
"""
return util.adjust_contrast(img, self.contrast_factor)
class AdjustGamma(ImageTensorOperation, PyTensorOperation):
r"""
Apply gamma correction on input image. Input image is expected to be in [..., H, W, C] or [H, W] format.

View File

@ -1021,6 +1021,19 @@ def check_adjust_brightness(method):
return new_method
def check_adjust_contrast(method):
"""Wrapper method to check the parameters of AdjustContrast ops (Python and C++)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[contrast_factor], _ = parse_user_args(method, *args, **kwargs)
type_check(contrast_factor, (float, int), "contrast_factor")
check_value(contrast_factor, (0, FLOAT_MAX_INTEGER), "contrast_factor")
return method(self, *args, **kwargs)
return new_method
def check_adjust_gamma(method):
"""Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""

View File

@ -2608,3 +2608,55 @@ TEST_F(MindDataTestPipeline, TestAdjustHueParamCheck) {
// Expect failure: invalid value of AdjustHue
EXPECT_EQ(iter1, nullptr);
}
/// Feature: AdjustContrast op
/// Description: Test AdjustContrast C implementation Pipeline
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestPipeline, TestAdjustContrast) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAdjustContrast.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto adjustcontrast_op = vision::AdjustContrast(2.0);
ds = ds->Map({adjustcontrast_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
iter->Stop();
}
/// Feature: AdjustContrast op
/// Description: Test improper parameters for AdjustContrast C implementation
/// Expectation: Throw ValueError exception
TEST_F(MindDataTestPipeline, TestAdjustContrastParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAdjustContrastParamCheck.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
// Case 1: Negative contrast_factor
// Create objects for the tensor ops
auto adjustcontrast_op = vision::AdjustContrast(-1);
auto ds1 = ds->Map({adjustcontrast_op});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid value of AdjustContrast
EXPECT_EQ(iter1, nullptr);
}

View File

@ -2835,7 +2835,7 @@ TEST_F(MindDataTestExecute, TestEraseEager) {
EXPECT_EQ(rc, Status::OK());
}
/// Feature: Execute Transform op
/// Feature: AdjustBrightness
/// Description: Test executing AdjustBrightness op in eager mode
/// Expectation: The data is processed successfully
TEST_F(MindDataTestExecute, TestAdjustBrightness) {
@ -2919,3 +2919,20 @@ TEST_F(MindDataTestExecute, TestAdjustHue) {
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
/// Feature: AdjustContrast
/// Description: Test executing AdjustContrast op in eager mode
/// Expectation: The data is processed successfully
TEST_F(MindDataTestExecute, TestAdjustContrast) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestAdjustContrast.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto adjust_contrast_op = vision::AdjustContrast(1);
auto transform = Execute({decode, adjust_contrast_op});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -17,10 +17,10 @@ Testing AdjustBrightness op in DE
"""
import numpy as np
from numpy.testing import assert_allclose
import mindspore.dataset as ds
import mindspore.dataset.transforms.transforms
import mindspore.dataset.vision as vision
from mindspore.dataset.vision import Decode
from mindspore import log as logger
from util import diff_mse
@ -51,7 +51,7 @@ def test_adjust_brightness_eager(plot=False):
img = np.fromfile(image_file, dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = Decode()(img)
img = vision.Decode()(img)
img_adjustbrightness = vision.AdjustBrightness(1)(img)
if plot:
visualize_image(img, img_adjustbrightness)
@ -64,7 +64,6 @@ def test_adjust_brightness_eager(plot=False):
def test_adjust_brightness_invalid_brightness_factor_param():
"""
Test AdjustBrightness implementation with invalid ignore parameter
Feature: AdjustBrightness op
Description: Test improper parameters for AdjustBrightness implementation
Expectation: Throw ValueError exception and TypeError exception
@ -141,6 +140,7 @@ def test_adjust_brightness_pipeline():
logger.info("MSE= {}".format(str(mse)))
assert mse == 0
if __name__ == "__main__":
test_adjust_brightness_eager()
test_adjust_brightness_invalid_brightness_factor_param()

View File

@ -0,0 +1,147 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing AdjustContrast op in DE
"""
import numpy as np
from numpy.testing import assert_allclose
import mindspore.dataset as ds
import mindspore.dataset.transforms.transforms
import mindspore.dataset.vision as vision
from mindspore import log as logger
from util import diff_mse
DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def generate_numpy_random_rgb(shape):
"""
Only generate floating points that are fractions like n / 256, since they
are RGB pixels. Some low-precision floating point types in this test can't
handle arbitrary precision floating points well.
"""
return np.random.randint(0, 256, shape) / 255.
def test_adjust_contrast_eager(plot=False):
"""
Feature: AdjustContrast op
Description: Test AdjustContrast in eager mode
Expectation: Output is the same as expected output
"""
# Eager 3-channel
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
img = np.fromfile(image_file, dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = vision.Decode()(img)
img_adjustcontrast = vision.AdjustContrast(1)(img)
if plot:
visualize_image(img, img_adjustcontrast)
logger.info("Image.type: {}, Image.shape: {}".format(type(img_adjustcontrast),
img_adjustcontrast.shape))
mse = diff_mse(img_adjustcontrast, img)
logger.info("MSE= {}".format(str(mse)))
assert mse == 0
def test_adjust_contrast_invalid_contrast_factor_param():
"""
Feature: AdjustContrast op
Description: Test improper parameters for AdjustContrast implementation
Expectation: Throw ValueError exception and TypeError exception
"""
logger.info("Test AdjustContrast Python implementation with invalid ignore parameter")
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
trans = mindspore.dataset.transforms.transforms.Compose([
vision.Decode(True),
vision.Resize((224, 224)),
vision.AdjustContrast(contrast_factor=-10.0),
vision.ToTensor()
])
data_set = data_set.map(operations=[trans], input_columns=["image"])
except ValueError as error:
logger.info("Got an exception in AdjustContrast: {}".format(str(error)))
assert "Input contrast_factor is not within the required interval of " in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
trans = ds.transforms.transforms.Compose([
vision.Decode(True),
vision.Resize((224, 224)),
vision.AdjustContrast(contrast_factor=[1, 2]),
vision.ToTensor()
])
data_set = data_set.map(operations=[trans], input_columns=["image"])
except TypeError as error:
logger.info("Got an exception in AdjustContrast: {}".format(str(error)))
assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
def test_adjust_contrast_pipeline():
"""
Feature: AdjustContrast op
Description: Test AdjustContrast in pipeline mode
Expectation: Output is the same as expected output
"""
# First dataset
transforms1 = [vision.Decode(True), vision.Resize([64, 64]), vision.ToTensor()]
transforms1 = mindspore.dataset.transforms.transforms.Compose(
transforms1)
ds1 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
# Second dataset
transforms2 = [
vision.Decode(True),
vision.Resize([64, 64]),
vision.AdjustContrast(1.0),
vision.ToTensor()
]
transform2 = mindspore.dataset.transforms.transforms.Compose(
transforms2)
ds2 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds2 = ds2.map(operations=transform2, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
ds2.create_dict_iterator(num_epochs=1)):
num_iter += 1
ori_img = data1["image"].asnumpy()
cvt_img = data2["image"].asnumpy()
assert_allclose(ori_img.flatten(),
cvt_img.flatten(),
rtol=1e-5,
atol=0)
mse = diff_mse(ori_img, cvt_img)
logger.info("MSE= {}".format(str(mse)))
assert mse == 0
if __name__ == "__main__":
test_adjust_contrast_eager()
test_adjust_contrast_invalid_contrast_factor_param()
test_adjust_contrast_pipeline()