forked from mindspore-Ecosystem/mindspore
random sharpness cpp op support
This commit is contained in:
parent
e14fff871d
commit
477528de7f
|
@ -43,6 +43,7 @@
|
||||||
#include "minddata/dataset/kernels/image/random_resize_op.h"
|
#include "minddata/dataset/kernels/image/random_resize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||||
|
@ -333,6 +334,15 @@ PYBIND_REGISTER(RandomRotationOp, 1, ([](const py::module *m) {
|
||||||
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
|
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(RandomSharpnessOp, 1, ([](const py::module *m) {
|
||||||
|
(void)py::class_<RandomSharpnessOp, TensorOp, std::shared_ptr<RandomSharpnessOp>>(
|
||||||
|
*m, "RandomSharpnessOp",
|
||||||
|
"Tensor operation to apply RandomSharpness."
|
||||||
|
"Takes a range for degrees")
|
||||||
|
.def(py::init<float, float>(), py::arg("startDegree") = RandomSharpnessOp::kDefStartDegree,
|
||||||
|
py::arg("endDegree") = RandomSharpnessOp::kDefEndDegree);
|
||||||
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
|
(void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
|
||||||
*m, "RandomSelectSubpolicyOp")
|
*m, "RandomSelectSubpolicyOp")
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||||
#include "minddata/dataset/kernels/image/resize_op.h"
|
#include "minddata/dataset/kernels/image/resize_op.h"
|
||||||
|
@ -209,6 +210,16 @@ std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, u
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to create RandomSharpnessOperation.
|
||||||
|
std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees) {
|
||||||
|
auto op = std::make_shared<RandomSharpnessOperation>(degrees);
|
||||||
|
// Input validation
|
||||||
|
if (!op->ValidateParams()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
// Function to create RandomVerticalFlipOperation.
|
// Function to create RandomVerticalFlipOperation.
|
||||||
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
|
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
|
||||||
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
|
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
|
||||||
|
@ -665,6 +676,22 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
||||||
return tensor_op;
|
return tensor_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to create RandomSharpness.
|
||||||
|
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
|
||||||
|
|
||||||
|
bool RandomSharpnessOperation::ValidateParams() {
|
||||||
|
if (degrees_.empty() || degrees_.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
|
||||||
|
std::shared_ptr<RandomSharpnessOp> tensor_op = std::make_shared<RandomSharpnessOp>(degrees_[0], degrees_[1]);
|
||||||
|
return tensor_op;
|
||||||
|
}
|
||||||
|
|
||||||
// RandomSolarizeOperation.
|
// RandomSolarizeOperation.
|
||||||
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
|
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
|
||||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
||||||
|
|
|
@ -61,6 +61,7 @@ class RandomColorAdjustOperation;
|
||||||
class RandomCropOperation;
|
class RandomCropOperation;
|
||||||
class RandomHorizontalFlipOperation;
|
class RandomHorizontalFlipOperation;
|
||||||
class RandomRotationOperation;
|
class RandomRotationOperation;
|
||||||
|
class RandomSharpnessOperation;
|
||||||
class RandomSolarizeOperation;
|
class RandomSolarizeOperation;
|
||||||
class RandomVerticalFlipOperation;
|
class RandomVerticalFlipOperation;
|
||||||
class ResizeOperation;
|
class ResizeOperation;
|
||||||
|
@ -209,6 +210,13 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(
|
||||||
std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
||||||
std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
|
std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
|
||||||
|
|
||||||
|
/// \brief Function to create a RandomSharpness TensorOperation.
|
||||||
|
/// \notes Tensor operation to perform random sharpness.
|
||||||
|
/// \param[in] start_degree - float representing the start of the range to uniformly sample the factor from it.
|
||||||
|
/// \param[in] end_degree - float representing the end of the range.
|
||||||
|
/// \return Shared pointer to the current TensorOperation.
|
||||||
|
std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees = {0.1, 1.9});
|
||||||
|
|
||||||
/// \brief Function to create a RandomSolarize TensorOperation.
|
/// \brief Function to create a RandomSolarize TensorOperation.
|
||||||
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
|
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
|
||||||
/// \param[in] threshold_min - lower limit
|
/// \param[in] threshold_min - lower limit
|
||||||
|
@ -468,6 +476,20 @@ class RandomRotationOperation : public TensorOperation {
|
||||||
std::vector<uint8_t> fill_value_;
|
std::vector<uint8_t> fill_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class RandomSharpnessOperation : public TensorOperation {
|
||||||
|
public:
|
||||||
|
explicit RandomSharpnessOperation(std::vector<float> degrees = {0.1, 1.9});
|
||||||
|
|
||||||
|
~RandomSharpnessOperation() = default;
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> Build() override;
|
||||||
|
|
||||||
|
bool ValidateParams() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<float> degrees_;
|
||||||
|
};
|
||||||
|
|
||||||
class RandomVerticalFlipOperation : public TensorOperation {
|
class RandomVerticalFlipOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit RandomVerticalFlipOperation(float probability = 0.5);
|
explicit RandomVerticalFlipOperation(float probability = 0.5);
|
||||||
|
|
|
@ -32,9 +32,11 @@ add_library(kernels-image OBJECT
|
||||||
random_solarize_op.cc
|
random_solarize_op.cc
|
||||||
random_vertical_flip_op.cc
|
random_vertical_flip_op.cc
|
||||||
random_vertical_flip_with_bbox_op.cc
|
random_vertical_flip_with_bbox_op.cc
|
||||||
|
random_sharpness_op.cc
|
||||||
rescale_op.cc
|
rescale_op.cc
|
||||||
resize_bilinear_op.cc
|
resize_bilinear_op.cc
|
||||||
resize_op.cc
|
resize_op.cc
|
||||||
|
sharpness_op.cc
|
||||||
solarize_op.cc
|
solarize_op.cc
|
||||||
swap_red_blue_op.cc
|
swap_red_blue_op.cc
|
||||||
uniform_aug_op.cc
|
uniform_aug_op.cc
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
|
||||||
|
#include <random>
|
||||||
|
#include "minddata/dataset/kernels/image/sharpness_op.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
const float RandomSharpnessOp::kDefStartDegree = 0.1;
|
||||||
|
const float RandomSharpnessOp::kDefEndDegree = 1.9;
|
||||||
|
|
||||||
|
/// constructor
|
||||||
|
RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree)
|
||||||
|
: start_degree_(start_degree), end_degree_(end_degree) {
|
||||||
|
rnd_.seed(GetSeed());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// main function call for random sharpness : Generate the random degrees
|
||||||
|
Status RandomSharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
IO_CHECK(input, output);
|
||||||
|
float random_double = distribution_(rnd_);
|
||||||
|
/// get the degree sharpness range
|
||||||
|
/// the way this op works (uniform distribution)
|
||||||
|
/// assumption here is that mDegreesEnd > mDegreeStart so we always get positive number
|
||||||
|
float degree_range = (end_degree_ - start_degree_) / 2;
|
||||||
|
float mid = (end_degree_ + start_degree_) / 2;
|
||||||
|
alpha_ = mid + random_double * degree_range;
|
||||||
|
|
||||||
|
SharpnessOp::Compute(input, output);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/sharpness_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class RandomSharpnessOp : public SharpnessOp {
|
||||||
|
public:
|
||||||
|
static const float kDefStartDegree;
|
||||||
|
static const float kDefEndDegree;
|
||||||
|
|
||||||
|
/// Adjust the sharpness of the input image by a random degree within the given range.
|
||||||
|
/// \@param[in] start_degree A float indicating the beginning of the range.
|
||||||
|
/// \@param[in] end_degree A float indicating the end of the range.
|
||||||
|
|
||||||
|
explicit RandomSharpnessOp(float start_degree = kDefStartDegree, const float end_degree = kDefEndDegree);
|
||||||
|
~RandomSharpnessOp() override = default;
|
||||||
|
void Print(std::ostream &out) const override { out << Name(); }
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kRandomSharpnessOp; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
float start_degree_;
|
||||||
|
float end_degree_;
|
||||||
|
std::uniform_real_distribution<float> distribution_{-1.0, 1.0};
|
||||||
|
std::mt19937 rnd_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_
|
|
@ -0,0 +1,84 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/sharpness_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
const float SharpnessOp::kDefAlpha = 1.0;
|
||||||
|
|
||||||
|
Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
IO_CHECK(input, output);
|
||||||
|
|
||||||
|
try {
|
||||||
|
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||||
|
cv::Mat input_img = input_cv->mat();
|
||||||
|
if (!input_cv->mat().data) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get number of channels and image matrix
|
||||||
|
std::size_t num_of_channels = input_cv->shape()[2];
|
||||||
|
if (num_of_channels != 1 && num_of_channels != 3) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// creating a smoothing filter. 1, 1, 1,
|
||||||
|
/// 1, 5, 1,
|
||||||
|
/// 1, 1, 1
|
||||||
|
|
||||||
|
float filterSum = 13.0;
|
||||||
|
cv::Mat filter = cv::Mat(3, 3, CV_32F, cv::Scalar::all(1.0 / filterSum));
|
||||||
|
filter.at<float>(1, 1) = 5.0 / filterSum;
|
||||||
|
|
||||||
|
/// applying filter on channels
|
||||||
|
cv::Mat result = cv::Mat();
|
||||||
|
cv::filter2D(input_img, result, -1, filter);
|
||||||
|
|
||||||
|
int height = input_cv->shape()[0];
|
||||||
|
int width = input_cv->shape()[1];
|
||||||
|
|
||||||
|
/// restoring the edges
|
||||||
|
input_img.row(0).copyTo(result.row(0));
|
||||||
|
input_img.row(height - 1).copyTo(result.row(height - 1));
|
||||||
|
input_img.col(0).copyTo(result.col(0));
|
||||||
|
input_img.col(width - 1).copyTo(result.col(width - 1));
|
||||||
|
|
||||||
|
/// blend based on alpha : (alpha_ *input_img) + ((1.0-alpha_) * result);
|
||||||
|
cv::addWeighted(input_img, alpha_, result, 1.0 - alpha_, 0.0, result);
|
||||||
|
|
||||||
|
std::shared_ptr<CVTensor> output_cv;
|
||||||
|
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(output_cv);
|
||||||
|
|
||||||
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
|
}
|
||||||
|
|
||||||
|
catch (const cv::Exception &e) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("OpenCV error in random sharpness");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class SharpnessOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// Default values, also used by bindings.cc
|
||||||
|
static const float kDefAlpha;
|
||||||
|
|
||||||
|
/// This class can be used to adjust the sharpness of an image.
|
||||||
|
/// \@param[in] alpha A float indicating the enhancement factor.
|
||||||
|
/// a factor of 0.0 gives a blurred image, a factor of 1.0 gives the
|
||||||
|
/// original image, and a factor of 2.0 gives a sharpened image.
|
||||||
|
|
||||||
|
explicit SharpnessOp(const float alpha = kDefAlpha) : alpha_(alpha) {}
|
||||||
|
|
||||||
|
~SharpnessOp() override = default;
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kSharpnessOp; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
float alpha_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_
|
|
@ -114,6 +114,7 @@ constexpr char kRandomResizeOp[] = "RandomResizeOp";
|
||||||
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
|
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
|
||||||
constexpr char kRandomRotationOp[] = "RandomRotationOp";
|
constexpr char kRandomRotationOp[] = "RandomRotationOp";
|
||||||
constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp";
|
constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp";
|
||||||
|
constexpr char kRandomSharpnessOp[] = "RandomSharpnessOp";
|
||||||
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
|
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
|
||||||
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
|
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
|
||||||
constexpr char kRescaleOp[] = "RescaleOp";
|
constexpr char kRescaleOp[] = "RescaleOp";
|
||||||
|
@ -121,6 +122,7 @@ constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
|
||||||
constexpr char kResizeOp[] = "ResizeOp";
|
constexpr char kResizeOp[] = "ResizeOp";
|
||||||
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
|
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
|
||||||
constexpr char kSolarizeOp[] = "SolarizeOp";
|
constexpr char kSolarizeOp[] = "SolarizeOp";
|
||||||
|
constexpr char kSharpnessOp[] = "SharpnessOp";
|
||||||
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
||||||
constexpr char kUniformAugOp[] = "UniformAugOp";
|
constexpr char kUniformAugOp[] = "UniformAugOp";
|
||||||
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
|
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
|
||||||
|
|
|
@ -48,7 +48,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
|
||||||
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||||
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
|
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
|
||||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, FLOAT_MAX_INTEGER
|
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
|
||||||
|
|
||||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||||
|
@ -90,6 +90,31 @@ class AutoContrast(cde.AutoContrastOp):
|
||||||
super().__init__(cutoff, ignore)
|
super().__init__(cutoff, ignore)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSharpness(cde.RandomSharpnessOp):
|
||||||
|
"""
|
||||||
|
Adjust the sharpness of the input image by a fixed or random degree. degree of 0.0 gives a blurred image,
|
||||||
|
a degree of 1.0 gives the original image, and a degree of 2.0 gives a sharpened image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
degrees (sequence): Range of random sharpness adjustment degrees.
|
||||||
|
it should be in (min, max) format. If min=max, then it is a
|
||||||
|
single fixed magnitude operation (default = (0.1, 1.9)).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError : If degrees is not a list or tuple.
|
||||||
|
ValueError: If degrees is not positive.
|
||||||
|
ValueError: If degrees is in (max, min) format instead of (min, max).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>>c_transform.RandomSharpness(degrees=(0.2,1.9))
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_positive_degrees
|
||||||
|
def __init__(self, degrees=(0.1, 1.9)):
|
||||||
|
self.degrees = degrees
|
||||||
|
super().__init__(*degrees)
|
||||||
|
|
||||||
|
|
||||||
class Equalize(cde.EqualizeOp):
|
class Equalize(cde.EqualizeOp):
|
||||||
"""
|
"""
|
||||||
Apply histogram equalization on input image.
|
Apply histogram equalization on input image.
|
||||||
|
|
|
@ -614,14 +614,16 @@ def check_positive_degrees(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[degrees], _ = parse_user_args(method, *args, **kwargs)
|
[degrees], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(degrees, (list, tuple)):
|
if isinstance(degrees, (list, tuple)):
|
||||||
if len(degrees) != 2:
|
if len(degrees) != 2:
|
||||||
raise ValueError("Degrees must be a sequence with length 2.")
|
raise ValueError("Degrees must be a sequence with length 2.")
|
||||||
|
for value in degrees:
|
||||||
|
check_value(value, (0., FLOAT_MAX_INTEGER))
|
||||||
check_positive(degrees[0], "degrees[0]")
|
check_positive(degrees[0], "degrees[0]")
|
||||||
if degrees[0] > degrees[1]:
|
if degrees[0] > degrees[1]:
|
||||||
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
||||||
|
else:
|
||||||
|
raise TypeError("Degrees should be a tuple or list.")
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -34,12 +34,12 @@
|
||||||
#include "minddata/dataset/include/samplers.h"
|
#include "minddata/dataset/include/samplers.h"
|
||||||
|
|
||||||
using namespace mindspore::dataset::api;
|
using namespace mindspore::dataset::api;
|
||||||
using mindspore::MsLogLevel::ERROR;
|
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
|
||||||
using mindspore::LogStream;
|
using mindspore::LogStream;
|
||||||
using mindspore::dataset::Tensor;
|
|
||||||
using mindspore::dataset::Status;
|
|
||||||
using mindspore::dataset::BorderType;
|
using mindspore::dataset::BorderType;
|
||||||
|
using mindspore::dataset::Status;
|
||||||
|
using mindspore::dataset::Tensor;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::ERROR;
|
||||||
|
|
||||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||||
protected:
|
protected:
|
||||||
|
@ -527,12 +527,12 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
||||||
std::shared_ptr<TensorOperation> random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5});
|
std::shared_ptr<TensorOperation> random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5});
|
||||||
EXPECT_NE(random_color_adjust1, nullptr);
|
EXPECT_NE(random_color_adjust1, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> random_color_adjust2 = vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5},
|
std::shared_ptr<TensorOperation> random_color_adjust2 =
|
||||||
{0.5, 0.5});
|
vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, {0.5, 0.5});
|
||||||
EXPECT_NE(random_color_adjust2, nullptr);
|
EXPECT_NE(random_color_adjust2, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> random_color_adjust3 = vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5},
|
std::shared_ptr<TensorOperation> random_color_adjust3 =
|
||||||
{0.25, 0.5});
|
vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, {0.25, 0.5});
|
||||||
EXPECT_NE(random_color_adjust3, nullptr);
|
EXPECT_NE(random_color_adjust3, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> random_color_adjust4 = vision::RandomColorAdjust();
|
std::shared_ptr<TensorOperation> random_color_adjust4 = vision::RandomColorAdjust();
|
||||||
|
@ -570,6 +570,64 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestRandomSharpness) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSharpness.";
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Repeat operation on ds
|
||||||
|
int32_t repeat_num = 2;
|
||||||
|
ds = ds->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorOperation> random_sharpness_op_1 = vision::RandomSharpness({0.4, 2.3});
|
||||||
|
EXPECT_NE(random_sharpness_op_1, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> random_sharpness_op_2 = vision::RandomSharpness({});
|
||||||
|
EXPECT_EQ(random_sharpness_op_2, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> random_sharpness_op_3 = vision::RandomSharpness();
|
||||||
|
EXPECT_NE(random_sharpness_op_3, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> random_sharpness_op_4 = vision::RandomSharpness({0.1});
|
||||||
|
EXPECT_EQ(random_sharpness_op_4, nullptr);
|
||||||
|
|
||||||
|
// Create a Map operation on ds
|
||||||
|
ds = ds->Map({random_sharpness_op_1, random_sharpness_op_3});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Batch operation on ds
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
ds = ds->Batch(batch_size);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto image = row["image"];
|
||||||
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 20);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestRandomAffineSuccess1) {
|
TEST_F(MindDataTestPipeline, TestRandomAffineSuccess1) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAffineSuccess1 with non-default params.";
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAffineSuccess1 with non-default params.";
|
||||||
|
|
||||||
|
|
|
@ -146,6 +146,14 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
|
||||||
expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg";
|
expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg";
|
||||||
actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg";
|
actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg";
|
||||||
break;
|
break;
|
||||||
|
case kInvert:
|
||||||
|
expect_image_path = dir_path + "imagefolder/apple_expect_invert.jpg";
|
||||||
|
actual_image_path = dir_path + "imagefolder/apple_actual_invert.jpg";
|
||||||
|
break;
|
||||||
|
case kRandomSharpness:
|
||||||
|
expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
|
||||||
|
actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
|
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
|
||||||
EXPECT_EQ(0, 1);
|
EXPECT_EQ(0, 1);
|
||||||
|
|
|
@ -39,6 +39,8 @@ class CVOpCommon : public Common {
|
||||||
kRandomSolarize,
|
kRandomSolarize,
|
||||||
kTemplate,
|
kTemplate,
|
||||||
kCrop,
|
kCrop,
|
||||||
|
kRandomSharpness,
|
||||||
|
kInvert,
|
||||||
kRandomAffine,
|
kRandomAffine,
|
||||||
kAutoContrast,
|
kAutoContrast,
|
||||||
kEqualize
|
kEqualize
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "common/cvop_common.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
|
||||||
|
class MindDataTestInvert : public UT::CVOP::CVOpCommon {
|
||||||
|
public:
|
||||||
|
MindDataTestInvert() : CVOpCommon() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestInvert, TestOp) {
|
||||||
|
MS_LOG(INFO) << "Doing test Invert.";
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
std::unique_ptr<InvertOp> op(new InvertOp());
|
||||||
|
EXPECT_TRUE(op->OneToOne());
|
||||||
|
Status st = op->Compute(input_tensor_, &output_tensor);
|
||||||
|
EXPECT_TRUE(st.IsOk());
|
||||||
|
CheckImageShapeAndData(output_tensor, kInvert);
|
||||||
|
MS_LOG(INFO) << "testInvert end.";
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "common/cvop_common.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "minddata/dataset/core/config_manager.h"
|
||||||
|
#include "minddata/dataset/core/global_context.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
|
||||||
|
class MindDataTestRandomSharpness : public UT::CVOP::CVOpCommon {
|
||||||
|
public:
|
||||||
|
MindDataTestRandomSharpness() : CVOpCommon() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestRandomSharpness, TestOp) {
|
||||||
|
MS_LOG(INFO) << "Doing test RandomSharpness.";
|
||||||
|
// setting seed here
|
||||||
|
u_int32_t curr_seed = GlobalContext::config_manager()->seed();
|
||||||
|
GlobalContext::config_manager()->set_seed(120);
|
||||||
|
// Sharpness with a factor in range [0.2,1.8]
|
||||||
|
float start_degree = 0.2;
|
||||||
|
float end_degree = 1.8;
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
// sharpening
|
||||||
|
std::unique_ptr<RandomSharpnessOp> op(new RandomSharpnessOp(start_degree, end_degree));
|
||||||
|
EXPECT_TRUE(op->OneToOne());
|
||||||
|
Status st = op->Compute(input_tensor_, &output_tensor);
|
||||||
|
EXPECT_TRUE(st.IsOk());
|
||||||
|
CheckImageShapeAndData(output_tensor, kRandomSharpness);
|
||||||
|
// restoring the seed
|
||||||
|
GlobalContext::config_manager()->set_seed(curr_seed);
|
||||||
|
MS_LOG(INFO) << "testRandomSharpness end.";
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 430 KiB |
Binary file not shown.
After Width: | Height: | Size: 435 KiB |
|
@ -19,20 +19,22 @@ import numpy as np
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.engine as de
|
import mindspore.dataset.engine as de
|
||||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from util import visualize_list, diff_mse, save_and_check_md5, \
|
from util import visualize_list, visualize_one_channel_dataset, diff_mse, save_and_check_md5, \
|
||||||
config_get_set_seed, config_get_set_num_parallel_workers
|
config_get_set_seed, config_get_set_num_parallel_workers
|
||||||
|
|
||||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||||
|
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||||
|
|
||||||
GENERATE_GOLDEN = False
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
|
|
||||||
def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
|
def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False):
|
||||||
"""
|
"""
|
||||||
Test RandomSharpness
|
Test RandomSharpness python op
|
||||||
"""
|
"""
|
||||||
logger.info("Test RandomSharpness")
|
logger.info("Test RandomSharpness python op")
|
||||||
|
|
||||||
# Original Images
|
# Original Images
|
||||||
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
@ -57,9 +59,13 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
|
||||||
# Random Sharpness Adjusted Images
|
# Random Sharpness Adjusted Images
|
||||||
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
|
py_op = F.RandomSharpness()
|
||||||
|
if degrees is not None:
|
||||||
|
py_op = F.RandomSharpness(degrees)
|
||||||
|
|
||||||
transforms_random_sharpness = F.ComposeOp([F.Decode(),
|
transforms_random_sharpness = F.ComposeOp([F.Decode(),
|
||||||
F.Resize((224, 224)),
|
F.Resize((224, 224)),
|
||||||
F.RandomSharpness(degrees=degrees),
|
py_op,
|
||||||
F.ToTensor()])
|
F.ToTensor()])
|
||||||
|
|
||||||
ds_random_sharpness = data.map(input_columns="image",
|
ds_random_sharpness = data.map(input_columns="image",
|
||||||
|
@ -86,11 +92,11 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
|
||||||
visualize_list(images_original, images_random_sharpness)
|
visualize_list(images_original, images_random_sharpness)
|
||||||
|
|
||||||
|
|
||||||
def test_random_sharpness_md5():
|
def test_random_sharpness_py_md5():
|
||||||
"""
|
"""
|
||||||
Test RandomSharpness with md5 comparison
|
Test RandomSharpness python op with md5 comparison
|
||||||
"""
|
"""
|
||||||
logger.info("Test RandomSharpness with md5 comparison")
|
logger.info("Test RandomSharpness python op with md5 comparison")
|
||||||
original_seed = config_get_set_seed(5)
|
original_seed = config_get_set_seed(5)
|
||||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
@ -107,7 +113,7 @@ def test_random_sharpness_md5():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
|
|
||||||
# check results with md5 comparison
|
# check results with md5 comparison
|
||||||
filename = "random_sharpness_01_result.npz"
|
filename = "random_sharpness_py_01_result.npz"
|
||||||
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
# Restore configuration
|
# Restore configuration
|
||||||
|
@ -115,8 +121,230 @@ def test_random_sharpness_md5():
|
||||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sharpness_c(degrees=(1.6, 1.6), plot=False):
|
||||||
|
"""
|
||||||
|
Test RandomSharpness cpp op
|
||||||
|
"""
|
||||||
|
print(degrees)
|
||||||
|
logger.info("Test RandomSharpness cpp op")
|
||||||
|
|
||||||
|
# Original Images
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
|
transforms_original = [C.Decode(),
|
||||||
|
C.Resize((224, 224))]
|
||||||
|
|
||||||
|
ds_original = data.map(input_columns="image",
|
||||||
|
operations=transforms_original)
|
||||||
|
|
||||||
|
ds_original = ds_original.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_original):
|
||||||
|
if idx == 0:
|
||||||
|
images_original = image
|
||||||
|
else:
|
||||||
|
images_original = np.append(images_original,
|
||||||
|
image,
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
# Random Sharpness Adjusted Images
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
|
c_op = C.RandomSharpness()
|
||||||
|
if degrees is not None:
|
||||||
|
c_op = C.RandomSharpness(degrees)
|
||||||
|
|
||||||
|
transforms_random_sharpness = [C.Decode(),
|
||||||
|
C.Resize((224, 224)),
|
||||||
|
c_op]
|
||||||
|
|
||||||
|
ds_random_sharpness = data.map(input_columns="image",
|
||||||
|
operations=transforms_random_sharpness)
|
||||||
|
|
||||||
|
ds_random_sharpness = ds_random_sharpness.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_random_sharpness):
|
||||||
|
if idx == 0:
|
||||||
|
images_random_sharpness = image
|
||||||
|
else:
|
||||||
|
images_random_sharpness = np.append(images_random_sharpness,
|
||||||
|
image,
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
num_samples = images_original.shape[0]
|
||||||
|
mse = np.zeros(num_samples)
|
||||||
|
for i in range(num_samples):
|
||||||
|
mse[i] = diff_mse(images_random_sharpness[i], images_original[i])
|
||||||
|
|
||||||
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
visualize_list(images_original, images_random_sharpness)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sharpness_c_md5():
|
||||||
|
"""
|
||||||
|
Test RandomSharpness cpp op with md5 comparison
|
||||||
|
"""
|
||||||
|
logger.info("Test RandomSharpness cpp op with md5 comparison")
|
||||||
|
original_seed = config_get_set_seed(200)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
transforms = [
|
||||||
|
C.Decode(),
|
||||||
|
C.RandomSharpness((0.1, 1.9))
|
||||||
|
]
|
||||||
|
|
||||||
|
# Generate dataset
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"], operations=transforms)
|
||||||
|
|
||||||
|
# check results with md5 comparison
|
||||||
|
filename = "random_sharpness_cpp_01_result.npz"
|
||||||
|
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sharpness_c_py(degrees=(1.0, 1.0), plot=False):
|
||||||
|
"""
|
||||||
|
Test Random Sharpness C and python Op
|
||||||
|
"""
|
||||||
|
logger.info("Test RandomSharpness C and python Op")
|
||||||
|
|
||||||
|
# RandomSharpness Images
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"],
|
||||||
|
operations=[C.Decode(),
|
||||||
|
C.Resize((200, 300))])
|
||||||
|
|
||||||
|
python_op = F.RandomSharpness(degrees)
|
||||||
|
c_op = C.RandomSharpness(degrees)
|
||||||
|
|
||||||
|
transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)),
|
||||||
|
python_op,
|
||||||
|
np.array])()
|
||||||
|
|
||||||
|
ds_random_sharpness_py = data.map(input_columns="image",
|
||||||
|
operations=transforms_op)
|
||||||
|
|
||||||
|
ds_random_sharpness_py = ds_random_sharpness_py.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_random_sharpness_py):
|
||||||
|
if idx == 0:
|
||||||
|
images_random_sharpness_py = image
|
||||||
|
|
||||||
|
else:
|
||||||
|
images_random_sharpness_py = np.append(images_random_sharpness_py,
|
||||||
|
image,
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"],
|
||||||
|
operations=[C.Decode(),
|
||||||
|
C.Resize((200, 300))])
|
||||||
|
|
||||||
|
ds_images_random_sharpness_c = data.map(input_columns="image",
|
||||||
|
operations=c_op)
|
||||||
|
|
||||||
|
ds_images_random_sharpness_c = ds_images_random_sharpness_c.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_images_random_sharpness_c):
|
||||||
|
if idx == 0:
|
||||||
|
images_random_sharpness_c = image
|
||||||
|
|
||||||
|
else:
|
||||||
|
images_random_sharpness_c = np.append(images_random_sharpness_c,
|
||||||
|
image,
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
num_samples = images_random_sharpness_c.shape[0]
|
||||||
|
mse = np.zeros(num_samples)
|
||||||
|
for i in range(num_samples):
|
||||||
|
mse[i] = diff_mse(images_random_sharpness_c[i], images_random_sharpness_py[i])
|
||||||
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
if plot:
|
||||||
|
visualize_list(images_random_sharpness_c, images_random_sharpness_py, visualize_mode=2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sharpness_one_channel_c(degrees=(1.4, 1.4), plot=False):
|
||||||
|
"""
|
||||||
|
Test Random Sharpness cpp op with one channel
|
||||||
|
"""
|
||||||
|
logger.info("Test RandomSharpness C Op With MNIST Dataset (Grayscale images)")
|
||||||
|
|
||||||
|
c_op = C.RandomSharpness()
|
||||||
|
if degrees is not None:
|
||||||
|
c_op = C.RandomSharpness(degrees)
|
||||||
|
# RandomSharpness Images
|
||||||
|
data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||||
|
ds_random_sharpness_c = data.map(input_columns="image", operations=c_op)
|
||||||
|
# Original images
|
||||||
|
data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||||
|
|
||||||
|
images = []
|
||||||
|
images_trans = []
|
||||||
|
labels = []
|
||||||
|
for _, (data_orig, data_trans) in enumerate(zip(data, ds_random_sharpness_c)):
|
||||||
|
image_orig, label_orig = data_orig
|
||||||
|
image_trans, _ = data_trans
|
||||||
|
images.append(image_orig)
|
||||||
|
labels.append(label_orig)
|
||||||
|
images_trans.append(image_trans)
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
visualize_one_channel_dataset(images, images_trans, labels)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sharpness_invalid_params():
|
||||||
|
"""
|
||||||
|
Test RandomSharpness with invalid input parameters.
|
||||||
|
"""
|
||||||
|
logger.info("Test RandomSharpness with invalid input parameters.")
|
||||||
|
try:
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"],
|
||||||
|
operations=[C.Decode(),
|
||||||
|
C.Resize((224, 224)),
|
||||||
|
C.RandomSharpness(10)])
|
||||||
|
except TypeError as error:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
|
assert "tuple" in str(error)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"],
|
||||||
|
operations=[C.Decode(),
|
||||||
|
C.Resize((224, 224)),
|
||||||
|
C.RandomSharpness((-10, 10))])
|
||||||
|
except ValueError as error:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
|
assert "interval" in str(error)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
data = data.map(input_columns=["image"],
|
||||||
|
operations=[C.Decode(),
|
||||||
|
C.Resize((224, 224)),
|
||||||
|
C.RandomSharpness((10, 5))])
|
||||||
|
except ValueError as error:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
|
assert "(min,max)" in str(error)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_random_sharpness()
|
test_random_sharpness_py(plot=True)
|
||||||
test_random_sharpness(plot=True)
|
test_random_sharpness_py(None, plot=True) # test with default values
|
||||||
test_random_sharpness(degrees=(0.5, 1.5), plot=True)
|
test_random_sharpness_py_md5()
|
||||||
test_random_sharpness_md5()
|
test_random_sharpness_c(plot=True)
|
||||||
|
test_random_sharpness_c(None, plot=True) # test with default values
|
||||||
|
test_random_sharpness_c_md5()
|
||||||
|
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
|
||||||
|
test_random_sharpness_c_py(degrees=[1, 1], plot=True)
|
||||||
|
test_random_sharpness_c_py(degrees=[10, 10], plot=True)
|
||||||
|
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
|
||||||
|
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values
|
||||||
|
test_random_sharpness_invalid_params()
|
||||||
|
|
Loading…
Reference in New Issue