RandomColor

This commit is contained in:
Alexey Shevlyakov 2020-08-13 14:30:27 -04:00
parent 2953720169
commit 8526d5414d
15 changed files with 532 additions and 33 deletions

View File

@ -32,6 +32,7 @@
#include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/random_color_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
@ -273,6 +274,14 @@ PYBIND_REGISTER(
py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
})); }));
PYBIND_REGISTER(RandomColorOp, 1, ([](const py::module *m) {
(void)py::class_<RandomColorOp, TensorOp, std::shared_ptr<RandomColorOp>>(
*m, "RandomColorOp",
"Tensor operation to blend an image with its grayscale version with random weights"
"Takes min and max for the range of random weights")
.def(py::init<float, float>(), py::arg("min"), py::arg("max"));
}));
PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) { PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) {
(void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>( (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
*m, "RandomColorAdjustOp", *m, "RandomColorAdjustOp",

View File

@ -27,6 +27,7 @@
#include "minddata/dataset/kernels/data/one_hot_op.h" #include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/random_color_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h" #include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
@ -140,6 +141,21 @@ std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint
return op; return op;
} }
// Function to create RandomColorOperation.
std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub) {
auto op = std::make_shared<RandomColorOperation>(t_lb, t_ub);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
std::shared_ptr<TensorOp> RandomColorOperation::Build() {
std::shared_ptr<RandomColorOp> tensor_op = std::make_shared<RandomColorOp>(t_lb_, t_ub_);
return tensor_op;
}
// Function to create RandomColorAdjustOperation. // Function to create RandomColorAdjustOperation.
std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness, std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
std::vector<float> contrast, std::vector<float> contrast,
@ -475,6 +491,18 @@ std::shared_ptr<TensorOp> PadOperation::Build() {
return tensor_op; return tensor_op;
} }
// RandomColorOperation.
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
bool RandomColorOperation::ValidateParams() {
// Do some input validation.
if (t_lb_ > t_ub_) {
MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound";
return false;
}
return true;
}
// RandomColorAdjustOperation. // RandomColorAdjustOperation.
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
std::vector<float> saturation, std::vector<float> hue) std::vector<float> saturation, std::vector<float> hue)

View File

@ -70,7 +70,7 @@ class CVTensor : public Tensor {
/// Get a reference to the CV::Mat /// Get a reference to the CV::Mat
/// \return a reference to the internal CV::Mat /// \return a reference to the internal CV::Mat
cv::Mat mat() const { return mat_; } cv::Mat &mat() { return mat_; }
/// Get a copy of the CV::Mat /// Get a copy of the CV::Mat
/// \return a copy of internal CV::Mat /// \return a copy of internal CV::Mat

View File

@ -57,6 +57,7 @@ class NormalizeOperation;
class OneHotOperation; class OneHotOperation;
class PadOperation; class PadOperation;
class RandomAffineOperation; class RandomAffineOperation;
class RandomColorOperation;
class RandomColorAdjustOperation; class RandomColorAdjustOperation;
class RandomCropOperation; class RandomCropOperation;
class RandomHorizontalFlipOperation; class RandomHorizontalFlipOperation;
@ -162,6 +163,14 @@ std::shared_ptr<RandomAffineOperation> RandomAffine(
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0}); const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Blends an image with its grayscale version with random weights
/// t and 1 - t generated from a given range. If the range is trivial
/// then the weights are determinate and t equals the bound of the interval
/// \param[in] t_lb lower bound on the range of random weights
/// \param[in] t_lb upper bound on the range of random weights
/// \return Shared pointer to the current TensorOp
std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub);
/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
@ -417,6 +426,21 @@ class RandomAffineOperation : public TensorOperation {
std::vector<uint8_t> fill_value_; std::vector<uint8_t> fill_value_;
}; };
class RandomColorOperation : public TensorOperation {
public:
RandomColorOperation(float t_lb, float t_ub);
~RandomColorOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
private:
float t_lb_;
float t_ub_;
};
class RandomColorAdjustOperation : public TensorOperation { class RandomColorAdjustOperation : public TensorOperation {
public: public:
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0}, RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},

View File

