add GaussianBlur API

This commit is contained in:
Xiao Tianci 2021-05-17 11:44:43 +08:00
parent 00280f662a
commit 3f495bbae1
23 changed files with 705 additions and 2 deletions

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
@ -142,6 +143,17 @@ PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(
GaussianBlurOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>(
*m, "GaussianBlurOperation")
.def(py::init([](std::vector<int32_t> kernel_size, std::vector<float> sigma) {
auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma);
THROW_IF_ERROR(gaussian_blur->ValidateParams());
return gaussian_blur;
}));
}));
PYBIND_REGISTER(HwcToChwOperation, 1, ([](const py::module *m) {
(void)
py::class_<vision::HwcToChwOperation, TensorOperation, std::shared_ptr<vision::HwcToChwOperation>>(

View File

@ -30,6 +30,7 @@
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
@ -296,6 +297,24 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
Equalize::Equalize() {}
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
#endif // not ENABLE_ANDROID
// GaussianBlur Transform Operation.
struct GaussianBlur::Data {
Data(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma)
: kernel_size_(kernel_size), sigma_(sigma) {}
std::vector<int32_t> kernel_size_;
std::vector<float> sigma_;
};
GaussianBlur::GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma)
: data_(std::make_shared<Data>(kernel_size, sigma)) {}
std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
return std::make_shared<GaussianBlurOperation>(data_->kernel_size_, data_->sigma_);
}
#ifndef ENABLE_ANDROID
// HwcToChw Transform Operation.
HWC2CHW::HWC2CHW() {}

View File

@ -154,6 +154,29 @@ class Decode final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief GaussianBlur TensorTransform.
/// \notes Blur the input image with specified Gaussian kernel.
class GaussianBlur final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] kernel_size A vector of Gaussian kernel size for width and height. The values must be positive and odd.
/// \param[in] sigma A vector of Gaussian kernel standard deviation sigma for width and height. The values must be
/// positive. Using default value 0 means to calculate the sigma according to the kernel size.
GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.});
/// \brief Destructor.
~GaussianBlur() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Normalize TensorTransform.
/// \note Normalize the input image with respect to mean and standard deviation.
class Normalize final : public TensorTransform {

View File

@ -15,6 +15,7 @@ add_library(kernels-image OBJECT
cutmix_batch_op.cc
decode_op.cc
equalize_op.cc
gaussian_blur_op.cc
hwc_to_chw_op.cc
image_utils.cc
invert_op.cc

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status GaussianBlurOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (input->Rank() != 3 && input->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in shape of <H,W,C> or <H,W>");
}
return GaussianBlur(input, output, kernel_x_, kernel_y_, sigma_x_, sigma_y_);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class GaussianBlurOp : public TensorOp {
public:
/// \brief Constructor to GaussianBlur Op
/// \param[in] kernel_x - Gaussian kernel size of width
/// \param[in] kernel_y - Gaussian kernel size of height
/// \param[in] sigma_x - Gaussian kernel standard deviation of width
/// \param[in] sigma_y - Gaussian kernel standard deviation of height
GaussianBlurOp(int32_t kernel_x, int32_t kernel_y, float sigma_x, float sigma_y)
: kernel_x_(kernel_x), kernel_y_(kernel_y), sigma_x_(sigma_x), sigma_y_(sigma_y) {}
~GaussianBlurOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kGaussianBlurOp; }
void Print(std::ostream &out) const override {
out << Name() << " kernel_size: (" << kernel_x_ << ", " << kernel_y_ << "), sigma: (" << sigma_x_ << ", "
<< sigma_y_ << ")";
}
protected:
int32_t kernel_x_;
int32_t kernel_y_;
float sigma_x_;
float sigma_y_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_

View File

@ -1232,5 +1232,21 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
}
}
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_x,
int32_t kernel_y, float sigma_x, float sigma_y) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
cv::GaussianBlur(input_cv->mat(), output_cv->mat(), cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
static_cast<double>(sigma_y));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("GaussianBlur: " + std::string(e.what()));
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -316,6 +316,16 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
/// \brief Filter the input image with a Gaussian kernel
/// \param[in] input Input Tensor
/// \param[out] output Transformed Tensor
/// \param[in] kernel_size_x Gaussian kernel size of width
/// \param[in] kernel_size_y Gaussian kernel size of height
/// \param[in] sigma_x Gaussian kernel standard deviation of width
/// \param[in] sigma_y Gaussian kernel standard deviation of height
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_size_x,
int32_t kernel_size_y, float sigma_x, float sigma_y);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -710,5 +710,40 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
}
}
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_x,
int32_t kernel_y, float sigma_x, float sigma_y) {
try {
LiteMat lite_mat_input;
if (input->Rank() == 3) {
if (input->shape()[2] != 1 && input->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in channel of 1 or 3");
}
lite_mat_input = LiteMat(input->shape()[1], input->shape()[0], input->shape()[2],
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
GetLiteCVDataType(input->type()));
} else if (input->Rank() == 2) {
lite_mat_input = LiteMat(input->shape()[1], input->shape()[0],
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
GetLiteCVDataType(input->type()));
} else {
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in shape of <H,W,C> or <H,W>");
}
std::shared_ptr<Tensor> output_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), &output_tensor));
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
LiteMat lite_mat_output;
lite_mat_output.Init(lite_mat_input.width_, lite_mat_input.height_, lite_mat_input.channel_,
reinterpret_cast<void *>(buffer), GetLiteCVDataType(input->type()));
bool ret = GaussianBlur(lite_mat_input, lite_mat_output, {kernel_x, kernel_y}, static_cast<double>(sigma_x),
static_cast<double>(sigma_y));
CHECK_FAIL_RETURN_UNEXPECTED(ret, "GaussianBlur: GaussianBlur failed.");
*output = output_tensor;
return Status::OK();
} catch (std::runtime_error &e) {
RETURN_STATUS_UNEXPECTED("GaussianBlur: " + std::string(e.what()));
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -142,6 +142,16 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
/// \brief Filter the input image with a Gaussian kernel
/// \param[in] input Input Tensor
/// \param[out] output Transformed Tensor
/// \param[in] kernel_size_x Gaussian kernel size of width
/// \param[in] kernel_size_y Gaussian kernel size of height
/// \param[in] sigma_x Gaussian kernel standard deviation of width
/// \param[in] sigma_y Gaussian kernel standard deviation of height
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_size_x,
int32_t kernel_size_y, float sigma_x, float sigma_y);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -97,6 +97,18 @@ Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float
return Status::OK();
}
Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value) {
for (int i = 0; i < value.size(); i++) {
if (value[i] % 2 != 1) {
std::string err_msg = op_name + ":" + vec_name + " must be odd value, got: " + vec_name + "[" +
std::to_string(i) + "]=" + std::to_string(value[i]);
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
}
return Status::OK();
}
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
@ -128,6 +140,19 @@ Status ValidateVectorNonNegative(const std::string &op_name, const std::string &
return Status::OK();
}
Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma) {
if (sigma.empty() || sigma.size() > 2) {
std::string err_msg = op_name + ": sigma expecting size 2, got sigma.size(): " + std::to_string(sigma.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &sigma_val : sigma) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "sigma", sigma_val, {0}, false));
}
return Status::OK();
}
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
if (size.empty() || size.size() > 2) {
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());

