!22560 [MS][crowdfunding]New operator implementation, RandomAutoContrast

Merge pull request !22560 from yangwm/autocontrast
This commit is contained in:
i-robot 2021-09-24 09:34:37 +00:00 committed by Gitee
commit ee38ffbd3d
15 changed files with 636 additions and 1 deletions

View File

@ -37,6 +37,7 @@
#include "minddata/dataset/kernels/ir/vision/normalize_pad_ir.h"
#include "minddata/dataset/kernels/ir/vision/pad_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_color_adjust_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_color_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_decode_resize_ir.h"
@ -287,6 +288,18 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(RandomAutoContrastOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomAutoContrastOperation, TensorOperation,
std::shared_ptr<vision::RandomAutoContrastOperation>>(*m,
"RandomAutoContrastOperation")
.def(py::init([](float cutoff, const std::vector<uint32_t> &ignore, float prob) {
auto random_auto_contrast =
std::make_shared<vision::RandomAutoContrastOperation>(cutoff, ignore, prob);
THROW_IF_ERROR(random_auto_contrast->ValidateParams());
return random_auto_contrast;
}));
}));
PYBIND_REGISTER(RandomColorAdjustOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomColorAdjustOperation, TensorOperation,
std::shared_ptr<vision::RandomColorAdjustOperation>>(*m,

View File

@ -41,6 +41,7 @@
#include "minddata/dataset/kernels/ir/vision/normalize_pad_ir.h"
#include "minddata/dataset/kernels/ir/vision/pad_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_color_adjust_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_color_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_decode_resize_ir.h"
@ -480,6 +481,22 @@ std::shared_ptr<TensorOperation> RandomAffine::Parse() {
}
#ifndef ENABLE_ANDROID
// RandomAutoContrast Transform Operation.
struct RandomAutoContrast::Data {
Data(float cutoff, const std::vector<uint32_t> &ignore, float prob)
: cutoff_(cutoff), ignore_(ignore), probability_(prob) {}
float cutoff_;
std::vector<uint32_t> ignore_;
float probability_;
};
RandomAutoContrast::RandomAutoContrast(float cutoff, std::vector<uint32_t> ignore, float prob)
: data_(std::make_shared<Data>(cutoff, ignore, prob)) {}
std::shared_ptr<TensorOperation> RandomAutoContrast::Parse() {
return std::make_shared<RandomAutoContrastOperation>(data_->cutoff_, data_->ignore_, data_->probability_);
}
// RandomColor Transform Operation.
struct RandomColor::Data {
Data(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}

View File

@ -323,6 +323,31 @@ class Pad final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Automatically adjust the contrast of the image with a given probability.
class RandomAutoContrast final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] cutoff Percent of the lightest and darkest pixels to be cut off from
/// the histogram of the input image. The value must be in range of [0.0, 50.0) (default=0.0).
/// \param[in] ignore The background pixel values to be ignored, each of which must be
/// in range of [0, 255] (default={}).
/// \param[in] prob A float representing the probability of AutoContrast, which must be
/// in range of [0, 1] (default=0.5).
explicit RandomAutoContrast(float cutoff = 0.0, std::vector<uint32_t> ignore = {}, float prob = 0.5);
/// \brief Destructor.
~RandomAutoContrast() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Blend an image with its grayscale version with random weights
/// t and 1 - t generated from a given range. If the range is trivial
/// then the weights are determinate and t equals to the bound of the interval.

View File

@ -29,6 +29,7 @@ add_library(kernels-image OBJECT
pad_op.cc
posterize_op.cc
random_affine_op.cc
random_auto_contrast_op.cc
random_color_adjust_op.cc
random_crop_decode_resize_op.cc
random_crop_and_resize_with_bbox_op.cc

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_auto_contrast_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const float RandomAutoContrastOp::kCutOff = 0.0;
const std::vector<uint32_t> RandomAutoContrastOp::kIgnore = {};
const float RandomAutoContrastOp::kDefProbability = 0.5;
Status RandomAutoContrastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (distribution_(rnd_)) {
return AutoContrast(input, output, cutoff_, ignore_);
}
*output = input;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AUTO_CONTRAST_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AUTO_CONTRAST_OP_H_
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RandomAutoContrastOp : public TensorOp {
public:
// Default values, also used by python_bindings.cc
static const float kCutOff;
static const std::vector<uint32_t> kIgnore;
static const float kDefProbability;
RandomAutoContrastOp(float cutoff, const std::vector<uint32_t> &ignore, float prob = kDefProbability)
: cutoff_(cutoff), ignore_(ignore), distribution_(prob) {
is_deterministic_ = false;
rnd_.seed(GetSeed());
}
~RandomAutoContrastOp() override = default;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const RandomAutoContrastOp &so) {
so.Print(out);
return out;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRandomAutoContrastOp; }
private:
std::mt19937 rnd_;
float cutoff_;
std::vector<uint32_t> ignore_;
std::bernoulli_distribution distribution_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AUTO_CONTRAST_OP_H_

View File

@ -22,6 +22,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
normalize_pad_ir.cc
pad_ir.cc
random_affine_ir.cc
random_auto_contrast_ir.cc
random_color_adjust_ir.cc
random_color_ir.cc
random_crop_decode_resize_ir.cc

View File

@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/random_auto_contrast_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/random_auto_contrast_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// RandomAutoContrastOperation
RandomAutoContrastOperation::RandomAutoContrastOperation(float cutoff, const std::vector<uint32_t> &ignore, float prob)
: cutoff_(cutoff), ignore_(ignore), probability_(prob) {}
RandomAutoContrastOperation::~RandomAutoContrastOperation() = default;
std::string RandomAutoContrastOperation::Name() const { return kRandomAutoContrastOperation; }
Status RandomAutoContrastOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalar("RandomAutoContrast", "cutoff", cutoff_, {0, 50}, false, true));
for (auto i = 0; i < ignore_.size(); i++) {
RETURN_IF_NOT_OK(ValidateScalar("RandomAutoContrast", "ignore[" + std::to_string(i) + "]", ignore_[i], {0, 255}));
}
RETURN_IF_NOT_OK(ValidateProbability("RandomAutoContrast", probability_));
return Status::OK();
}
std::shared_ptr<TensorOp> RandomAutoContrastOperation::Build() {
std::shared_ptr<RandomAutoContrastOp> tensor_op =
std::make_shared<RandomAutoContrastOp>(cutoff_, ignore_, probability_);
return tensor_op;
}
Status RandomAutoContrastOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["cutoff"] = cutoff_;
args["ignore"] = ignore_;
args["prob"] = probability_;
*out_json = args;
return Status::OK();
}
Status RandomAutoContrastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("cutoff") != op_params.end(), "Failed to find cutoff");
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("ignore") != op_params.end(), "Failed to find ignore");
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("prob") != op_params.end(), "Failed to find prob");
float cutoff = op_params["cutoff"];
std::vector<uint32_t> ignore = op_params["ignore"];
float prob = op_params["prob"];
*operation = std::make_shared<vision::RandomAutoContrastOperation>(cutoff, ignore, prob);
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_AUTO_CONTRAST_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_AUTO_CONTRAST_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kRandomAutoContrastOperation[] = "RandomAutoContrast";
class RandomAutoContrastOperation : public TensorOperation {
public:
RandomAutoContrastOperation(float cutoff, const std::vector<uint32_t> &ignore, float prob);
~RandomAutoContrastOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
float cutoff_;
std::vector<uint32_t> ignore_;
float probability_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_AUTO_CONTRAST_IR_H_

View File

@ -80,6 +80,7 @@ constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kNormalizePadOp[] = "NormalizePadOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomAffineOp[] = "RandomAffineOp";
constexpr char kRandomAutoContrastOp[] = "RandomAutoContrastOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomColorOp[] = "RandomColorOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";

View File

@ -51,7 +51,7 @@ from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, \
check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, check_random_auto_contrast, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
@ -776,6 +776,38 @@ class RandomAffine(ImageTensorOperation):
self.fill_value)
class RandomAutoContrast(ImageTensorOperation):
"""
Automatically adjust the contrast of the image with a given probability.
Args:
cutoff (float, optional): Percent of the lightest and darkest pixels to be cut off from
the histogram of the input image. The value must be in range of [0.0, 50.0) (default=0.0).
ignore (Union[int, sequence], optional): The background pixel values to be ignored, each of
which must be in range of [0, 255] (default=None).
prob (float, optional): Probability of the image being automatically contrasted, which
must be in range of [0, 1] (default=0.5).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.RandomAutoContrast(cutoff=0.0, ignore=None, prob=0.5)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_random_auto_contrast
def __init__(self, cutoff=0.0, ignore=None, prob=0.5):
if ignore is None:
ignore = []
if isinstance(ignore, int):
ignore = [ignore]
self.cutoff = cutoff
self.ignore = ignore
self.prob = prob
def parse(self):
return cde.RandomAutoContrastOperation(self.cutoff, self.ignore, self.prob)
class RandomColor(ImageTensorOperation):
"""
Adjust the color of the input image by a fixed or random degree.