@ -44,5 +44,6 @@ add_library(kernels-image OBJECT
uniform_aug_op.cc uniform_aug_op.cc
resize_with_bbox_op.cc resize_with_bbox_op.cc
random_resize_with_bbox_op.cc random_resize_with_bbox_op.cc
random_color_op.cc
) )
add_dependencies(kernels-image kernels-soft-dvpp-image) add_dependencies(kernels-image kernels-soft-dvpp-image)

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_color_op.h"
#include "minddata/dataset/core/cv_tensor.h"
namespace mindspore {
namespace dataset {
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {}
Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
IO_CHECK(in, out);
if (in->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("image must have 3 channels");
}
// 0.5 pixel precision assuming an 8 bit image
const auto eps = 0.00195;
const auto t = dist_(rnd_);
if (abs(t - 1.0) < eps) {
// Just return input? Can we do it given that input would otherwise get consumed in CVTensor constructor anyway?
*out = in;
return Status::OK();
}
auto cvt_in = CVTensor::AsCVTensor(in);
auto m1 = cvt_in->mat();
cv::Mat gray;
// gray is allocated without using the allocator
cv::cvtColor(m1, gray, cv::COLOR_RGB2GRAY);
// luminosity is not preserved, consider using weights.
cv::Mat temp[3] = {gray, gray, gray};
cv::Mat cv_out;
cv::merge(temp, 3, cv_out);
std::shared_ptr<CVTensor> cvt_out;
CVTensor::CreateFromMat(cv_out, &cvt_out);
if (abs(t - 0.0) < eps) {
// return grayscale
*out = std::static_pointer_cast<Tensor>(cvt_out);
return Status::OK();
}
// return blended image. addWeighted takes care of overflow for uint8_t
cv::addWeighted(m1, t, cvt_out->mat(), 1 - t, 0, cvt_out->mat());
*out = std::static_pointer_cast<Tensor>(cvt_out);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H
#include <memory>
#include <random>
#include <vector>
#include <string>
#include <opencv2/imgproc/imgproc.hpp>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
/// \class RandomColorOp random_color_op.h
/// \brief Blends an image with its grayscale version with random weights
/// t and 1 - t generated from a given range.
/// If the range is trivial then the weights are determinate and
/// t equals the bound of the interval
class RandomColorOp : public TensorOp {
public:
RandomColorOp() = default;
/// \brief Constructor
/// \param[in] t_lb lower bound for the random weights
/// \param[in] t_ub upper bound for the random weights
RandomColorOp(float t_lb, float t_ub);
/// \brief the main function performing computations
/// \param[in] in 2- or 3- dimensional tensor representing an image
/// \param[out] out 2- or 3- dimensional tensor representing an image
/// with the same dimensions as in
Status Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) override;
/// \brief returns the name of the op
std::string Name() const override { return kRandomColorOp; }
private:
std::mt19937 rnd_;
std::uniform_real_distribution<float> dist_;
float t_lb_;
float t_ub_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H

View File

@ -129,6 +129,7 @@ constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
constexpr char kUniformAugOp[] = "UniformAugOp"; constexpr char kUniformAugOp[] = "UniformAugOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp"; constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
constexpr char kRandomColorOp[] = "RandomColorOp";
// text // text
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";

View File

@ -46,7 +46,8 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
@ -628,6 +629,21 @@ class CenterCrop(cde.CenterCropOp):
super().__init__(*size) super().__init__(*size)
class RandomColor(cde.RandomColorOp):
"""
Adjust the color of the input image by a fixed or random degree.
Args:
degrees (sequence): Range of random color adjustment degrees.
It should be in (min, max) format. If min=max, then it is a
single fixed magnitude operation (default=(0.1,1.9)).
Works with 3-channel color images.
"""
@check_positive_degrees
def __init__(self, degrees=(0.1, 1.9)):
super().__init__(*degrees)
class RandomColorAdjust(cde.RandomColorAdjustOp): class RandomColorAdjust(cde.RandomColorAdjustOp):
""" """
Randomly adjust the brightness, contrast, saturation, and hue of the input image. Randomly adjust the brightness, contrast, saturation, and hue of the input image.

View File

@ -609,21 +609,23 @@ def check_uniform_augment_py(method):
def check_positive_degrees(method): def check_positive_degrees(method):
"""A wrapper method to check degrees parameter in RandSharpness and RandColor""" """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (python and cpp)"""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[degrees], _ = parse_user_args(method, *args, **kwargs) [degrees], _ = parse_user_args(method, *args, **kwargs)
if isinstance(degrees, (list, tuple)):
if degrees is not None:
if not isinstance(degrees, (list, tuple)):
raise TypeError("degrees must be either a tuple or a list.")
type_check_list(degrees, (int, float), "degrees")
if len(degrees) != 2: if len(degrees) != 2:
raise ValueError("Degrees must be a sequence with length 2.") raise ValueError("degrees must be a sequence with length 2.")
for value in degrees: for degree in degrees:
check_value(value, (0., FLOAT_MAX_INTEGER)) check_value(degree, (0, FLOAT_MAX_INTEGER))
check_positive(degrees[0], "degrees[0]")
if degrees[0] > degrees[1]: if degrees[0] > degrees[1]:
raise ValueError("Degrees should be in (min,max) format. Got (max,min).") raise ValueError("degrees should be in (min,max) format. Got (max,min).")
else:
raise TypeError("Degrees should be a tuple or list.")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
@ -698,4 +700,5 @@ def check_random_solarize(method):
raise ValueError("threshold must be in min max format numbers") raise ValueError("threshold must be in min max format numbers")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method

View File

@ -39,6 +39,7 @@ SET(DE_UT_SRCS
project_op_test.cc project_op_test.cc
queue_test.cc queue_test.cc
random_affine_op_test.cc random_affine_op_test.cc
random_color_op_test.cc
random_crop_op_test.cc random_crop_op_test.cc
random_crop_with_bbox_op_test.cc random_crop_with_bbox_op_test.cc
random_crop_decode_resize_op_test.cc random_crop_decode_resize_op_test.cc

View File

@ -63,10 +63,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) {
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (row.size() != 0) {
i++; i++;
auto image = row["image"]; auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape(); MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row); iter->GetNextRow(&row);
} }
EXPECT_EQ(i, 20); EXPECT_EQ(i, 20);
@ -160,8 +160,9 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) {
auto image = row["image"]; auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape(); MS_LOG(INFO) << "Tensor image shape: " << image->shape();
// check if the image is in NCHW // check if the image is in NCHW
EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] && 2268 == image->shape()[2] &&
&& 2268 == image->shape()[2] && 4032 == image->shape()[3], true); 4032 == image->shape()[3],
true);
iter->GetNextRow(&row); iter->GetNextRow(&row);
} }
EXPECT_EQ(i, 20); EXPECT_EQ(i, 20);
@ -186,7 +187,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
EXPECT_NE(one_hot_op, nullptr); EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds // Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"}); ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1); std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1);
@ -209,7 +210,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
EXPECT_NE(one_hot_op, nullptr); EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds // Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"}); ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5); std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
@ -258,7 +259,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
EXPECT_NE(one_hot_op, nullptr); EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds // Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"}); ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(); std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch();
@ -379,10 +380,10 @@ TEST_F(MindDataTestPipeline, TestPad) {
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (row.size() != 0) {
i++; i++;
auto image = row["image"]; auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape(); MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row); iter->GetNextRow(&row);
} }
EXPECT_EQ(i, 20); EXPECT_EQ(i, 20);
@ -504,6 +505,61 @@ TEST_F(MindDataTestPipeline, TestRandomAffineSuccess2) {
iter->Stop(); iter->Stop();
} }
TEST_F(MindDataTestPipeline, TestRandomColor) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomColor with non-default params.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_color_op_1 = vision::RandomColor(0.0, 0.0);
EXPECT_NE(random_color_op_1, nullptr);
std::shared_ptr<TensorOperation> random_color_op_2 = vision::RandomColor(1.0, 0.1);
EXPECT_EQ(random_color_op_2, nullptr);
std::shared_ptr<TensorOperation> random_color_op_3 = vision::RandomColor(0.0, 1.1);
EXPECT_NE(random_color_op_3, nullptr);
// Create a Map operation on ds
ds = ds->Map({random_color_op_1, random_color_op_3});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
// Create an ImageFolder Dataset // Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::string folder_path = datasets_root_path_ + "/testPK/data/";
@ -780,7 +836,8 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops // Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(23, 23); //vision::RandomSolarize(); std::shared_ptr<TensorOperation> random_solarize =
mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize();
EXPECT_NE(random_solarize, nullptr); EXPECT_NE(random_solarize, nullptr);
// Create a Map operation on ds // Create a Map operation on ds