View File

@ -76,6 +76,9 @@ Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uin
// Helper function to validate mean/std value
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
// Helper function to validate odd value
Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value);
// Helper function to validate padding
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
@ -86,6 +89,9 @@ Status ValidateVectorPositive(const std::string &op_name, const std::string &vec
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec);
// Helper function to validate size of sigma
Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma);
// Helper function to validate size of size
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);

View File

@ -11,6 +11,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
cutout_ir.cc
decode_ir.cc
equalize_ir.cc
gaussian_blur_ir.cc
hwc_to_chw_ir.cc
invert_ir.cc
mixup_batch_ir.cc

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
GaussianBlurOperation::GaussianBlurOperation(const std::vector<int32_t> kernel_size, const std::vector<float> sigma)
: kernel_size_(kernel_size), sigma_(sigma) {}
GaussianBlurOperation::~GaussianBlurOperation() = default;
std::string GaussianBlurOperation::Name() const { return kGaussianBlurOperation; }
Status GaussianBlurOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorSize("GaussianBlur", kernel_size_));
RETURN_IF_NOT_OK(ValidateVectorOdd("GaussianBlur", "kernel_size", kernel_size_));
RETURN_IF_NOT_OK(ValidateVectorSigma("GaussianBlur", sigma_));
return Status::OK();
}
std::shared_ptr<TensorOp> GaussianBlurOperation::Build() {
int32_t kernel_x = kernel_size_[0];
int32_t kernel_y = kernel_size_[0];
// User has specified kernel_y.
if (kernel_size_.size() == 2) {
kernel_y = kernel_size_[1];
}
float sigma_x = sigma_[0] <= 0 ? kernel_x * 0.15 + 0.35 : sigma_[0];
float sigma_y = sigma_x;
// User has specified sigma_y.
if (sigma_.size() == 2) {
sigma_y = sigma_[1] <= 0 ? kernel_y * 0.15 + 0.35 : sigma_[1];
}
std::shared_ptr<GaussianBlurOp> tensor_op = std::make_shared<GaussianBlurOp>(kernel_x, kernel_y, sigma_x, sigma_y);
return tensor_op;
}
Status GaussianBlurOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["kernel_size"] = kernel_size_;
args["sigma"] = sigma_;
*out_json = args;
return Status::OK();
}
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kGaussianBlurOperation[] = "GaussianBlur";
class GaussianBlurOperation : public TensorOperation {
public:
GaussianBlurOperation(const std::vector<int32_t> kernel_size, const std::vector<float> sigma);
~GaussianBlurOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> kernel_size_;
std::vector<float> sigma_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_

View File

@ -69,6 +69,7 @@ constexpr char kDvppDecodePngOp[] = "DvppDecodePngOp";
constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp";
constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp";
constexpr char kEqualizeOp[] = "EqualizeOp";
constexpr char kGaussianBlurOp[] = "GaussianBlurOp";
constexpr char kHwcToChwOp[] = "HWC2CHWOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";

View File

@ -114,6 +114,13 @@ def check_positive(value, arg_name=""):
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
def check_odd(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value % 2 != 1:
raise ValueError(
"Input {0}is not an odd value.".format(arg_name))
def check_2tuple(value, arg_name=""):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))

