[feat][assistant][I3CEGM] add op adjust gamma

This commit is contained in:
chenx2ovo 2021-08-02 18:09:40 +08:00
parent 84c9082468
commit 4191dde45f
23 changed files with 924 additions and 9 deletions

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
@ -67,6 +68,17 @@
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
AdjustGammaOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AdjustGammaOperation, TensorOperation, std::shared_ptr<vision::AdjustGammaOperation>>(
*m, "AdjustGammaOperation")
.def(py::init([](float gamma, float gain) {
auto ajust_gamma = std::make_shared<vision::AdjustGammaOperation>(gamma, gain);
THROW_IF_ERROR(ajust_gamma->ValidateParams());
return ajust_gamma;
}));
}));
PYBIND_REGISTER(
AutoContrastOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AutoContrastOperation, TensorOperation, std::shared_ptr<vision::AutoContrastOperation>>(

View File

@ -21,6 +21,7 @@
#endif
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
@ -118,6 +119,19 @@ std::shared_ptr<TensorOperation> Affine::Parse() {
}
#ifndef ENABLE_ANDROID
// AdjustGamma Transform Operation.
struct AdjustGamma::Data {
Data(float gamma, float gain) : gamma_(gamma), gain_(gain) {}
float gamma_;
float gain_;
};
AdjustGamma::AdjustGamma(float gamma, float gain) : data_(std::make_shared<Data>(gamma, gain)) {}
std::shared_ptr<TensorOperation> AdjustGamma::Parse() {
return std::make_shared<AdjustGammaOperation>(data_->gamma_, data_->gain_);
}
// AutoContrast Transform Operation.
struct AutoContrast::Data {
Data(float cutoff, const std::vector<uint32_t> &ignore) : cutoff_(cutoff), ignore_(ignore) {}

View File

@ -57,7 +57,31 @@ class AutoContrast final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Apply a given image transform on a random selection of bounding box regions of a given image.
/// \brief AdjustGamma TensorTransform.
/// \notes Apply gamma correction on input image.
class AdjustGamma final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] gamma Non negative real number, which makes the output image pixel value
/// exponential in relation to the input image pixel value.
/// \param[in] gain The constant multiplier.
explicit AdjustGamma(float gamma, float gain = 1);
/// \brief Destructor.
~AdjustGamma() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief BoundingBoxAugment TensorTransform.
/// \note Apply a given image transform on a random selection of bounding box regions of a given image.
class BoundingBoxAugment final : public TensorTransform {
public:
/// \brief Constructor.

View File

@ -6,6 +6,7 @@ if(ENABLE_ACL)
add_subdirectory(dvpp)
endif()
add_library(kernels-image OBJECT
adjust_gamma_op.cc
affine_op.cc
auto_contrast_op.cc
bounding_box.cc

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/adjust_gamma_op.h"
#include <memory>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
const float AdjustGammaOp::kGain = 1.0;
Status AdjustGammaOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// typecast
CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
"AdjustGamma: input tensor type should be [int, float, double], but got string.");
if (input->type().IsFloat()) {
std::shared_ptr<Tensor> input_tensor;
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
return AdjustGamma(input_tensor, output, gamma_, gain_);
} else {
return AdjustGamma(input, output, gamma_, gain_);
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class AdjustGammaOp : public TensorOp {
public:
/// Default gain to be used
static const float kGain;
AdjustGammaOp(const float &gamma, const float &gain) : gamma_(gamma), gain_(gain) {}
~AdjustGammaOp() override = default;
/// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const AdjustGammaOp &so) {
so.Print(out);
return out;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kAdjustGammaOp; }
private:
float gamma_;
float gain_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_

View File

@ -872,6 +872,64 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
return Status::OK();
}
Status AdjustGamma(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &gamma,
const float &gain) {
try {
int num_channels = 1;
if (input->Rank() < 2) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: image shape is not <...,H,W,C> or <H,W>.");
}
if (input->Rank() > 2) {
num_channels = input->shape()[-1];
}
if (num_channels != 1 && num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: channel of input image should be 1 or 3.");
}
if (input->type().IsFloat()) {
for (auto itr = input->begin<float>(); itr != input->end<float>(); itr++) {
*itr = pow((*itr) * gain, gamma);
*itr = std::min(std::max((*itr), 0.0f), 1.0f);
}
*output = input;
} else {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: load image failed.");
}
cv::Mat input_img = input_cv->mat();
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
uchar LUT[256] = {};
for (int i = 0; i < 256; i++) {
float f = i / 255.0;
f = pow(f, gamma);
LUT[i] = static_cast<uchar>(floor(std::min(f * (255.0 + 1 - 1e-3) * gain, 255.0)));
}
if (input_img.channels() == 1) {
cv::MatIterator_<uchar> it = input_img.begin<uchar>();
cv::MatIterator_<uchar> it_end = input_img.end<uchar>();
for (; it != it_end; ++it) {
*it = LUT[(*it)];
}
} else {
cv::MatIterator_<cv::Vec3b> it = input_img.begin<cv::Vec3b>();
cv::MatIterator_<cv::Vec3b> it_end = input_img.end<cv::Vec3b>();
for (; it != it_end; ++it) {
(*it)[0] = LUT[(*it)[0]];
(*it)[1] = LUT[(*it)[1]];
(*it)[2] = LUT[(*it)[2]];
}
}
output_cv->mat() = input_img * 1;
*output = std::static_pointer_cast<Tensor>(output_cv);
}
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: " + std::string(e.what()));
}
return Status::OK();
}
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
const std::vector<uint32_t> &ignore) {
try {

View File

@ -234,6 +234,16 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
const std::vector<uint32_t> &ignore);
/// \brief Returns image with gamma correction.
/// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and any OpenCV compatible type,
/// see CVTensor.
/// \param[in] gamma: Non negative real number, same as gamma in the equation. gamma larger than 1 make the shadows
/// darker, while gamma smaller than 1 make dark regions lighter.
/// \param[in] gain: The constant multiplier.
/// \param[out] output: Adjusted image of same shape and type.
Status AdjustGamma(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &gamma,
const float &gain);
/// \brief Returns image with adjusted saturation.
/// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param alpha: Alpha value to adjust saturation by. Should be a positive number.

View File

@ -38,6 +38,11 @@ Status ValidateFloatScalarPositive(const std::string &op_name, const std::string
return Status::OK();
}
Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, false));
return Status::OK();
}
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
std::string err_msg =

