forked from mindspore-Ecosystem/mindspore
!16702 Add GaussianBlur API for CV data processing
From: @tiancixiao Reviewed-by: @liucunwei,@jonyguo Signed-off-by: @liucunwei
This commit is contained in:
commit
730e211e52
|
@ -25,6 +25,7 @@
|
||||||
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
|
||||||
|
@ -142,6 +143,17 @@ PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(
|
||||||
|
GaussianBlurOperation, 1, ([](const py::module *m) {
|
||||||
|
(void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>(
|
||||||
|
*m, "GaussianBlurOperation")
|
||||||
|
.def(py::init([](std::vector<int32_t> kernel_size, std::vector<float> sigma) {
|
||||||
|
auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma);
|
||||||
|
THROW_IF_ERROR(gaussian_blur->ValidateParams());
|
||||||
|
return gaussian_blur;
|
||||||
|
}));
|
||||||
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(HwcToChwOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(HwcToChwOperation, 1, ([](const py::module *m) {
|
||||||
(void)
|
(void)
|
||||||
py::class_<vision::HwcToChwOperation, TensorOperation, std::shared_ptr<vision::HwcToChwOperation>>(
|
py::class_<vision::HwcToChwOperation, TensorOperation, std::shared_ptr<vision::HwcToChwOperation>>(
|
||||||
|
|
|
@ -30,6 +30,7 @@
|
||||||
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
|
||||||
|
@ -296,6 +297,24 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
|
||||||
Equalize::Equalize() {}
|
Equalize::Equalize() {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
|
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
|
||||||
|
#endif // not ENABLE_ANDROID
|
||||||
|
|
||||||
|
// GaussianBlur Transform Operation.
|
||||||
|
struct GaussianBlur::Data {
|
||||||
|
Data(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma)
|
||||||
|
: kernel_size_(kernel_size), sigma_(sigma) {}
|
||||||
|
std::vector<int32_t> kernel_size_;
|
||||||
|
std::vector<float> sigma_;
|
||||||
|
};
|
||||||
|
|
||||||
|
GaussianBlur::GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma)
|
||||||
|
: data_(std::make_shared<Data>(kernel_size, sigma)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
|
||||||
|
return std::make_shared<GaussianBlurOperation>(data_->kernel_size_, data_->sigma_);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
// HwcToChw Transform Operation.
|
// HwcToChw Transform Operation.
|
||||||
HWC2CHW::HWC2CHW() {}
|
HWC2CHW::HWC2CHW() {}
|
||||||
|
|
||||||
|
|
|
@ -154,6 +154,29 @@ class Decode final : public TensorTransform {
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief GaussianBlur TensorTransform.
|
||||||
|
/// \notes Blur the input image with specified Gaussian kernel.
|
||||||
|
class GaussianBlur final : public TensorTransform {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
/// \param[in] kernel_size A vector of Gaussian kernel size for width and height. The values must be positive and odd.
|
||||||
|
/// \param[in] sigma A vector of Gaussian kernel standard deviation sigma for width and height. The values must be
|
||||||
|
/// positive. Using default value 0 means to calculate the sigma according to the kernel size.
|
||||||
|
GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.});
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~GaussianBlur() = default;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||||
|
/// \return Shared pointer to TensorOperation object.
|
||||||
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Data;
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief Normalize TensorTransform.
|
/// \brief Normalize TensorTransform.
|
||||||
/// \note Normalize the input image with respect to mean and standard deviation.
|
/// \note Normalize the input image with respect to mean and standard deviation.
|
||||||
class Normalize final : public TensorTransform {
|
class Normalize final : public TensorTransform {
|
||||||
|
|
|
@ -15,6 +15,7 @@ add_library(kernels-image OBJECT
|
||||||
cutmix_batch_op.cc
|
cutmix_batch_op.cc
|
||||||
decode_op.cc
|
decode_op.cc
|
||||||
equalize_op.cc
|
equalize_op.cc
|
||||||
|
gaussian_blur_op.cc
|
||||||
hwc_to_chw_op.cc
|
hwc_to_chw_op.cc
|
||||||
image_utils.cc
|
image_utils.cc
|
||||||
invert_op.cc
|
invert_op.cc
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||||
|
#else
|
||||||
|
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status GaussianBlurOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
IO_CHECK(input, output);
|
||||||
|
if (input->Rank() != 3 && input->Rank() != 2) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in shape of <H,W,C> or <H,W>");
|
||||||
|
}
|
||||||
|
return GaussianBlur(input, output, kernel_x_, kernel_y_, sigma_x_, sigma_y_);
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||||
|
#else
|
||||||
|
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class GaussianBlurOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor to GaussianBlur Op
|
||||||
|
/// \param[in] kernel_x - Gaussian kernel size of width
|
||||||
|
/// \param[in] kernel_y - Gaussian kernel size of height
|
||||||
|
/// \param[in] sigma_x - Gaussian kernel standard deviation of width
|
||||||
|
/// \param[in] sigma_y - Gaussian kernel standard deviation of height
|
||||||
|
GaussianBlurOp(int32_t kernel_x, int32_t kernel_y, float sigma_x, float sigma_y)
|
||||||
|
: kernel_x_(kernel_x), kernel_y_(kernel_y), sigma_x_(sigma_x), sigma_y_(sigma_y) {}
|
||||||
|
|
||||||
|
~GaussianBlurOp() override = default;
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kGaussianBlurOp; }
|
||||||
|
|
||||||
|
void Print(std::ostream &out) const override {
|
||||||
|
out << Name() << " kernel_size: (" << kernel_x_ << ", " << kernel_y_ << "), sigma: (" << sigma_x_ << ", "
|
||||||
|
<< sigma_y_ << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int32_t kernel_x_;
|
||||||
|
int32_t kernel_y_;
|
||||||
|
float sigma_x_;
|
||||||
|
float sigma_y_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_GAUSSIAN_BLUR_OP_H_
|
|
@ -1232,5 +1232,21 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_x,
|
||||||
|
int32_t kernel_y, float sigma_x, float sigma_y) {
|
||||||
|
try {
|
||||||
|
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||||
|
std::shared_ptr<CVTensor> output_cv;
|
||||||
|
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(output_cv);
|
||||||
|
|
||||||
|
cv::GaussianBlur(input_cv->mat(), output_cv->mat(), cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
|
||||||
|
static_cast<double>(sigma_y));
|
||||||
|
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
|
return Status::OK();
|
||||||
|
} catch (const cv::Exception &e) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("GaussianBlur: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -316,6 +316,16 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
|
||||||
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
|
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
|
||||||
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
|
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
|
||||||
|
|
||||||
|
/// \brief Filter the input image with a Gaussian kernel
|
||||||
|
/// \param[in] input Input Tensor
|
||||||
|
/// \param[out] output Transformed Tensor
|
||||||
|
/// \param[in] kernel_size_x Gaussian kernel size of width
|
||||||
|
/// \param[in] kernel_size_y Gaussian kernel size of height
|
||||||
|
/// \param[in] sigma_x Gaussian kernel standard deviation of width
|
||||||
|
/// \param[in] sigma_y Gaussian kernel standard deviation of height
|
||||||
|
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_size_x,
|
||||||
|
int32_t kernel_size_y, float sigma_x, float sigma_y);
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
||||||
|
|
|
@ -710,5 +710,40 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_x,
|
||||||
|
int32_t kernel_y, float sigma_x, float sigma_y) {
|
||||||
|
try {
|
||||||
|
LiteMat lite_mat_input;
|
||||||
|
if (input->Rank() == 3) {
|
||||||
|
if (input->shape()[2] != 1 && input->shape()[2] != 3) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in channel of 1 or 3");
|
||||||
|
}
|
||||||
|
lite_mat_input = LiteMat(input->shape()[1], input->shape()[0], input->shape()[2],
|
||||||
|
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
|
||||||
|
GetLiteCVDataType(input->type()));
|
||||||
|
} else if (input->Rank() == 2) {
|
||||||
|
lite_mat_input = LiteMat(input->shape()[1], input->shape()[0],
|
||||||
|
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
|
||||||
|
GetLiteCVDataType(input->type()));
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("GaussianBlur: input image is not in shape of <H,W,C> or <H,W>");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), &output_tensor));
|
||||||
|
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
|
||||||
|
LiteMat lite_mat_output;
|
||||||
|
lite_mat_output.Init(lite_mat_input.width_, lite_mat_input.height_, lite_mat_input.channel_,
|
||||||
|
reinterpret_cast<void *>(buffer), GetLiteCVDataType(input->type()));
|
||||||
|
bool ret = GaussianBlur(lite_mat_input, lite_mat_output, {kernel_x, kernel_y}, static_cast<double>(sigma_x),
|
||||||
|
static_cast<double>(sigma_y));
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(ret, "GaussianBlur: GaussianBlur failed.");
|
||||||
|
*output = output_tensor;
|
||||||
|
return Status::OK();
|
||||||
|
} catch (std::runtime_error &e) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("GaussianBlur: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -142,6 +142,16 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
|
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
|
||||||
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
|
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
|
||||||
|
|
||||||
|
/// \brief Filter the input image with a Gaussian kernel
|
||||||
|
/// \param[in] input Input Tensor
|
||||||
|
/// \param[out] output Transformed Tensor
|
||||||
|
/// \param[in] kernel_size_x Gaussian kernel size of width
|
||||||
|
/// \param[in] kernel_size_y Gaussian kernel size of height
|
||||||
|
/// \param[in] sigma_x Gaussian kernel standard deviation of width
|
||||||
|
/// \param[in] sigma_y Gaussian kernel standard deviation of height
|
||||||
|
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_size_x,
|
||||||
|
int32_t kernel_size_y, float sigma_x, float sigma_y);
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
||||||
|
|
|
@ -97,6 +97,18 @@ Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value) {
|
||||||
|
for (int i = 0; i < value.size(); i++) {
|
||||||
|
if (value[i] % 2 != 1) {
|
||||||
|
std::string err_msg = op_name + ":" + vec_name + " must be odd value, got: " + vec_name + "[" +
|
||||||
|
std::to_string(i) + "]=" + std::to_string(value[i]);
|
||||||
|
MS_LOG(ERROR) << err_msg;
|
||||||
|
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
|
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
|
||||||
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
|
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
|
||||||
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
|
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
|
||||||
|
@ -128,6 +140,19 @@ Status ValidateVectorNonNegative(const std::string &op_name, const std::string &
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma) {
|
||||||
|
if (sigma.empty() || sigma.size() > 2) {
|
||||||
|
std::string err_msg = op_name + ": sigma expecting size 2, got sigma.size(): " + std::to_string(sigma.size());
|
||||||
|
MS_LOG(ERROR) << err_msg;
|
||||||
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
|
}
|
||||||
|
for (const auto &sigma_val : sigma) {
|
||||||
|
RETURN_IF_NOT_OK(ValidateScalar(op_name, "sigma", sigma_val, {0}, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
|
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
|
||||||
if (size.empty() || size.size() > 2) {
|
if (size.empty() || size.size() > 2) {
|
||||||
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
|
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
|
||||||
|
|
|
@ -76,6 +76,9 @@ Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uin
|
||||||
// Helper function to validate mean/std value
|
// Helper function to validate mean/std value
|
||||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
|
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
|
||||||
|
|
||||||
|
// Helper function to validate odd value
|
||||||
|
Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value);
|
||||||
|
|
||||||
// Helper function to validate padding
|
// Helper function to validate padding
|
||||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
|
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
|
||||||
|
|
||||||
|
@ -86,6 +89,9 @@ Status ValidateVectorPositive(const std::string &op_name, const std::string &vec
|
||||||
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
||||||
const std::vector<int32_t> &vec);
|
const std::vector<int32_t> &vec);
|
||||||
|
|
||||||
|
// Helper function to validate size of sigma
|
||||||
|
Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma);
|
||||||
|
|
||||||
// Helper function to validate size of size
|
// Helper function to validate size of size
|
||||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
|
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
||||||
cutout_ir.cc
|
cutout_ir.cc
|
||||||
decode_ir.cc
|
decode_ir.cc
|
||||||
equalize_ir.cc
|
equalize_ir.cc
|
||||||
|
gaussian_blur_ir.cc
|
||||||
hwc_to_chw_ir.cc
|
hwc_to_chw_ir.cc
|
||||||
invert_ir.cc
|
invert_ir.cc
|
||||||
mixup_batch_ir.cc
|
mixup_batch_ir.cc
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/validators.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
GaussianBlurOperation::GaussianBlurOperation(const std::vector<int32_t> kernel_size, const std::vector<float> sigma)
|
||||||
|
: kernel_size_(kernel_size), sigma_(sigma) {}
|
||||||
|
|
||||||
|
GaussianBlurOperation::~GaussianBlurOperation() = default;
|
||||||
|
|
||||||
|
std::string GaussianBlurOperation::Name() const { return kGaussianBlurOperation; }
|
||||||
|
|
||||||
|
Status GaussianBlurOperation::ValidateParams() {
|
||||||
|
RETURN_IF_NOT_OK(ValidateVectorSize("GaussianBlur", kernel_size_));
|
||||||
|
RETURN_IF_NOT_OK(ValidateVectorOdd("GaussianBlur", "kernel_size", kernel_size_));
|
||||||
|
RETURN_IF_NOT_OK(ValidateVectorSigma("GaussianBlur", sigma_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> GaussianBlurOperation::Build() {
|
||||||
|
int32_t kernel_x = kernel_size_[0];
|
||||||
|
int32_t kernel_y = kernel_size_[0];
|
||||||
|
// User has specified kernel_y.
|
||||||
|
if (kernel_size_.size() == 2) {
|
||||||
|
kernel_y = kernel_size_[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
float sigma_x = sigma_[0] <= 0 ? kernel_x * 0.15 + 0.35 : sigma_[0];
|
||||||
|
float sigma_y = sigma_x;
|
||||||
|
|
||||||
|
// User has specified sigma_y.
|
||||||
|
if (sigma_.size() == 2) {
|
||||||
|
sigma_y = sigma_[1] <= 0 ? kernel_y * 0.15 + 0.35 : sigma_[1];
|
||||||
|
}
|
||||||
|
std::shared_ptr<GaussianBlurOp> tensor_op = std::make_shared<GaussianBlurOp>(kernel_x, kernel_y, sigma_x, sigma_y);
|
||||||
|
return tensor_op;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GaussianBlurOperation::to_json(nlohmann::json *out_json) {
|
||||||
|
nlohmann::json args;
|
||||||
|
args["kernel_size"] = kernel_size_;
|
||||||
|
args["sigma"] = sigma_;
|
||||||
|
*out_json = args;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "include/api/status.h"
|
||||||
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
|
#include "minddata/dataset/include/dataset/transforms.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
constexpr char kGaussianBlurOperation[] = "GaussianBlur";
|
||||||
|
|
||||||
|
class GaussianBlurOperation : public TensorOperation {
|
||||||
|
public:
|
||||||
|
GaussianBlurOperation(const std::vector<int32_t> kernel_size, const std::vector<float> sigma);
|
||||||
|
|
||||||
|
~GaussianBlurOperation();
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> Build() override;
|
||||||
|
|
||||||
|
Status ValidateParams() override;
|
||||||
|
|
||||||
|
std::string Name() const override;
|
||||||
|
|
||||||
|
Status to_json(nlohmann::json *out_json) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int32_t> kernel_size_;
|
||||||
|
std::vector<float> sigma_;
|
||||||
|
};
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_GAUSSIAN_BLUR_IR_H_
|
|
@ -69,6 +69,7 @@ constexpr char kDvppDecodePngOp[] = "DvppDecodePngOp";
|
||||||
constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp";
|
constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp";
|
||||||
constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp";
|
constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp";
|
||||||
constexpr char kEqualizeOp[] = "EqualizeOp";
|
constexpr char kEqualizeOp[] = "EqualizeOp";
|
||||||
|
constexpr char kGaussianBlurOp[] = "GaussianBlurOp";
|
||||||
constexpr char kHwcToChwOp[] = "HWC2CHWOp";
|
constexpr char kHwcToChwOp[] = "HWC2CHWOp";
|
||||||
constexpr char kInvertOp[] = "InvertOp";
|
constexpr char kInvertOp[] = "InvertOp";
|
||||||
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
|
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
|
||||||
|
|
|
@ -114,6 +114,13 @@ def check_positive(value, arg_name=""):
|
||||||
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
|
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
|
||||||
|
|
||||||
|
|
||||||
|
def check_odd(value, arg_name=""):
|
||||||
|
arg_name = pad_arg_name(arg_name)
|
||||||
|
if value % 2 != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Input {0}is not an odd value.".format(arg_name))
|
||||||
|
|
||||||
|
|
||||||
def check_2tuple(value, arg_name=""):
|
def check_2tuple(value, arg_name=""):
|
||||||
if not (isinstance(value, tuple) and len(value) == 2):
|
if not (isinstance(value, tuple) and len(value) == 2):
|
||||||
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
|
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
|
||||||
|
|
|
@ -54,7 +54,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
|
||||||
check_uniform_augment_cpp, \
|
check_uniform_augment_cpp, \
|
||||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
||||||
check_cut_mix_batch_c, check_posterize
|
check_cut_mix_batch_c, check_posterize, check_gaussian_blur
|
||||||
from ..transforms.c_transforms import TensorOperation
|
from ..transforms.c_transforms import TensorOperation
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,6 +295,41 @@ class Equalize(ImageTensorOperation):
|
||||||
return cde.EqualizeOperation()
|
return cde.EqualizeOperation()
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianBlur(ImageTensorOperation):
|
||||||
|
"""
|
||||||
|
BLur input image with the specified Gaussian kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size (Union[int, sequence]): Size of the Gaussian kernel to use. The value must be positive and odd. If
|
||||||
|
only an integer is provied, the kernel size will be (size, size). If a sequence of integer is provied, it
|
||||||
|
must be a sequence of 2 values which represents (width, height).
|
||||||
|
sigma (Union[float, sequence], optional): Standard deviation of the Gaussian kernel to use (default=None). The
|
||||||
|
value must be positive. If only an float is provied, the sigma will be (sigma, sigma). If a sequence of
|
||||||
|
float is provied, it must be a sequence of 2 values which represents the sigma of width and height. If None
|
||||||
|
is provided, the sigma will be calculated as ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8.
|
||||||
|
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> transforms_list = [c_vision.Decode(), c_vision.GaussianBlur(3, 3)]
|
||||||
|
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||||
|
... input_columns=["image"])
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_gaussian_blur
|
||||||
|
def __init__(self, kernel_size, sigma=None):
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size,)
|
||||||
|
if sigma is None:
|
||||||
|
sigma = (0,)
|
||||||
|
elif isinstance(sigma, (int, float)):
|
||||||
|
sigma = (float(sigma),)
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
return cde.GaussianBlurOperation(self.kernel_size, self.sigma)
|
||||||
|
|
||||||
|
|
||||||
class HWC2CHW(ImageTensorOperation):
|
class HWC2CHW(ImageTensorOperation):
|
||||||
"""
|
"""
|
||||||
Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
|
||||||
|
|
||||||
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||||
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
||||||
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio
|
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio, check_odd
|
||||||
from .utils import Inter, Border, ImageBatchFormat
|
from .utils import Inter, Border, ImageBatchFormat
|
||||||
|
|
||||||
|
|
||||||
|
@ -493,6 +493,7 @@ def check_rgb_to_hsv(method):
|
||||||
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
||||||
type_check(is_hwc, (bool,), "is_hwc")
|
type_check(is_hwc, (bool,), "is_hwc")
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
@ -504,6 +505,7 @@ def check_hsv_to_rgb(method):
|
||||||
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
||||||
type_check(is_hwc, (bool,), "is_hwc")
|
type_check(is_hwc, (bool,), "is_hwc")
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
@ -819,3 +821,39 @@ def check_random_solarize(method):
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_gaussian_blur(method):
|
||||||
|
"""Wrapper method to check the parameters of GaussianBlur."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
[kernel_size, sigma], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
|
type_check(kernel_size, (int, list, tuple), "kernel_size")
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size")
|
||||||
|
check_odd(kernel_size, "kernel_size")
|
||||||
|
elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2:
|
||||||
|
for index, value in enumerate(kernel_size):
|
||||||
|
type_check(value, (int,), "kernel_size[{}]".format(index))
|
||||||
|
check_value(value, (1, FLOAT_MAX_INTEGER), "kernel_size")
|
||||||
|
check_odd(value, "kernel_size[{}]".format(index))
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.")
|
||||||
|
|
||||||
|
if sigma is not None:
|
||||||
|
type_check(sigma, (numbers.Number, list, tuple), "sigma")
|
||||||
|
if isinstance(sigma, numbers.Number):
|
||||||
|
check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma")
|
||||||
|
elif isinstance(sigma, (list, tuple)) and len(sigma) == 2:
|
||||||
|
for index, value in enumerate(sigma):
|
||||||
|
type_check(value, (numbers.Number,), "size[{}]".format(index))
|
||||||
|
check_value(value, (0, FLOAT_MAX_INTEGER), "sigma")
|
||||||
|
else:
|
||||||
|
raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.")
|
||||||
|
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
|
@ -208,6 +208,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
||||||
${MINDDATA_DIR}/kernels/image/center_crop_op.cc
|
${MINDDATA_DIR}/kernels/image/center_crop_op.cc
|
||||||
${MINDDATA_DIR}/kernels/image/crop_op.cc
|
${MINDDATA_DIR}/kernels/image/crop_op.cc
|
||||||
${MINDDATA_DIR}/kernels/image/decode_op.cc
|
${MINDDATA_DIR}/kernels/image/decode_op.cc
|
||||||
|
${MINDDATA_DIR}/kernels/image/gaussian_blur_op.cc
|
||||||
${MINDDATA_DIR}/kernels/image/normalize_op.cc
|
${MINDDATA_DIR}/kernels/image/normalize_op.cc
|
||||||
${MINDDATA_DIR}/kernels/image/resize_op.cc
|
${MINDDATA_DIR}/kernels/image/resize_op.cc
|
||||||
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
|
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
|
||||||
|
@ -233,6 +234,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
|
||||||
|
${MINDDATA_DIR}/kernels/ir/vision/gaussian_blur_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc
|
||||||
|
@ -479,6 +481,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
|
||||||
"${MINDDATA_DIR}/kernels/image/cut_out_op.cc"
|
"${MINDDATA_DIR}/kernels/image/cut_out_op.cc"
|
||||||
"${MINDDATA_DIR}/kernels/image/cutmix_batch_op.cc"
|
"${MINDDATA_DIR}/kernels/image/cutmix_batch_op.cc"
|
||||||
"${MINDDATA_DIR}/kernels/image/equalize_op.cc"
|
"${MINDDATA_DIR}/kernels/image/equalize_op.cc"
|
||||||
|
"${MINDDATA_DIR}/kernels/image/gaussian_blur.cc"
|
||||||
"${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc"
|
"${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc"
|
||||||
"${MINDDATA_DIR}/kernels/image/image_utils.cc"
|
"${MINDDATA_DIR}/kernels/image/image_utils.cc"
|
||||||
"${MINDDATA_DIR}/kernels/image/invert_op.cc"
|
"${MINDDATA_DIR}/kernels/image/invert_op.cc"
|
||||||
|
@ -539,6 +542,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/cutout_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/decode_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/equalize_ir.cc
|
||||||
|
${MINDDATA_DIR}/kernels/ir/vision/gaussian_blur_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/hwc_to_chw_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/invert_ir.cc
|
||||||
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc
|
${MINDDATA_DIR}/kernels/ir/vision/mixup_batch_ir.cc
|
||||||
|
|
|
@ -69,6 +69,7 @@ SET(DE_UT_SRCS
|
||||||
equalize_op_test.cc
|
equalize_op_test.cc
|
||||||
execution_tree_test.cc
|
execution_tree_test.cc
|
||||||
fill_op_test.cc
|
fill_op_test.cc
|
||||||
|
c_api_vision_gaussian_blur_test.cc
|
||||||
global_context_test.cc
|
global_context_test.cc
|
||||||
gnn_graph_test.cc
|
gnn_graph_test.cc
|
||||||
image_folder_op_test.cc
|
image_folder_op_test.cc
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/image/gaussian_blur_op.h"
|
||||||
|
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "minddata/dataset/include/dataset/datasets.h"
|
||||||
|
#include "minddata/dataset/include/dataset/execute.h"
|
||||||
|
#include "minddata/dataset/include/dataset/vision.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
|
||||||
|
class MindDataTestGaussianBlur : public UT::DatasetOpTesting {
|
||||||
|
protected:
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurParamCheck) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurParamCheck with invalid parameters.";
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Case 1: Kernel size is not positive
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorTransform> gaussian_blur1(new vision::GaussianBlur({-1}));
|
||||||
|
auto ds1 = ds->Map({gaussian_blur1});
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||||
|
// Expect failure: invalid kernel_size for GaussianBlur
|
||||||
|
EXPECT_EQ(iter1, nullptr);
|
||||||
|
|
||||||
|
// Case 2: Kernel size is not odd
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorTransform> gaussian_blur2(new vision::GaussianBlur({2, 2}, {3, 3}));
|
||||||
|
auto ds2 = ds->Map({gaussian_blur2});
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||||
|
// Expect failure: invalid kernel_size for GaussianBlur
|
||||||
|
EXPECT_EQ(iter2, nullptr);
|
||||||
|
|
||||||
|
// Case 3: Sigma is not positive
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorTransform> gaussian_blur3(new vision::GaussianBlur({3}, {-3}));
|
||||||
|
auto ds3 = ds->Map({gaussian_blur3});
|
||||||
|
EXPECT_NE(ds3, nullptr);
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||||
|
// Expect failure: invalid sigma for GaussianBlur
|
||||||
|
EXPECT_EQ(iter3, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurPipeline) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurPipeline.";
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorTransform> gaussian_blur(new vision::GaussianBlur({3, 3}, {5, 5}));
|
||||||
|
|
||||||
|
// Create a Map operation on ds
|
||||||
|
ds = ds->Map({gaussian_blur});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Batch operation on ds
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
ds = ds->Batch(batch_size);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto image = row["image"];
|
||||||
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 10);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestGaussianBlur, TestGaussianBlurEager) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurEager.";
|
||||||
|
|
||||||
|
// Read images
|
||||||
|
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||||
|
|
||||||
|
// Transform params
|
||||||
|
auto decode = vision::Decode();
|
||||||
|
auto gaussian_blur = vision::GaussianBlur({7}, {3.5});
|
||||||
|
|
||||||
|
auto transform = Execute({decode, gaussian_blur});
|
||||||
|
Status rc = transform(image, &image);
|
||||||
|
|
||||||
|
EXPECT_EQ(rc, Status::OK());
|
||||||
|
}
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing GaussianBlur op in DE
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.vision.c_transforms as c_vision
|
||||||
|
|
||||||
|
from mindspore import log as logger
|
||||||
|
from util import visualize_image, diff_mse
|
||||||
|
|
||||||
|
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
IMAGE_FILE = "../data/dataset/apple.jpg"
|
||||||
|
|
||||||
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
|
|
||||||
|
def test_gaussian_blur_pipeline(plot=False):
|
||||||
|
"""
|
||||||
|
Test GaussianBlur of c_transforms
|
||||||
|
"""
|
||||||
|
logger.info("test_gaussian_blur_pipeline")
|
||||||
|
|
||||||
|
# First dataset
|
||||||
|
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||||
|
decode_op = c_vision.Decode()
|
||||||
|
gaussian_blur_op = c_vision.GaussianBlur(3, 3)
|
||||||
|
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
|
||||||
|
dataset1 = dataset1.map(operations=gaussian_blur_op, input_columns=["image"])
|
||||||
|
|
||||||
|
# Second dataset
|
||||||
|
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
|
||||||
|
|
||||||
|
num_iter = 0
|
||||||
|
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||||
|
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||||
|
if num_iter > 0:
|
||||||
|
break
|
||||||
|
gaussian_blur_ms = data1["image"]
|
||||||
|
original = data2["image"]
|
||||||
|
gaussian_blur_cv = cv2.GaussianBlur(original, (3, 3), 3)
|
||||||
|
mse = diff_mse(gaussian_blur_ms, gaussian_blur_cv)
|
||||||
|
logger.info("gaussian_blur_{}, mse: {}".format(num_iter + 1, mse))
|
||||||
|
assert mse == 0
|
||||||
|
num_iter += 1
|
||||||
|
if plot:
|
||||||
|
visualize_image(original, gaussian_blur_ms, mse, gaussian_blur_cv)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gaussian_blur_eager():
|
||||||
|
"""
|
||||||
|
Test GaussianBlur with eager mode
|
||||||
|
"""
|
||||||
|
logger.info("test_gaussian_blur_eager")
|
||||||
|
img = cv2.imread(IMAGE_FILE)
|
||||||
|
|
||||||
|
img_ms = c_vision.GaussianBlur((3, 5), (3.5, 3.5))(img)
|
||||||
|
img_cv = cv2.GaussianBlur(img, (3, 5), 3.5, 3.5)
|
||||||
|
mse = diff_mse(img_ms, img_cv)
|
||||||
|
assert mse == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_gaussian_blur_exception():
|
||||||
|
"""
|
||||||
|
Test GsianBlur with invalid parameters
|
||||||
|
"""
|
||||||
|
logger.info("test_gaussian_blur_exception")
|
||||||
|
try:
|
||||||
|
_ = c_vision.GaussianBlur([2, 2])
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
|
||||||
|
assert "not an odd value" in str(e)
|
||||||
|
try:
|
||||||
|
_ = c_vision.GaussianBlur(3.0, [3, 3])
|
||||||
|
except TypeError as e:
|
||||||
|
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
|
||||||
|
assert "not of type [<class 'int'>, <class 'list'>, <class 'tuple'>]" in str(e)
|
||||||
|
try:
|
||||||
|
_ = c_vision.GaussianBlur(3, -3)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
|
||||||
|
assert "not within the required interval" in str(e)
|
||||||
|
try:
|
||||||
|
_ = c_vision.GaussianBlur(3, [3, 3, 3])
|
||||||
|
except TypeError as e:
|
||||||
|
logger.info("Got an exception in GaussianBlur: {}".format(str(e)))
|
||||||
|
assert "should be a single number or a list/tuple of length 2" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_gaussian_blur_pipeline(plot=True)
|
||||||
|
test_gaussian_blur_eager()
|
||||||
|
test_gaussian_blur_exception()
|
Loading…
Reference in New Issue