[feat] [assidtant] [I40GZN] add new data ops RandomEqualize

This commit is contained in:
despicablemme 2021-09-13 10:54:39 +08:00
parent d0023355a3
commit 7063907bdc
14 changed files with 474 additions and 0 deletions

View File

@ -43,6 +43,7 @@
#include "minddata/dataset/kernels/ir/vision/random_crop_decode_resize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
@ -364,6 +365,16 @@ PYBIND_REGISTER(RandomCropWithBBoxOperation, 1, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(RandomEqualizeOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomEqualizeOperation, TensorOperation,
std::shared_ptr<vision::RandomEqualizeOperation>>(*m, "RandomEqualizeOperation")
.def(py::init([](float prob) {
auto random_equalize = std::make_shared<vision::RandomEqualizeOperation>(prob);
THROW_IF_ERROR(random_equalize->ValidateParams());
return random_equalize;
}));
}));
PYBIND_REGISTER(RandomHorizontalFlipOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomHorizontalFlipOperation, TensorOperation,
std::shared_ptr<vision::RandomHorizontalFlipOperation>>(

View File

@ -47,6 +47,7 @@
#include "minddata/dataset/kernels/ir/vision/random_crop_decode_resize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
@ -602,6 +603,18 @@ std::shared_ptr<TensorOperation> RandomCropWithBBox::Parse() {
data_->fill_value_, data_->padding_mode_);
}
// RandomEqualize Transform Operation.
struct RandomEqualize::Data {
explicit Data(float prob) : probability_(prob) {}
float probability_;
};
RandomEqualize::RandomEqualize(float prob) : data_(std::make_shared<Data>(prob)) {}
std::shared_ptr<TensorOperation> RandomEqualize::Parse() {
return std::make_shared<RandomEqualizeOperation>(data_->probability_);
}
// RandomHorizontalFlip.
struct RandomHorizontalFlip::Data {
explicit Data(float prob) : probability_(prob) {}

View File

@ -520,6 +520,27 @@ class RandomCropWithBBox final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Randomly apply histogram equalization on the input image with a given probability.
class RandomEqualize final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] prob A float representing the probability of equalization, which
/// must be in range of [0, 1] (default=0.5).
explicit RandomEqualize(float prob = 0.5);
/// \brief Destructor.
~RandomEqualize() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Randomly flip the input image horizontally with a given probability.
class RandomHorizontalFlip final : public TensorTransform {
public:

View File

@ -36,6 +36,7 @@ add_library(kernels-image OBJECT
random_crop_and_resize_op.cc
random_crop_op.cc
random_crop_with_bbox_op.cc
random_equalize_op.cc
random_horizontal_flip_op.cc
random_horizontal_flip_with_bbox_op.cc
random_invert_op.cc

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_equalize_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const float RandomEqualizeOp::kDefProbability = 0.5;
Status RandomEqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (distribution_(rnd_)) {
return Equalize(input, output);
}
*output = input;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_EQUALIZE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_EQUALIZE_OP_H_
#include <memory>
#include <random>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RandomEqualizeOp : public TensorOp {
public:
// Default values, also used by python_bindings.cc
static const float kDefProbability;
explicit RandomEqualizeOp(float prob = kDefProbability) : distribution_(prob) {
is_deterministic_ = false;
rnd_.seed(GetSeed());
}
~RandomEqualizeOp() override = default;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const RandomEqualizeOp &so) {
so.Print(out);
return out;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRandomEqualizeOp; }
private:
std::mt19937 rnd_;
std::bernoulli_distribution distribution_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_EQUALIZE_OP_H_

View File

@ -28,6 +28,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
random_crop_decode_resize_ir.cc
random_crop_ir.cc
random_crop_with_bbox_ir.cc
random_equalize_ir.cc
random_horizontal_flip_ir.cc
random_horizontal_flip_with_bbox_ir.cc
random_invert_ir.cc

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/random_equalize_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/random_equalize_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// RandomEqualizeOperation
RandomEqualizeOperation::RandomEqualizeOperation(float prob) : TensorOperation(true), probability_(prob) {}
RandomEqualizeOperation::~RandomEqualizeOperation() = default;
std::string RandomEqualizeOperation::Name() const { return kRandomEqualizeOperation; }
Status RandomEqualizeOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomEqualize", probability_));
return Status::OK();
}
std::shared_ptr<TensorOp> RandomEqualizeOperation::Build() {
std::shared_ptr<RandomEqualizeOp> tensor_op = std::make_shared<RandomEqualizeOp>(probability_);
return tensor_op;
}
Status RandomEqualizeOperation::to_json(nlohmann::json *out_json) {
(*out_json)["prob"] = probability_;
return Status::OK();
}
Status RandomEqualizeOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("prob") != op_params.end(), "Failed to find prob");
float prob = op_params["prob"];
*operation = std::make_shared<vision::RandomEqualizeOperation>(prob);
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_EQUALIZE_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_EQUALIZE_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kRandomEqualizeOperation[] = "RandomEqualize";
class RandomEqualizeOperation : public TensorOperation {
public:
explicit RandomEqualizeOperation(float prob);
~RandomEqualizeOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
float probability_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_EQUALIZE_IR_H_

View File

@ -88,6 +88,7 @@ constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp
constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp";
constexpr char kRandomCropOp[] = "RandomCropOp";
constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
constexpr char kRandomEqualizeOp[] = "RandomEqualizeOp";
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
constexpr char kRandomInvertOp[] = "RandomInvertOp";

View File

@ -1082,6 +1082,28 @@ class RandomCropWithBBox(ImageTensorOperation):
border_type)
class RandomEqualize(ImageTensorOperation):
"""
Apply histogram equalization on the input image with a given probability.
Args:
prob (float, optional): Probability of the image being equalized, which
must be in range of [0, 1] (default=0.5).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.RandomEqualize(0.5)]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_prob
def __init__(self, prob=0.5):
self.prob = prob
def parse(self):
return cde.RandomEqualizeOperation(self.prob)
class RandomHorizontalFlip(ImageTensorOperation):
"""
Randomly flip the input image horizontally with a given probability.

