random sharpness cpp op support

This commit is contained in:
avakh 2020-08-04 11:53:29 -04:00
parent e14fff871d
commit 477528de7f
21 changed files with 758 additions and 36 deletions

View File

@ -43,6 +43,7 @@
#include "minddata/dataset/kernels/image/random_resize_op.h"
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
#include "minddata/dataset/kernels/image/random_solarize_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
@ -333,6 +334,15 @@ PYBIND_REGISTER(RandomRotationOp, 1, ([](const py::module *m) {
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
}));
PYBIND_REGISTER(RandomSharpnessOp, 1, ([](const py::module *m) {
(void)py::class_<RandomSharpnessOp, TensorOp, std::shared_ptr<RandomSharpnessOp>>(
*m, "RandomSharpnessOp",
"Tensor operation to apply RandomSharpness."
"Takes a range for degrees")
.def(py::init<float, float>(), py::arg("startDegree") = RandomSharpnessOp::kDefStartDegree,
py::arg("endDegree") = RandomSharpnessOp::kDefEndDegree);
}));
PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) {
(void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
*m, "RandomSelectSubpolicyOp")

View File

@ -31,6 +31,7 @@
#include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
#include "minddata/dataset/kernels/image/random_solarize_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
#include "minddata/dataset/kernels/image/resize_op.h"
@ -209,6 +210,16 @@ std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, u
return op;
}
// Function to create RandomSharpnessOperation.
std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees) {
auto op = std::make_shared<RandomSharpnessOperation>(degrees);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create RandomVerticalFlipOperation.
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
@ -665,6 +676,22 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
return tensor_op;
}
// Function to create RandomSharpness.
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
bool RandomSharpnessOperation::ValidateParams() {
if (degrees_.empty() || degrees_.size() != 2) {
MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()";
return false;
}
return true;
}
std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
std::shared_ptr<RandomSharpnessOp> tensor_op = std::make_shared<RandomSharpnessOp>(degrees_[0], degrees_[1]);
return tensor_op;
}
// RandomSolarizeOperation.
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}

View File