View File

@ -36,6 +36,9 @@ Status ValidateIntScalarPositive(const std::string &op_name, const std::string &
// Helper function to positive float scalar
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to non-negative float scalar
Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate scalar
template <typename T>
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,

View File

@ -2,6 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
set(DATASET_KERNELS_IR_VISION_SRC_FILES
adjust_gamma_ir.cc
affine_ir.cc
auto_contrast_ir.cc
bounding_box_augment_ir.cc

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/adjust_gamma_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// AdjustGammaOperation
AdjustGammaOperation::AdjustGammaOperation(float gamma, float gain) : gamma_(gamma), gain_(gain) {}
Status AdjustGammaOperation::ValidateParams() {
// gamma
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("AdjustGamma", "gamma", gamma_));
return Status::OK();
}
std::shared_ptr<TensorOp> AdjustGammaOperation::Build() {
std::shared_ptr<AdjustGammaOp> tensor_op = std::make_shared<AdjustGammaOp>(gamma_, gain_);
return tensor_op;
}
Status AdjustGammaOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["gamma"] = gamma_;
args["gain"] = gain_;
*out_json = args;
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kAdjustGammaOperation[] = "AdjustGamma";
class AdjustGammaOperation : public TensorOperation {
public:
explicit AdjustGammaOperation(float gamma, float gain);
~AdjustGammaOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kAdjustGammaOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float gamma_;
float gain_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_

View File

@ -53,6 +53,7 @@ namespace dataset {
constexpr char kTensorOp[] = "TensorOp";
// image
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
constexpr char kAffineOp[] = "AffineOp";
constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";

View File

@ -54,7 +54,7 @@ from .validators import check_prob, check_crop, check_center_crop, check_resize_
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
from ..transforms.c_transforms import TensorOperation
@ -107,6 +107,37 @@ def parse_padding(padding):
return padding
class AdjustGamma(ImageTensorOperation):
r"""
Apply gamma correction on input image. Input image is expected to be in [..., H, W, C] or [H, W, C] format.
.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
gamma (float): Non negative real number.
The output image pixel value is exponentially related to the input image pixel value.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float, optional): The constant multiplier (default=1).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.AdjustGamma(gamma=10.0, gain=1.0)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_adjust_gamma
def __init__(self, gamma, gain=1):
self.gamma = gamma
self.gain = gain
def parse(self):
return cde.AdjustGammaOperation(self.gamma, self.gain)
class AutoContrast(ImageTensorOperation):
"""
Apply automatic contrast on input image. This operator calculates histogram of image, reassign cutoff percent

View File

@ -31,7 +31,8 @@ from .validators import check_prob, check_center_crop, check_five_crop, check_re
check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr, \
check_adjust_gamma
from .utils import Inter, Border
from .py_transforms_util import is_pil
@ -1375,7 +1376,6 @@ class RgbToBgr:
return util.rgb_to_bgrs(rgb_imgs, self.is_hwc)
class RgbToHsv:
"""
Convert a NumPy RGB image or a batch of NumPy RGB images to HSV images.
@ -1525,6 +1525,44 @@ class RandomSharpness:
return util.random_sharpness(img, self.degrees)
class AdjustGamma:
"""
Adjust gamma of the input PIL image.
Args:
gamma (float): Non negative real number, same as gamma in the equation.
gain (float, optional): The constant multiplier.
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> transforms_list = Compose([py_vision.Decode(),
... py_vision.AdjustGamma(),
... py_vision.ToTensor()])
>>> # apply the transform to dataset through map function
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns="image")
"""
@check_adjust_gamma
def __init__(self, gamma, gain=1.0):
self.gamma = gamma
self.gain = gain
self.random = False
def __call__(self, img):
"""
Call method.
Args:
img (PIL image): Image to be augmented with AutoContrast.
Returns:
img (PIL image), Augmented image.
"""
return util.adjust_gamma(img, self.gamma, self.gain)
class AutoContrast:
"""
Automatically maximize the contrast of the input PIL image.

View File

@ -19,7 +19,6 @@ import math
import numbers
import random
import colorsys
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, __version__
@ -1243,6 +1242,7 @@ def rgb_to_bgr(np_rgb_img, is_hwc):
np_bgr_img = np_rgb_img[::-1, :, :]
return np_bgr_img
def rgb_to_bgrs(np_rgb_imgs, is_hwc):
"""
Convert RGB imgs to BGR imgs.
@ -1473,6 +1473,32 @@ def random_sharpness(img, degrees):
return ImageEnhance.Sharpness(img).enhance(v)
def adjust_gamma(img, gamma, gain):
"""
Adjust gamma of the input PIL image.
Args:
img (PIL image): Image to be augmented with AdjustGamma.
gamma (float): Non negative real number, same as gamma in the equation.
gain (float, optional): The constant multiplier.
Returns:
img (PIL image), Augmented image.
"""
if not is_pil(img):
raise TypeError("img should be PIL image. Got {}.".format(type(img)))
gamma_table = [(255 + 1 - 1e-3) * gain * pow(x / 255., gamma) for x in range(256)]
if len(img.split()) == 3:
gamma_table = gamma_table * 3
img = img.point(gamma_table)
elif len(img.split()) == 1:
img = img.point(gamma_table)
return img
def auto_contrast(img, cutoff, ignore):
"""
Automatically maximize the contrast of the input PIL image.

View File

@ -19,10 +19,10 @@ from functools import wraps
import numpy as np
from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, parse_user_args, type_check, \
type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio, \
check_odd
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
check_value_cutoff, check_value_ratio, check_odd
from .utils import Inter, Border, ImageBatchFormat, SliceMode
@ -788,6 +788,22 @@ def check_bounding_box_augment_cpp(method):
return new_method
def check_adjust_gamma(method):
"""Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[gamma, gain], _ = parse_user_args(method, *args, **kwargs)
type_check(gamma, (float, int), "gamma")
check_value(gamma, (0, FLOAT_MAX_INTEGER))
if gain is not None:
type_check(gain, (float, int), "gain")
check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER))
return method(self, *args, **kwargs)
return new_method
def check_auto_contrast(method):
"""Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""

View File

@ -27,6 +27,102 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
// Tests for vision C++ API A to Q TensorTransform Operations (in alphabetical order)
TEST_F(MindDataTestPipeline, TestAdjustGammaSuccess1) {
// pipeline 3-channel
MS_LOG(INFO) << "Pipeline Test.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds2, nullptr);
auto adjustgamma_op = vision::AdjustGamma(10.0);
ds1 = ds1->Map({adjustgamma_op});
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_NE(iter1, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row1;
iter1->GetNextRow(&row1);
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
EXPECT_NE(iter2, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row2;
iter2->GetNextRow(&row2);
uint64_t i = 0;
while (row1.size() != 0) {
i++;
auto image = row1["image"];
iter1->GetNextRow(&row1);
iter2->GetNextRow(&row2);
}
EXPECT_EQ(i, 2);
iter1->Stop();
iter2->Stop();
}
TEST_F(MindDataTestPipeline, TestAdjustGammaSuccess2) {
// pipeline 1-channel
MS_LOG(INFO) << "Pipeline Test.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds2, nullptr);
auto adjustgamma_op = vision::AdjustGamma(10.0);
auto rgb2gray_op = vision::RGB2GRAY();
ds1 = ds1->Map({rgb2gray_op, adjustgamma_op});
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_NE(iter1, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row1;
iter1->GetNextRow(&row1);
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
EXPECT_NE(iter2, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row2;
iter2->GetNextRow(&row2);
uint64_t i = 0;
while (row1.size() != 0) {
i++;
auto image = row1["image"];
iter1->GetNextRow(&row1);
iter2->GetNextRow(&row2);
}
EXPECT_EQ(i, 2);
iter1->Stop();
iter2->Stop();
}
TEST_F(MindDataTestPipeline, TestAdjustGammaParamCheck) {
// pipeline 3-channel
MS_LOG(INFO) << "Pipeline Test.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
// Case 1: Negative gamma
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> adjust_gamma(new vision::AdjustGamma(-1, 1.0));
auto ds1 = ds->Map({adjust_gamma});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid value of AdjustGamma
EXPECT_EQ(iter1, nullptr);
}
TEST_F(MindDataTestPipeline, TestAutoContrastSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAutoContrastSuccess1.";

View File

@ -134,6 +134,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
expect_image_path = dir_path + "imagefolder/apple_expect_randomaffine.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_randomaffine.jpg";
break;
case kAdjustGamma:
expect_image_path = dir_path + "imagefolder/apple_expect_adjustgamma.png";
actual_image_path = dir_path + "imagefolder/apple_actual_adjustgamma.png";
break;
case kAutoContrast:
expect_image_path = dir_path + "imagefolder/apple_expect_autocontrast.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_autocontrast.jpg";

View File

@ -44,6 +44,7 @@ class CVOpCommon : public Common {
kRandomAffine,
kRandomPosterize,
kAutoContrast,
kAdjustGamma,
kEqualize
};

View File

@ -70,6 +70,35 @@ TEST_F(MindDataTestExecute, TestAllpassBiquadWithWrongArg) {
EXPECT_FALSE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestAdjustGammaEager1) {
// 3-channel eager
MS_LOG(INFO) << "3-channel image test";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto adjust_gamma_op = vision::AdjustGamma(0.1, 1.0);
auto transform = Execute({decode, adjust_gamma_op});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestAdjustGammaEager2) {
// 1-channel eager
MS_LOG(INFO) << "1-channel image test";
auto m1 = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto rgb2gray = vision::RGB2GRAY();
auto adjust_gamma_op = vision::AdjustGamma(0.1, 1.0);
auto transform = Execute({decode, rgb2gray, adjust_gamma_op});
Status rc = transform(m1, &m1);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestComposeTransforms) {
MS_LOG(INFO) << "Doing TestComposeTransforms.";

View File

@ -0,0 +1,329 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing AdjustGamma op in DE
"""
import numpy as np
from numpy.testing import assert_allclose
import PIL
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger
DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
GENERATE_GOLDEN = False
def generate_numpy_random_rgb(shape):
"""
Only generate floating points that are fractions like n / 256, since they
are RGB pixels. Some low-precision floating point types in this test can't
handle arbitrary precision floating points well.
"""
return np.random.randint(0, 256, shape) / 255.
def test_adjust_gamma_c_eager():
# Eager 3-channel
rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
img_in = rgb_flat.reshape((8, 8, 3))
adjustgamma_op = C.AdjustGamma(10, 1)
img_out = adjustgamma_op(img_in)
assert img_out is not None
def test_adjust_gamma_py_eager():
# Eager 3-channel
rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.uint8)
img_in = PIL.Image.fromarray(rgb_flat.reshape((8, 8, 3)))
adjustgamma_op = F.AdjustGamma(10, 1)
img_out = adjustgamma_op(img_in)
assert img_out is not None
def test_adjust_gamma_c_eager_gray():
# Eager 3-channel
rgb_flat = generate_numpy_random_rgb((64, 1)).astype(np.float32)
img_in = rgb_flat.reshape((8, 8))
adjustgamma_op = C.AdjustGamma(10, 1)
img_out = adjustgamma_op(img_in)
assert img_out is not None
def test_adjust_gamma_py_eager_gray():
# Eager 3-channel
rgb_flat = generate_numpy_random_rgb((64, 1)).astype(np.uint8)
img_in = PIL.Image.fromarray(rgb_flat.reshape((8, 8)))
adjustgamma_op = F.AdjustGamma(10, 1)
img_out = adjustgamma_op(img_in)
assert img_out is not None
def test_adjust_gamma_invalid_gamma_param_c():
"""
Test AdjustGamma C Op with invalid ignore parameter
"""
logger.info("Test AdjustGamma C Op with invalid ignore parameter")
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(operations=[C.Decode(),
C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])],
input_columns=["image"])
# invalid gamma
data_set = data_set.map(operations=C.AdjustGamma(gamma=-10.0,
gain=1.0),
input_columns="image")
except ValueError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "Input is not within the required interval of " in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(operations=[C.Decode(),
C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])],
input_columns=["image"])
# invalid gamma
data_set = data_set.map(operations=C.AdjustGamma(gamma=[1, 2],
gain=1.0),
input_columns="image")
except TypeError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
def test_adjust_gamma_invalid_gamma_param_py():
"""
Test AdjustGamma python Op with invalid ignore parameter
"""
logger.info("Test AdjustGamma python Op with invalid ignore parameter")
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
trans = mindspore.dataset.transforms.py_transforms.Compose([
F.Decode(),
F.Resize((224, 224)),
F.AdjustGamma(gamma=-10.0),
F.ToTensor()
])
data_set = data_set.map(operations=[trans],
input_columns=["image"])
except ValueError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "Input is not within the required interval of " in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
trans = mindspore.dataset.transforms.py_transforms.Compose([
F.Decode(),
F.Resize((224, 224)),
F.AdjustGamma(gamma=[1, 2]),
F.ToTensor()
])
data_set = data_set.map(operations=[trans],
input_columns=["image"])
except TypeError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
def test_adjust_gamma_invalid_gain_param_c():
"""
Test AdjustGamma C Op with invalid gain parameter
"""
logger.info("Test AdjustGamma C Op with invalid gain parameter")
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(operations=[C.Decode(),
C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])],
input_columns=["image"])
# invalid gain
data_set = data_set.map(operations=C.AdjustGamma(gamma=10.0,
gain=[1, 10]),
input_columns="image")
except TypeError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "is not of type [<class 'float'>, <class 'int'>], but got " in str(error)
def test_adjust_gamma_invalid_gain_param_py():
"""
Test AdjustGamma python Op with invalid gain parameter
"""
logger.info("Test AdjustGamma python Op with invalid gain parameter")
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
trans = mindspore.dataset.transforms.py_transforms.Compose([
F.Decode(),
F.Resize((224, 224)),
F.AdjustGamma(gamma=10.0, gain=[1, 10]),
F.ToTensor()
])
data_set = data_set.map(operations=[trans],
input_columns=["image"])
except TypeError as error:
logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
assert "is not of type [<class 'float'>, <class 'int'>], but got " in str(error)
def test_adjust_gamma_pipeline_c():
"""
Test AdjustGamma C Op Pipeline
"""
# First dataset
transforms1 = [C.Decode(), C.Resize([64, 64])]
transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
transforms1)
ds1 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
# Second dataset
transforms2 = [
C.Decode(),
C.Resize([64, 64]),
C.AdjustGamma(1.0, 1.0)
]
transform2 = mindspore.dataset.transforms.py_transforms.Compose(
transforms2)
ds2 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds2 = ds2.map(operations=transform2, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
ds2.create_dict_iterator(num_epochs=1)):
num_iter += 1
ori_img = data1["image"].asnumpy()
cvt_img = data2["image"].asnumpy()
assert_allclose(ori_img.flatten(),
cvt_img.flatten(),
rtol=1e-5,
atol=0)
assert ori_img.shape == cvt_img.shape
def test_adjust_gamma_pipeline_py():
"""
Test AdjustGamma python Op Pipeline
"""
# First dataset
transforms1 = [F.Decode(), F.Resize([64, 64]), F.ToTensor()]
transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
transforms1)
ds1 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
# Second dataset
transforms2 = [
F.Decode(),
F.Resize([64, 64]),
F.AdjustGamma(1.0, 1.0),
F.ToTensor()
]
transform2 = mindspore.dataset.transforms.py_transforms.Compose(
transforms2)
ds2 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds2 = ds2.map(operations=transform2, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
ds2.create_dict_iterator(num_epochs=1)):
num_iter += 1
ori_img = data1["image"].asnumpy()
cvt_img = data2["image"].asnumpy()
assert_allclose(ori_img.flatten(),
cvt_img.flatten(),
rtol=1e-5,
atol=0)
assert ori_img.shape == cvt_img.shape
def test_adjust_gamma_pipeline_py_gray():
"""
Test AdjustGamma python Op Pipeline 1-channel
"""
# First dataset
transforms1 = [F.Decode(), F.Resize([64, 64]), F.Grayscale(), F.ToTensor()]
transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
transforms1)
ds1 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
# Second dataset
transforms2 = [
F.Decode(),
F.Resize([64, 64]),
F.Grayscale(),
F.AdjustGamma(1.0, 1.0),
F.ToTensor()
]
transform2 = mindspore.dataset.transforms.py_transforms.Compose(
transforms2)
ds2 = ds.TFRecordDataset(DATA_DIR_2,
SCHEMA_DIR,
columns_list=["image"],
shuffle=False)
ds2 = ds2.map(operations=transform2, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
ds2.create_dict_iterator(num_epochs=1)):
num_iter += 1
ori_img = data1["image"].asnumpy()
cvt_img = data2["image"].asnumpy()
assert_allclose(ori_img.flatten(),
cvt_img.flatten(),
rtol=1e-5,
atol=0)
if __name__ == "__main__":
test_adjust_gamma_c_eager()
test_adjust_gamma_py_eager()
test_adjust_gamma_c_eager_gray()
test_adjust_gamma_py_eager_gray()
test_adjust_gamma_invalid_gamma_param_c()
test_adjust_gamma_invalid_gamma_param_py()
test_adjust_gamma_invalid_gain_param_c()
test_adjust_gamma_invalid_gain_param_py()
test_adjust_gamma_pipeline_c()
test_adjust_gamma_pipeline_py()
test_adjust_gamma_pipeline_py_gray()