View File

@ -0,0 +1,99 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/random_color_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestRandomColorOp : public UT::CVOP::CVOpCommon {
public:
MindDataTestRandomColorOp() : CVOpCommon(), shape({3, 3, 3}) {
std::shared_ptr<Tensor> in;
std::shared_ptr<Tensor> gray;
(void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &in);
(void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &input_tensor);
Status s = in->Fill<uint8_t>(42);
s = input_tensor->Fill<uint8_t>(42);
cvt_in = CVTensor::AsCVTensor(in);
cv::Mat m2;
auto m1 = cvt_in->mat();
cv::cvtColor(m1, m2, cv::COLOR_RGB2GRAY);
cv::Mat temp[3] = {m2 , m2 , m2 };
cv::Mat cv_out;
cv::merge(temp, 3, cv_out);
std::shared_ptr<CVTensor> cvt_out;
CVTensor::CreateFromMat(cv_out, &cvt_out);
gray_tensor = std::static_pointer_cast<Tensor>(cvt_out);
}
TensorShape shape;
std::shared_ptr<Tensor> input_tensor;
std::shared_ptr<CVTensor> cvt_in;
std::shared_ptr<Tensor> gray_tensor;
};
int64_t Compare(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2) {
auto shape = t1->shape();
int64_t sum = 0;
for (auto i = 0; i < shape[0]; i++) {
for (auto j = 0; j < shape[1]; j++) {
for (auto k = 0; k < shape[2]; k++) {
uint8_t value1;
uint8_t value2;
(void)t1->GetItemAt<uint8_t>(&value1, {i, j, k});
(void)t2->GetItemAt<uint8_t>(&value2, {i, j, k});
sum += abs(static_cast<int>(value1) - static_cast<int>(value2));
}
}
}
return sum;
}
// these tests are tautological, write better tests when the requirements for the output are determined
// e. g. how do we want to convert to gray and what does it mean to blend with a gray image (pre- post- gamma corrected,
// what weights).
TEST_F(MindDataTestRandomColorOp, TestOp1) {
std::shared_ptr<Tensor> output_tensor;
auto op = RandomColorOp(1, 1);
auto s = op.Compute(input_tensor, &output_tensor);
auto res = Compare(input_tensor, output_tensor);
EXPECT_EQ(0, res);
}
TEST_F(MindDataTestRandomColorOp, TestOp2) {
std::shared_ptr<Tensor> output_tensor;
auto op = RandomColorOp(0, 0);
auto s = op.Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
auto res = Compare(output_tensor, gray_tensor);
EXPECT_EQ(res, 0);
}
TEST_F(MindDataTestRandomColorOp, TestOp3) {
std::shared_ptr<Tensor> output_tensor;
auto op = RandomColorOp(0.0, 1.0);
for (auto i = 0; i < 1; i++) {
auto s = op.Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
}
}

