add HorizontalFlip and VerticalFlip APIs

This commit is contained in:
Xiao Tianci 2021-05-25 15:24:41 +08:00
parent 8d5147662b
commit 9c74d55a49
21 changed files with 772 additions and 8 deletions

View File

@ -26,6 +26,7 @@
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
#include "minddata/dataset/kernels/ir/vision/horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
@ -57,6 +58,7 @@
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
namespace mindspore {
namespace dataset {
@ -154,6 +156,16 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(HorizontalFlipOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::HorizontalFlipOperation, TensorOperation,
std::shared_ptr<vision::HorizontalFlipOperation>>(*m, "HorizontalFlipOperation")
.def(py::init([]() {
auto horizontal_flip = std::make_shared<vision::HorizontalFlipOperation>();
THROW_IF_ERROR(horizontal_flip->ValidateParams());
return horizontal_flip;
}));
}));
PYBIND_REGISTER(HwcToChwOperation, 1, ([](const py::module *m) {
(void)
py::class_<vision::HwcToChwOperation, TensorOperation, std::shared_ptr<vision::HwcToChwOperation>>(
@ -539,5 +551,16 @@ PYBIND_REGISTER(
return uniform_aug;
}));
}));
PYBIND_REGISTER(
VerticalFlipOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::VerticalFlipOperation, TensorOperation, std::shared_ptr<vision::VerticalFlipOperation>>(
*m, "VerticalFlipOperation")
.def(py::init([]() {
auto vertical_flip = std::make_shared<vision::VerticalFlipOperation>();
THROW_IF_ERROR(vertical_flip->ValidateParams());
return vertical_flip;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@ -31,6 +31,7 @@
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
#include "minddata/dataset/kernels/ir/vision/equalize_ir.h"
#include "minddata/dataset/kernels/ir/vision/gaussian_blur_ir.h"
#include "minddata/dataset/kernels/ir/vision/horizontal_flip_ir.h"
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
#include "minddata/dataset/kernels/ir/vision/invert_ir.h"
#include "minddata/dataset/kernels/ir/vision/mixup_batch_ir.h"
@ -68,6 +69,7 @@
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h"
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
@ -315,6 +317,11 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
}
#ifndef ENABLE_ANDROID
// HorizontalFlip Transform Operation.
HorizontalFlip::HorizontalFlip() {}
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); }
// HwcToChw Transform Operation.
HWC2CHW::HWC2CHW() {}
@ -917,6 +924,11 @@ UniformAugment::UniformAugment(const std::vector<std::reference_wrapper<TensorTr
std::shared_ptr<TensorOperation> UniformAugment::Parse() {
return std::make_shared<UniformAugOperation>(data_->transforms_, data_->num_ops_);
}
// VerticalFlip Transform Operation.
VerticalFlip::VerticalFlip() {}
std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); }
#endif // not ENABLE_ANDROID
} // namespace vision

View File