View File

@ -350,3 +350,49 @@ TEST_F(MindDataTestPipeline, TestRGB2BGR) {
iter1->Stop();
iter2->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomEqualize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomEqualize.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_equalize_op = vision::RandomEqualize(0.5);
ds = ds->Map({random_equalize_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomEqualizeInvalidProb) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomEqualizeInvalidProb.";
std::string MindDataPath = "data/dataset";
std::string folder_path = MindDataPath + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
EXPECT_NE(ds, nullptr);
auto random_equalize_op = vision::RandomEqualize(1.5);
ds = ds->Map({random_equalize_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}

View File

@ -1126,3 +1126,17 @@ TEST_F(MindDataTestExecute, TestRandomAutoContrastEager) {
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestRandomEqualizeEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomEqualizeEager.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto random_equalize_op = vision::RandomEqualize(0.6);
auto transform = Execute({decode, random_equalize_op});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,129 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomEqualize op in DE
"""
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.vision.c_transforms import Decode, Resize, RandomEqualize, Equalize
from mindspore import log as logger
from util import visualize_list, visualize_image, diff_mse
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
data_dir = "../data/dataset/testImageNetData/train/"
def test_random_equalize_pipeline(plot=False):
"""
Test RandomEqualize pipeline
"""
logger.info("Test RandomEqualize pipeline")
# Original Images
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
transforms_original = [Decode(), Resize(size=[224, 224])]
ds_original = data_set.map(operations=transforms_original, input_columns="image")
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image.asnumpy()
else:
images_original = np.append(images_original,
image.asnumpy(),
axis=0)
# Randomly Equalized Images
data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
transform_random_equalize = [Decode(), Resize(size=[224, 224]), RandomEqualize(0.6)]
ds_random_equalize = data_set1.map(operations=transform_random_equalize, input_columns="image")
ds_random_equalize = ds_random_equalize.batch(512)
for idx, (image, _) in enumerate(ds_random_equalize):
if idx == 0:
images_random_equalize = image.asnumpy()
else:
images_random_equalize = np.append(images_random_equalize,
image.asnumpy(),
axis=0)
if plot:
visualize_list(images_original, images_random_equalize)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_random_equalize[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_random_equalize_eager():
"""
Test RandomEqualize eager.
"""
img = np.fromfile(image_file, dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = Decode()(img)
img_equalized = Equalize()(img)
img_random_equalized = RandomEqualize(1.0)(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img_random_equalized), img_random_equalized.shape))
assert img_random_equalized.all() == img_equalized.all()
def test_random_equalize_comp(plot=False):
"""
Test RandomEqualize op compared with Equalize op.
"""
random_equalize_op = RandomEqualize(prob=1.0)
equalize_op = Equalize()
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
for item in dataset1.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item['image']
dataset1.map(operations=random_equalize_op, input_columns=['image'])
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
dataset2.map(operations=equalize_op, input_columns=['image'])
for item1, item2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_random_equalized = item1['image']
image_equalized = item2['image']
mse = diff_mse(image_equalized, image_random_equalized)
assert mse == 0
logger.info("mse: {}".format(mse))
if plot:
visualize_image(image, image_random_equalized, mse, image_equalized)
def test_random_equalize_invalid_prob():
"""
Test eager. prob out of range.
"""
logger.info("test_random_equalize_invalid_prob")
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
try:
random_equalize_op = RandomEqualize(1.5)
dataset = dataset.map(operations=random_equalize_op, input_columns=['image'])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
if __name__ == "__main__":
test_random_equalize_pipeline(plot=True)
test_random_equalize_eager()
test_random_equalize_comp(plot=True)
test_random_equalize_invalid_prob()