View File

@ -54,7 +54,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c, check_posterize
check_cut_mix_batch_c, check_posterize, check_gaussian_blur
from ..transforms.c_transforms import TensorOperation
@ -295,6 +295,41 @@ class Equalize(ImageTensorOperation):
return cde.EqualizeOperation()
class GaussianBlur(ImageTensorOperation):
"""
BLur input image with the specified Gaussian kernel.
Args:
kernel_size (Union[int, sequence]): Size of the Gaussian kernel to use. The value must be positive and odd. If
only an integer is provied, the kernel size will be (size, size). If a sequence of integer is provied, it
must be a sequence of 2 values which represents (width, height).
sigma (Union[float, sequence], optional): Standard deviation of the Gaussian kernel to use (default=None). The
value must be positive. If only an float is provied, the sigma will be (sigma, sigma). If a sequence of
float is provied, it must be a sequence of 2 values which represents the sigma of width and height. If None
is provided, the sigma will be calculated as ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8.
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.GaussianBlur(3, 3)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_gaussian_blur
def __init__(self, kernel_size, sigma=None):
if isinstance(kernel_size, int):
kernel_size = (kernel_size,)
if sigma is None:
sigma = (0,)
elif isinstance(sigma, (int, float)):
sigma = (float(sigma),)
self.kernel_size = kernel_size
self.sigma = sigma
def parse(self):
return cde.GaussianBlurOperation(self.kernel_size, self.sigma)
class HWC2CHW(ImageTensorOperation):
"""
Transpose the input image; shape (H, W, C) to shape (C, H, W).

View File

@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio, check_odd
from .utils import Inter, Border, ImageBatchFormat
@ -493,6 +493,7 @@ def check_rgb_to_hsv(method):
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
type_check(is_hwc, (bool,), "is_hwc")
return method(self, *args, **kwargs)
return new_method
@ -504,6 +505,7 @@ def check_hsv_to_rgb(method):
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
type_check(is_hwc, (bool,), "is_hwc")
return method(self, *args, **kwargs)
return new_method
@ -819,3 +821,39 @@ def check_random_solarize(method):
return method(self, *args, **kwargs)
return new_method
def check_gaussian_blur(method):
"""Wrapper method to check the parameters of GaussianBlur."""
@wraps(method)
def new_method(self, *args, **kwargs):
[kernel_size, sigma], _ = parse_user_args(method, *args, **kwargs)
type_check(kernel_size, (int, list, tuple), "kernel_size")
if isinstance(kernel_size, int):
check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size")
check_odd(kernel_size, "kernel_size")
elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2:
for index, value in enumerate(kernel_size):
type_check(value, (int,), "kernel_size[{}]".format(index))
check_value(value, (1, FLOAT_MAX_INTEGER), "kernel_size")
check_odd(value, "kernel_size[{}]".format(index))
else:
raise TypeError(
"Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.")
if sigma is not None:
type_check(sigma, (numbers.Number, list, tuple), "sigma")
if isinstance(sigma, numbers.Number):
check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma")
elif isinstance(sigma, (list, tuple)) and len(sigma) == 2:
for index, value in enumerate(sigma):
type_check(value, (numbers.Number,), "size[{}]".format(index))
check_value(value, (0, FLOAT_MAX_INTEGER), "sigma")
else:
raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.")
return method(self, *args, **kwargs)
return new_method

