This commit is contained in:
jinzi 2022-06-25 11:04:30 +08:00
parent 600a7b14dc
commit f4835f6df1
12 changed files with 515 additions and 4 deletions

View File

@ -70,6 +70,7 @@
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
#include "minddata/dataset/kernels/ir/vision/to_tensor_ir.h"
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
@ -700,6 +701,17 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(SolarizeOperation, 1, ([](const py::module *m) {
(void)
py::class_<vision::SolarizeOperation, TensorOperation, std::shared_ptr<vision::SolarizeOperation>>(
*m, "SolarizeOperation")
.def(py::init([](const std::vector<float> &threshold) {
auto solarize = std::make_shared<vision::SolarizeOperation>(threshold);
THROW_IF_ERROR(solarize->ValidateParams());
return solarize;
}));
}));
PYBIND_REGISTER(ToTensorOperation, 1, ([](const py::module *m) {
(void)
py::class_<vision::ToTensorOperation, TensorOperation, std::shared_ptr<vision::ToTensorOperation>>(

View File

@ -75,6 +75,7 @@
#include "minddata/dataset/kernels/ir/vision/rgba_to_rgb_ir.h"
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
#include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h"
#include "minddata/dataset/kernels/ir/vision/to_tensor_ir.h"
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
@ -1130,6 +1131,16 @@ std::shared_ptr<TensorOperation> SlicePatches::Parse() {
data_->fill_value_);
}
// Solarize Transform Operation.
struct Solarize::Data {
explicit Data(const std::vector<float> &threshold) : threshold_(threshold) {}
std::vector<float> threshold_;
};
Solarize::Solarize(const std::vector<float> &threshold) : data_(std::make_shared<Data>(threshold)) {}
std::shared_ptr<TensorOperation> Solarize::Parse() { return std::make_shared<SolarizeOperation>(data_->threshold_); }
// SwapRedBlue Transform Operation.
SwapRedBlue::SwapRedBlue() = default;
std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); }

View File

@ -1570,6 +1570,38 @@ class MS_API SlicePatches final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Invert pixels within a specified range.
class MS_API Solarize final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] threshold A vector with two elements specifying the pixel range to invert.
/// Threshold values should always be in (min, max) format.
/// If min=max, it will to invert all pixels above min(max).
/// \par Example
/// \code
/// /* Define operations */
/// auto decode_op = vision::Decode();
/// auto solarize_op = vision::Solarize({0, 255});
///
/// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({decode_op, solarize_op}, // operations
/// {"image"}); // input columns
/// \endcode
explicit Solarize(const std::vector<float> &threshold);
/// \brief Destructor.
~Solarize() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Swap the red and blue channels of the input image.
class MS_API SwapRedBlue final : public TensorTransform {
public:

View File

@ -1892,6 +1892,7 @@ Status SlicePatches(const std::shared_ptr<Tensor> &input, std::vector<std::share
Status Solarize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::vector<float> &threshold) {
try {
RETURN_IF_NOT_OK(ValidateImage(input, "Solarize", {1, 2, 3, 4, 5, 6, 11, 12}, {2, 3}, {1, 3}));
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
cv::Mat input_img = input_cv->mat();
if (!input_cv->mat().data) {
@ -1917,8 +1918,10 @@ Status Solarize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
// solarize desired portion
const float max_size = 255.f;
output_cv_tensor->mat() = cv::Scalar::all(max_size) - mask_mat_tensor->mat();
input_cv->mat().copyTo(output_cv_tensor->mat(), mask_mat_tensor->mat() == 0);
input_cv->mat().copyTo(output_cv_tensor->mat(), input_cv->mat() < threshold_min);
if (threshold_min < threshold_max) {
input_cv->mat().copyTo(output_cv_tensor->mat(), input_cv->mat() > threshold_max);
}
*output = std::static_pointer_cast<Tensor>(output_cv_tensor);
}

View File

@ -57,6 +57,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
rgba_to_rgb_ir.cc
rotate_ir.cc
slice_patches_ir.cc
solarize_ir.cc
swap_red_blue_ir.cc
to_tensor_ir.cc
uniform_aug_ir.cc

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/solarize_ir.h"
#include "minddata/dataset/kernels/image/solarize_op.h"
#include "minddata/dataset/kernels/ir/validators.h"
#include "minddata/dataset/util/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// SolarizeOperation
SolarizeOperation::SolarizeOperation(const std::vector<float> &threshold) : threshold_(threshold) {}
SolarizeOperation::~SolarizeOperation() = default;
Status SolarizeOperation::ValidateParams() {
constexpr size_t kThresholdSize = 2;
constexpr float kThresholdMax = 255;
if (threshold_.size() != kThresholdSize) {
std::string err_msg =
"Solarize: threshold must be a vector of two values, got: " + std::to_string(threshold_.size());
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (size_t i = 0; i < threshold_.size(); ++i) {
if (threshold_[i] < 0 || threshold_[i] > kThresholdMax) {
std::string err_msg = "Solarize: threshold has to be between 0 and 255, got:" + std::to_string(threshold_[i]);
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
if (threshold_[0] > threshold_[1]) {
std::string err_msg = "Solarize: threshold must be passed in a (min, max) format";
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::shared_ptr<TensorOp> SolarizeOperation::Build() {
std::shared_ptr<SolarizeOp> tensor_op = std::make_shared<SolarizeOp>(threshold_);
return tensor_op;
}
Status SolarizeOperation::to_json(nlohmann::json *out_json) {
(*out_json)["threshold"] = threshold_;
return Status::OK();
}
Status SolarizeOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "threshold", kSolarizeOperation));
std::vector<float> threshold = op_params["threshold"];
*operation = std::make_shared<vision::SolarizeOperation>(threshold);
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kSolarizeOperation[] = "Solarize";
class SolarizeOperation : public TensorOperation {
public:
explicit SolarizeOperation(const std::vector<float> &threshold);
~SolarizeOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kSolarizeOperation; };
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
std::vector<float> threshold_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SOLARIZE_IR_H_

View File

@ -52,6 +52,7 @@ from .transforms import AdjustGamma, AutoAugment, AutoContrast, BoundingBoxAugme
RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, \
RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, \
RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, Resize, ResizeWithBBox, \
RgbToHsv, Rotate, SlicePatches, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, UniformAugment, VerticalFlip, not_random
RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, UniformAugment, VerticalFlip, \
not_random
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, get_image_num_channels, \
get_image_size

View File

@ -70,7 +70,7 @@ from .validators import check_adjust_gamma, check_alpha, check_auto_augment, che
check_random_affine, check_random_auto_contrast, check_random_color_adjust, check_random_crop, \
check_random_erasing, check_random_perspective, check_random_resize_crop, check_random_rotation, \
check_random_select_subpolicy_op, check_random_solarize, check_range, check_rescale, check_resize, \
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_ten_crop, \
check_resize_interpolation, check_rgb_to_hsv, check_rotate, check_slice_patches, check_solarize, check_ten_crop, \
check_uniform_augment, check_to_tensor, FLOAT_MAX_INTEGER
from ..core.datatypes import mstype_to_detype, nptype_to_detype
from ..transforms.py_transforms_util import Implementation
@ -3292,6 +3292,40 @@ class SlicePatches(ImageTensorOperation):
SliceMode.to_c_type(self.slice_mode), self.fill_value)
class Solarize(ImageTensorOperation):
"""
Solarize the image by inverting all pixel values within the threshold.
Args:
threshold (Union[float, tuple[float, float]]): Range of solarize threshold, should always
be in (min, max) format, where min and max are integers in range of [0, 255], and min <= max.
If min=max, then invert all pixel values above min(max).
Raises:
TypeError: If `threshold` is not of type float or tuple[float, float].
ValueError: If `threshold` is not in range of [0, 255].
Supported Platforms:
``CPU``
Examples:
>>> transforms_list = [vision.Decode(), vision.Solarize(threshold=(10, 100))]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_solarize
def __init__(self, threshold):
super().__init__()
if isinstance(threshold, (float, int)):
threshold = (threshold, threshold)
self.threshold = threshold
self.implementation = Implementation.C
def parse(self):
return cde.SolarizeOperation(self.threshold)
class TenCrop(PyTensorOperation):
"""
Crop the given image into one central crop and four corners with the flipped version of these.

View File

@ -22,7 +22,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore._c_expression import typing
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, UINT8_MIN, check_value_normalize_std, \
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
check_pos_int32, check_tensor_op, deprecator_factory
from mindspore.dataset.transforms.validators import check_transform_op_type
@ -1216,3 +1216,26 @@ def deprecated_py_vision(substitute_name=None, substitute_module=None):
"""
return deprecator_factory("1.8", "mindspore.dataset.vision.py_transforms", "mindspore.dataset.vision",
substitute_name, substitute_module)
def check_solarize(method):
"""Wrapper method to check the parameters of SolarizeOp."""
@wraps(method)
def new_method(self, *args, **kwargs):
[threshold], _ = parse_user_args(method, *args, **kwargs)
type_check(threshold, (float, int, list, tuple), "threshold")
if isinstance(threshold, (float, int)):
threshold = (threshold, threshold)
type_check_list(threshold, (float, int), "threshold")
if len(threshold) != 2:
raise TypeError("threshold must be a single number or sequence of two numbers.")
for i, value in enumerate(threshold):
check_value(value, (UINT8_MIN, UINT8_MAX), "threshold[{}]".format(i))
if threshold[1] < threshold[0]:
raise ValueError("threshold must be in order of (min, max).")
return method(self, *args, **kwargs)
return new_method

View File

@ -2089,3 +2089,58 @@ TEST_F(MindDataTestPipeline, TestPadToSizeInvalid) {
// Expect failure: Invalid offset value
EXPECT_EQ(iter, nullptr);
}
/// Feature: Solarize
/// Description: Test default usage
/// Expectation: The returned result is as expected
TEST_F(MindDataTestPipeline, TestSolarize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSolarize.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
std::vector<float> threshold = {1.0, 255.0};
auto solarize_op = vision::Solarize(threshold);
ds = ds->Map({solarize_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
iter->Stop();
}
/// Feature: Solarize
/// Description: Test parameter check
/// Expectation: Error logs are as expected
TEST_F(MindDataTestPipeline, TestSolarizeInvalidFillValue) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSolarizeInvalidFillValue.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
std::vector<float> threshold = {150, 100};
auto solarize_op = vision::Solarize(threshold);
ds = ds->Map({solarize_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}

View File

@ -0,0 +1,207 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Solarize op in DE
"""
import numpy as np
from PIL import Image, ImageOps
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import log as logger
from util import visualize_list, config_get_set_seed, config_get_set_num_parallel_workers, \
visualize_one_channel_dataset, visualize_image, diff_mse
GENERATE_GOLDEN = False
MNIST_DATA_DIR = "../data/dataset/testMnistData"
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def solarize(threshold, plot=False):
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = vision.Decode()
solarize_op = vision.Solarize(threshold)
data1 = data1.map(operations=decode_op, input_columns=["image"])
data1 = data1.map(operations=solarize_op, input_columns=["image"])
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for dat1, dat2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
solarize_ms = dat1["image"]
original = dat2["image"]
original = Image.fromarray(original.astype('uint8')).convert('RGB')
solarize_cv = ImageOps.solarize(original, threshold)
solarize_ms = np.array(solarize_ms)
solarize_cv = np.array(solarize_cv)
mse = diff_mse(solarize_ms, solarize_cv)
logger.info("rotate_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, solarize_ms, mse, solarize_cv)
image_solarized = []
image = []
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_solarized.append(item1["image"].copy())
image.append(item2["image"].copy())
if plot:
visualize_list(image, image_solarized)
def test_solarize_basic(plot=False):
"""
Feature: Solarize
Description: Test Solarize op basic usage
Expectation: The dataset is processed as expected
"""
solarize(150.1, plot)
solarize(120, plot)
solarize(115, plot)
def test_solarize_mnist(plot=False):
"""
Feature: Solarize op
Description: Test Solarize op with MNIST dataset (Grayscale images)
Expectation: The dataset is processed as expected
"""
original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
mnist_1 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
mnist_2 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
mnist_2 = mnist_2.map(operations=vision.Solarize((1.0, 255.0)), input_columns="image")
images = []
images_trans = []
labels = []
for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)):
image_orig, label_orig = data_orig
image_trans, _ = data_trans
images.append(image_orig.asnumpy())
labels.append(label_orig.asnumpy())
images_trans.append(image_trans.asnumpy())
if plot:
visualize_one_channel_dataset(images, images_trans, labels)
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_solarize_errors():
"""
Feature: Solarize op
Description: Test that Solarize errors with bad input
Expectation: Passes the error check test
"""
with pytest.raises(ValueError) as error_info:
vision.Solarize((12, 1))
assert "threshold must be in order of (min, max)." in str(error_info.value)
with pytest.raises(ValueError) as error_info:
vision.Solarize((-1, 200))
assert "Input threshold[0] is not within the required interval of [0, 255]." in str(error_info.value)
try:
vision.Solarize(("122.1", "140"))
except TypeError as e:
assert "Argument threshold[0] with value 122.1 is not of type [<class 'float'>, <class 'int'>]" in str(e)
try:
vision.Solarize((122, 100, 30))
except TypeError as e:
assert "threshold must be a single number or sequence of two numbers." in str(e)
try:
vision.Solarize((120,))
except TypeError as e:
assert "threshold must be a single number or sequence of two numbers." in str(e)
def test_input_shape_errors():
"""
Feature: Solarize op
Description: Test that Solarize errors with bad input shape
Expectation: Passes the error check test
"""
try:
image = np.random.randint(0, 256, (300, 300, 3, 3)).astype(np.uint8)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the dimension of image tensor does not match the requirement of operator" in str(e)
try:
image = np.random.randint(0, 256, (4, 300, 300)).astype(np.uint8)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the channel of image tensor does not match the requirement of operator" in str(e)
try:
image = np.random.randint(0, 256, (3, 300, 300)).astype(np.uint8)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the channel of image tensor does not match the requirement of operator" in str(e)
def test_input_type_errors():
"""
Feature: Solarize op
Description: Test that Solarize errors with bad input type
Expectation: Passes the error check test
"""
try:
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint32)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
try:
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint64)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
try:
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.float16)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
try:
image = np.random.randint(0, 256, (300, 300, 3)).astype(np.float64)
vision.Solarize(5)(image)
except RuntimeError as e:
assert "Solarize: the data type of image tensor does not match the requirement of operator." in str(e)
if __name__ == "__main__":
test_solarize_basic()
test_solarize_mnist(plot=False)
test_solarize_errors()
test_input_shape_errors()
test_input_type_errors()