forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3CEGR] add op random lighting
This commit is contained in:
parent 889f3ddc1f
commit 77caf907c4
@@ -48,6 +48,7 @@
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
@@ -422,6 +423,16 @@ PYBIND_REGISTER(
                    }));
                }));

PYBIND_REGISTER(RandomLightingOperation, 1, ([](const py::module *m) {
                  (void)py::class_<vision::RandomLightingOperation, TensorOperation,
                                   std::shared_ptr<vision::RandomLightingOperation>>(*m, "RandomLightingOperation")
                    .def(py::init([](float alpha) {
                      auto random_lighting = std::make_shared<vision::RandomLightingOperation>(alpha);
                      THROW_IF_ERROR(random_lighting->ValidateParams());
                      return random_lighting;
                    }));
                }));

PYBIND_REGISTER(RandomPosterizeOperation, 1, ([](const py::module *m) {
                  (void)py::class_<vision::RandomPosterizeOperation, TensorOperation,
                                   std::shared_ptr<vision::RandomPosterizeOperation>>(*m, "RandomPosterizeOperation")
@@ -52,11 +52,12 @@
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resize_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_rotation_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_select_subpolicy_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_sharpness_ir.h"
@@ -665,6 +666,18 @@ std::shared_ptr<TensorOperation> RandomInvert::Parse() {
  return std::make_shared<RandomInvertOperation>(data_->probability_);
}

// RandomLighting Transform Operation.
struct RandomLighting::Data {
  explicit Data(float alpha) : alpha_(alpha) {}
  float alpha_;
};

RandomLighting::RandomLighting(float alpha) : data_(std::make_shared<Data>(alpha)) {}

std::shared_ptr<TensorOperation> RandomLighting::Parse() {
  return std::make_shared<RandomLightingOperation>(data_->alpha_);
}

// RandomPosterize Transform Operation.
struct RandomPosterize::Data {
  explicit Data(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
@@ -625,6 +625,26 @@ class RandomInvert final : public TensorTransform {
  std::shared_ptr<Data> data_;
};

/// \brief Add AlexNet-style PCA-based noise to an image.
class RandomLighting final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] alpha A float representing the intensity of the image (default=0.05).
  explicit RandomLighting(float alpha = 0.05);

  /// \brief Destructor.
  ~RandomLighting() = default;

 protected:
  /// \brief The function to convert a TensorTransform object into a TensorOperation object.
  /// \return Shared pointer to TensorOperation object.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Reduce the number of bits for each color channel randomly.
class RandomPosterize final : public TensorTransform {
 public:
@@ -41,6 +41,7 @@ add_library(kernels-image OBJECT
    random_horizontal_flip_op.cc
    random_horizontal_flip_with_bbox_op.cc
    random_invert_op.cc
    random_lighting_op.cc
    bounding_box_augment_op.cc
    random_posterize_op.cc
    random_resize_op.cc
@@ -1327,6 +1327,54 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
  }
}

Status RandomLighting(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rnd_r, float rnd_g,
                      float rnd_b) {
  try {
    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
    cv::Mat input_img = input_cv->mat();

    if (!input_cv->mat().data) {
      RETURN_STATUS_UNEXPECTED("[Internal ERROR] RandomLighting: load image failed.");
    }

    if (input_cv->Rank() != DEFAULT_IMAGE_RANK || input_cv->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) {
      RETURN_STATUS_UNEXPECTED(
        "RandomLighting: input tensor is not in shape of <H,W,C> or channel is not 3, got rank: " +
        std::to_string(input_cv->Rank()) + ", and channel: " + std::to_string(input_cv->shape()[CHANNEL_INDEX]));
    }
    auto input_type = input->type();
    CHECK_FAIL_RETURN_UNEXPECTED(input_type != DataType::DE_UINT32 && input_type != DataType::DE_UINT64 &&
                                   input_type != DataType::DE_INT64 && input_type != DataType::DE_STRING,
                                 "RandomLighting: invalid tensor type of uint32, int64, uint64 or string.");

    std::vector<std::vector<float>> eig = {{55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009},
                                           {55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140},
                                           {55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203}};

    float pca_r = eig[0][0] * rnd_r + eig[0][1] * rnd_g + eig[0][2] * rnd_b;
    float pca_g = eig[1][0] * rnd_r + eig[1][1] * rnd_g + eig[1][2] * rnd_b;
    float pca_b = eig[2][0] * rnd_r + eig[2][1] * rnd_g + eig[2][2] * rnd_b;
    for (int row = 0; row < input_img.rows; row++) {
      for (int col = 0; col < input_img.cols; col++) {
        float r = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[0]);
        float g = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[1]);
        float b = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[2]);
        input_img.at<cv::Vec3b>(row, col)[0] = cv::saturate_cast<uchar>(r + pca_r);
        input_img.at<cv::Vec3b>(row, col)[1] = cv::saturate_cast<uchar>(g + pca_g);
        input_img.at<cv::Vec3b>(row, col)[2] = cv::saturate_cast<uchar>(b + pca_b);
      }
    }

    std::shared_ptr<CVTensor> output_cv;
    RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_img, input_cv->Rank(), &output_cv));

    *output = std::static_pointer_cast<Tensor>(output_cv);
    return Status::OK();
  } catch (const cv::Exception &e) {
    RETURN_STATUS_UNEXPECTED("RandomLighting: " + std::string(e.what()));
  }
}

Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  try {
    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
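For reference, the kernel above adds a single per-channel offset, eig · (rnd_r, rnd_g, rnd_b), to every pixel and saturates the result to [0, 255]. A minimal NumPy sketch of the same computation follows; it is illustrative only and not part of this commit, and the helper name and its arguments are hypothetical.

import numpy as np

def pca_lighting_sketch(img_hwc_uint8, alpha=0.05, rng=None):
    # Hypothetical helper mirroring the RandomLighting kernel above.
    rng = np.random.default_rng() if rng is None else rng
    # Eigenvalue-scaled eigenvector table, identical to `eig` in the C++ kernel.
    eig = np.array([[55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009],
                    [55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140],
                    [55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203]])
    rnd = rng.normal(0.0, alpha, size=3)           # rnd_r, rnd_g, rnd_b
    pca = eig @ rnd                                # per-channel offsets (pca_r, pca_g, pca_b)
    out = img_hwc_uint8.astype(np.float32) + pca   # broadcast over H x W
    return np.clip(out, 0, 255).astype(np.uint8)   # same effect as cv::saturate_cast<uchar>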
@@ -304,6 +304,16 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
           const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
           uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);

/// \brief Add AlexNet-style PCA-based noise to an image.
/// \param[in] input The input image.
/// \param[out] output The output image.
/// \param[in] rnd_r Random weight for red channel.
/// \param[in] rnd_g Random weight for green channel.
/// \param[in] rnd_b Random weight for blue channel.
/// \return Status code.
Status RandomLighting(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rnd_r, float rnd_g,
                      float rnd_b);

/// \brief Take in a 4 channel image in RBGA to RGB
/// \param[in] input The input image
/// \param[out] output The output image
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/kernels/image/random_lighting_op.h"

#include "minddata/dataset/kernels/image/image_utils.h"

namespace mindspore {
namespace dataset {
const float RandomLightingOp::kAlpha = 0.05;

Status RandomLightingOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);

  float rnd_r = dist_(rnd_rgb_);
  float rnd_g = dist_(rnd_rgb_);
  float rnd_b = dist_(rnd_rgb_);
  return RandomLighting(input, output, rnd_r, rnd_g, rnd_b);
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H

#include <memory>
#include <random>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
class RandomLightingOp : public TensorOp {
 public:
  // Default values
  static const float kAlpha;

  explicit RandomLightingOp(float alpha = kAlpha) : dist_(0, alpha) {
    rnd_rgb_.seed(GetSeed());
    is_deterministic_ = false;
  }

  ~RandomLightingOp() override = default;

  Status Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) override;

  std::string Name() const override { return kRandomLightingOp; }

 private:
  std::mt19937 rnd_rgb_;
  std::normal_distribution<float> dist_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H
@@ -33,6 +33,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
    random_horizontal_flip_ir.cc
    random_horizontal_flip_with_bbox_ir.cc
    random_invert_ir.cc
    random_lighting_ir.cc
    random_posterize_ir.cc
    random_resized_crop_ir.cc
    random_resized_crop_with_bbox_ir.cc
@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"

#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/random_lighting_op.h"
#endif

#include "minddata/dataset/kernels/ir/validators.h"

namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// RandomLightingOperation.
RandomLightingOperation::RandomLightingOperation(float alpha) : TensorOperation(true), alpha_(alpha) {}

RandomLightingOperation::~RandomLightingOperation() = default;

std::string RandomLightingOperation::Name() const { return kRandomLightingOperation; }

Status RandomLightingOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("RandomLighting", "alpha", alpha_));
  return Status::OK();
}

std::shared_ptr<TensorOp> RandomLightingOperation::Build() {
  std::shared_ptr<RandomLightingOp> tensor_op = std::make_shared<RandomLightingOp>(alpha_);
  return tensor_op;
}

Status RandomLightingOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["alpha"] = alpha_;
  *out_json = args;
  return Status::OK();
}
#endif
}  // namespace vision
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"

namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kRandomLightingOperation[] = "RandomLighting";

class RandomLightingOperation : public TensorOperation {
 public:
  explicit RandomLightingOperation(float alpha);

  ~RandomLightingOperation();

  std::shared_ptr<TensorOp> Build() override;

  Status ValidateParams() override;

  std::string Name() const override;

  Status to_json(nlohmann::json *out_json) override;

 private:
  float alpha_;
};
}  // namespace vision
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_
@@ -93,6 +93,7 @@ constexpr char kRandomEqualizeOp[] = "RandomEqualizeOp";
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
constexpr char kRandomInvertOp[] = "RandomInvertOp";
constexpr char kRandomLightingOp[] = "RandomLightingOp";
constexpr char kRandomResizeOp[] = "RandomResizeOp";
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
constexpr char kRandomRotationOp[] = "RandomRotationOp";
@@ -50,7 +50,7 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, \
    check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
-    check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
+    check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, check_alpha, \
    check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, check_random_auto_contrast, \
    check_random_adjust_sharpness, \
    check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
@@ -1196,6 +1196,28 @@ class RandomInvert(ImageTensorOperation):
        return cde.RandomInvertOperation(self.prob)


class RandomLighting(ImageTensorOperation):
    """
    Add AlexNet-style PCA-based noise to an image. The eigenvalues and eigenvectors of AlexNet's PCA noise
    are calculated from the ImageNet dataset.

    Args:
        alpha (float, optional): Intensity of the image (default=0.05).

    Examples:
        >>> transforms_list = [c_vision.Decode(), c_vision.RandomLighting(0.1)]
        >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
        ...                                                 input_columns=["image"])
    """

    @check_alpha
    def __init__(self, alpha=0.05):
        self.alpha = alpha

    def parse(self):
        return cde.RandomLightingOperation(self.alpha)


class RandomPosterize(ImageTensorOperation):
    """
    Reduce the number of bits for each color channel to posterize the input image randomly with a given probability.
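As the docstring's example suggests, the new operation drops into an ordinary c_transforms pipeline; per the C++ kernel above, alpha is the standard deviation of the normal distribution from which the per-channel weights are drawn, so larger values give stronger lighting jitter. A short usage sketch under that reading (the dataset path is a placeholder):

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

ds.config.set_seed(140)  # the op is non-deterministic; seeding makes runs repeatable
dataset = ds.ImageFolderDataset(dataset_dir="path/to/image_folder", shuffle=False)
dataset = dataset.map(operations=[c_vision.Decode(), c_vision.RandomLighting(alpha=0.1)],
                      input_columns=["image"])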
@@ -30,7 +30,7 @@ from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_center_crop, check_five_crop, check_resize_interpolation, check_random_resize_crop, \
    check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
-    check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
+    check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, check_alpha, \
    check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
    check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr, \
    check_adjust_gamma
@@ -1482,6 +1482,42 @@ class RandomColor(py_transforms.PyTensorOperation):
        return util.random_color(img, self.degrees)


class RandomLighting:
    """
    Add AlexNet-style PCA-based noise to an image.

    Args:
        alpha (float, optional): Intensity of the image (default=0.05).

    Examples:
        >>> from mindspore.dataset.transforms.py_transforms import Compose
        >>>
        >>> transforms_list = Compose([py_vision.Decode(),
        ...                            py_vision.RandomLighting(0.1),
        ...                            py_vision.ToTensor()])
        >>> # apply the transform to dataset through map function
        >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
        ...                                                 input_columns="image")
    """

    @check_alpha
    def __init__(self, alpha=0.05):
        self.alpha = alpha

    def __call__(self, img):
        """
        Call method.

        Args:
            img (PIL Image): Image to which AlexNet-style PCA-based noise will be added.

        Returns:
            PIL Image, image with noise added.
        """
        return util.random_lighting(img, self.alpha)


class RandomSharpness(py_transforms.PyTensorOperation):
    """
    Adjust the sharpness of the input PIL Image by a random degree.
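Because __call__ simply forwards to util.random_lighting, the Python op can also be applied eagerly to a single PIL image outside a pipeline. A small sketch under that assumption (the all-zero image is a placeholder):

import numpy as np
from PIL import Image
import mindspore.dataset.vision.py_transforms as py_vision

img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # placeholder RGB image
out = py_vision.RandomLighting(alpha=0.1)(img)                   # PIL Image in, PIL Image out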
@@ -680,6 +680,42 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
    return img


def random_lighting(img, alpha):
    """
    Add AlexNet-style PCA-based noise to an image.

    Args:
        img (PIL Image): Image to which AlexNet-style PCA-based noise will be added.
        alpha (float): Intensity of the image.

    Returns:
        PIL Image, image with noise added.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if img.mode != 'RGB':
        img = img.convert("RGB")

    alpha_r = np.random.normal(loc=0.0, scale=alpha)
    alpha_g = np.random.normal(loc=0.0, scale=alpha)
    alpha_b = np.random.normal(loc=0.0, scale=alpha)
    table = np.array([
        [55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009],
        [55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140],
        [55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203]
    ])
    pca_r = table[0][0] * alpha_r + table[0][1] * alpha_g + table[0][2] * alpha_b
    pca_g = table[1][0] * alpha_r + table[1][1] * alpha_g + table[1][2] * alpha_b
    pca_b = table[2][0] * alpha_r + table[2][1] * alpha_g + table[2][2] * alpha_b
    img_arr = np.array(img).astype(np.float64)
    img_arr[:, :, 0] += pca_r
    img_arr[:, :, 1] += pca_g
    img_arr[:, :, 2] += pca_b
    img_arr = np.uint8(np.minimum(np.maximum(img_arr, 0), 255))
    img = Image.fromarray(img_arr)
    return img


def random_rotation(img, degrees, resample, expand, center, fill_value):
    """
    Rotate the input PIL image by a random angle.
@@ -365,6 +365,20 @@ def check_prob(method):
    return new_method


def check_alpha(method):
    """A wrapper method to check alpha parameter in RandomLighting."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [alpha], _ = parse_user_args(method, *args, **kwargs)
        type_check(alpha, (float, int,), "alpha")
        check_non_negative_float32(alpha, "alpha")

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_c(method):
    """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
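type_check restricts alpha to float or int and check_non_negative_float32 rejects negative values, so invalid arguments fail at construction time. A sketch of the resulting behaviour, with the error text taken from test_random_lighting_invalid_params further below:

import mindspore.dataset.vision.c_transforms as C

C.RandomLighting(0.1)      # accepted: non-negative float
C.RandomLighting(1)        # accepted: int is also allowed by type_check
try:
    C.RandomLighting(-2)   # rejected by check_non_negative_float32
except ValueError as err:
    print(err)  # "Input alpha is not within the required interval of [0, 16777216]."
try:
    C.RandomLighting('1')  # rejected by type_check
except TypeError as err:
    print(err)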
@@ -26,6 +26,74 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {

// Tests for vision C++ API R to Z TensorTransform Operations (in alphabetical order)

/// Feature: RandomLighting
/// Description: test RandomLighting Op on pipeline when alpha=0.1
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestRandomLightingPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomLightingPipeline.";
  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<SequentialSampler>(0, 1));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  auto image = row["image"];

  // Create objects for the tensor ops
  std::shared_ptr<TensorTransform> randomlighting(new mindspore::dataset::vision::RandomLighting(0.1));
  // Note: No need to check for output after calling API class constructor

  // Convert to the same type
  std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8));
  // Note: No need to check for output after calling API class constructor

  ds = ds->Map({randomlighting, type_cast}, {"image"});
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter1 = ds->CreateIterator();
  EXPECT_NE(iter1, nullptr);

  // Iterate the dataset and get each row1
  std::unordered_map<std::string, mindspore::MSTensor> row1;
  ASSERT_OK(iter1->GetNextRow(&row1));

  auto image1 = row1["image"];

  // Manually terminate the pipeline
  iter1->Stop();
}

/// Feature: RandomLighting
/// Description: test param check for RandomLighting Op
/// Expectation: get nullptr when params are invalid
TEST_F(MindDataTestPipeline, TestRandomLightingParamCheck) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomLightingParamCheck.";
  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Case 1: negative alpha
  // Create objects for the tensor ops
  std::shared_ptr<TensorTransform> random_lighting_op(new mindspore::dataset::vision::RandomLighting(-0.1));
  auto ds2 = ds->Map({random_lighting_op});
  EXPECT_NE(ds2, nullptr);
  // Create an iterator over the result of the above dataset
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  // Expect failure: invalid value of alpha
  EXPECT_EQ(iter2, nullptr);
}

TEST_F(MindDataTestPipeline, TestRescaleSucess1) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRescaleSucess1.";
  // Create an ImageFolder Dataset
@@ -341,7 +409,7 @@ TEST_F(MindDataTestPipeline, TestRGB2BGR) {
  uint64_t i = 0;
  while (row1.size() != 0) {
    i++;
-    auto image =row1["image"];
+    auto image = row1["image"];
    iter1->GetNextRow(&row1);
    iter2->GetNextRow(&row2);
  }
@@ -158,6 +158,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
      expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
      actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
      break;
    case kRandomLighting:
      expect_image_path = dir_path + "imagefolder/apple_expect_random_lighting.jpg";
      actual_image_path = dir_path + "imagefolder/apple_actual_random_lighting.jpg";
      break;
    case kRandomPosterize:
      expect_image_path = dir_path + "imagefolder/apple_expect_random_posterize.jpg";
      actual_image_path = dir_path + "imagefolder/apple_actual_random_posterize.jpg";
@@ -40,6 +40,7 @@ class CVOpCommon : public Common {
    kTemplate,
    kCrop,
    kRandomSharpness,
    kRandomLighting,
    kInvert,
    kRandomAffine,
    kRandomPosterize,
@@ -208,6 +208,23 @@ TEST_F(MindDataTestExecute, TestFrequencyMasking) {
  EXPECT_TRUE(status.IsOk());
}

/// Feature: RandomLighting
/// Description: test RandomLighting Op when alpha=0.1
/// Expectation: the data is processed successfully
TEST_F(MindDataTestExecute, TestRandomLighting) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomLighting.";
  // Read images
  auto image = ReadFileToTensor("data/dataset/apple.jpg");

  // Transform params
  auto decode = vision::Decode();
  auto random_lighting_op = vision::RandomLighting(0.1);

  auto transform = Execute({decode, random_lighting_op});
  Status rc = transform(image, &image);
  EXPECT_EQ(rc, Status::OK());
}

TEST_F(MindDataTestExecute, TestTimeMasking) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestTimeMasking.";
  std::shared_ptr<Tensor> input_tensor_;

Binary file not shown.
@@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomLighting op in DE
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger
from util import visualize_list, diff_mse, save_and_check_md5, \
    config_get_set_seed, config_get_set_num_parallel_workers

DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"

GENERATE_GOLDEN = False


def test_random_lighting_py(alpha=1, plot=False):
    """
    Feature: RandomLighting
    Description: test RandomLighting python op
    Expectation: equal results
    """
    logger.info("Test RandomLighting python op")

    # Original Images
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                              F.Resize((224, 224)),
                                                                              F.ToTensor()])

    ds_original = data.map(operations=transforms_original, input_columns="image")

    ds_original = ds_original.batch(512)

    for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_original = np.transpose(image, (0, 2, 3, 1))
        else:
            images_original = np.append(images_original, np.transpose(image, (0, 2, 3, 1)), axis=0)

    # Random Lighting Adjusted Images
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    alpha = alpha if alpha is not None else 0.05
    py_op = F.RandomLighting(alpha)

    transforms_random_lighting = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                                     F.Resize((224, 224)),
                                                                                     py_op,
                                                                                     F.ToTensor()])
    ds_random_lighting = data.map(operations=transforms_random_lighting, input_columns="image")

    ds_random_lighting = ds_random_lighting.batch(512)

    for idx, (image, _) in enumerate(ds_random_lighting.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_random_lighting = np.transpose(image, (0, 2, 3, 1))
        else:
            images_random_lighting = np.append(images_random_lighting, np.transpose(image, (0, 2, 3, 1)), axis=0)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_random_lighting[i], images_original[i])

    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_original, images_random_lighting)


def test_random_lighting_py_md5():
    """
    Feature: RandomLighting
    Description: test RandomLighting python op with md5 comparison
    Expectation: same MD5
    """
    logger.info("Test RandomLighting python op with md5 comparison")
    original_seed = config_get_set_seed(140)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # define map operations
    transforms = [
        F.Decode(),
        F.Resize((224, 224)),
        F.RandomLighting(1),
        F.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)

    # Generate dataset
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data = data.map(operations=transform, input_columns=["image"])

    # check results with md5 comparison
    filename = "random_lighting_py_01_result.npz"
    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)

    # Restore configuration
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_random_lighting_c(alpha=1, plot=False):
    """
    Feature: RandomLighting
    Description: test RandomLighting cpp op
    Expectation: equal results from Mindspore and benchmark
    """
    logger.info("Test RandomLighting cpp op")
    # Original Images
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = [C.Decode(), C.Resize((224, 224))]

    ds_original = data.map(operations=transforms_original, input_columns="image")

    ds_original = ds_original.batch(512)

    for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_original = image
        else:
            images_original = np.append(images_original, image, axis=0)

    # Random Lighting Adjusted Images
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    alpha = alpha if alpha is not None else 0.05
    c_op = C.RandomLighting(alpha)

    transforms_random_lighting = [C.Decode(), C.Resize((224, 224)), c_op]

    ds_random_lighting = data.map(operations=transforms_random_lighting, input_columns="image")

    ds_random_lighting = ds_random_lighting.batch(512)

    for idx, (image, _) in enumerate(ds_random_lighting.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_random_lighting = image
        else:
            images_random_lighting = np.append(images_random_lighting, image, axis=0)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_random_lighting[i], images_original[i])

    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_original, images_random_lighting)


def test_random_lighting_c_py(alpha=1, plot=False):
    """
    Feature: RandomLighting
    Description: test Random Lighting Cpp and Python Op
    Expectation: equal results from Cpp and Python
    """
    logger.info("Test RandomLighting Cpp and python Op")

    # RandomLighting Images
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data = data.map(operations=[C.Decode(), C.Resize((200, 300))], input_columns=["image"])

    python_op = F.RandomLighting(alpha)
    c_op = C.RandomLighting(alpha)

    transforms_op = mindspore.dataset.transforms.py_transforms.Compose([lambda img: F.ToPIL()(img.astype(np.uint8)),
                                                                        python_op,
                                                                        np.array])

    ds_random_lighting_py = data.map(operations=transforms_op, input_columns="image")

    ds_random_lighting_py = ds_random_lighting_py.batch(512)

    for idx, (image, _) in enumerate(ds_random_lighting_py.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_random_lighting_py = image
        else:
            images_random_lighting_py = np.append(images_random_lighting_py, image, axis=0)

    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data = data.map(operations=[C.Decode(), C.Resize((200, 300))], input_columns=["image"])

    ds_images_random_lighting_c = data.map(operations=c_op, input_columns="image")

    ds_random_lighting_c = ds_images_random_lighting_c.batch(512)

    for idx, (image, _) in enumerate(ds_random_lighting_c.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        if idx == 0:
            images_random_lighting_c = image
        else:
            images_random_lighting_c = np.append(images_random_lighting_c, image, axis=0)

    num_samples = images_random_lighting_c.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_random_lighting_c[i], images_random_lighting_py[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))
    if plot:
        visualize_list(images_random_lighting_c, images_random_lighting_py, visualize_mode=2)


def test_random_lighting_invalid_params():
    """
    Feature: RandomLighting
    Description: test RandomLighting with invalid input parameters
    Expectation: throw ValueError or TypeError
    """
    logger.info("Test RandomLighting with invalid input parameters.")
    with pytest.raises(ValueError) as error_info:
        data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
        data = data.map(operations=[C.Decode(), C.Resize((224, 224)),
                                    C.RandomLighting(-2)], input_columns=["image"])
    assert "Input alpha is not within the required interval of [0, 16777216]." in str(error_info.value)

    with pytest.raises(TypeError) as error_info:
        data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
        data = data.map(operations=[C.Decode(), C.Resize((224, 224)),
                                    C.RandomLighting('1')], input_columns=["image"])
    err_msg = "Argument alpha with value 1 is not of type [<class 'float'>, <class 'int'>], but got <class 'str'>."
    assert err_msg in str(error_info.value)


if __name__ == "__main__":
    test_random_lighting_py()
    test_random_lighting_py_md5()
    test_random_lighting_c()
    test_random_lighting_c_py()
    test_random_lighting_invalid_params()