@ -61,6 +61,7 @@ class RandomColorAdjustOperation;
class RandomCropOperation;
class RandomHorizontalFlipOperation;
class RandomRotationOperation;
class RandomSharpnessOperation;
class RandomSolarizeOperation;
class RandomVerticalFlipOperation;
class ResizeOperation;
@ -209,6 +210,13 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(
std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
/// \brief Function to create a RandomSharpness TensorOperation.
/// \notes Tensor operation to perform random sharpness.
/// \param[in] start_degree - float representing the start of the range to uniformly sample the factor from it.
/// \param[in] end_degree - float representing the end of the range.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees = {0.1, 1.9});
/// \brief Function to create a RandomSolarize TensorOperation.
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
/// \param[in] threshold_min - lower limit
@ -468,6 +476,20 @@ class RandomRotationOperation : public TensorOperation {
std::vector<uint8_t> fill_value_;
};
class RandomSharpnessOperation : public TensorOperation {
public:
explicit RandomSharpnessOperation(std::vector<float> degrees = {0.1, 1.9});
~RandomSharpnessOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
private:
std::vector<float> degrees_;
};
class RandomVerticalFlipOperation : public TensorOperation {
public:
explicit RandomVerticalFlipOperation(float probability = 0.5);

View File

@ -32,9 +32,11 @@ add_library(kernels-image OBJECT
random_solarize_op.cc
random_vertical_flip_op.cc
random_vertical_flip_with_bbox_op.cc
random_sharpness_op.cc
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
sharpness_op.cc
solarize_op.cc
swap_red_blue_op.cc
uniform_aug_op.cc

View File

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
#include <random>
#include "minddata/dataset/kernels/image/sharpness_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const float RandomSharpnessOp::kDefStartDegree = 0.1;
const float RandomSharpnessOp::kDefEndDegree = 1.9;
/// constructor
RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree)
: start_degree_(start_degree), end_degree_(end_degree) {
rnd_.seed(GetSeed());
}
/// main function call for random sharpness : Generate the random degrees
Status RandomSharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
float random_double = distribution_(rnd_);
/// get the degree sharpness range
/// the way this op works (uniform distribution)
/// assumption here is that mDegreesEnd > mDegreeStart so we always get positive number
float degree_range = (end_degree_ - start_degree_) / 2;
float mid = (end_degree_ + start_degree_) / 2;
alpha_ = mid + random_double * degree_range;
SharpnessOp::Compute(input, output);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/kernels/image/sharpness_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RandomSharpnessOp : public SharpnessOp {
public:
static const float kDefStartDegree;
static const float kDefEndDegree;
/// Adjust the sharpness of the input image by a random degree within the given range.
/// \@param[in] start_degree A float indicating the beginning of the range.
/// \@param[in] end_degree A float indicating the end of the range.
explicit RandomSharpnessOp(float start_degree = kDefStartDegree, const float end_degree = kDefEndDegree);
~RandomSharpnessOp() override = default;
void Print(std::ostream &out) const override { out << Name(); }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRandomSharpnessOp; }
protected:
float start_degree_;
float end_degree_;
std::uniform_real_distribution<float> distribution_{-1.0, 1.0};
std::mt19937 rnd_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_

View File

@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/sharpness_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const float SharpnessOp::kDefAlpha = 1.0;
Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
cv::Mat input_img = input_cv->mat();
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
}
/// Get number of channels and image matrix
std::size_t num_of_channels = input_cv->shape()[2];
if (num_of_channels != 1 && num_of_channels != 3) {
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
}
/// creating a smoothing filter. 1, 1, 1,
/// 1, 5, 1,
/// 1, 1, 1
float filterSum = 13.0;
cv::Mat filter = cv::Mat(3, 3, CV_32F, cv::Scalar::all(1.0 / filterSum));
filter.at<float>(1, 1) = 5.0 / filterSum;
/// applying filter on channels
cv::Mat result = cv::Mat();
cv::filter2D(input_img, result, -1, filter);
int height = input_cv->shape()[0];
int width = input_cv->shape()[1];
/// restoring the edges
input_img.row(0).copyTo(result.row(0));
input_img.row(height - 1).copyTo(result.row(height - 1));
input_img.col(0).copyTo(result.col(0));
input_img.col(width - 1).copyTo(result.col(width - 1));
/// blend based on alpha : (alpha_ *input_img) + ((1.0-alpha_) * result);
cv::addWeighted(input_img, alpha_, result, 1.0 - alpha_, 0.0, result);
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
*output = std::static_pointer_cast<Tensor>(output_cv);
}
catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("OpenCV error in random sharpness");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class SharpnessOp : public TensorOp {
public:
/// Default values, also used by bindings.cc
static const float kDefAlpha;
/// This class can be used to adjust the sharpness of an image.
/// \@param[in] alpha A float indicating the enhancement factor.
/// a factor of 0.0 gives a blurred image, a factor of 1.0 gives the
/// original image, and a factor of 2.0 gives a sharpened image.
explicit SharpnessOp(const float alpha = kDefAlpha) : alpha_(alpha) {}
~SharpnessOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kSharpnessOp; }
protected:
float alpha_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_

View File

@ -114,6 +114,7 @@ constexpr char kRandomResizeOp[] = "RandomResizeOp";
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
constexpr char kRandomRotationOp[] = "RandomRotationOp";
constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp";
constexpr char kRandomSharpnessOp[] = "RandomSharpnessOp";
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
constexpr char kRescaleOp[] = "RescaleOp";
@ -121,6 +122,7 @@ constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
constexpr char kUniformAugOp[] = "UniformAugOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";

View File