@ -152,6 +152,22 @@ class Equalize final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief HorizontalFlip TensorTransform.
/// \note Flip the input image horizontally.
class HorizontalFlip final : public TensorTransform {
public:
/// \brief Constructor.
HorizontalFlip();
/// \brief Destructor.
~HorizontalFlip() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief HwcToChw TensorTransform.
/// \note Transpose the input image; shape (H, W, C) to shape (C, H, W).
class HWC2CHW final : public TensorTransform {
@ -949,6 +965,22 @@ class UniformAugment final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief VerticalFlip TensorTransform.
/// \note Flip the input image Vertically.
class VerticalFlip final : public TensorTransform {
public:
/// \brief Constructor.
VerticalFlip();
/// \brief Destructor.
~VerticalFlip() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -16,6 +16,7 @@ add_library(kernels-image OBJECT
decode_op.cc
equalize_op.cc
gaussian_blur_op.cc
horizontal_flip_op.cc
hwc_to_chw_op.cc
image_utils.cc
invert_op.cc
@ -58,6 +59,7 @@ add_library(kernels-image OBJECT
random_color_op.cc
rotate_op.cc
resize_cubic_op.cc
vertical_flip_op.cc
)
if(ENABLE_ACL)
add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image)

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
Status HorizontalFlipOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return HorizontalFlip(input, output);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_HORIZONTAL_FLIP_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_HORIZONTAL_FLIP_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class HorizontalFlipOp : public TensorOp {
public:
HorizontalFlipOp() {}
~HorizontalFlipOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kHorizontalFlipOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_HORIZONTAL_FLIP_OP_H_

View File

@ -1240,12 +1240,11 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
int32_t kernel_y, float sigma_x, float sigma_y) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
cv::GaussianBlur(input_cv->mat(), output_cv->mat(), cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
cv::Mat output_cv_mat;
cv::GaussianBlur(input_cv->mat(), output_cv_mat, cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
static_cast<double>(sigma_y));
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_cv_mat, &output_cv));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/vertical_flip_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
Status VerticalFlipOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return VerticalFlip(input, output);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_VERTICAL_FLIP_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_VERTICAL_FLIP_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class VerticalFlipOp : public TensorOp {
public:
VerticalFlipOp() {}
~VerticalFlipOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kVerticalFlipOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_VERTICAL_FLIP_OP_H_

View File

@ -12,6 +12,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
decode_ir.cc
equalize_ir.cc
gaussian_blur_ir.cc
horizontal_flip_ir.cc
hwc_to_chw_ir.cc
invert_ir.cc
mixup_batch_ir.cc
@ -49,6 +50,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
softdvpp_decode_resize_jpeg_ir.cc
swap_red_blue_ir.cc
uniform_aug_ir.cc
vertical_flip_ir.cc
)
if(ENABLE_ACL)

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/horizontal_flip_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/horizontal_flip_op.h"
#endif
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// VerticalFlipOperation
HorizontalFlipOperation::HorizontalFlipOperation() {}
HorizontalFlipOperation::~HorizontalFlipOperation() = default;
std::string HorizontalFlipOperation::Name() const { return kHorizontalFlipOperation; }
Status HorizontalFlipOperation::ValidateParams() { return Status::OK(); }
std::shared_ptr<TensorOp> HorizontalFlipOperation::Build() {
std::shared_ptr<HorizontalFlipOp> tensor_op = std::make_shared<HorizontalFlipOp>();
return tensor_op;
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_HORIZONTAL_FLIP_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_HORIZONTAL_FLIP_IR_H_
#include <memory>
#include <string>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kHorizontalFlipOperation[] = "HorizontalFlip";
class HorizontalFlipOperation : public TensorOperation {
public:
HorizontalFlipOperation();
~HorizontalFlipOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_HORIZONTAL_FLIP_IR_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/vertical_flip_op.h"
#endif
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// VerticalFlipOperation
VerticalFlipOperation::VerticalFlipOperation() {}
VerticalFlipOperation::~VerticalFlipOperation() = default;
std::string VerticalFlipOperation::Name() const { return kVerticalFlipOperation; }
Status VerticalFlipOperation::ValidateParams() { return Status::OK(); }
std::shared_ptr<TensorOp> VerticalFlipOperation::Build() {
std::shared_ptr<VerticalFlipOp> tensor_op = std::make_shared<VerticalFlipOp>();
return tensor_op;
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_VERTICAL_FLIP_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_VERTICAL_FLIP_IR_H_
#include <memory>
#include <string>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kVerticalFlipOperation[] = "VerticalFlip";
class VerticalFlipOperation : public TensorOperation {
public:
VerticalFlipOperation();
~VerticalFlipOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_VERTICAL_FLIP_IR_H_

View File

@ -70,6 +70,7 @@ constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp";
constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp";
constexpr char kEqualizeOp[] = "EqualizeOp";
constexpr char kGaussianBlurOp[] = "GaussianBlurOp";
constexpr char kHorizontalFlipOp[] = "HorizontalFlipOp";
constexpr char kHwcToChwOp[] = "HWC2CHWOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
@ -78,6 +79,7 @@ constexpr char kNormalizePadOp[] = "NormalizePadOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomAffineOp[] = "RandomAffineOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomColorOp[] = "RandomColorOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";
constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp";
constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp";
@ -101,12 +103,12 @@ constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";
constexpr char kRgbToGrayOp[] = "RgbToGrayOp";
constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
constexpr char kUniformAugOp[] = "UniformAugOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
constexpr char kRandomColorOp[] = "RandomColorOp";
constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
// text
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";

View File

@ -330,6 +330,20 @@ class GaussianBlur(ImageTensorOperation):
return cde.GaussianBlurOperation(self.kernel_size, self.sigma)
class HorizontalFlip(ImageTensorOperation):
"""
Flip the input image horizontally.
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.HorizontalFlip()]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
def parse(self):
return cde.HorizontalFlipOperation()
class HWC2CHW(ImageTensorOperation):
"""
Transpose the input image; shape (H, W, C) to shape (C, H, W).
@ -1531,3 +1545,17 @@ class UniformAugment(ImageTensorOperation):
else:
transforms.append(op)
return cde.UniformAugOperation(transforms, self.num_ops)
class VerticalFlip(ImageTensorOperation):
"""
Flip the input image vertically.
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.VerticalFlip()]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
def parse(self):
return cde.VerticalFlipOperation()

View File

@ -39,11 +39,13 @@ SET(DE_UT_SRCS
c_api_vision_a_to_q_test.cc
c_api_vision_affine_test.cc
c_api_vision_bounding_box_augment_test.cc
c_api_vision_horizontal_flip_test.cc
c_api_vision_random_subselect_policy_test.cc
c_api_vision_random_test.cc
c_api_vision_r_to_z_test.cc
c_api_vision_soft_dvpp_test.cc
c_api_vision_uniform_aug_test.cc
c_api_vision_vertical_flip_test.cc
celeba_op_test.cc
center_crop_op_test.cc
channel_swap_test.cc

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/horizontal_flip_op.h"
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
class MindDataTestHorizontalFlip : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestHorizontalFlip, TestHorizontalFlipPipeline) {
MS_LOG(INFO) << "Doing MindDataTestHorizontalFlip-TestHorizontalFlipPipeline.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> horizontal_flip(new vision::HorizontalFlip());
// Create a Map operation on ds
ds = ds->Map({horizontal_flip});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestHorizontalFlip, TestHorizontalFlipEager) {
MS_LOG(INFO) << "Doing MindDataTestHorizontalFlip-TestHorizontalFlipEager.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto horizontal_flip = vision::HorizontalFlip();
auto transform = Execute({decode, horizontal_flip});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/vertical_flip_op.h"
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
class MindDataTestVerticalFlip : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestVerticalFlip, TestVerticalFlipPipeline) {
MS_LOG(INFO) << "Doing MindDataTestVerticalFlip-TestVerticalFlipPipeline.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> vertical_flip(new vision::VerticalFlip());
// Create a Map operation on ds
ds = ds->Map({vertical_flip});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestVerticalFlip, TestVerticalFlipEager) {
MS_LOG(INFO) << "Doing MindDataTestVerticalFlip-TestVerticalFlipEager.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto vertical_flip = vision::VerticalFlip();
auto transform = Execute({decode, vertical_flip});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing HorizontalFlip Python API
"""
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
def test_horizontal_flip_pipeline(plot=False):
"""
Test HorizontalFlip of c_transforms
"""
logger.info("test_horizontal_flip_pipeline")
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
horizontal_flip_op = c_vision.HorizontalFlip()
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=horizontal_flip_op, input_columns=["image"])
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
horizontal_flip_ms = data1["image"]
original = data2["image"]
horizontal_flip_cv = cv2.flip(original, 1)
mse = diff_mse(horizontal_flip_ms, horizontal_flip_cv)
logger.info("gaussian_blur_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, horizontal_flip_ms, mse, horizontal_flip_cv)
def test_horizontal_flip_eager():
"""
Test HorizontalFlip with eager mode
"""
logger.info("test_horizontal_flip_eager")
img = cv2.imread(IMAGE_FILE)
img_ms = c_vision.HorizontalFlip()(img)
img_cv = cv2.flip(img, 1)
mse = diff_mse(img_ms, img_cv)
assert mse == 0
if __name__ == "__main__":
test_horizontal_flip_pipeline(plot=True)
test_horizontal_flip_eager()

View File

@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing VerticalFlip Python API
"""
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
def test_vertical_flip_pipeline(plot=False):
"""
Test VerticalFlip of c_transforms
"""
logger.info("test_vertical_flip_pipeline")
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
vertical_flip_op = c_vision.VerticalFlip()
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=vertical_flip_op, input_columns=["image"])
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
vertical_flip_ms = data1["image"]
original = data2["image"]
vertical_flip_cv = cv2.flip(original, 0)
mse = diff_mse(vertical_flip_ms, vertical_flip_cv)
logger.info("gaussian_blur_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, vertical_flip_ms, mse, vertical_flip_cv)
def test_vertical_flip_eager():
"""
Test VerticalFlip with eager mode
"""
logger.info("test_vertical_flip_eager")
img = cv2.imread(IMAGE_FILE)
img_ms = c_vision.VerticalFlip()(img)
img_cv = cv2.flip(img, 0)
mse = diff_mse(img_ms, img_cv)
assert mse == 0
if __name__ == "__main__":
test_vertical_flip_pipeline(plot=True)
test_vertical_flip_eager()