forked from OSSInnovation/mindspore
Added Posterize Op
This commit is contained in:
parent
0e27a04da1
commit
979111d0ef
|
@ -42,6 +42,7 @@
|
|||
#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
|
@ -142,6 +143,13 @@ PYBIND_REGISTER(RandomAffineOp, 1, ([](const py::module *m) {
|
|||
py::arg("fill_value") = RandomAffineOp::kFillValue);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomPosterizeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomPosterizeOp, TensorOp, std::shared_ptr<RandomPosterizeOp>>(
|
||||
*m, "RandomPosterizeOp", "Tensor operation to apply random posterize operation on an image.")
|
||||
.def(py::init<uint8_t, uint8_t>(), py::arg("min_bit") = RandomPosterizeOp::kMinBit,
|
||||
py::arg("max_bit") = RandomPosterizeOp::kMaxBit);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomResizeWithBBoxOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
|
@ -217,6 +218,16 @@ std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob)
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomPosterizeOperation.
|
||||
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit, uint8_t max_bit) {
|
||||
auto op = std::make_shared<RandomPosterizeOperation>(min_bit, max_bit);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
|
||||
bool expand, std::vector<float> center,
|
||||
|
@ -725,6 +736,31 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomPosterizeOperation
|
||||
RandomPosterizeOperation::RandomPosterizeOperation(uint8_t min_bit, uint8_t max_bit)
|
||||
: min_bit_(min_bit), max_bit_(max_bit) {}
|
||||
|
||||
bool RandomPosterizeOperation::ValidateParams() {
|
||||
if (min_bit_ < 1 || min_bit_ > 8) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << min_bit_;
|
||||
return false;
|
||||
}
|
||||
if (max_bit_ < 1 || max_bit_ > 8) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << max_bit_;
|
||||
return false;
|
||||
}
|
||||
if (max_bit_ < min_bit_) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << max_bit_ << ", min = " << min_bit_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
|
||||
std::shared_ptr<RandomPosterizeOp> tensor_op = std::make_shared<RandomPosterizeOp>(min_bit_, max_bit_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
|
||||
bool expand, std::vector<float> center,
|
||||
|
|
|
@ -62,6 +62,7 @@ class RandomColorOperation;
|
|||
class RandomColorAdjustOperation;
|
||||
class RandomCropOperation;
|
||||
class RandomHorizontalFlipOperation;
|
||||
class RandomPosterizeOperation;
|
||||
class RandomRotationOperation;
|
||||
class RandomSharpnessOperation;
|
||||
class RandomSolarizeOperation;
|
||||
|
@ -220,6 +221,13 @@ std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::
|
|||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob = 0.5);
|
||||
|
||||
/// \brief Function to create a RandomPosterize TensorOperation.
|
||||
/// \notes Tensor operation to perform random posterize.
|
||||
/// \param[in] min_bit - uint8_t representing the minimum bit in range. (Default=8)
|
||||
/// \param[in] max_bit - uint8_t representing the maximum bit in range. (Default=8)
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit = 8, uint8_t max_bit = 8);
|
||||
|
||||
/// \brief Function to create a RandomRotation TensorOp
|
||||
/// \notes Rotates the image according to parameters
|
||||
/// \param[in] degrees A float vector size 2, representing the starting and ending degree
|
||||
|
@ -521,6 +529,21 @@ class RandomHorizontalFlipOperation : public TensorOperation {
|
|||
float probability_;
|
||||
};
|
||||
|
||||
class RandomPosterizeOperation : public TensorOperation {
|
||||
public:
|
||||
explicit RandomPosterizeOperation(uint8_t min_bit = 8, uint8_t max_bit = 8);
|
||||
|
||||
~RandomPosterizeOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
uint8_t min_bit_;
|
||||
uint8_t max_bit_;
|
||||
};
|
||||
|
||||
class RandomRotationOperation : public TensorOperation {
|
||||
public:
|
||||
RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
|
||||
|
|
|
@ -17,6 +17,7 @@ add_library(kernels-image OBJECT
|
|||
mixup_batch_op.cc
|
||||
normalize_op.cc
|
||||
pad_op.cc
|
||||
posterize_op.cc
|
||||
random_affine_op.cc
|
||||
random_color_adjust_op.cc
|
||||
random_crop_decode_resize_op.cc
|
||||
|
@ -27,6 +28,7 @@ add_library(kernels-image OBJECT
|
|||
random_horizontal_flip_op.cc
|
||||
random_horizontal_flip_with_bbox_op.cc
|
||||
bounding_box_augment_op.cc
|
||||
random_posterize_op.cc
|
||||
random_resize_op.cc
|
||||
random_rotation_op.cc
|
||||
random_select_subpolicy_op.cc
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/posterize_op.h"
|
||||
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const uint8_t PosterizeOp::kBit = 8;
|
||||
|
||||
PosterizeOp::PosterizeOp(uint8_t bit) : bit_(bit) {}
|
||||
|
||||
Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
uint8_t mask_value = ~((uint8_t)(1 << (8 - bit_)) - 1);
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
}
|
||||
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C> or <H,W>");
|
||||
}
|
||||
std::vector<uint8_t> lut_vector;
|
||||
for (std::size_t i = 0; i < 256; i++) {
|
||||
lut_vector.push_back(i & mask_value);
|
||||
}
|
||||
cv::Mat in_image = input_cv->mat();
|
||||
cv::Mat output_img;
|
||||
cv::LUT(in_image, lut_vector, output_img);
|
||||
std::shared_ptr<CVTensor> result_tensor;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
|
||||
*output = std::static_pointer_cast<Tensor>(result_tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class PosterizeOp : public TensorOp {
|
||||
public:
|
||||
/// Default values
|
||||
static const uint8_t kBit;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param[in] bit: bits to use
|
||||
explicit PosterizeOp(uint8_t bit = kBit);
|
||||
|
||||
~PosterizeOp() override = default;
|
||||
|
||||
std::string Name() const override { return kPosterizeOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// Member variables
|
||||
private:
|
||||
std::string kPosterizeOp = "PosterizeOp";
|
||||
|
||||
protected:
|
||||
uint8_t bit_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
|
||||
#include <random>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const uint8_t RandomPosterizeOp::kMinBit = 8;
|
||||
const uint8_t RandomPosterizeOp::kMaxBit = 8;
|
||||
|
||||
RandomPosterizeOp::RandomPosterizeOp(uint8_t min_bit, uint8_t max_bit)
|
||||
: PosterizeOp(min_bit), min_bit_(min_bit), max_bit_(max_bit) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
bit_ = (min_bit_ == max_bit_) ? min_bit_ : std::uniform_int_distribution<uint8_t>(min_bit_, max_bit_)(rnd_);
|
||||
return PosterizeOp::Compute(input, output);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/posterize_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class RandomPosterizeOp : public PosterizeOp {
|
||||
public:
|
||||
/// Default values
|
||||
static const uint8_t kMinBit;
|
||||
static const uint8_t kMaxBit;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param[in] min_bit: Minimum bit in range
|
||||
/// \param[in] max_bit: Maximum bit in range
|
||||
explicit RandomPosterizeOp(uint8_t min_bit = kMinBit, uint8_t max_bit = kMaxBit);
|
||||
|
||||
~RandomPosterizeOp() override = default;
|
||||
|
||||
std::string Name() const override { return kRandomPosterizeOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// Member variables
|
||||
private:
|
||||
std::string kRandomPosterizeOp = "RandomPosterizeOp";
|
||||
uint8_t min_bit_;
|
||||
uint8_t max_bit_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
|
@ -50,7 +50,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
|
|||
check_uniform_augment_cpp, \
|
||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
||||
check_cut_mix_batch_c
|
||||
check_cut_mix_batch_c, check_posterize
|
||||
|
||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
|
@ -459,6 +459,26 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
|
|||
super().__init__(prob)
|
||||
|
||||
|
||||
class RandomPosterize(cde.RandomPosterizeOp):
|
||||
"""
|
||||
Reduce the number of bits for each color channel.
|
||||
|
||||
Args:
|
||||
bits (sequence or int): Range of random posterize to compress image.
|
||||
bits values should always be in range of [1,8], and include at
|
||||
least one integer values in the given range. It should be in
|
||||
(min, max) or integer format. If min=max, then it is a single fixed
|
||||
magnitude operation (default=8).
|
||||
"""
|
||||
|
||||
@check_posterize
|
||||
def __init__(self, bits=(8, 8)):
|
||||
self.bits = bits
|
||||
if isinstance(bits, int):
|
||||
bits = (bits, bits)
|
||||
super().__init__(bits[0], bits[1])
|
||||
|
||||
|
||||
class RandomVerticalFlip(cde.RandomVerticalFlipOp):
|
||||
"""
|
||||
Flip the input image vertically, randomly with a given probability.
|
||||
|
@ -676,6 +696,7 @@ class RandomColor(cde.RandomColorOp):
|
|||
def __init__(self, degrees=(0.1, 1.9)):
|
||||
super().__init__(*degrees)
|
||||
|
||||
|
||||
class RandomColorAdjust(cde.RandomColorAdjustOp):
|
||||
"""
|
||||
Randomly adjust the brightness, contrast, saturation, and hue of the input image.
|
||||
|
|
|
@ -162,6 +162,28 @@ def check_crop(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_posterize(method):
|
||||
""""A wrapper that wraps a parameter checker to the original function(posterize operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[bits], _ = parse_user_args(method, *args, **kwargs)
|
||||
if bits is not None:
|
||||
type_check(bits, (list, tuple, int), "bits")
|
||||
if isinstance(bits, int):
|
||||
check_value(bits, [1, 8])
|
||||
if isinstance(bits, (list, tuple)):
|
||||
if len(bits) != 2:
|
||||
raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
|
||||
for item in bits:
|
||||
check_uint8(item, "bits")
|
||||
# also checks if min <= max
|
||||
check_range(bits, [1, 8])
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_resize_interpolation(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
|
||||
|
||||
|
|
|
@ -789,6 +789,120 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomPosterizeFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterize with invalid params.";
|
||||
|
||||
// Create objects for the tensor ops
|
||||
// Invalid max > 8
|
||||
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize(1, 9);
|
||||
EXPECT_EQ(posterize, nullptr);
|
||||
// Invalid min < 1
|
||||
posterize = vision::RandomPosterize(0, 8);
|
||||
EXPECT_EQ(posterize, nullptr);
|
||||
// min > max
|
||||
posterize = vision::RandomPosterize(8, 1);
|
||||
EXPECT_EQ(posterize, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterizeSuccess1 with non-default params.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> posterize =
|
||||
vision::RandomPosterize(1, 4);
|
||||
EXPECT_NE(posterize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({posterize});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 1;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterizeSuccess2 with default params.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize();
|
||||
EXPECT_NE(posterize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({posterize});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 1;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomSharpness) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSharpness.";
|
||||
|
||||
|
|
|
@ -154,6 +154,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
|
|||
expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
|
||||
actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
|
||||
break;
|
||||
case kRandomPosterize:
|
||||
expect_image_path = dir_path + "imagefolder/apple_expect_random_posterize.jpg";
|
||||
actual_image_path = dir_path + "imagefolder/apple_actual_random_posterize.jpg";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
|
||||
EXPECT_EQ(0, 1);
|
||||
|
|
|
@ -42,6 +42,7 @@ class CVOpCommon : public Common {
|
|||
kRandomSharpness,
|
||||
kInvert,
|
||||
kRandomAffine,
|
||||
kRandomPosterize,
|
||||
kAutoContrast,
|
||||
kEqualize
|
||||
};
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestRandomPosterizeOp : public UT::CVOP::CVOpCommon {
|
||||
public:
|
||||
MindDataTestRandomPosterizeOp() : CVOpCommon() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestRandomPosterizeOp, TestOp1) {
|
||||
MS_LOG(INFO) << "Doing testRandomPosterize.";
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
std::unique_ptr<RandomPosterizeOp> op(new RandomPosterizeOp(1, 1));
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_tensor_, &output_tensor);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
CheckImageShapeAndData(output_tensor, kRandomPosterize);
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 380 KiB |
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RandomPosterize op in DE
|
||||
"""
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, save_and_check_md5, \
|
||||
config_get_set_seed, config_get_set_num_parallel_workers
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def test_random_posterize_op_c(plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomPosterize in C transformations
|
||||
"""
|
||||
logger.info("test_random_posterize_op_c")
|
||||
|
||||
original_seed = config_get_set_seed(55)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# define map operations
|
||||
transforms1 = [
|
||||
c_vision.Decode(),
|
||||
c_vision.RandomPosterize((1, 8))
|
||||
]
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data1 = data1.map(input_columns=["image"], operations=transforms1)
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
|
||||
|
||||
image_posterize = []
|
||||
image_original = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image1 = item1["image"]
|
||||
image2 = item2["image"]
|
||||
image_posterize.append(image1)
|
||||
image_original.append(image2)
|
||||
|
||||
if run_golden:
|
||||
# check results with md5 comparison
|
||||
filename = "random_posterize_01_result_c.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(image_original, image_posterize)
|
||||
|
||||
# Restore configuration
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_random_posterize_op_fixed_point_c(plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomPosterize in C transformations with fixed point
|
||||
"""
|
||||
logger.info("test_random_posterize_op_c")
|
||||
|
||||
# define map operations
|
||||
transforms1 = [
|
||||
c_vision.Decode(),
|
||||
c_vision.RandomPosterize(1)
|
||||
]
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data1 = data1.map(input_columns=["image"], operations=transforms1)
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
|
||||
|
||||
image_posterize = []
|
||||
image_original = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image1 = item1["image"]
|
||||
image2 = item2["image"]
|
||||
image_posterize.append(image1)
|
||||
image_original.append(image2)
|
||||
|
||||
if run_golden:
|
||||
# check results with md5 comparison
|
||||
filename = "random_posterize_fixed_point_01_result_c.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(image_original, image_posterize)
|
||||
|
||||
|
||||
def test_random_posterize_exception_bit():
|
||||
"""
|
||||
Test RandomPosterize: out of range input bits and invalid type
|
||||
"""
|
||||
logger.info("test_random_posterize_exception_bit")
|
||||
# Test max > 8
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((1, 9))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test min < 1
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((0, 7))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test max < min
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((8, 1))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test wrong type (not uint8)
|
||||
try:
|
||||
_ = c_vision.RandomPosterize(1.1)
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Argument bits with value 1.1 is not of type (<class 'list'>, <class 'tuple'>, <class 'int'>)."
|
||||
# Test wrong number of bits
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((1, 1, 1))
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_posterize_op_c(plot=True)
|
||||
test_random_posterize_op_fixed_point_c(plot=True)
|
||||
test_random_posterize_exception_bit()
|
Loading…
Reference in New Issue