RandomColor
This commit is contained in:
parent
2953720169
commit
8526d5414d
|
@ -32,6 +32,7 @@
|
||||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_color_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
|
||||||
|
@ -273,6 +274,14 @@ PYBIND_REGISTER(
|
||||||
py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
|
py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(RandomColorOp, 1, ([](const py::module *m) {
|
||||||
|
(void)py::class_<RandomColorOp, TensorOp, std::shared_ptr<RandomColorOp>>(
|
||||||
|
*m, "RandomColorOp",
|
||||||
|
"Tensor operation to blend an image with its grayscale version with random weights"
|
||||||
|
"Takes min and max for the range of random weights")
|
||||||
|
.def(py::init<float, float>(), py::arg("min"), py::arg("max"));
|
||||||
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
|
(void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
|
||||||
*m, "RandomColorAdjustOp",
|
*m, "RandomColorAdjustOp",
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "minddata/dataset/kernels/data/one_hot_op.h"
|
#include "minddata/dataset/kernels/data/one_hot_op.h"
|
||||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_color_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||||
|
@ -140,6 +141,21 @@ std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to create RandomColorOperation.
|
||||||
|
std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub) {
|
||||||
|
auto op = std::make_shared<RandomColorOperation>(t_lb, t_ub);
|
||||||
|
// Input validation
|
||||||
|
if (!op->ValidateParams()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> RandomColorOperation::Build() {
|
||||||
|
std::shared_ptr<RandomColorOp> tensor_op = std::make_shared<RandomColorOp>(t_lb_, t_ub_);
|
||||||
|
return tensor_op;
|
||||||
|
}
|
||||||
|
|
||||||
// Function to create RandomColorAdjustOperation.
|
// Function to create RandomColorAdjustOperation.
|
||||||
std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
|
std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
|
||||||
std::vector<float> contrast,
|
std::vector<float> contrast,
|
||||||
|
@ -475,6 +491,18 @@ std::shared_ptr<TensorOp> PadOperation::Build() {
|
||||||
return tensor_op;
|
return tensor_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RandomColorOperation.
|
||||||
|
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
|
||||||
|
|
||||||
|
bool RandomColorOperation::ValidateParams() {
|
||||||
|
// Do some input validation.
|
||||||
|
if (t_lb_ > t_ub_) {
|
||||||
|
MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// RandomColorAdjustOperation.
|
// RandomColorAdjustOperation.
|
||||||
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
|
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
|
||||||
std::vector<float> saturation, std::vector<float> hue)
|
std::vector<float> saturation, std::vector<float> hue)
|
||||||
|
|
|
@ -70,7 +70,7 @@ class CVTensor : public Tensor {
|
||||||
|
|
||||||
/// Get a reference to the CV::Mat
|
/// Get a reference to the CV::Mat
|
||||||
/// \return a reference to the internal CV::Mat
|
/// \return a reference to the internal CV::Mat
|
||||||
cv::Mat mat() const { return mat_; }
|
cv::Mat &mat() { return mat_; }
|
||||||
|
|
||||||
/// Get a copy of the CV::Mat
|
/// Get a copy of the CV::Mat
|
||||||
/// \return a copy of internal CV::Mat
|
/// \return a copy of internal CV::Mat
|
||||||
|
|
|
@ -57,6 +57,7 @@ class NormalizeOperation;
|
||||||
class OneHotOperation;
|
class OneHotOperation;
|
||||||
class PadOperation;
|
class PadOperation;
|
||||||
class RandomAffineOperation;
|
class RandomAffineOperation;
|
||||||
|
class RandomColorOperation;
|
||||||
class RandomColorAdjustOperation;
|
class RandomColorAdjustOperation;
|
||||||
class RandomCropOperation;
|
class RandomCropOperation;
|
||||||
class RandomHorizontalFlipOperation;
|
class RandomHorizontalFlipOperation;
|
||||||
|
@ -162,6 +163,14 @@ std::shared_ptr<RandomAffineOperation> RandomAffine(
|
||||||
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
||||||
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||||
|
|
||||||
|
/// \brief Blends an image with its grayscale version with random weights
|
||||||
|
/// t and 1 - t generated from a given range. If the range is trivial
|
||||||
|
/// then the weights are determinate and t equals the bound of the interval
|
||||||
|
/// \param[in] t_lb lower bound on the range of random weights
|
||||||
|
/// \param[in] t_lb upper bound on the range of random weights
|
||||||
|
/// \return Shared pointer to the current TensorOp
|
||||||
|
std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub);
|
||||||
|
|
||||||
/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
|
/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
|
||||||
/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
|
/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
|
||||||
/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
|
/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
|
||||||
|
@ -417,6 +426,21 @@ class RandomAffineOperation : public TensorOperation {
|
||||||
std::vector<uint8_t> fill_value_;
|
std::vector<uint8_t> fill_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class RandomColorOperation : public TensorOperation {
|
||||||
|
public:
|
||||||
|
RandomColorOperation(float t_lb, float t_ub);
|
||||||
|
|
||||||
|
~RandomColorOperation() = default;
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> Build() override;
|
||||||
|
|
||||||
|
bool ValidateParams() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
float t_lb_;
|
||||||
|
float t_ub_;
|
||||||
|
};
|
||||||
|
|
||||||
class RandomColorAdjustOperation : public TensorOperation {
|
class RandomColorAdjustOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
|
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
|
||||||
|
|
|
@ -44,5 +44,6 @@ add_library(kernels-image OBJECT
|
||||||
uniform_aug_op.cc
|
uniform_aug_op.cc
|
||||||
resize_with_bbox_op.cc
|
resize_with_bbox_op.cc
|
||||||
random_resize_with_bbox_op.cc
|
random_resize_with_bbox_op.cc
|
||||||
|
random_color_op.cc
|
||||||
)
|
)
|
||||||
add_dependencies(kernels-image kernels-soft-dvpp-image)
|
add_dependencies(kernels-image kernels-soft-dvpp-image)
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/image/random_color_op.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {}
|
||||||
|
|
||||||
|
Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
|
||||||
|
IO_CHECK(in, out);
|
||||||
|
if (in->Rank() != 3) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("image must have 3 channels");
|
||||||
|
}
|
||||||
|
// 0.5 pixel precision assuming an 8 bit image
|
||||||
|
const auto eps = 0.00195;
|
||||||
|
const auto t = dist_(rnd_);
|
||||||
|
if (abs(t - 1.0) < eps) {
|
||||||
|
// Just return input? Can we do it given that input would otherwise get consumed in CVTensor constructor anyway?
|
||||||
|
*out = in;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
auto cvt_in = CVTensor::AsCVTensor(in);
|
||||||
|
auto m1 = cvt_in->mat();
|
||||||
|
cv::Mat gray;
|
||||||
|
// gray is allocated without using the allocator
|
||||||
|
cv::cvtColor(m1, gray, cv::COLOR_RGB2GRAY);
|
||||||
|
// luminosity is not preserved, consider using weights.
|
||||||
|
cv::Mat temp[3] = {gray, gray, gray};
|
||||||
|
cv::Mat cv_out;
|
||||||
|
cv::merge(temp, 3, cv_out);
|
||||||
|
std::shared_ptr<CVTensor> cvt_out;
|
||||||
|
CVTensor::CreateFromMat(cv_out, &cvt_out);
|
||||||
|
if (abs(t - 0.0) < eps) {
|
||||||
|
// return grayscale
|
||||||
|
*out = std::static_pointer_cast<Tensor>(cvt_out);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// return blended image. addWeighted takes care of overflow for uint8_t
|
||||||
|
cv::addWeighted(m1, t, cvt_out->mat(), 1 - t, 0, cvt_out->mat());
|
||||||
|
*out = std::static_pointer_cast<Tensor>(cvt_out);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <opencv2/imgproc/imgproc.hpp>
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
/// \class RandomColorOp random_color_op.h
|
||||||
|
/// \brief Blends an image with its grayscale version with random weights
|
||||||
|
/// t and 1 - t generated from a given range.
|
||||||
|
/// If the range is trivial then the weights are determinate and
|
||||||
|
/// t equals the bound of the interval
|
||||||
|
class RandomColorOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
RandomColorOp() = default;
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param[in] t_lb lower bound for the random weights
|
||||||
|
/// \param[in] t_ub upper bound for the random weights
|
||||||
|
RandomColorOp(float t_lb, float t_ub);
|
||||||
|
/// \brief the main function performing computations
|
||||||
|
/// \param[in] in 2- or 3- dimensional tensor representing an image
|
||||||
|
/// \param[out] out 2- or 3- dimensional tensor representing an image
|
||||||
|
/// with the same dimensions as in
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) override;
|
||||||
|
/// \brief returns the name of the op
|
||||||
|
std::string Name() const override { return kRandomColorOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mt19937 rnd_;
|
||||||
|
std::uniform_real_distribution<float> dist_;
|
||||||
|
float t_lb_;
|
||||||
|
float t_ub_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H
|
|
@ -129,6 +129,7 @@ constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
|
||||||
constexpr char kUniformAugOp[] = "UniformAugOp";
|
constexpr char kUniformAugOp[] = "UniformAugOp";
|
||||||
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
|
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
|
||||||
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
|
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
|
||||||
|
constexpr char kRandomColorOp[] = "RandomColorOp";
|
||||||
|
|
||||||
// text
|
// text
|
||||||
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";
|
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";
|
||||||
|
|
|
@ -46,7 +46,8 @@ import mindspore._c_dataengine as cde
|
||||||
from .utils import Inter, Border
|
from .utils import Inter, Border
|
||||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||||
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||||
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
|
check_range, check_resize, check_rescale, check_pad, check_cutout, \
|
||||||
|
check_uniform_augment_cpp, \
|
||||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
|
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
|
||||||
|
|
||||||
|
@ -628,6 +629,21 @@ class CenterCrop(cde.CenterCropOp):
|
||||||
super().__init__(*size)
|
super().__init__(*size)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomColor(cde.RandomColorOp):
|
||||||
|
"""
|
||||||
|
Adjust the color of the input image by a fixed or random degree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
degrees (sequence): Range of random color adjustment degrees.
|
||||||
|
It should be in (min, max) format. If min=max, then it is a
|
||||||
|
single fixed magnitude operation (default=(0.1,1.9)).
|
||||||
|
Works with 3-channel color images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_positive_degrees
|
||||||
|
def __init__(self, degrees=(0.1, 1.9)):
|
||||||
|
super().__init__(*degrees)
|
||||||
|
|
||||||
class RandomColorAdjust(cde.RandomColorAdjustOp):
|
class RandomColorAdjust(cde.RandomColorAdjustOp):
|
||||||
"""
|
"""
|
||||||
Randomly adjust the brightness, contrast, saturation, and hue of the input image.
|
Randomly adjust the brightness, contrast, saturation, and hue of the input image.
|
||||||
|
|
|
@ -609,21 +609,23 @@ def check_uniform_augment_py(method):
|
||||||
|
|
||||||
|
|
||||||
def check_positive_degrees(method):
|
def check_positive_degrees(method):
|
||||||
"""A wrapper method to check degrees parameter in RandSharpness and RandColor"""
|
"""A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (python and cpp)"""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[degrees], _ = parse_user_args(method, *args, **kwargs)
|
[degrees], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if isinstance(degrees, (list, tuple)):
|
|
||||||
|
if degrees is not None:
|
||||||
|
if not isinstance(degrees, (list, tuple)):
|
||||||
|
raise TypeError("degrees must be either a tuple or a list.")
|
||||||
|
type_check_list(degrees, (int, float), "degrees")
|
||||||
if len(degrees) != 2:
|
if len(degrees) != 2:
|
||||||
raise ValueError("Degrees must be a sequence with length 2.")
|
raise ValueError("degrees must be a sequence with length 2.")
|
||||||
for value in degrees:
|
for degree in degrees:
|
||||||
check_value(value, (0., FLOAT_MAX_INTEGER))
|
check_value(degree, (0, FLOAT_MAX_INTEGER))
|
||||||
check_positive(degrees[0], "degrees[0]")
|
|
||||||
if degrees[0] > degrees[1]:
|
if degrees[0] > degrees[1]:
|
||||||
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
raise ValueError("degrees should be in (min,max) format. Got (max,min).")
|
||||||
else:
|
|
||||||
raise TypeError("Degrees should be a tuple or list.")
|
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
@ -698,4 +700,5 @@ def check_random_solarize(method):
|
||||||
raise ValueError("threshold must be in min max format numbers")
|
raise ValueError("threshold must be in min max format numbers")
|
||||||
|
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -39,6 +39,7 @@ SET(DE_UT_SRCS
|
||||||
project_op_test.cc
|
project_op_test.cc
|
||||||
queue_test.cc
|
queue_test.cc
|
||||||
random_affine_op_test.cc
|
random_affine_op_test.cc
|
||||||
|
random_color_op_test.cc
|
||||||
random_crop_op_test.cc
|
random_crop_op_test.cc
|
||||||
random_crop_with_bbox_op_test.cc
|
random_crop_with_bbox_op_test.cc
|
||||||
random_crop_decode_resize_op_test.cc
|
random_crop_decode_resize_op_test.cc
|
||||||
|
|
|
@ -63,10 +63,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) {
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
auto image = row["image"];
|
auto image = row["image"];
|
||||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_EQ(i, 20);
|
EXPECT_EQ(i, 20);
|
||||||
|
@ -160,8 +160,9 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) {
|
||||||
auto image = row["image"];
|
auto image = row["image"];
|
||||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
// check if the image is in NCHW
|
// check if the image is in NCHW
|
||||||
EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1]
|
EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] && 2268 == image->shape()[2] &&
|
||||||
&& 2268 == image->shape()[2] && 4032 == image->shape()[3], true);
|
4032 == image->shape()[3],
|
||||||
|
true);
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
EXPECT_EQ(i, 20);
|
EXPECT_EQ(i, 20);
|
||||||
|
@ -186,7 +187,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
|
||||||
EXPECT_NE(one_hot_op, nullptr);
|
EXPECT_NE(one_hot_op, nullptr);
|
||||||
|
|
||||||
// Create a Map operation on ds
|
// Create a Map operation on ds
|
||||||
ds = ds->Map({one_hot_op},{"label"});
|
ds = ds->Map({one_hot_op}, {"label"});
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1);
|
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1);
|
||||||
|
@ -209,7 +210,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
|
||||||
EXPECT_NE(one_hot_op, nullptr);
|
EXPECT_NE(one_hot_op, nullptr);
|
||||||
|
|
||||||
// Create a Map operation on ds
|
// Create a Map operation on ds
|
||||||
ds = ds->Map({one_hot_op},{"label"});
|
ds = ds->Map({one_hot_op}, {"label"});
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
|
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
|
||||||
|
@ -258,7 +259,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
|
||||||
EXPECT_NE(one_hot_op, nullptr);
|
EXPECT_NE(one_hot_op, nullptr);
|
||||||
|
|
||||||
// Create a Map operation on ds
|
// Create a Map operation on ds
|
||||||
ds = ds->Map({one_hot_op},{"label"});
|
ds = ds->Map({one_hot_op}, {"label"});
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch();
|
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch();
|
||||||
|
@ -379,10 +380,10 @@ TEST_F(MindDataTestPipeline, TestPad) {
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
auto image = row["image"];
|
auto image = row["image"];
|
||||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_EQ(i, 20);
|
EXPECT_EQ(i, 20);
|
||||||
|
@ -504,6 +505,61 @@ TEST_F(MindDataTestPipeline, TestRandomAffineSuccess2) {
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestRandomColor) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomColor with non-default params.";
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Repeat operation on ds
|
||||||
|
int32_t repeat_num = 2;
|
||||||
|
ds = ds->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create objects for the tensor ops
|
||||||
|
std::shared_ptr<TensorOperation> random_color_op_1 = vision::RandomColor(0.0, 0.0);
|
||||||
|
EXPECT_NE(random_color_op_1, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> random_color_op_2 = vision::RandomColor(1.0, 0.1);
|
||||||
|
EXPECT_EQ(random_color_op_2, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> random_color_op_3 = vision::RandomColor(0.0, 1.1);
|
||||||
|
EXPECT_NE(random_color_op_3, nullptr);
|
||||||
|
|
||||||
|
// Create a Map operation on ds
|
||||||
|
ds = ds->Map({random_color_op_1, random_color_op_3});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Batch operation on ds
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
ds = ds->Batch(batch_size);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto image = row["image"];
|
||||||
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 20);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
||||||
// Create an ImageFolder Dataset
|
// Create an ImageFolder Dataset
|
||||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
@ -780,7 +836,8 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
// Create objects for the tensor ops
|
// Create objects for the tensor ops
|
||||||
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(23, 23); //vision::RandomSolarize();
|
std::shared_ptr<TensorOperation> random_solarize =
|
||||||
|
mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize();
|
||||||
EXPECT_NE(random_solarize, nullptr);
|
EXPECT_NE(random_solarize, nullptr);
|
||||||
|
|
||||||
// Create a Map operation on ds
|
// Create a Map operation on ds
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "common/cvop_common.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_color_op.h"
|
||||||
|
#include "minddata/dataset/core/cv_tensor.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
|
||||||
|
class MindDataTestRandomColorOp : public UT::CVOP::CVOpCommon {
|
||||||
|
public:
|
||||||
|
MindDataTestRandomColorOp() : CVOpCommon(), shape({3, 3, 3}) {
|
||||||
|
std::shared_ptr<Tensor> in;
|
||||||
|
std::shared_ptr<Tensor> gray;
|
||||||
|
|
||||||
|
(void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &in);
|
||||||
|
(void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &input_tensor);
|
||||||
|
Status s = in->Fill<uint8_t>(42);
|
||||||
|
s = input_tensor->Fill<uint8_t>(42);
|
||||||
|
cvt_in = CVTensor::AsCVTensor(in);
|
||||||
|
cv::Mat m2;
|
||||||
|
auto m1 = cvt_in->mat();
|
||||||
|
cv::cvtColor(m1, m2, cv::COLOR_RGB2GRAY);
|
||||||
|
cv::Mat temp[3] = {m2 , m2 , m2 };
|
||||||
|
cv::Mat cv_out;
|
||||||
|
cv::merge(temp, 3, cv_out);
|
||||||
|
std::shared_ptr<CVTensor> cvt_out;
|
||||||
|
CVTensor::CreateFromMat(cv_out, &cvt_out);
|
||||||
|
gray_tensor = std::static_pointer_cast<Tensor>(cvt_out);
|
||||||
|
}
|
||||||
|
TensorShape shape;
|
||||||
|
std::shared_ptr<Tensor> input_tensor;
|
||||||
|
std::shared_ptr<CVTensor> cvt_in;
|
||||||
|
std::shared_ptr<Tensor> gray_tensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
int64_t Compare(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2) {
|
||||||
|
auto shape = t1->shape();
|
||||||
|
int64_t sum = 0;
|
||||||
|
for (auto i = 0; i < shape[0]; i++) {
|
||||||
|
for (auto j = 0; j < shape[1]; j++) {
|
||||||
|
for (auto k = 0; k < shape[2]; k++) {
|
||||||
|
uint8_t value1;
|
||||||
|
uint8_t value2;
|
||||||
|
(void)t1->GetItemAt<uint8_t>(&value1, {i, j, k});
|
||||||
|
(void)t2->GetItemAt<uint8_t>(&value2, {i, j, k});
|
||||||
|
sum += abs(static_cast<int>(value1) - static_cast<int>(value2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// these tests are tautological, write better tests when the requirements for the output are determined
|
||||||
|
// e. g. how do we want to convert to gray and what does it mean to blend with a gray image (pre- post- gamma corrected,
|
||||||
|
// what weights).
|
||||||
|
TEST_F(MindDataTestRandomColorOp, TestOp1) {
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
auto op = RandomColorOp(1, 1);
|
||||||
|
auto s = op.Compute(input_tensor, &output_tensor);
|
||||||
|
auto res = Compare(input_tensor, output_tensor);
|
||||||
|
EXPECT_EQ(0, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestRandomColorOp, TestOp2) {
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
auto op = RandomColorOp(0, 0);
|
||||||
|
auto s = op.Compute(input_tensor, &output_tensor);
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
auto res = Compare(output_tensor, gray_tensor);
|
||||||
|
EXPECT_EQ(res, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestRandomColorOp, TestOp3) {
|
||||||
|
std::shared_ptr<Tensor> output_tensor;
|
||||||
|
auto op = RandomColorOp(0.0, 1.0);
|
||||||
|
for (auto i = 0; i < 1; i++) {
|
||||||
|
auto s = op.Compute(input_tensor, &output_tensor);
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
}
|
||||||
|
}
|
Binary file not shown.
|
@ -16,9 +16,11 @@
|
||||||
Testing RandomColor op in DE
|
Testing RandomColor op in DE
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.engine as de
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from util import visualize_list, diff_mse, save_and_check_md5, \
|
from util import visualize_list, diff_mse, save_and_check_md5, \
|
||||||
|
@ -26,11 +28,17 @@ from util import visualize_list, diff_mse, save_and_check_md5, \
|
||||||
|
|
||||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||||
|
|
||||||
|
C_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
C_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||||
|
|
||||||
GENERATE_GOLDEN = False
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
def test_random_color(degrees=(0.1, 1.9), plot=False):
|
|
||||||
|
def test_random_color_py(degrees=(0.1, 1.9), plot=False):
|
||||||
"""
|
"""
|
||||||
Test RandomColor
|
Test Python RandomColor
|
||||||
"""
|
"""
|
||||||
logger.info("Test RandomColor")
|
logger.info("Test RandomColor")
|
||||||
|
|
||||||
|
@ -85,9 +93,53 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):
|
||||||
visualize_list(images_original, images_random_color)
|
visualize_list(images_original, images_random_color)
|
||||||
|
|
||||||
|
|
||||||
def test_random_color_md5():
|
def test_random_color_c(degrees=(0.1, 1.9), plot=False, run_golden=True):
|
||||||
"""
|
"""
|
||||||
Test RandomColor with md5 check
|
Test Cpp RandomColor
|
||||||
|
"""
|
||||||
|
logger.info("test_random_color_op")
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(10)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
# Decode with rgb format set to True
|
||||||
|
data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
|
||||||
|
# Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
|
||||||
|
if degrees is None:
|
||||||
|
c_op = vision.RandomColor()
|
||||||
|
else:
|
||||||
|
c_op = vision.RandomColor(degrees)
|
||||||
|
|
||||||
|
data1 = data1.map(input_columns=["image"], operations=[vision.Decode()])
|
||||||
|
data2 = data2.map(input_columns=["image"], operations=[vision.Decode(), c_op])
|
||||||
|
|
||||||
|
image_random_color_op = []
|
||||||
|
image = []
|
||||||
|
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||||
|
actual = item1["image"]
|
||||||
|
expected = item2["image"]
|
||||||
|
image.append(actual)
|
||||||
|
image_random_color_op.append(expected)
|
||||||
|
|
||||||
|
if run_golden:
|
||||||
|
# Compare with expected md5 from images
|
||||||
|
filename = "random_color_op_02_result.npz"
|
||||||
|
save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
visualize_list(image, image_random_color_op)
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers((original_num_parallel_workers))
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_color_py_md5():
|
||||||
|
"""
|
||||||
|
Test Python RandomColor with md5 check
|
||||||
"""
|
"""
|
||||||
logger.info("Test RandomColor with md5 check")
|
logger.info("Test RandomColor with md5 check")
|
||||||
original_seed = config_get_set_seed(10)
|
original_seed = config_get_set_seed(10)
|
||||||
|
@ -110,8 +162,94 @@ def test_random_color_md5():
|
||||||
ds.config.set_num_parallel_workers((original_num_parallel_workers))
|
ds.config.set_num_parallel_workers((original_num_parallel_workers))
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_random_color_op(degrees=None, plot=False):
|
||||||
|
"""
|
||||||
|
Compare Random Color op in Python and Cpp
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info("test_random_color_op")
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(5)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
# Decode with rgb format set to True
|
||||||
|
data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
|
||||||
|
if degrees is None:
|
||||||
|
c_op = vision.RandomColor()
|
||||||
|
p_op = F.RandomColor()
|
||||||
|
else:
|
||||||
|
c_op = vision.RandomColor(degrees)
|
||||||
|
p_op = F.RandomColor(degrees)
|
||||||
|
|
||||||
|
transforms_random_color_py = F.ComposeOp([lambda img: img.astype(np.uint8), F.ToPIL(),
|
||||||
|
p_op, np.array])
|
||||||
|
|
||||||
|
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(), c_op])
|
||||||
|
data2 = data2.map(input_columns=["image"], operations=[vision.Decode()])
|
||||||
|
data2 = data2.map(input_columns=["image"], operations=transforms_random_color_py())
|
||||||
|
|
||||||
|
image_random_color_op = []
|
||||||
|
image = []
|
||||||
|
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||||
|
actual = item1["image"]
|
||||||
|
expected = item2["image"]
|
||||||
|
image_random_color_op.append(actual)
|
||||||
|
image.append(expected)
|
||||||
|
assert actual.shape == expected.shape
|
||||||
|
mse = diff_mse(actual, expected)
|
||||||
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
visualize_list(image, image_random_color_op)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_color_c_errors():
|
||||||
|
"""
|
||||||
|
Test that Cpp RandomColor errors with bad input
|
||||||
|
"""
|
||||||
|
with pytest.raises(TypeError) as error_info:
|
||||||
|
vision.RandomColor((12))
|
||||||
|
assert "degrees must be either a tuple or a list." in str(error_info.value)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as error_info:
|
||||||
|
vision.RandomColor(("col", 3))
|
||||||
|
assert "Argument degrees[0] with value col is not of type (<class 'int'>, <class 'float'>)." in str(
|
||||||
|
error_info.value)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
vision.RandomColor((0.9, 0.1))
|
||||||
|
assert "degrees should be in (min,max) format. Got (max,min)." in str(error_info.value)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
vision.RandomColor((0.9,))
|
||||||
|
assert "degrees must be a sequence with length 2." in str(error_info.value)
|
||||||
|
|
||||||
|
# RandomColor Cpp Op will fail with one channel input
|
||||||
|
mnist_ds = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=vision.RandomColor())
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as error_info:
|
||||||
|
for _ in enumerate(mnist_ds):
|
||||||
|
pass
|
||||||
|
assert "Invalid number of channels in input image" in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_random_color()
|
test_random_color_py()
|
||||||
test_random_color(plot=True)
|
test_random_color_py(plot=True)
|
||||||
test_random_color(degrees=(0.5, 1.5), plot=True)
|
test_random_color_py(degrees=(0.5, 1.5), plot=True)
|
||||||
test_random_color_md5()
|
test_random_color_py_md5()
|
||||||
|
|
||||||
|
test_random_color_c()
|
||||||
|
test_random_color_c(plot=True)
|
||||||
|
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False)
|
||||||
|
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
|
||||||
|
test_compare_random_color_op(plot=True)
|
||||||
|
test_random_color_c_errors()
|
||||||
|
|
Loading…
Reference in New Issue