forked from mindspore-Ecosystem/mindspore
add files
This commit is contained in:
parent
e73e9a9aee
commit
7766efd58e
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/rescale_op.h"
|
||||
|
@ -383,5 +384,11 @@ PYBIND_REGISTER(
|
|||
py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomSolarizeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomSolarizeOp, TensorOp, std::shared_ptr<RandomSolarizeOp>>(*m,
|
||||
"RandomSolarizeOp")
|
||||
.def(py::init<uint8_t, uint8_t>());
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/swap_red_blue_op.h"
|
||||
|
@ -198,6 +199,16 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degre
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomSolarizeOperation.
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) {
|
||||
auto op = std::make_shared<RandomSolarizeOperation>(threshold_min, threshold_max);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomVerticalFlipOperation.
|
||||
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
|
||||
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
|
||||
|
@ -654,6 +665,23 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomSolarizeOperation.
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
||||
|
||||
bool RandomSolarizeOperation::ValidateParams() {
|
||||
if (threshold_max_ < threshold_min_) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
|
||||
std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_min_, threshold_max_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomVerticalFlipOperation
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ class RandomColorAdjustOperation;
|
|||
class RandomCropOperation;
|
||||
class RandomHorizontalFlipOperation;
|
||||
class RandomRotationOperation;
|
||||
class RandomSolarizeOperation;
|
||||
class RandomVerticalFlipOperation;
|
||||
class ResizeOperation;
|
||||
class SwapRedBlueOperation;
|
||||
|
@ -208,6 +209,13 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(
|
|||
std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
||||
std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
|
||||
|
||||
/// \brief Function to create a RandomSolarize TensorOperation.
|
||||
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
|
||||
/// \param[in] threshold_min - lower limit
|
||||
/// \param[in] threshold_max - upper limit
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255);
|
||||
|
||||
/// \brief Function to create a RandomVerticalFlip TensorOperation.
|
||||
/// \notes Tensor operation to perform random vertical flip.
|
||||
/// \param[in] prob - float representing the probability of flip.
|
||||
|
@ -515,6 +523,21 @@ class SwapRedBlueOperation : public TensorOperation {
|
|||
|
||||
bool ValidateParams() override;
|
||||
};
|
||||
|
||||
class RandomSolarizeOperation : public TensorOperation {
|
||||
public:
|
||||
explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max);
|
||||
|
||||
~RandomSolarizeOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
|
|
|
@ -29,11 +29,13 @@ add_library(kernels-image OBJECT
|
|||
random_resize_op.cc
|
||||
random_rotation_op.cc
|
||||
random_select_subpolicy_op.cc
|
||||
random_solarize_op.cc
|
||||
random_vertical_flip_op.cc
|
||||
random_vertical_flip_with_bbox_op.cc
|
||||
rescale_op.cc
|
||||
resize_bilinear_op.cc
|
||||
resize_op.cc
|
||||
solarize_op.cc
|
||||
swap_red_blue_op.cc
|
||||
uniform_aug_op.cc
|
||||
resize_with_bbox_op.cc
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
|
||||
"threshold_min must be smaller or equal to threshold_max.");
|
||||
|
||||
uint8_t threshold_min = std::uniform_int_distribution(threshold_min_, threshold_max_)(rnd_);
|
||||
uint8_t threshold_max = std::uniform_int_distribution(threshold_min_, threshold_max_)(rnd_);
|
||||
|
||||
if (threshold_max < threshold_min) {
|
||||
uint8_t temp = threshold_min;
|
||||
threshold_min = threshold_max;
|
||||
threshold_max = temp;
|
||||
}
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold_min, threshold_max));
|
||||
return op->Compute(input, output);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SOLARIZE_OP_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SOLARIZE_OP_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class RandomSolarizeOp : public SolarizeOp {
|
||||
public:
|
||||
// Pick a random threshold value to solarize the image with
|
||||
explicit RandomSolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
~RandomSolarizeOp() = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kRandomSolarizeOp; }
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SOLARIZE_OP_H
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// only supports RGB images
|
||||
const uint8_t kPixelValue = 255;
|
||||
|
||||
Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
|
||||
"threshold_min must be smaller or equal to threshold_max.");
|
||||
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
cv::Mat input_img = input_cv->mat();
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
}
|
||||
|
||||
if (input_cv->Rank() != 2 && input_cv->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Shape not of either <H,W,C> or <H,W> format.");
|
||||
}
|
||||
if (input_cv->Rank() == 3) {
|
||||
int num_channels = input_cv->shape()[2];
|
||||
if (num_channels != 3 && num_channels != 1) {
|
||||
RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CVTensor> mask_mat_tensor;
|
||||
std::shared_ptr<CVTensor> output_cv_tensor;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_cv->mat(), &mask_mat_tensor));
|
||||
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv_tensor));
|
||||
RETURN_UNEXPECTED_IF_NULL(mask_mat_tensor);
|
||||
RETURN_UNEXPECTED_IF_NULL(output_cv_tensor);
|
||||
|
||||
if (threshold_min_ == threshold_max_) {
|
||||
mask_mat_tensor->mat().setTo(0, ~(input_cv->mat() >= threshold_min_));
|
||||
} else {
|
||||
mask_mat_tensor->mat().setTo(0, ~((input_cv->mat() >= threshold_min_) & (input_cv->mat() <= threshold_max_)));
|
||||
}
|
||||
|
||||
// solarize desired portion
|
||||
output_cv_tensor->mat() = cv::Scalar::all(255) - mask_mat_tensor->mat();
|
||||
input_cv->mat().copyTo(output_cv_tensor->mat(), mask_mat_tensor->mat() == 0);
|
||||
input_cv->mat().copyTo(output_cv_tensor->mat(), input_cv->mat() < threshold_min_);
|
||||
|
||||
*output = std::static_pointer_cast<Tensor>(output_cv_tensor);
|
||||
}
|
||||
|
||||
catch (const cv::Exception &e) {
|
||||
const char *cv_err_msg = e.what();
|
||||
std::string err_message = "Error in SolarizeOp: ";
|
||||
err_message += cv_err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_message);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SOLARIZE_OP_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SOLARIZE_OP_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class SolarizeOp : public TensorOp {
|
||||
public:
|
||||
explicit SolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
||||
|
||||
~SolarizeOp() = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kSolarizeOp; }
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SOLARIZE_OP_H
|
|
@ -113,12 +113,14 @@ constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
|
|||
constexpr char kRandomResizeOp[] = "RandomResizeOp";
|
||||
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
|
||||
constexpr char kRandomRotationOp[] = "RandomRotationOp";
|
||||
constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp";
|
||||
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
|
||||
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
|
||||
constexpr char kRescaleOp[] = "RescaleOp";
|
||||
constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
|
||||
constexpr char kResizeOp[] = "ResizeOp";
|
||||
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
|
||||
constexpr char kSolarizeOp[] = "SolarizeOp";
|
||||
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
||||
constexpr char kUniformAugOp[] = "UniformAugOp";
|
||||
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
|
||||
|
|
|
@ -48,7 +48,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
|
|||
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
|
||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||
check_soft_dvpp_decode_random_crop_resize_jpeg, FLOAT_MAX_INTEGER
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, FLOAT_MAX_INTEGER
|
||||
|
||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
|
@ -932,3 +932,20 @@ class SoftDvppDecodeRandomCropResizeJpeg(cde.SoftDvppDecodeRandomCropResizeJpegO
|
|||
self.ratio = ratio
|
||||
self.max_attempts = max_attempts
|
||||
super().__init__(*size, *scale, *ratio, max_attempts)
|
||||
|
||||
|
||||
class RandomSolarize(cde.RandomSolarizeOp):
|
||||
"""
|
||||
Invert all pixel values above a threshold.
|
||||
|
||||
Args:
|
||||
threshold (sequence): Range of random solarize threshold.
|
||||
Threshold values should always be in range of [0, 255], and
|
||||
include at least one integer value in the given range and
|
||||
be in (min, max) format. If min=max, then it is a single
|
||||
fixed magnitude operation (default=(0, 255)).
|
||||
"""
|
||||
|
||||
@check_random_solarize
|
||||
def __init__(self, threshold=(0, 255)):
|
||||
super().__init__(*threshold)
|
||||
|
|
|
@ -21,7 +21,8 @@ from mindspore._c_dataengine import TensorOp
|
|||
|
||||
from .utils import Inter, Border
|
||||
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, check_tensor_op
|
||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
||||
check_tensor_op, UINT8_MAX
|
||||
|
||||
|
||||
def check_crop_size(size):
|
||||
|
@ -674,4 +675,25 @@ def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
|
|||
check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_random_solarize(method):
|
||||
"""Wrapper method to check the parameters of RandomSolarizeOp."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[threshold], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(threshold, (tuple,), "threshold")
|
||||
type_check_list(threshold, (int,), "threshold")
|
||||
if len(threshold) != 2:
|
||||
raise ValueError("threshold must be a sequence of two numbers")
|
||||
for element in threshold:
|
||||
check_value(element, (0, UINT8_MAX))
|
||||
if threshold[1] < threshold[0]:
|
||||
raise ValueError("threshold must be in min max format numbers")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
return new_method
|
||||
|
|
|
@ -50,6 +50,7 @@ SET(DE_UT_SRCS
|
|||
random_resize_op_test.cc
|
||||
random_resize_with_bbox_op_test.cc
|
||||
random_rotation_op_test.cc
|
||||
random_solarize_op_test.cc
|
||||
random_vertical_flip_op_test.cc
|
||||
random_vertical_flip_with_bbox_op_test.cc
|
||||
rename_op_test.cc
|
||||
|
@ -104,8 +105,9 @@ SET(DE_UT_SRCS
|
|||
sliding_window_op_test.cc
|
||||
epoch_ctrl_op_test.cc
|
||||
sentence_piece_vocab_op_test.cc
|
||||
swap_red_blue_test.cc
|
||||
distributed_sampler_test.cc
|
||||
solarize_op_test.cc
|
||||
swap_red_blue_test.cc
|
||||
distributed_sampler_test.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
|
|
|
@ -729,3 +729,50 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) {
|
|||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomSolarize) {
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(23, 23); //vision::RandomSolarize();
|
||||
EXPECT_NE(random_solarize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({random_solarize});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 1;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
|
|
@ -142,6 +142,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
|
|||
expect_image_path = dir_path + "imagefolder/apple_expect_equalize.jpg";
|
||||
actual_image_path = dir_path + "imagefolder/apple_actual_equalize.jpg";
|
||||
break;
|
||||
case kRandomSolarize:
|
||||
expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg";
|
||||
actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
|
||||
EXPECT_EQ(0, 1);
|
||||
|
|
|
@ -36,6 +36,7 @@ class CVOpCommon : public Common {
|
|||
kDecode,
|
||||
kChannelSwap,
|
||||
kChangeMode,
|
||||
kRandomSolarize,
|
||||
kTemplate,
|
||||
kCrop,
|
||||
kRandomAffine,
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestRandomSolarizeOp : public UT::CVOP::CVOpCommon {
|
||||
protected:
|
||||
MindDataTestRandomSolarizeOp() : CVOpCommon() {}
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor_;
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestRandomSolarizeOp, TestOp1) {
|
||||
MS_LOG(INFO) << "Doing testRandomSolarizeOp1.";
|
||||
// setting seed here
|
||||
uint32_t curr_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
|
||||
std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(100, 100));
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_tensor_, &output_tensor_);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
CheckImageShapeAndData(output_tensor_, kRandomSolarize);
|
||||
// restoring the seed
|
||||
GlobalContext::config_manager()->set_seed(curr_seed);
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestSolarizeOp : public UT::CVOP::CVOpCommon {
|
||||
protected:
|
||||
MindDataTestSolarizeOp() : CVOpCommon() {}
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor_;
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp1) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp1.";
|
||||
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp());
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_tensor_, &output_tensor_);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp2) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp2 - test default values";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp());
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 0};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
std::shared_ptr<Tensor> expected_output_tensor;
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({1, (long int)test_vector.size(), 1}), &test_input_tensor);
|
||||
Tensor::CreateFromVector(expected_output_vector, TensorShape({1, (long int)test_vector.size(), 1}),
|
||||
&expected_output_tensor);
|
||||
|
||||
std::shared_ptr<Tensor> test_output_tensor;
|
||||
Status s = op->Compute(test_input_tensor, &test_output_tensor);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
ASSERT_TRUE(test_output_tensor->shape() == expected_output_tensor->shape());
|
||||
ASSERT_TRUE(test_output_tensor->type() == expected_output_tensor->type());
|
||||
MS_LOG(DEBUG) << *test_output_tensor << std::endl;
|
||||
MS_LOG(DEBUG) << *expected_output_tensor << std::endl;
|
||||
|
||||
ASSERT_TRUE(*test_output_tensor == *expected_output_tensor);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp3) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp3 - Pass in only threshold_min parameter";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 0};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
std::shared_ptr<Tensor> expected_output_tensor;
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({1, (long int)test_vector.size(), 1}), &test_input_tensor);
|
||||
Tensor::CreateFromVector(expected_output_vector, TensorShape({1, (long int)test_vector.size(), 1}),
|
||||
&expected_output_tensor);
|
||||
|
||||
std::shared_ptr<Tensor> test_output_tensor;
|
||||
Status s = op->Compute(test_input_tensor, &test_output_tensor);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(test_output_tensor->shape() == expected_output_tensor->shape());
|
||||
ASSERT_TRUE(test_output_tensor->type() == expected_output_tensor->type());
|
||||
MS_LOG(DEBUG) << *test_output_tensor << std::endl;
|
||||
MS_LOG(DEBUG) << *expected_output_tensor << std::endl;
|
||||
ASSERT_TRUE(*test_output_tensor == *expected_output_tensor);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp4) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp4 - Pass in both threshold parameters.";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
std::shared_ptr<Tensor> expected_output_tensor;
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({1, (long int)test_vector.size(), 1}), &test_input_tensor);
|
||||
Tensor::CreateFromVector(expected_output_vector, TensorShape({1, (long int)test_vector.size(), 1}),
|
||||
&expected_output_tensor);
|
||||
|
||||
std::shared_ptr<Tensor> test_output_tensor;
|
||||
Status s = op->Compute(test_input_tensor, &test_output_tensor);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(test_output_tensor->shape() == expected_output_tensor->shape());
|
||||
ASSERT_TRUE(test_output_tensor->type() == expected_output_tensor->type());
|
||||
MS_LOG(DEBUG) << *test_output_tensor << std::endl;
|
||||
MS_LOG(DEBUG) << *expected_output_tensor << std::endl;
|
||||
ASSERT_TRUE(*test_output_tensor == *expected_output_tensor);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp5) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp5 - Rank 2 input tensor.";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
std::shared_ptr<Tensor> expected_output_tensor;
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({1, (long int)test_vector.size()}), &test_input_tensor);
|
||||
Tensor::CreateFromVector(expected_output_vector, TensorShape({1, (long int)test_vector.size()}),
|
||||
&expected_output_tensor);
|
||||
|
||||
std::shared_ptr<Tensor> test_output_tensor;
|
||||
Status s = op->Compute(test_input_tensor, &test_output_tensor);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(test_output_tensor->shape() == expected_output_tensor->shape());
|
||||
ASSERT_TRUE(test_output_tensor->type() == expected_output_tensor->type());
|
||||
MS_LOG(DEBUG) << *test_output_tensor << std::endl;
|
||||
MS_LOG(DEBUG) << *expected_output_tensor << std::endl;
|
||||
|
||||
ASSERT_TRUE(*test_output_tensor == *expected_output_tensor);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestSolarizeOp, TestOp6) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp6 - Bad Input.";
|
||||
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(10, 1));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
std::shared_ptr<Tensor> test_output_tensor;
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({1, (long int)test_vector.size(), 1}), &test_input_tensor);
|
||||
|
||||
Status s = op->Compute(test_input_tensor, &test_output_tensor);
|
||||
|
||||
EXPECT_TRUE(s.IsError());
|
||||
EXPECT_NE(s.ToString().find("threshold_min must be smaller or equal to threshold_max."), std::string::npos);
|
||||
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
|
||||
}
|
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 544 KiB |
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RandomSolarizeOp op in DE
|
||||
"""
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def test_random_solarize_op(threshold=None, plot=False):
|
||||
"""
|
||||
Test RandomSolarize
|
||||
"""
|
||||
logger.info("Test RandomSolarize")
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
|
||||
decode_op = vision.Decode()
|
||||
|
||||
if threshold is None:
|
||||
solarize_op = vision.RandomSolarize()
|
||||
else:
|
||||
solarize_op = vision.RandomSolarize(threshold)
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=solarize_op)
|
||||
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
|
||||
data2 = data2.map(input_columns=["image"], operations=decode_op)
|
||||
|
||||
image_solarized = []
|
||||
image = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image_solarized.append(item1["image"].copy())
|
||||
image.append(item2["image"].copy())
|
||||
if plot:
|
||||
visualize_list(image, image_solarized)
|
||||
|
||||
|
||||
def test_random_solarize_md5():
|
||||
"""
|
||||
Test RandomSolarize
|
||||
"""
|
||||
logger.info("Test RandomSolarize")
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
random_solarize_op = vision.RandomSolarize((10, 150))
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=random_solarize_op)
|
||||
# Compare with expected md5 from images
|
||||
filename = "random_solarize_01_result.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
# Restore config setting
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_random_solarize_errors():
|
||||
"""
|
||||
Test that RandomSolarize errors with bad input
|
||||
"""
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.RandomSolarize((12, 1))
|
||||
assert "threshold must be in min max format numbers" in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.RandomSolarize((12, 1000))
|
||||
assert "Input is not within the required interval of (0 to 255)." in str(error_info.value)
|
||||
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
vision.RandomSolarize((122.1, 140))
|
||||
assert "Argument threshold[0] with value 122.1 is not of type (<class 'int'>,)." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.RandomSolarize((122, 100, 30))
|
||||
assert "threshold must be a sequence of two numbers" in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.RandomSolarize((120,))
|
||||
assert "threshold must be a sequence of two numbers" in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_solarize_op((100, 100), plot=True)
|
||||
test_random_solarize_op((12, 120), plot=True)
|
||||
test_random_solarize_op(plot=True)
|
||||
test_random_solarize_errors()
|
||||
test_random_solarize_md5()
|
Loading…
Reference in New Issue