View File

@ -311,6 +311,30 @@ def check_random_resize_crop(method):
return new_method
def check_random_auto_contrast(method):
"""Wrapper method to check the parameters of Python RandomAutoContrast op."""
@wraps(method)
def new_method(self, *args, **kwargs):
[cutoff, ignore, prob], _ = parse_user_args(method, *args, **kwargs)
type_check(cutoff, (int, float), "cutoff")
check_value_cutoff(cutoff, [0, 50], "cutoff")
if ignore is not None:
type_check(ignore, (list, tuple, int), "ignore")
if isinstance(ignore, int):
check_value(ignore, [0, 255], "ignore")
if isinstance(ignore, (list, tuple)):
for item in ignore:
type_check(item, (int,), "item")
check_value(item, [0, 255], "ignore")
type_check(prob, (float, int,), "prob")
check_value(prob, [0., 1.], "prob")
return method(self, *args, **kwargs)
return new_method
def check_prob(method):
"""A wrapper that wraps a parameter checker (to confirm probability) around the original function."""

View File

@ -1316,3 +1316,83 @@ TEST_F(MindDataTestPipeline, TestRandomInvertInvalidProb) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomAutoContrast) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrast.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {0, 255}, 0.5);
ds = ds->Map({random_auto_contrast_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidProb) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidProb.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_auto_contrast_op = vision::RandomAutoContrast(0.0, {}, 1.5);
ds = ds->Map({random_auto_contrast_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidCutoff) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_auto_contrast_op = vision::RandomAutoContrast(-2.0, {}, 0.5);
ds = ds->Map({random_auto_contrast_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidIgnore) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {10, 256}, 0.5);
ds = ds->Map({random_auto_contrast_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}

View File

@ -1112,3 +1112,17 @@ TEST_F(MindDataTestExecute, TestRandomInvertEager) {
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestRandomAutoContrastEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomAutoContrastEager.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto random_auto_contrast_op = vision::RandomAutoContrast(0.6);
auto transform = Execute({decode, random_auto_contrast_op});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,186 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomAutoContrast op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_list, visualize_image, diff_mse
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
data_dir = "../data/dataset/testImageNetData/train/"
def test_random_auto_contrast_pipeline(plot=False):
"""
Test RandomAutoContrast pipeline
"""
logger.info("Test RandomAutoContrast pipeline")
# Original Images
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
transforms_original = [c_vision.Decode(), c_vision.Resize(size=[224, 224])]
ds_original = data_set.map(operations=transforms_original, input_columns="image")
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image.asnumpy()
else:
images_original = np.append(images_original,
image.asnumpy(),
axis=0)
# Randomly Automatically Contrasted Images
data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
transform_random_auto_contrast = [c_vision.Decode(),
c_vision.Resize(size=[224, 224]),
c_vision.RandomAutoContrast(prob=0.6)]
ds_random_auto_contrast = data_set1.map(operations=transform_random_auto_contrast, input_columns="image")
ds_random_auto_contrast = ds_random_auto_contrast.batch(512)
for idx, (image, _) in enumerate(ds_random_auto_contrast):
if idx == 0:
images_random_auto_contrast = image.asnumpy()
else:
images_random_auto_contrast = np.append(images_random_auto_contrast,
image.asnumpy(),
axis=0)
if plot:
visualize_list(images_original, images_random_auto_contrast)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_auto_contrast[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_random_auto_contrast_eager():
"""
Test RandomAutoContrast eager.
"""
img = np.fromfile(image_file, dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = c_vision.Decode()(img)
img_auto_contrast = c_vision.AutoContrast(1.0, None)(img)
img_random_auto_contrast = c_vision.RandomAutoContrast(1.0, None, 1.0)(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img_auto_contrast), img_random_auto_contrast.shape))
assert img_auto_contrast.all() == img_random_auto_contrast.all()
def test_random_auto_contrast_comp(plot=False):
"""
Test RandomAutoContrast op compared with AutoContrast op.
"""
random_auto_contrast_op = c_vision.RandomAutoContrast(prob=1.0)
auto_contrast_op = c_vision.AutoContrast()
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
for item in dataset1.create_dict_iterator(output_numpy=True):
image = item['image']
dataset1.map(operations=random_auto_contrast_op, input_columns=['image'])
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
dataset2.map(operations=auto_contrast_op, input_columns=['image'])
for item1, item2 in zip(dataset1.create_dict_iterator(output_numpy=True),
dataset2.create_dict_iterator(output_numpy=True)):
image_random_auto_contrast = item1['image']
image_auto_contrast = item2['image']
mse = diff_mse(image_auto_contrast, image_random_auto_contrast)
assert mse == 0
logger.info("mse: {}".format(mse))
if plot:
visualize_image(image, image_random_auto_contrast, mse, image_auto_contrast)
def test_random_auto_contrast_invalid_prob():
"""
Test RandomAutoContrast Op with invalid prob parameter.
"""
logger.info("test_random_auto_contrast_invalid_prob")
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
try:
random_auto_contrast_op = c_vision.RandomAutoContrast(prob=1.5)
dataset = dataset.map(operations=random_auto_contrast_op, input_columns=['image'])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
def test_random_auto_contrast_invalid_ignore():
"""
Test RandomAutoContrast Op with invalid ignore parameter.
"""
logger.info("test_random_auto_contrast_invalid_ignore")
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(),
c_vision.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
# invalid ignore
data_set = data_set.map(operations=c_vision.RandomAutoContrast(ignore=255.5), input_columns="image")
except TypeError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Argument ignore with value 255.5 is not of type" in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
# invalid ignore
data_set = data_set.map(operations=c_vision.RandomAutoContrast(ignore=(10, 100)), input_columns="image")
except TypeError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Argument ignore with value (10,100) is not of type" in str(error)
def test_random_auto_contrast_invalid_cutoff():
"""
Test RandomAutoContrast Op with invalid cutoff parameter.
"""
logger.info("test_random_auto_contrast_invalid_cutoff")
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(),
c_vision.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
# invalid cutoff
data_set = data_set.map(operations=c_vision.RandomAutoContrast(cutoff=-10.0), input_columns="image")
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(),
c_vision.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
# invalid cutoff
data_set = data_set.map(operations=c_vision.RandomAutoContrast(cutoff=120.0), input_columns="image")
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
if __name__ == "__main__":
test_random_auto_contrast_pipeline(plot=True)
test_random_auto_contrast_eager()
test_random_auto_contrast_comp(plot=True)
test_random_auto_contrast_invalid_prob()
test_random_auto_contrast_invalid_ignore()
test_random_auto_contrast_invalid_cutoff()