[feat][assistant][I3CEGR] add op random lighting

chenx2ovo 2021-10-19 20:38:01 +08:00
parent 889f3ddc1f
commit 77caf907c4
22 changed files with 757 additions and 5 deletions

View File

@ -48,6 +48,7 @@
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
@ -422,6 +423,16 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(RandomLightingOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomLightingOperation, TensorOperation,
std::shared_ptr<vision::RandomLightingOperation>>(*m, "RandomLightingOperation")
.def(py::init([](float alpha) {
auto random_lighting = std::make_shared<vision::RandomLightingOperation>(alpha);
THROW_IF_ERROR(random_lighting->ValidateParams());
return random_lighting;
}));
}));
PYBIND_REGISTER(RandomPosterizeOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomPosterizeOperation, TensorOperation,
std::shared_ptr<vision::RandomPosterizeOperation>>(*m, "RandomPosterizeOperation")

View File

@ -52,11 +52,12 @@
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resize_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_rotation_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_select_subpolicy_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_sharpness_ir.h"
@ -665,6 +666,18 @@ std::shared_ptr<TensorOperation> RandomInvert::Parse() {
return std::make_shared<RandomInvertOperation>(data_->probability_);
}
// RandomLighting Transform Operation.
struct RandomLighting::Data {
explicit Data(float alpha) : alpha_(alpha) {}
float alpha_;
};
RandomLighting::RandomLighting(float alpha) : data_(std::make_shared<Data>(alpha)) {}
std::shared_ptr<TensorOperation> RandomLighting::Parse() {
return std::make_shared<RandomLightingOperation>(data_->alpha_);
}
// RandomPosterize Transform Operation.
struct RandomPosterize::Data {
explicit Data(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}

View File

@ -625,6 +625,26 @@ class RandomInvert final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Add AlexNet-style PCA-based noise to an image.
class RandomLighting final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] alpha A float representing the intensity of the noise; it is used as the standard deviation of the zero-mean Gaussian from which the per-channel lighting weights are drawn (default=0.05).
explicit RandomLighting(float alpha = 0.05);
/// \brief Destructor.
~RandomLighting() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Reduce the number of bits for each color channel randomly.
class RandomPosterize final : public TensorTransform {
public:

View File

@ -41,6 +41,7 @@ add_library(kernels-image OBJECT
random_horizontal_flip_op.cc
random_horizontal_flip_with_bbox_op.cc
random_invert_op.cc
random_lighting_op.cc
bounding_box_augment_op.cc
random_posterize_op.cc
random_resize_op.cc

View File

@ -1327,6 +1327,54 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
}
}
Status RandomLighting(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rnd_r, float rnd_g,
float rnd_b) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
cv::Mat input_img = input_cv->mat();
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("[Internal ERROR] RandomLighting: load image failed.");
}
if (input_cv->Rank() != DEFAULT_IMAGE_RANK || input_cv->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) {
RETURN_STATUS_UNEXPECTED(
"RandomLighting: input tensor is not in shape of <H,W,C> or channel is not 3, got rank: " +
std::to_string(input_cv->Rank()) + ", and channel: " + std::to_string(input_cv->shape()[CHANNEL_INDEX]));
}
auto input_type = input->type();
CHECK_FAIL_RETURN_UNEXPECTED(input_type != DataType::DE_UINT32 && input_type != DataType::DE_UINT64 &&
input_type != DataType::DE_INT64 && input_type != DataType::DE_STRING,
"RandomLighting: invalid tensor type of uint32, int64, uint64 or string.");
std::vector<std::vector<float>> eig = {{55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009},
{55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140},
{55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203}};
float pca_r = eig[0][0] * rnd_r + eig[0][1] * rnd_g + eig[0][2] * rnd_b;
float pca_g = eig[1][0] * rnd_r + eig[1][1] * rnd_g + eig[1][2] * rnd_b;
float pca_b = eig[2][0] * rnd_r + eig[2][1] * rnd_g + eig[2][2] * rnd_b;
for (int row = 0; row < input_img.rows; row++) {
for (int col = 0; col < input_img.cols; col++) {
float r = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[0]);
float g = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[1]);
float b = static_cast<float>(input_img.at<cv::Vec3b>(row, col)[2]);
input_img.at<cv::Vec3b>(row, col)[0] = cv::saturate_cast<uchar>(r + pca_r);
input_img.at<cv::Vec3b>(row, col)[1] = cv::saturate_cast<uchar>(g + pca_g);
input_img.at<cv::Vec3b>(row, col)[2] = cv::saturate_cast<uchar>(b + pca_b);
}
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_img, input_cv->Rank(), &output_cv));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("RandomLighting: " + std::string(e.what()));
}
}
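For reference, the per-channel shift the kernel computes is just the eigen table multiplied by the three random weights. A minimal NumPy sketch of that arithmetic (the table values are copied from the code above; the seed and variable names are illustrative only):
import numpy as np

# eig[c][i] = eigenvalue_i * eigenvector_i[c], copied from the C++ kernel above.
eig = np.array([[55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009],
                [55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140],
                [55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203]])

rng = np.random.default_rng(0)                    # illustrative seed
alpha = 0.05                                      # default intensity
rnd = rng.normal(loc=0.0, scale=alpha, size=3)    # rnd_r, rnd_g, rnd_b

pca = eig @ rnd                                   # matches pca_r, pca_g, pca_b in the kernel
print(pca)                                        # typically around +/- one to a few gray levels for alpha=0.05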
Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));

View File

@ -304,6 +304,16 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
/// \brief Add AlexNet-style PCA-based noise to an image.
/// \param[in] input The input image.
/// \param[out] output The output image.
/// \param[in] rnd_r Random weight for red channel.
/// \param[in] rnd_g Random weight for green channel.
/// \param[in] rnd_b Random weight for blue channel.
/// \return Status code.
Status RandomLighting(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rnd_r, float rnd_g,
float rnd_b);
/// \brief Take in a 4-channel RGBA image and convert it to RGB.
/// \param[in] input The input image
/// \param[out] output The output image

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_lighting_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
const float RandomLightingOp::kAlpha = 0.05;
Status RandomLightingOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
float rnd_r = dist_(rnd_rgb_);
float rnd_g = dist_(rnd_rgb_);
float rnd_b = dist_(rnd_rgb_);
return RandomLighting(input, output, rnd_r, rnd_g, rnd_b);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RandomLightingOp : public TensorOp {
public:
// Default values
static const float kAlpha;
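// alpha is the standard deviation of the zero-mean normal distribution from which the
// per-channel noise weights (rnd_r, rnd_g, rnd_b) are drawn.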
explicit RandomLightingOp(float alpha = kAlpha) : dist_(0, alpha) {
rnd_rgb_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomLightingOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) override;
std::string Name() const override { return kRandomLightingOp; }
private:
std::mt19937 rnd_rgb_;
std::normal_distribution<float> dist_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_LIGHTING_OP_H

View File

@ -33,6 +33,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
random_horizontal_flip_ir.cc
random_horizontal_flip_with_bbox_ir.cc
random_invert_ir.cc
random_lighting_ir.cc
random_posterize_ir.cc
random_resized_crop_ir.cc
random_resized_crop_with_bbox_ir.cc

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/random_lighting_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/random_lighting_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// RandomLightingOperation.
RandomLightingOperation::RandomLightingOperation(float alpha) : TensorOperation(true), alpha_(alpha) {}
RandomLightingOperation::~RandomLightingOperation() = default;
std::string RandomLightingOperation::Name() const { return kRandomLightingOperation; }
Status RandomLightingOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("RandomLighting", "alpha", alpha_));
return Status::OK();
}
std::shared_ptr<TensorOp> RandomLightingOperation::Build() {
std::shared_ptr<RandomLightingOp> tensor_op = std::make_shared<RandomLightingOp>(alpha_);
return tensor_op;
}
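// Serialize the op's parameters for dataset pipeline serialization (serdes); yields, e.g., {"alpha": 0.05}.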
Status RandomLightingOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["alpha"] = alpha_;
*out_json = args;
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kRandomLightingOperation[] = "RandomLighting";
class RandomLightingOperation : public TensorOperation {
public:
explicit RandomLightingOperation(float alpha);
~RandomLightingOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
private:
float alpha_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_LIGHTING_IR_H_

View File

@ -93,6 +93,7 @@ constexpr char kRandomEqualizeOp[] = "RandomEqualizeOp";
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
constexpr char kRandomInvertOp[] = "RandomInvertOp";
constexpr char kRandomLightingOp[] = "RandomLightingOp";
constexpr char kRandomResizeOp[] = "RandomResizeOp";
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
constexpr char kRandomRotationOp[] = "RandomRotationOp";

View File

@ -50,7 +50,7 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, check_alpha, \
check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, check_random_auto_contrast, \
check_random_adjust_sharpness, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
@ -1196,6 +1196,28 @@ class RandomInvert(ImageTensorOperation):
return cde.RandomInvertOperation(self.prob)
class RandomLighting(ImageTensorOperation):
"""
Add AlexNet-style PCA-based noise to an image. The eigenvalues and eigenvectors of AlexNet's PCA noise are
calculated from the ImageNet dataset.
Args:
alpha (float, optional): Intensity of the noise; the standard deviation of the Gaussian used to sample the lighting weights (default=0.05).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.RandomLighting(0.1)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_alpha
def __init__(self, alpha=0.05):
self.alpha = alpha
def parse(self):
return cde.RandomLightingOperation(self.alpha)
class RandomPosterize(ImageTensorOperation):
"""
Reduce the number of bits for each color channel to posterize the input image randomly with a given probability.

View File

@ -30,7 +30,7 @@ from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_center_crop, check_five_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, check_alpha, \
check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr, \
check_adjust_gamma
@ -1482,6 +1482,42 @@ class RandomColor(py_transforms.PyTensorOperation):
return util.random_color(img, self.degrees)
class RandomLighting:
"""
Add AlexNet-style PCA-based noise to an image.
Args:
alpha (float, optional): Intensity of the noise; the standard deviation of the Gaussian used to sample the lighting weights (default=0.05).
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>>
>>> transforms_list = Compose([py_vision.Decode(),
... py_vision.RandomLighting(0.1),
... py_vision.ToTensor()])
>>> # apply the transform to dataset through map function
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns="image")
"""
@check_alpha
def __init__(self, alpha=0.05):
self.alpha = alpha
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to which AlexNet-style PCA-based noise will be added.
Returns:
PIL Image, image with noise added.
"""
return util.random_lighting(img, self.alpha)
class RandomSharpness(py_transforms.PyTensorOperation):
"""
Adjust the sharpness of the input PIL Image by a random degree.

View File

@ -680,6 +680,42 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
return img
def random_lighting(img, alpha):
"""
Add AlexNet-style PCA-based noise to an image.
Args:
img (PIL Image): Image to which AlexNet-style PCA-based noise will be added.
alpha (float): Intensity of the noise; the standard deviation of the Gaussian used to sample the lighting weights.
Returns:
PIL Image, image with noise added.
"""
if not is_pil(img):
raise TypeError(augment_error_message.format(type(img)))
if img.mode != 'RGB':
img = img.convert("RGB")
alpha_r = np.random.normal(loc=0.0, scale=alpha)
alpha_g = np.random.normal(loc=0.0, scale=alpha)
alpha_b = np.random.normal(loc=0.0, scale=alpha)
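# table[c][i] = eigenvalue_i * eigenvector_i[c] of the ImageNet RGB covariance (AlexNet-style PCA lighting).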
table = np.array([
[55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009],
[55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140],
[55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203]
])
pca_r = table[0][0] * alpha_r + table[0][1] * alpha_g + table[0][2] * alpha_b
pca_g = table[1][0] * alpha_r + table[1][1] * alpha_g + table[1][2] * alpha_b
pca_b = table[2][0] * alpha_r + table[2][1] * alpha_g + table[2][2] * alpha_b
img_arr = np.array(img).astype(np.float64)
img_arr[:, :, 0] += pca_r
img_arr[:, :, 1] += pca_g
img_arr[:, :, 2] += pca_b
img_arr = np.uint8(np.minimum(np.maximum(img_arr, 0), 255))
img = Image.fromarray(img_arr)
return img
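A quick usage sketch of the helper above (assuming random_lighting as defined here is in scope; the synthetic image and seed are illustrative only):
import numpy as np
from PIL import Image

np.random.seed(0)                                  # make the sampled alphas reproducible
img = Image.new("RGB", (4, 4), (128, 64, 32))      # tiny synthetic RGB image
out = random_lighting(img, alpha=0.05)
assert out.size == img.size and out.mode == "RGB"  # geometry and mode are preserved
print(np.asarray(out)[0, 0])                       # original pixel shifted by the PCA noise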
def random_rotation(img, degrees, resample, expand, center, fill_value):
"""
Rotate the input PIL image by a random angle.

View File

@ -365,6 +365,20 @@ def check_prob(method):
return new_method
def check_alpha(method):
"""A wrapper method to check alpha parameter in RandomLighting."""
@wraps(method)
def new_method(self, *args, **kwargs):
[alpha], _ = parse_user_args(method, *args, **kwargs)
type_check(alpha, (float, int,), "alpha")
check_non_negative_float32(alpha, "alpha")
return method(self, *args, **kwargs)
return new_method
def check_normalize_c(method):
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""

View File

@ -26,6 +26,74 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
// Tests for vision C++ API R to Z TensorTransform Operations (in alphabetical order)
/// Feature: RandomLighting
/// Description: test RandomLighting Op on pipeline when alpha=0.1
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestRandomLightingPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomLightingPipeline.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<SequentialSampler>(0, 1));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
auto image = row["image"];
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> randomlighting(new mindspore::dataset::vision::RandomLighting(0.1));
// Note: No need to check for output after calling API class constructor
// Convert to the same type
std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8));
// Note: No need to check for output after calling API class constructor
ds = ds->Map({randomlighting, type_cast}, {"image"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds->CreateIterator();
EXPECT_NE(iter1, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row1;
ASSERT_OK(iter1->GetNextRow(&row1));
auto image1 = row1["image"];
// Manually terminate the pipeline
iter1->Stop();
}
/// Feature: RandomLighting
/// Description: test param check for RandomLighting Op
/// Expectation: get nullptr when params are invalid
TEST_F(MindDataTestPipeline, TestRandomLightingParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomLightingParamCheck.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Case 1: negative alpha
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> random_lighting_op(new mindspore::dataset::vision::RandomLighting(-0.1));
auto ds2 = ds->Map({random_lighting_op});
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid value of alpha
EXPECT_EQ(iter2, nullptr);
}
TEST_F(MindDataTestPipeline, TestRescaleSucess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRescaleSucess1.";
// Create an ImageFolder Dataset
@ -341,7 +409,7 @@ TEST_F(MindDataTestPipeline, TestRGB2BGR) {
uint64_t i = 0;
while (row1.size() != 0) {
i++;
auto image =row1["image"];
auto image = row1["image"];
iter1->GetNextRow(&row1);
iter2->GetNextRow(&row2);
}

View File

@ -158,6 +158,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
break;
case kRandomLighting:
expect_image_path = dir_path + "imagefolder/apple_expect_random_lighting.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_lighting.jpg";
break;
case kRandomPosterize:
expect_image_path = dir_path + "imagefolder/apple_expect_random_posterize.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_posterize.jpg";

View File

@ -40,6 +40,7 @@ class CVOpCommon : public Common {
kTemplate,
kCrop,
kRandomSharpness,
kRandomLighting,
kInvert,
kRandomAffine,
kRandomPosterize,

View File

@ -208,6 +208,23 @@ TEST_F(MindDataTestExecute, TestFrequencyMasking) {
EXPECT_TRUE(status.IsOk());
}
/// Feature: RandomLighting
/// Description: test RandomLighting Op when alpha=0.1
/// Expectation: the data is processed successfully
TEST_F(MindDataTestExecute, TestRandomLighting) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomLighting.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto random_lighting_op = vision::RandomLighting(0.1);
auto transform = Execute({decode, random_lighting_op});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestTimeMasking) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTimeMasking.";
std::shared_ptr<Tensor> input_tensor_;

View File

@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomLighting op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger
from util import visualize_list, diff_mse, save_and_check_md5, \
config_get_set_seed, config_get_set_num_parallel_workers
DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
GENERATE_GOLDEN = False
def test_random_lighting_py(alpha=1, plot=False):
"""
Feature: RandomLighting
Description: test RandomLighting python op
Expectation: equal results
"""
logger.info("Test RandomLighting python op")
# Original Images
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
F.Resize((224, 224)),
F.ToTensor()])
ds_original = data.map(operations=transforms_original, input_columns="image")
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1))
else:
images_original = np.append(images_original, np.transpose(image, (0, 2, 3, 1)), axis=0)
# Random Lighting Adjusted Images
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
alpha = alpha if alpha is not None else 0.05
py_op = F.RandomLighting(alpha)
transforms_random_lighting = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
F.Resize((224, 224)),
py_op,
F.ToTensor()])
ds_random_lighting = data.map(operations=transforms_random_lighting, input_columns="image")
ds_random_lighting = ds_random_lighting.batch(512)
for idx, (image, _) in enumerate(ds_random_lighting.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_lighting = np.transpose(image, (0, 2, 3, 1))
else:
images_random_lighting = np.append(images_random_lighting, np.transpose(image, (0, 2, 3, 1)), axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_lighting[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_original, images_random_lighting)
def test_random_lighting_py_md5():
"""
Feature: RandomLighting
Description: test RandomLighting python op with md5 comparison
Expectation: same MD5
"""
logger.info("Test RandomLighting python op with md5 comparison")
original_seed = config_get_set_seed(140)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# define map operations
transforms = [
F.Decode(),
F.Resize((224, 224)),
F.RandomLighting(1),
F.ToTensor()
]
transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
# Generate dataset
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(operations=transform, input_columns=["image"])
# check results with md5 comparison
filename = "random_lighting_py_01_result.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
# Restore configuration
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_random_lighting_c(alpha=1, plot=False):
"""
Feature: RandomLighting
Description: test RandomLighting cpp op
Expectation: equal results from MindSpore and benchmark
"""
logger.info("Test RandomLighting cpp op")
# Original Images
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = [C.Decode(), C.Resize((224, 224))]
ds_original = data.map(operations=transforms_original, input_columns="image")
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# Random Lighting Adjusted Images
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
alpha = alpha if alpha is not None else 0.05
c_op = C.RandomLighting(alpha)
transforms_random_lighting = [C.Decode(), C.Resize((224, 224)), c_op]
ds_random_lighting = data.map(operations=transforms_random_lighting, input_columns="image")
ds_random_lighting = ds_random_lighting.batch(512)
for idx, (image, _) in enumerate(ds_random_lighting.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_lighting = image
else:
images_random_lighting = np.append(images_random_lighting, image, axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_lighting[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_original, images_random_lighting)
def test_random_lighting_c_py(alpha=1, plot=False):
"""
Feature: RandomLighting
Description: test RandomLighting Cpp and Python ops
Expectation: equal results from Cpp and Python
"""
logger.info("Test RandomLighting Cpp and python Op")
# RandomLighting Images
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(operations=[C.Decode(), C.Resize((200, 300))], input_columns=["image"])
python_op = F.RandomLighting(alpha)
c_op = C.RandomLighting(alpha)
transforms_op = mindspore.dataset.transforms.py_transforms.Compose([lambda img: F.ToPIL()(img.astype(np.uint8)),
python_op,
np.array])
ds_random_lighting_py = data.map(operations=transforms_op, input_columns="image")
ds_random_lighting_py = ds_random_lighting_py.batch(512)
for idx, (image, _) in enumerate(ds_random_lighting_py.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_lighting_py = image
else:
images_random_lighting_py = np.append(images_random_lighting_py, image, axis=0)
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(operations=[C.Decode(), C.Resize((200, 300))], input_columns=["image"])
ds_images_random_lighting_c = data.map(operations=c_op, input_columns="image")
ds_random_lighting_c = ds_images_random_lighting_c.batch(512)
for idx, (image, _) in enumerate(ds_random_lighting_c.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_lighting_c = image
else:
images_random_lighting_c = np.append(images_random_lighting_c, image, axis=0)
num_samples = images_random_lighting_c.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_lighting_c[i], images_random_lighting_py[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_random_lighting_c, images_random_lighting_py, visualize_mode=2)
def test_random_lighting_invalid_params():
"""
Feature: RandomLighting
Description: test RandomLighting with invalid input parameters
Expectation: throw ValueError or TypeError
"""
logger.info("Test RandomLighting with invalid input parameters.")
with pytest.raises(ValueError) as error_info:
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(operations=[C.Decode(), C.Resize((224, 224)),
C.RandomLighting(-2)], input_columns=["image"])
assert "Input alpha is not within the required interval of [0, 16777216]." in str(error_info.value)
with pytest.raises(TypeError) as error_info:
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data = data.map(operations=[C.Decode(), C.Resize((224, 224)),
C.RandomLighting('1')], input_columns=["image"])
err_msg = "Argument alpha with value 1 is not of type [<class 'float'>, <class 'int'>], but got <class 'str'>."
assert err_msg in str(error_info.value)
if __name__ == "__main__":
test_random_lighting_py()
test_random_lighting_py_md5()
test_random_lighting_c()
test_random_lighting_c_py()
test_random_lighting_invalid_params()