!35406 [assistant] [I501QD] Add new operator Solarize
Merge pull request !35406 from HKO/Solarize
This commit is contained in:
commit
0364b852ac
|
@ -70,6 +70,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/to_tensor_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
|
||||
|
@ -700,6 +701,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SolarizeOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<vision::SolarizeOperation, TensorOperation, std::shared_ptr<vision::SolarizeOperation>>(
|
||||
*m, "SolarizeOperation")
|
||||
.def(py::init([](const std::vector<float> &threshold) {
|
||||
auto solarize = std::make_shared<vision::SolarizeOperation>(threshold);
|
||||
THROW_IF_ERROR(solarize->ValidateParams());
|
||||
return solarize;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ToTensorOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<vision::ToTensorOperation, TensorOperation, std::shared_ptr<vision::ToTensorOperation>>(
|
||||
|
|
|
@ -75,6 +75,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/rgba_to_rgb_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/to_tensor_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
|
||||
|
@ -1130,6 +1131,16 @@ std::shared_ptr<TensorOperation> SlicePatches::Parse() {
|
|||
data_->fill_value_);
|
||||
}
|
||||
|
||||
// Solarize Transform Operation.
|
||||
struct Solarize::Data {
|
||||
explicit Data(const std::vector<float> &threshold) : threshold_(threshold) {}
|
||||
std::vector<float> threshold_;
|
||||
};
|
||||
|
||||
Solarize::Solarize(const std::vector<float> &threshold) : data_(std::make_shared<Data>(threshold)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Solarize::Parse() { return std::make_shared<SolarizeOperation>(data_->threshold_); }
|
||||
|
||||
// SwapRedBlue Transform Operation.
|
||||
SwapRedBlue::SwapRedBlue() = default;
|
||||
std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); }
|
||||
|
|
|
@ -1571,6 +1571,38 @@ class MS_API SlicePatches final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Invert pixels within a specified range.
|
||||
class MS_API Solarize final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] threshold A vector with two elements specifying the pixel range to invert.
|
||||
/// Threshold values should always be in (min, max) format.
|
||||
/// If min=max, it will to invert all pixels above min(max).
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define operations */
|
||||
/// auto decode_op = vision::Decode();
|
||||
/// auto solarize_op = vision::Solarize({0, 255});
|
||||
///
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({decode_op, solarize_op}, // operations
|
||||
/// {"image"}); // input columns
|
||||
/// \endcode
|
||||
explicit Solarize(const std::vector<float> &threshold);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Solarize() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Swap the red and blue channels of the input image.
|
||||
class MS_API SwapRedBlue final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -1895,6 +1895,7 @@ Status SlicePatches(const std::shared_ptr<Tensor> &input, std::vector<std::share
|
|||
Status Solarize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
const std::vector<float> &threshold) {
|
||||
try {
|
||||
RETURN_IF_NOT_OK(ValidateImage(input, "Solarize", {1, 2, 3, 4, 5, 6, 11, 12}, {2, 3}, {1, 3}));
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
cv::Mat input_img = input_cv->mat();
|
||||
if (!input_cv->mat().data) {
|
||||
|
@ -1920,8 +1921,10 @@ Status Solarize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
// solarize desired portion
|
||||
const float max_size = 255.f;
|
||||
output_cv_tensor->mat() = cv::Scalar::all(max_size) - mask_mat_tensor->mat();
|
||||
input_cv->mat().copyTo(output_cv_tensor->mat(), mask_mat_tensor->mat() == 0);
|
||||
input_cv->mat().copyTo(output_cv_tensor->mat(), input_cv->mat() < threshold_min);
|
||||
if (threshold_min < threshold_max) {
|
||||
input_cv->mat().copyTo(output_cv_tensor->mat(), input_cv->mat() > threshold_max);
|
||||
}
|
||||
|
||||
*output = std::static_pointer_cast<Tensor>(output_cv_tensor);
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
|||
rgba_to_rgb_ir.cc
|
||||
rotate_ir.cc
|
||||
slice_patches_ir.cc
|
||||
solarize_ir.cc
|
||||
swap_red_blue_ir.cc
|
||||
to_tensor_ir.cc
|
||||
uniform_aug_ir.cc
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
|
||||
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// SolarizeOperation
|
||||
SolarizeOperation::SolarizeOperation(const std::vector<float> &threshold) : threshold_(threshold) {}
|
||||
|
||||
SolarizeOperation::~SolarizeOperation() = default;
|
||||
|
||||
Status SolarizeOperation::ValidateParams() {
|
||||
constexpr size_t kThresholdSize = 2;
|
||||
constexpr float kThresholdMax = 255;
|
||||
|
||||
if (threshold_.size() != kThresholdSize) {
|
||||
std::string err_msg =
|
||||
"Solarize: threshold must be a vector of two values, got: " + std::to_string(threshold_.size());
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (size_t i = 0; i < threshold_.size(); ++i) {
|
||||
if (threshold_[i] < 0 || threshold_[i] > kThresholdMax) {
|
||||
std::string err_msg = "Solarize: threshold has to be between 0 and 255, got:" + std::to_string(threshold_[i]);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
if (threshold_[0] > threshold_[1]) {
|
||||
std::string err_msg = "Solarize: threshold must be passed in a (min, max) format";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> SolarizeOperation::Build() {
|
||||
std::shared_ptr<SolarizeOp> tensor_op = std::make_shared<SolarizeOp>(threshold_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status SolarizeOperation::to_json(nlohmann::json *out_json) {
|
||||
(*out_json)["threshold"] = threshold_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SolarizeOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "threshold", kSolarizeOperation));
|
||||
std::vector<float> threshold = op_params["threshold"];
|
||||
*operation = std::make_shared<vision::SolarizeOperation>(threshold);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
constexpr char kSolarizeOperation[] = "Solarize";
|
||||
|
||||
class SolarizeOperation : public TensorOperation {
|
||||
public:
|
||||
explicit SolarizeOperation(const std::vector<float> &threshold);
|
||||
|
||||
~SolarizeOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kSolarizeOperation; };
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
std::vector<float> threshold_;
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_
|
|
@ -52,6 +52,7 @@ from .transforms import AdjustGamma, AutoAugment, AutoContrast, BoundingBoxAugme
|
|||
RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, \
|
||||
RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, \
|
||||
RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, \
|
||||
RgbToHsv, Rotate, SlicePatches, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, UniformAugment, VerticalFlip, not_random
|
||||
RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, UniformAugment, VerticalFlip, \
|
||||
not_random
|
||||
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, get_image_num_channels, \
|
||||
get_image_size
|
||||
|
|
|
@ -70,7 +70,7 @@ from .validators import check_adjust_gamma, check_alpha, check_auto_augment, che
|
|||
check_random_affine, check_random_auto_contrast, check_random_color_adjust, check_random_crop, \
|
||||
check_random_erasing, check_random_perspective, check_random_resize_crop, check_random_rotation, \
|
||||
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
|
||||
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_ten_crop, \
|
||||
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_solarize, check_ten_crop, \
|
||||
check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
|
||||
from ..core.datatypes import mstype_to_detype, nptype_to_detype
|
||||
from ..transforms.py_transforms_util import Implementation
|
||||
|
@ -3292,6 +3292,40 @@ class SlicePatches(ImageTensorOperation):
|
|||
SliceMode.to_c_type(self.slice_mode), self.fill_value)
|
||||
|
||||
|
||||
class Solarize(ImageTensorOperation):
|
||||
"""
|
||||
Solarize the image by inverting all pixel values within the threshold.
|
||||
|
||||
Args:
|
||||
threshold (Union[float, tuple[float, float]]): Range of solarize threshold, should always
|
||||
be in (min, max) format, where min and max are integers in range of [0, 255], and min <= max.
|
||||
If min=max, then invert all pixel values above min(max).
|
||||
|
||||
Raises:
|
||||
TypeError: If `threshold` is not of type float or tuple[float, float].
|
||||
ValueError: If `threshold` is not in range of [0, 255].
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> transforms_list = [vision.Decode(), vision.Solarize(threshold=(10, 100))]
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
@check_solarize
|
||||
def __init__(self, threshold):
|
||||
super().__init__()
|
||||
if isinstance(threshold, (float, int)):
|
||||
threshold = (threshold, threshold)
|
||||
self.threshold = threshold
|
||||
self.implementation = Implementation.C
|
||||
|
||||
def parse(self):
|
||||
return cde.SolarizeOperation(self.threshold)
|
||||
|
||||
|
||||
class TenCrop(PyTensorOperation):
|
||||
"""
|
||||
Crop the given image into one central crop and four corners with the flipped version of these.
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
|
|||
from mindspore._c_expression import typing
|
||||
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
|
||||
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
||||
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
|
||||
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, UINT8_MIN, check_value_normalize_std, \
|
||||
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
|
||||
check_pos_int32, check_tensor_op, deprecator_factory
|
||||
from mindspore.dataset.transforms.validators import check_transform_op_type
|
||||
|
@ -1216,3 +1216,26 @@ def deprecated_py_vision(substitute_name=None, substitute_module=None):
|
|||
"""
|
||||
return deprecator_factory("1.8", "mindspore.dataset.vision.py_transforms", "mindspore.dataset.vision",
|
||||
substitute_name, substitute_module)
|
||||
|
||||
|
||||
def check_solarize(method):
|
||||
"""Wrapper method to check the parameters of SolarizeOp."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
||||
[threshold], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(threshold, (float, int, list, tuple), "threshold")
|
||||
if isinstance(threshold, (float, int)):
|
||||
threshold = (threshold, threshold)
|
||||
type_check_list(threshold, (float, int), "threshold")
|
||||
if len(threshold) != 2:
|
||||
raise TypeError("threshold must be a single number or sequence of two numbers.")
|
||||
for i, value in enumerate(threshold):
|
||||
check_value(value, (UINT8_MIN, UINT8_MAX), "threshold[{}]".format(i))
|
||||
if threshold[1] < threshold[0]:
|
||||
raise ValueError("threshold must be in order of (min, max).")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -2139,3 +2139,58 @@ TEST_F(MindDataTestPipeline, TestPadToSizeInvalid) {
|
|||
// Expect failure: Invalid offset value
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Solarize
|
||||
/// Description: Test default usage
|
||||
/// Expectation: The returned result is as expected
|
||||
TEST_F(MindDataTestPipeline, TestSolarize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSolarize.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<float> threshold = {1.0, 255.0};
|
||||
auto solarize_op = vision::Solarize(threshold);
|
||||
|
||||
ds = ds->Map({solarize_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Solarize
|
||||
/// Description: Test parameter check
|
||||
/// Expectation: Error logs are as expected
|
||||
TEST_F(MindDataTestPipeline, TestSolarizeInvalidFillValue) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSolarizeInvalidFillValue.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<float> threshold = {150, 100};
|
||||
auto solarize_op = vision::Solarize(threshold);
|
||||
|
||||
ds = ds->Map({solarize_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Solarize op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, config_get_set_seed, config_get_set_num_parallel_workers, \
|
||||
visualize_one_channel_dataset, visualize_image, diff_mse
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def solarize(threshold, plot=False):
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
solarize_op = vision.Solarize(threshold)
|
||||
data1 = data1.map(operations=decode_op, input_columns=["image"])
|
||||
data1 = data1.map(operations=solarize_op, input_columns=["image"])
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(operations=decode_op, input_columns=["image"])
|
||||
num_iter = 0
|
||||
for dat1, dat2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
if num_iter > 0:
|
||||
break
|
||||
solarize_ms = dat1["image"]
|
||||
original = dat2["image"]
|
||||
original = Image.fromarray(original.astype('uint8')).convert('RGB')
|
||||
solarize_cv = ImageOps.solarize(original, threshold)
|
||||
solarize_ms = np.array(solarize_ms)
|
||||
solarize_cv = np.array(solarize_cv)
|
||||
mse = diff_mse(solarize_ms, solarize_cv)
|
||||
logger.info("rotate_{}, mse: {}".format(num_iter + 1, mse))
|
||||
assert mse == 0
|
||||
num_iter += 1
|
||||
if plot:
|
||||
visualize_image(original, solarize_ms, mse, solarize_cv)
|
||||
|
||||
image_solarized = []
|
||||
image = []
|
||||
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_solarized.append(item1["image"].copy())
|
||||
image.append(item2["image"].copy())
|
||||
if plot:
|
||||
visualize_list(image, image_solarized)
|
||||
|
||||
|
||||
def test_solarize_basic(plot=False):
|
||||
"""
|
||||
Feature: Solarize
|
||||
Description: Test Solarize op basic usage
|
||||
Expectation: The dataset is processed as expected
|
||||
"""
|
||||
solarize(150.1, plot)
|
||||
solarize(120, plot)
|
||||
solarize(115, plot)
|
||||
|
||||
|
||||
def test_solarize_mnist(plot=False):
|
||||
"""
|
||||
Feature: Solarize op
|
||||
Description: Test Solarize op with MNIST dataset (Grayscale images)
|
||||
Expectation: The dataset is processed as expected
|
||||
"""
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
mnist_1 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||
mnist_2 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||
mnist_2 = mnist_2.map(operations=vision.Solarize((1.0, 255.0)), input_columns="image")
|
||||
|
||||
images = []
|
||||
images_trans = []
|
||||
labels = []
|
||||
|
||||
for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)):
|
||||
image_orig, label_orig = data_orig
|
||||
image_trans, _ = data_trans
|
||||
images.append(image_orig.asnumpy())
|
||||
labels.append(label_orig.asnumpy())
|
||||
images_trans.append(image_trans.asnumpy())
|
||||
|
||||
if plot:
|
||||
visualize_one_channel_dataset(images, images_trans, labels)
|
||||
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_solarize_errors():
|
||||
"""
|
||||
Feature: Solarize op
|
||||
Description: Test that Solarize errors with bad input
|
||||
Expectation: Passes the error check test
|
||||
"""
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.Solarize((12, 1))
|
||||
assert "threshold must be in order of (min, max)." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
vision.Solarize((-1, 200))
|
||||
assert "Input threshold[0] is not within the required interval of [0, 255]." in str(error_info.value)
|
||||
|
||||
try:
|
||||
vision.Solarize(("122.1", "140"))
|
||||
except TypeError as e:
|
||||
assert "Argument threshold[0] with value 122.1 is not of type [<class 'float'>, <class 'int'>]" in str(e)
|
||||
|
||||
try:
|
||||
vision.Solarize((122, 100, 30))
|
||||
except TypeError as e:
|
||||
assert "threshold must be a single number or sequence of two numbers." in str(e)
|
||||
|
||||
try:
|
||||
vision.Solarize((120,))
|
||||
except TypeError as e:
|
||||
assert "threshold must be a single number or sequence of two numbers." in str(e)
|
||||
|
||||
|
||||
def test_input_shape_errors():
|
||||
"""
|
||||
Feature: Solarize op
|
||||
Description: Test that Solarize errors with bad input shape
|
||||
Expectation: Passes the error check test
|
||||
"""
|
||||
try:
|
||||
image = np.random.randint(0, 256, (300, 300, 3, 3)).astype(np.uint8)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the dimension of image tensor does not match the requirement of operator" in str(e)
|
||||
|
||||
try:
|
||||
image = np.random.randint(0, 256, (4, 300, 300)).astype(np.uint8)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the channel of image tensor does not match the requirement of operator" in str(e)
|
||||
|
||||
try:
|
||||
image = np.random.randint(0, 256, (3, 300, 300)).astype(np.uint8)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the channel of image tensor does not match the requirement of operator" in str(e)
|
||||
|
||||
|
||||
def test_input_type_errors():
|
||||
"""
|
||||
Feature: Solarize op
|
||||
Description: Test that Solarize errors with bad input type
|
||||
Expectation: Passes the error check test
|
||||
"""
|
||||
try:
|
||||
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint32)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
|
||||
|
||||
try:
|
||||
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint64)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
|
||||
|
||||
try:
|
||||
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.float16)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
|
||||
|
||||
try:
|
||||
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.float64)
|
||||
vision.Solarize(5)(image)
|
||||
except RuntimeError as e:
|
||||
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_solarize_basic()
|
||||
test_solarize_mnist(plot=False)
|
||||
test_solarize_errors()
|
||||
test_input_shape_errors()
|
||||
test_input_type_errors()
|
Loading…
Reference in New Issue