@ -48,7 +48,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, FLOAT_MAX_INTEGER
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
@ -90,6 +90,31 @@ class AutoContrast(cde.AutoContrastOp):
super().__init__(cutoff, ignore)
class RandomSharpness(cde.RandomSharpnessOp):
"""
Adjust the sharpness of the input image by a fixed or random degree. degree of 0.0 gives a blurred image,
a degree of 1.0 gives the original image, and a degree of 2.0 gives a sharpened image.
Args:
degrees (sequence): Range of random sharpness adjustment degrees.
it should be in (min, max) format. If min=max, then it is a
single fixed magnitude operation (default = (0.1, 1.9)).
Raises:
TypeError : If degrees is not a list or tuple.
ValueError: If degrees is not positive.
ValueError: If degrees is in (max, min) format instead of (min, max).
Examples:
>>>c_transform.RandomSharpness(degrees=(0.2,1.9))
"""
@check_positive_degrees
def __init__(self, degrees=(0.1, 1.9)):
self.degrees = degrees
super().__init__(*degrees)
class Equalize(cde.EqualizeOp):
"""
Apply histogram equalization on input image.

View File

@ -614,14 +614,16 @@ def check_positive_degrees(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[degrees], _ = parse_user_args(method, *args, **kwargs)
if isinstance(degrees, (list, tuple)):
if len(degrees) != 2:
raise ValueError("Degrees must be a sequence with length 2.")
for value in degrees:
check_value(value, (0., FLOAT_MAX_INTEGER))
check_positive(degrees[0], "degrees[0]")
if degrees[0] > degrees[1]:
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
else:
raise TypeError("Degrees should be a tuple or list.")
return method(self, *args, **kwargs)
return new_method

View File

@ -34,12 +34,12 @@
#include "minddata/dataset/include/samplers.h"
using namespace mindspore::dataset::api;
using mindspore::MsLogLevel::ERROR;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
using mindspore::dataset::Tensor;
using mindspore::dataset::Status;
using mindspore::dataset::BorderType;
using mindspore::dataset::Status;
using mindspore::dataset::Tensor;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::ERROR;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
@ -308,10 +308,10 @@ TEST_F(MindDataTestPipeline, TestPad) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
@ -358,10 +358,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
@ -527,12 +527,12 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
std::shared_ptr<TensorOperation> random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5});
EXPECT_NE(random_color_adjust1, nullptr);
std::shared_ptr<TensorOperation> random_color_adjust2 = vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5},
{0.5, 0.5});
std::shared_ptr<TensorOperation> random_color_adjust2 =
vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, {0.5, 0.5});
EXPECT_NE(random_color_adjust2, nullptr);
std::shared_ptr<TensorOperation> random_color_adjust3 = vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5},
{0.25, 0.5});
std::shared_ptr<TensorOperation> random_color_adjust3 =
vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, {0.25, 0.5});
EXPECT_NE(random_color_adjust3, nullptr);
std::shared_ptr<TensorOperation> random_color_adjust4 = vision::RandomColorAdjust();
@ -558,10 +558,68 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomSharpness) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSharpness.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_sharpness_op_1 = vision::RandomSharpness({0.4, 2.3});
EXPECT_NE(random_sharpness_op_1, nullptr);
std::shared_ptr<TensorOperation> random_sharpness_op_2 = vision::RandomSharpness({});
EXPECT_EQ(random_sharpness_op_2, nullptr);
std::shared_ptr<TensorOperation> random_sharpness_op_3 = vision::RandomSharpness();
EXPECT_NE(random_sharpness_op_3, nullptr);
std::shared_ptr<TensorOperation> random_sharpness_op_4 = vision::RandomSharpness({0.1});
EXPECT_EQ(random_sharpness_op_4, nullptr);
// Create a Map operation on ds
ds = ds->Map({random_sharpness_op_1, random_sharpness_op_3});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);

View File

@ -146,6 +146,14 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg";
break;
case kInvert:
expect_image_path = dir_path + "imagefolder/apple_expect_invert.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_invert.jpg";
break;
case kRandomSharpness:
expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
break;
default:
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
EXPECT_EQ(0, 1);

View File

@ -39,6 +39,8 @@ class CVOpCommon : public Common {
kRandomSolarize,
kTemplate,
kCrop,
kRandomSharpness,
kInvert,
kRandomAffine,
kAutoContrast,
kEqualize

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/invert_op.h"
#include "common/common.h"
#include "common/cvop_common.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestInvert : public UT::CVOP::CVOpCommon {
public:
MindDataTestInvert() : CVOpCommon() {}
};
TEST_F(MindDataTestInvert, TestOp) {
MS_LOG(INFO) << "Doing test Invert.";
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<InvertOp> op(new InvertOp());
EXPECT_TRUE(op->OneToOne());
Status st = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(st.IsOk());
CheckImageShapeAndData(output_tensor, kInvert);
MS_LOG(INFO) << "testInvert end.";
}

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
#include "common/common.h"
#include "common/cvop_common.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestRandomSharpness : public UT::CVOP::CVOpCommon {
public:
MindDataTestRandomSharpness() : CVOpCommon() {}
};
TEST_F(MindDataTestRandomSharpness, TestOp) {
MS_LOG(INFO) << "Doing test RandomSharpness.";
// setting seed here
u_int32_t curr_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(120);
// Sharpness with a factor in range [0.2,1.8]
float start_degree = 0.2;
float end_degree = 1.8;
std::shared_ptr<Tensor> output_tensor;
// sharpening
std::unique_ptr<RandomSharpnessOp> op(new RandomSharpnessOp(start_degree, end_degree));
EXPECT_TRUE(op->OneToOne());
Status st = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(st.IsOk());
CheckImageShapeAndData(output_tensor, kRandomSharpness);
// restoring the seed
GlobalContext::config_manager()->set_seed(curr_seed);
MS_LOG(INFO) << "testRandomSharpness end.";
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 430 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 435 KiB

View File

@ -19,20 +19,22 @@ import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
import mindspore.dataset.transforms.vision.c_transforms as C
from mindspore import log as logger
from util import visualize_list, diff_mse, save_and_check_md5, \
from util import visualize_list, visualize_one_channel_dataset, diff_mse, save_and_check_md5, \
config_get_set_seed, config_get_set_num_parallel_workers
DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
GENERATE_GOLDEN = False
def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False):
"""
Test RandomSharpness
Test RandomSharpness python op
"""
logger.info("Test RandomSharpness")
logger.info("Test RandomSharpness python op")
# Original Images
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
@ -54,12 +56,16 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
np.transpose(image, (0, 2, 3, 1)),
axis=0)
# Random Sharpness Adjusted Images
# Random Sharpness Adjusted Images
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
py_op = F.RandomSharpness()
if degrees is not None:
py_op = F.RandomSharpness(degrees)
transforms_random_sharpness = F.ComposeOp([F.Decode(),
F.Resize((224, 224)),
F.RandomSharpness(degrees=degrees),
py_op,
F.ToTensor()])
ds_random_sharpness = data.map(input_columns="image",
@ -86,11 +92,11 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
visualize_list(images_original, images_random_sharpness)
def test_random_sharpness_md5():
def test_random_sharpness_py_md5():
"""
Test RandomSharpness with md5 comparison
Test RandomSharpness python op with md5 comparison
"""
logger.info("Test RandomSharpness with md5 comparison")
logger.info("Test RandomSharpness python op with md5 comparison")
original_seed = config_get_set_seed(5)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
@ -107,7 +113,7 @@ def test_random_sharpness_md5():
data = data.map(input_columns=["image"], operations=transform())
# check results with md5 comparison
filename = "random_sharpness_01_result.npz"
filename = "random_sharpness_py_01_result.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
# Restore configuration
@ -115,8 +121,230 @@ def test_random_sharpness_md5():
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_random_sharpness_c(degrees=(1.6, 1.6), plot=False):
"""
Test RandomSharpness cpp op
"""
print(degrees)
logger.info("Test RandomSharpness cpp op")
# Original Images
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = [C.Decode(),
C.Resize((224, 224))]
ds_original = data.map(input_columns="image",
operations=transforms_original)
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original,
image,
axis=0)
# Random Sharpness Adjusted Images
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
c_op = C.RandomSharpness()
if degrees is not None:
c_op = C.RandomSharpness(degrees)
transforms_random_sharpness = [C.Decode(),
C.Resize((224, 224)),
c_op]
ds_random_sharpness = data.map(input_columns="image",
operations=transforms_random_sharpness)
ds_random_sharpness = ds_random_sharpness.batch(512)
for idx, (image, _) in enumerate(ds_random_sharpness):
if idx == 0:
images_random_sharpness = image
else:
images_random_sharpness = np.append(images_random_sharpness,
image,
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_sharpness[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_original, images_random_sharpness)
def test_random_sharpness_c_md5():
"""
Test RandomSharpness cpp op with md5 comparison
"""
logger.info("Test RandomSharpness cpp op with md5 comparison")
original_seed = config_get_set_seed(200)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# define map operations
transforms = [
C.Decode(),
C.RandomSharpness((0.1, 1.9))
]
# Generate dataset
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"], operations=transforms)
# check results with md5 comparison
filename = "random_sharpness_cpp_01_result.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
# Restore configuration
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_random_sharpness_c_py(degrees=(1.0, 1.0), plot=False):
"""
Test Random Sharpness C and python Op
"""
logger.info("Test RandomSharpness C and python Op")
# RandomSharpness Images
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((200, 300))])
python_op = F.RandomSharpness(degrees)
c_op = C.RandomSharpness(degrees)
transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)),
python_op,
np.array])()
ds_random_sharpness_py = data.map(input_columns="image",
operations=transforms_op)
ds_random_sharpness_py = ds_random_sharpness_py.batch(512)
for idx, (image, _) in enumerate(ds_random_sharpness_py):
if idx == 0:
images_random_sharpness_py = image
else:
images_random_sharpness_py = np.append(images_random_sharpness_py,
image,
axis=0)
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((200, 300))])
ds_images_random_sharpness_c = data.map(input_columns="image",
operations=c_op)
ds_images_random_sharpness_c = ds_images_random_sharpness_c.batch(512)
for idx, (image, _) in enumerate(ds_images_random_sharpness_c):
if idx == 0:
images_random_sharpness_c = image
else:
images_random_sharpness_c = np.append(images_random_sharpness_c,
image,
axis=0)
num_samples = images_random_sharpness_c.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_sharpness_c[i], images_random_sharpness_py[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_random_sharpness_c, images_random_sharpness_py, visualize_mode=2)
def test_random_sharpness_one_channel_c(degrees=(1.4, 1.4), plot=False):
"""
Test Random Sharpness cpp op with one channel
"""
logger.info("Test RandomSharpness C Op With MNIST Dataset (Grayscale images)")
c_op = C.RandomSharpness()
if degrees is not None:
c_op = C.RandomSharpness(degrees)
# RandomSharpness Images
data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
ds_random_sharpness_c = data.map(input_columns="image", operations=c_op)
# Original images
data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
images = []
images_trans = []
labels = []
for _, (data_orig, data_trans) in enumerate(zip(data, ds_random_sharpness_c)):
image_orig, label_orig = data_orig
image_trans, _ = data_trans
images.append(image_orig)
labels.append(label_orig)
images_trans.append(image_trans)
if plot:
visualize_one_channel_dataset(images, images_trans, labels)
def test_random_sharpness_invalid_params():
"""
Test RandomSharpness with invalid input parameters.
"""
logger.info("Test RandomSharpness with invalid input parameters.")
try:
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((224, 224)),
C.RandomSharpness(10)])
except TypeError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "tuple" in str(error)
try:
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((224, 224)),
C.RandomSharpness((-10, 10))])
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "interval" in str(error)
try:
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((224, 224)),
C.RandomSharpness((10, 5))])
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "(min,max)" in str(error)
if __name__ == "__main__":
test_random_sharpness()
test_random_sharpness(plot=True)
test_random_sharpness(degrees=(0.5, 1.5), plot=True)
test_random_sharpness_md5()
test_random_sharpness_py(plot=True)
test_random_sharpness_py(None, plot=True) # test with default values
test_random_sharpness_py_md5()
test_random_sharpness_c(plot=True)
test_random_sharpness_c(None, plot=True) # test with default values
test_random_sharpness_c_md5()
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
test_random_sharpness_c_py(degrees=[1, 1], plot=True)
test_random_sharpness_c_py(degrees=[10, 10], plot=True)
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values
test_random_sharpness_invalid_params()