View File

@ -208,6 +208,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/kernels/image/center_crop_op.cc
${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/decode_op.cc
${MINDDATA_DIR}/kernels/image/gaussian_blur_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
@ -233,6 +234,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/gaussian_blur_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc
@ -479,6 +481,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
"${MINDDATA_DIR}/kernels/image/cut_out_op.cc"
"${MINDDATA_DIR}/kernels/image/cutmix_batch_op.cc"
"${MINDDATA_DIR}/kernels/image/equalize_op.cc"
"${MINDDATA_DIR}/kernels/image/gaussian_blur.cc"
"${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc"
"${MINDDATA_DIR}/kernels/image/image_utils.cc"
"${MINDDATA_DIR}/kernels/image/invert_op.cc"
@ -539,6 +542,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/gaussian_blur_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc

View File

@ -69,6 +69,7 @@ SET(DE_UT_SRCS
equalize_op_test.cc
execution_tree_test.cc
fill_op_test.cc
c_api_vision_gaussian_blur_test.cc
global_context_test.cc
gnn_graph_test.cc
image_folder_op_test.cc

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
class MindDataTestGaussianBlur : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurParamCheck with invalid parameters.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Case 1: Kernel size is not positive
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> gaussian_blur1(new vision::GaussianBlur({-1}));
auto ds1 = ds->Map({gaussian_blur1});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid kernel_size for GaussianBlur
EXPECT_EQ(iter1, nullptr);
// Case 2: Kernel size is not odd
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> gaussian_blur2(new vision::GaussianBlur({2, 2}, {3, 3}));
auto ds2 = ds->Map({gaussian_blur2});
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid kernel_size for GaussianBlur
EXPECT_EQ(iter2, nullptr);
// Case 3: Sigma is not positive
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> gaussian_blur3(new vision::GaussianBlur({3}, {-3}));
auto ds3 = ds->Map({gaussian_blur3});
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid sigma for GaussianBlur
EXPECT_EQ(iter3, nullptr);
}
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurPipeline) {
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurPipeline.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> gaussian_blur(new vision::GaussianBlur({3, 3}, {5, 5}));
// Create a Map operation on ds
ds = ds->Map({gaussian_blur});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurEager) {
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurEager.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto gaussian_blur = vision::GaussianBlur({7}, {3.5});
auto transform = Execute({decode, gaussian_blur});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing GaussianBlur op in DE
"""
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
GENERATE_GOLDEN = False
def test_gaussian_blur_pipeline(plot=False):
"""
Test GaussianBlur of c_transforms
"""
logger.info("test_gaussian_blur_pipeline")
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
gaussian_blur_op = c_vision.GaussianBlur(3, 3)
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=gaussian_blur_op, input_columns=["image"])
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
gaussian_blur_ms = data1["image"]
original = data2["image"]
gaussian_blur_cv = cv2.GaussianBlur(original, (3, 3), 3)
mse = diff_mse(gaussian_blur_ms, gaussian_blur_cv)
logger.info("gaussian_blur_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, gaussian_blur_ms, mse, gaussian_blur_cv)
def test_gaussian_blur_eager():
"""
Test GaussianBlur with eager mode
"""
logger.info("test_gaussian_blur_eager")
img = cv2.imread(IMAGE_FILE)
img_ms = c_vision.GaussianBlur((3, 5), (3.5, 3.5))(img)
img_cv = cv2.GaussianBlur(img, (3, 5), 3.5, 3.5)
mse = diff_mse(img_ms, img_cv)
assert mse == 0
def test_gaussian_blur_exception():
"""
Test GsianBlur with invalid parameters
"""
logger.info("test_gaussian_blur_exception")
try:
_ = c_vision.GaussianBlur([2, 2])
except ValueError as e:
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
assert "not an odd value" in str(e)
try:
_ = c_vision.GaussianBlur(3.0, [3, 3])
except TypeError as e:
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
assert "not of type [<class 'int'>, <class 'list'>, <class 'tuple'>]" in str(e)
try:
_ = c_vision.GaussianBlur(3, -3)
except ValueError as e:
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
assert "not within the required interval" in str(e)
try:
_ = c_vision.GaussianBlur(3, [3, 3, 3])
except TypeError as e:
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
assert "should be a single number or a list/tuple of length 2" in str(e)
if __name__ == "__main__":
test_gaussian_blur_pipeline(plot=True)
test_gaussian_blur_eager()
test_gaussian_blur_exception()