View File

@ -16,9 +16,11 @@
Testing RandomColor op in DE Testing RandomColor op in DE
""" """
import numpy as np import numpy as np
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.vision.py_transforms as F import mindspore.dataset.transforms.vision.py_transforms as F
from mindspore import log as logger from mindspore import log as logger
from util import visualize_list, diff_mse, save_and_check_md5, \ from util import visualize_list, diff_mse, save_and_check_md5, \
@ -26,11 +28,17 @@ from util import visualize_list, diff_mse, save_and_check_md5, \
DATA_DIR = "../data/dataset/testImageNetData/train/" DATA_DIR = "../data/dataset/testImageNetData/train/"
C_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
C_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MNIST_DATA_DIR = "../data/dataset/testMnistData"
GENERATE_GOLDEN = False GENERATE_GOLDEN = False
def test_random_color(degrees=(0.1, 1.9), plot=False):
def test_random_color_py(degrees=(0.1, 1.9), plot=False):
""" """
Test RandomColor Test Python RandomColor
""" """
logger.info("Test RandomColor") logger.info("Test RandomColor")
@ -85,9 +93,53 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):
visualize_list(images_original, images_random_color) visualize_list(images_original, images_random_color)
def test_random_color_md5(): def test_random_color_c(degrees=(0.1, 1.9), plot=False, run_golden=True):
""" """
Test RandomColor with md5 check Test Cpp RandomColor
"""
logger.info("test_random_color_op")
original_seed = config_get_set_seed(10)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Decode with rgb format set to True
data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
# Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
if degrees is None:
c_op = vision.RandomColor()
else:
c_op = vision.RandomColor(degrees)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode()])
data2 = data2.map(input_columns=["image"], operations=[vision.Decode(), c_op])
image_random_color_op = []
image = []
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
actual = item1["image"]
expected = item2["image"]
image.append(actual)
image_random_color_op.append(expected)
if run_golden:
# Compare with expected md5 from images
filename = "random_color_op_02_result.npz"
save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN)
if plot:
visualize_list(image, image_random_color_op)
# Restore configuration
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers((original_num_parallel_workers))
def test_random_color_py_md5():
"""
Test Python RandomColor with md5 check
""" """
logger.info("Test RandomColor with md5 check") logger.info("Test RandomColor with md5 check")
original_seed = config_get_set_seed(10) original_seed = config_get_set_seed(10)
@ -110,8 +162,94 @@ def test_random_color_md5():
ds.config.set_num_parallel_workers((original_num_parallel_workers)) ds.config.set_num_parallel_workers((original_num_parallel_workers))
def test_compare_random_color_op(degrees=None, plot=False):
"""
Compare Random Color op in Python and Cpp
"""
logger.info("test_random_color_op")
original_seed = config_get_set_seed(5)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Decode with rgb format set to True
data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
if degrees is None:
c_op = vision.RandomColor()
p_op = F.RandomColor()
else:
c_op = vision.RandomColor(degrees)
p_op = F.RandomColor(degrees)
transforms_random_color_py = F.ComposeOp([lambda img: img.astype(np.uint8), F.ToPIL(),
p_op, np.array])
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(), c_op])
data2 = data2.map(input_columns=["image"], operations=[vision.Decode()])
data2 = data2.map(input_columns=["image"], operations=transforms_random_color_py())
image_random_color_op = []
image = []
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
actual = item1["image"]
expected = item2["image"]
image_random_color_op.append(actual)
image.append(expected)
assert actual.shape == expected.shape
mse = diff_mse(actual, expected)
logger.info("MSE= {}".format(str(np.mean(mse))))
# Restore configuration
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
if plot:
visualize_list(image, image_random_color_op)
def test_random_color_c_errors():
"""
Test that Cpp RandomColor errors with bad input
"""
with pytest.raises(TypeError) as error_info:
vision.RandomColor((12))
assert "degrees must be either a tuple or a list." in str(error_info.value)
with pytest.raises(TypeError) as error_info:
vision.RandomColor(("col", 3))
assert "Argument degrees[0] with value col is not of type (<class 'int'>, <class 'float'>)." in str(
error_info.value)
with pytest.raises(ValueError) as error_info:
vision.RandomColor((0.9, 0.1))
assert "degrees should be in (min,max) format. Got (max,min)." in str(error_info.value)
with pytest.raises(ValueError) as error_info:
vision.RandomColor((0.9,))
assert "degrees must be a sequence with length 2." in str(error_info.value)
# RandomColor Cpp Op will fail with one channel input
mnist_ds = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
mnist_ds = mnist_ds.map(input_columns="image", operations=vision.RandomColor())
with pytest.raises(RuntimeError) as error_info:
for _ in enumerate(mnist_ds):
pass
assert "Invalid number of channels in input image" in str(error_info.value)
if __name__ == "__main__": if __name__ == "__main__":
test_random_color() test_random_color_py()
test_random_color(plot=True) test_random_color_py(plot=True)
test_random_color(degrees=(0.5, 1.5), plot=True) test_random_color_py(degrees=(0.5, 1.5), plot=True)
test_random_color_md5() test_random_color_py_md5()
test_random_color_c()
test_random_color_c(plot=True)
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False)
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
test_compare_random_color_op(plot=True)
test_random_color_c_errors()