add slice patches in python and c++

This commit is contained in:
liyong 2021-06-10 19:26:42 +08:00
parent 1e54401388
commit 9a1de0af65
20 changed files with 768 additions and 4 deletions

View File

@ -131,5 +131,12 @@ PYBIND_REGISTER(ImageBatchFormat, 0, ([](const py::module *m) {
.export_values();
}));
// Expose the C++ SliceMode enum to Python under the name "SliceMode".
// DE_SLICE_PAD maps to SliceMode::kPad and DE_SLICE_DROP maps to SliceMode::kDrop.
PYBIND_REGISTER(SliceMode, 0, ([](const py::module *m) {
(void)py::enum_<SliceMode>(*m, "SliceMode", py::arithmetic())
.value("DE_SLICE_PAD", SliceMode::kPad)
.value("DE_SLICE_DROP", SliceMode::kDrop)
.export_values();
}));
} // namespace dataset
} // namespace mindspore

View File

@ -58,6 +58,7 @@
#include "minddata/dataset/kernels/ir/vision/resize_with_bbox_ir.h"
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
@ -553,6 +554,18 @@ PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) {
}));
}));
// Bind vision::SlicePatchesOperation for Python under the name "SlicePatchesOperation".
// The init lambda builds the operation and runs ValidateParams() on it through THROW_IF_ERROR,
// so the parameters are validated at construction time rather than when the op is built.
PYBIND_REGISTER(
SlicePatchesOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::SlicePatchesOperation, TensorOperation, std::shared_ptr<vision::SlicePatchesOperation>>(
*m, "SlicePatchesOperation")
.def(py::init([](int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value) {
auto slice_patches =
std::make_shared<vision::SlicePatchesOperation>(num_height, num_width, slice_mode, fill_value);
THROW_IF_ERROR(slice_patches->ValidateParams());
return slice_patches;
}));
}));
PYBIND_REGISTER(SoftDvppDecodeRandomCropResizeJpegOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::SoftDvppDecodeRandomCropResizeJpegOperation, TensorOperation,
std::shared_ptr<vision::SoftDvppDecodeRandomCropResizeJpegOperation>>(

View File

@ -66,6 +66,7 @@
#include "minddata/dataset/kernels/ir/vision/rgba_to_bgr_ir.h"
#include "minddata/dataset/kernels/ir/vision/rgba_to_rgb_ir.h"
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
#include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h"
@ -877,6 +878,24 @@ RGBA2RGB::RGBA2RGB() {}
std::shared_ptr<TensorOperation> RGBA2RGB::Parse() { return std::make_shared<RgbaToRgbOperation>(); }
// SlicePatches Transform Operation.
// Private state of the SlicePatches transform: the constructor arguments, kept until Parse() is called.
struct SlicePatches::Data {
  Data(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
      : num_height_{num_height}, num_width_{num_width}, slice_mode_{slice_mode}, fill_value_{fill_value} {}

  int32_t num_height_;    // number of patches in vertical direction
  int32_t num_width_;     // number of patches in horizontal direction
  SliceMode slice_mode_;  // SliceMode::kPad or SliceMode::kDrop
  uint8_t fill_value_;    // pixel value used for the padded area
};
// Store the arguments in the shared Data holder; they are only read again in Parse().
SlicePatches::SlicePatches(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
: data_(std::make_shared<Data>(num_height, num_width, slice_mode, fill_value)) {}
std::shared_ptr<TensorOperation> SlicePatches::Parse() {
return std::make_shared<SlicePatchesOperation>(data_->num_height_, data_->num_width_, data_->slice_mode_,
data_->fill_value_);
}
// SoftDvppDecodeRandomCropResizeJpeg Transform Operation.
struct SoftDvppDecodeRandomCropResizeJpeg::Data {
Data(const std::vector<int32_t> &size, const std::vector<float> &scale, const std::vector<float> &ratio,

View File

@ -125,6 +125,12 @@ enum class RelationalOp {
kGreaterEqual, ///< equal to `>=`
};
/// \brief Possible modes for slice patches, i.e. how an image that does not divide evenly is handled.
enum class SliceMode {
kPad = 0, ///< Pad the image before slicing so that it divides evenly into patches.
kDrop = 1, ///< Drop the remainder pixels that do not fill a whole patch before slicing.
};
/// \brief Possible options for SamplingStrategy.
enum class SamplingStrategy {
kRandom = 0, ///< Random sampling with replacement.

View File

@ -823,6 +823,33 @@ class RGBA2RGB final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Slice the tensor to multiple patches in horizontal and vertical directions.
class SlicePatches final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] num_height The number of patches in vertical direction (default=1).
/// \param[in] num_width The number of patches in horizontal direction (default=1).
/// \param[in] slice_mode An enum for the mode of slice (default=SliceMode::kPad).
/// \param[in] fill_value A value representing the pixel to fill the padding area in right and
/// bottom border if slice_mode is kPad. Then padded tensor could be just sliced to multiple patches (default=0).
/// \note The usage scenario is suitable to tensor with large height and width. The tensor will keep the same
/// if set both num_height and num_width to 1. And the number of output tensors is equal to num_height*num_width.
SlicePatches(int32_t num_height = 1, int32_t num_width = 1, SliceMode slice_mode = SliceMode::kPad,
uint8_t fill_value = 0);
/// \brief Destructor.
~SlicePatches() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
// Holder for the constructor arguments; defined in the implementation file.
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Decode, randomly crop and resize a JPEG image using the simulation algorithm of
/// Ascend series chip DVPP module. The application scenario is consistent with SoftDvppDecodeResizeJpeg.
/// The input image size should be in range [32*32, 8192*8192].

View File

@ -52,6 +52,7 @@ add_library(kernels-image OBJECT
rgba_to_bgr_op.cc
rgba_to_rgb_op.cc
sharpness_op.cc
slice_patches_op.cc
solarize_op.cc
swap_red_blue_op.cc
uniform_aug_op.cc

View File

@ -18,7 +18,6 @@
#include <algorithm>
#include <vector>
#include <stdexcept>
#include <utility>
#include <opencv2/imgcodecs.hpp>
#include "utils/ms_utils.h"
#include "minddata/dataset/core/cv_tensor.h"
@ -29,6 +28,9 @@
#include "minddata/dataset/kernels/image/resize_cubic_op.h"
const int32_t MAX_INT_PRECISION = 16777216; // float int precision is 16777216
// Patch counts for which SlicePatches passes the input tensor through unchanged.
const int32_t DEFAULT_NUM_HEIGHT = 1;
const int32_t DEFAULT_NUM_WIDTH = 1;
namespace mindspore {
namespace dataset {
int GetCVInterpolationMode(InterpolationMode mode) {
@ -1281,5 +1283,81 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
RETURN_STATUS_UNEXPECTED("GaussianBlur: " + std::string(e.what()));
}
}
// Compute the height and width of one patch for the given image and patch counts.
// - input_cv:   image as a CVTensor, expected to be <H,W,C> or <H,W>.
// - patch_size: out-parameter; on success first = patch height, second = patch width.
// - num_height / num_width: number of patches along each axis; each must be in [1, image extent].
// - slice_mode: with kPad the patch size is rounded up (the caller pads the image),
//               with kDrop it is rounded down (remainder pixels are discarded).
// Returns an error Status when any argument is null or out of range.
Status ComputePatchSize(const std::shared_ptr<CVTensor> &input_cv,
                        std::shared_ptr<std::pair<int32_t, int32_t>> *patch_size, int32_t num_height, int32_t num_width,
                        SliceMode slice_mode) {
  // Guard both pointers before any dereference.
  if (input_cv == nullptr || input_cv->mat().data == nullptr) {
    RETURN_STATUS_UNEXPECTED("SlicePatches: Tensor could not convert to CV Tensor.");
  }
  if (patch_size == nullptr || *patch_size == nullptr) {
    RETURN_STATUS_UNEXPECTED("SlicePatches: patch_size is null.");
  }
  if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
    RETURN_STATUS_UNEXPECTED("SlicePatches: image shape is not <H,W,C> or <H,W>.");
  }

  const cv::Size s = input_cv->mat().size();
  // Reject non-positive counts as well as zero: a negative count would otherwise yield a negative patch size.
  if (num_height <= 0 || num_height > s.height) {
    RETURN_STATUS_UNEXPECTED(
      "SlicePatches: The number of patches on height axis is not positive or is greater than height.");
  }
  if (num_width <= 0 || num_width > s.width) {
    RETURN_STATUS_UNEXPECTED(
      "SlicePatches: The number of patches on width axis is not positive or is greater than width.");
  }

  const bool pad = (slice_mode == SliceMode::kPad);
  int32_t patch_h = s.height / num_height;
  if (pad && s.height % num_height != 0) {
    patch_h += 1;  // round up; the caller pads patch_h * num_height - s.height rows
  }
  int32_t patch_w = s.width / num_width;
  if (pad && s.width % num_width != 0) {
    patch_w += 1;  // round up; the caller pads patch_w * num_width - s.width columns
  }

  (*patch_size)->first = patch_h;
  (*patch_size)->second = patch_w;
  return Status::OK();
}
// Slice an image tensor into num_height * num_width patches and append them to *output.
// - input:      image tensor; it is appended unchanged when both counts equal the defaults (1, 1).
// - output:     receives the patches in row-major order (all patches of the top row first).
// - slice_mode: kPad pads the right and bottom borders with fill_value so the image divides evenly,
//               kDrop slices the image as is and leaves the remainder pixels out.
// Returns an error Status when output is null, when ComputePatchSize rejects the arguments,
// or when OpenCV throws.
Status SlicePatches(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output,
                    int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value) {
  if (output == nullptr) {
    RETURN_STATUS_UNEXPECTED("SlicePatches: output is null.");
  }
  if (num_height == DEFAULT_NUM_HEIGHT && num_width == DEFAULT_NUM_WIDTH) {
    output->push_back(input);
    return Status::OK();
  }

  auto patch_size = std::make_shared<std::pair<int32_t, int32_t>>(0, 0);
  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
  RETURN_IF_NOT_OK(ComputePatchSize(input_cv, &patch_size, num_height, num_width, slice_mode));
  const int32_t patch_h = patch_size->first;
  const int32_t patch_w = patch_size->second;

  cv::Mat in_img = input_cv->mat();
  const cv::Size s = in_img.size();
  try {
    cv::Mat out_img = in_img;
    const int32_t padded_h = patch_h * num_height;
    const int32_t padded_w = patch_w * num_width;
    // Only allocate and copy when there really is something to pad on the right or bottom border.
    if (slice_mode == SliceMode::kPad && (padded_h != s.height || padded_w != s.width)) {
      out_img = cv::Mat(padded_h, padded_w, in_img.type(), cv::Scalar::all(fill_value));
      in_img.copyTo(out_img(cv::Rect(0, 0, s.width, s.height)));
    }

    for (int i = 0; i < num_height; ++i) {
      for (int j = 0; j < num_width; ++j) {
        std::shared_ptr<CVTensor> patch_cv;
        const cv::Rect patch(j * patch_w, i * patch_h, patch_w, patch_h);
        RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_img(patch), &patch_cv));
        output->push_back(std::static_pointer_cast<Tensor>(patch_cv));
      }
    }
    return Status::OK();
  } catch (const cv::Exception &e) {
    RETURN_STATUS_UNEXPECTED("SlicePatches: " + std::string(e.what()));
  }
}
} // namespace dataset
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#if defined(_WIN32) || defined(_WIN64)
#undef HAVE_STDDEF_H
@ -338,6 +339,25 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t kernel_size_x,
int32_t kernel_size_y, float sigma_x, float sigma_y);
/// \brief Slice tensor to multiple patches.
/// \param[in] input Input Tensor
/// \param[out] output Vector of Output Tensor; the patches are appended to it.
/// \param[in] num_height Number of patches in vertical direction.
/// \param[in] num_width Number of patches in horizontal direction.
/// \param[in] slice_mode Mode represents padding or drop.
/// \param[in] fill_value The value of filled pixel in right and bottom border when padding.
/// \return Status code.
Status SlicePatches(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output,
int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value);
/// \brief Compute patch height and width.
/// \param[in] input_cv Input CVTensor
/// \param[out] patch_size Size of patch, stored as (patch height, patch width).
/// \param[in] num_height Number of patches in vertical direction.
/// \param[in] num_width Number of patches in horizontal direction.
/// \param[in] slice_mode Mode represents padding or drop.
/// \return Status code.
Status ComputePatchSize(const std::shared_ptr<CVTensor> &input_cv,
std::shared_ptr<std::pair<int32_t, int32_t>> *patch_size, int32_t num_height, int32_t num_width,
SliceMode slice_mode);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/slice_patches_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Definitions of the default constructor arguments declared in SlicePatchesOp.
const int32_t SlicePatchesOp::kDefNumH = 1;
const int32_t SlicePatchesOp::kDefNumW = 1;
const uint8_t SlicePatchesOp::kDefFillV = 0;
const SliceMode SlicePatchesOp::kDefSliceMode = SliceMode::kPad;
// Store the slicing parameters; no validation is done here.
SlicePatchesOp::SlicePatchesOp(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
: num_height_(num_height), num_width_(num_width), slice_mode_(slice_mode), fill_value_(fill_value) {}
// Slice the single tensor of the input row into patches and append every patch to the output row.
// - input:  row that must hold exactly one numeric tensor of rank >= 2.
// - output: row that receives the patches produced by SlicePatches().
// Returns an error Status when a check fails or when SlicePatches() reports a failure.
Status SlicePatchesOp::Compute(const TensorRow &input, TensorRow *output) {
  IO_CHECK_VECTOR(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input tensor size should be 1.");

  auto in_tensor = input[0];
  auto in_type = in_tensor->type();
  auto in_shape = in_tensor->shape();

  CHECK_FAIL_RETURN_UNEXPECTED(in_type.IsNumeric(), "Input Tensor type should be numeric.");
  // The message now matches the condition: rank 2 is accepted.
  CHECK_FAIL_RETURN_UNEXPECTED(in_shape.Rank() >= 2, "Input Tensor rank should be at least 2.");

  std::vector<std::shared_ptr<Tensor>> out;
  RETURN_IF_NOT_OK(SlicePatches(in_tensor, &out, num_height_, num_width_, slice_mode_, fill_value_));
  (void)std::copy(out.begin(), out.end(), std::back_inserter(*output));
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SLICE_PATCHES_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SLICE_PATCHES_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Tensor op that slices one image tensor into num_height * num_width patches.
class SlicePatchesOp : public TensorOp {
public:
// Default values, also used by python_bindings.cc
static const int32_t kDefNumH;
static const int32_t kDefNumW;
static const uint8_t kDefFillV;
static const SliceMode kDefSliceMode;
SlicePatchesOp(int32_t num_height = kDefNumH, int32_t num_width = kDefNumW, SliceMode slice_mode = kDefSliceMode,
uint8_t fill_value = kDefFillV);
~SlicePatchesOp() override = default;
// Print the op name and both patch counts (slice mode and fill value are not printed).
void Print(std::ostream &out) const override {
out << Name() << " patches number on height: " << num_height_ << ", patches number on width: " << num_width_;
}
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kSlicePatchesOp; }
protected:
int32_t num_height_; // number of patches on height axis
int32_t num_width_; // number of patches on width axis
SliceMode slice_mode_; // SliceMode::kPad or SliceMode::kDrop
uint8_t fill_value_; // pixel value used to fill the padded area (not a border width)
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SLICE_PATCHES_OP_H_

View File

@ -47,6 +47,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
rgba_to_bgr_ir.cc
rgba_to_rgb_ir.cc
rotate_ir.cc
slice_patches_ir.cc
softdvpp_decode_random_crop_resize_jpeg_ir.cc
softdvpp_decode_resize_jpeg_ir.cc
swap_red_blue_ir.cc

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/image/slice_patches_op.h"
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
// SlicePatchesOperation
// Store the parameters as given; range checks are done separately in ValidateParams().
SlicePatchesOperation::SlicePatchesOperation(int32_t num_height, int32_t num_width, SliceMode slice_mode,
uint8_t fill_value)
: TensorOperation(),
num_height_(num_height),
num_width_(num_width),
slice_mode_(slice_mode),
fill_value_(fill_value) {}
// Defaulted out of line; the class only holds plain value members.
SlicePatchesOperation::~SlicePatchesOperation() = default;
// Return the operation name constant kSlicePatchesOperation.
std::string SlicePatchesOperation::Name() const { return kSlicePatchesOperation; }
// Check that both patch counts pass ValidateIntScalarPositive; the first failing check is returned.
// slice_mode_ and fill_value_ are not checked here.
Status SlicePatchesOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_height", num_height_));
RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_width", num_width_));
return Status::OK();
}
// Create the runtime SlicePatchesOp that carries the same four parameters as this IR node.
std::shared_ptr<TensorOp> SlicePatchesOperation::Build() {
  return std::make_shared<SlicePatchesOp>(num_height_, num_width_, slice_mode_, fill_value_);
}
// Serialize the four parameters into *out_json, keyed by their parameter names.
Status SlicePatchesOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json serialized;
  serialized["num_height"] = num_height_;
  serialized["num_width"] = num_width_;
  serialized["slice_mode"] = slice_mode_;
  serialized["fill_value"] = fill_value_;

  *out_json = serialized;
  return Status::OK();
}
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SLICE_PATCHES_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SLICE_PATCHES_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
// Name returned by SlicePatchesOperation::Name().
constexpr char kSlicePatchesOperation[] = "SlicePatches";
// IR node for the SlicePatches transform: holds the parameters, validates them and builds the runtime op.
class SlicePatchesOperation : public TensorOperation {
public:
SlicePatchesOperation(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value);
~SlicePatchesOperation();
// Create the runtime tensor op from the stored parameters.
std::shared_ptr<TensorOp> Build() override;
// Check the stored parameters and return an error Status if they are invalid.
Status ValidateParams() override;
std::string Name() const override;
// Write the stored parameters into out_json.
Status to_json(nlohmann::json *out_json) override;
private:
int32_t num_height_; // number of patches in vertical direction
int32_t num_width_; // number of patches in horizontal direction
SliceMode slice_mode_; // SliceMode::kPad or SliceMode::kDrop
uint8_t fill_value_; // pixel value used for the padded area
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_SLICE_PATCHES_IR_H_

View File

@ -105,6 +105,7 @@ constexpr char kRgbToBgrOp[] = "RgbToBgrOp";
constexpr char kRgbToGrayOp[] = "RgbToGrayOp";
constexpr char kRotateOp[] = "RotateOp";
constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSlicePatchesOp[] = "SlicePatchesOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
constexpr char kSolarizeOp[] = "SolarizeOp";

View File

@ -47,14 +47,14 @@ import numpy as np
from PIL import Image
import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat
from .utils import Inter, Border, ImageBatchFormat, SliceMode
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches
from ..transforms.c_transforms import TensorOperation
@ -87,6 +87,8 @@ DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBO
Inter.AREA: cde.InterpolationMode.DE_INTER_AREA,
Inter.PILCUBIC: cde.InterpolationMode.DE_INTER_PILCUBIC}
# Maps the Python SliceMode enum onto the SliceMode values exported by the C++ bindings.
DE_C_SLICE_MODE = {SliceMode.PAD: cde.SliceMode.DE_SLICE_PAD,
SliceMode.DROP: cde.SliceMode.DE_SLICE_DROP}
def parse_padding(padding):
""" Parses and prepares the padding tuple"""
@ -1541,6 +1543,43 @@ class Rotate(ImageTensorOperation):
return cde.RotateOperation(degrees, interpolation, expand, center, fill_value)
class SlicePatches(ImageTensorOperation):
"""
Slice Tensor to multiple patches in horizontal and vertical directions.
The usage scenario is suitable to large height and width Tensor. The Tensor
will keep the same if set both num_height and num_width to 1. And the
number of output tensors is equal to num_height*num_width.
Args:
num_height (int, optional): The number of patches in vertical direction (default=1).
num_width (int, optional): The number of patches in horizontal direction (default=1).
slice_mode (SliceMode, optional): A mode represents pad or drop (default=SliceMode.PAD).
It can be any of [SliceMode.PAD, SliceMode.DROP].
fill_value (int, optional): The pixel value used to fill the padded area on the
right and bottom border if slice_mode is set to be SliceMode.PAD (default=0).
Examples:
>>> # default padding mode
>>> slice_patches_op = c_vision.SlicePatches(num_h, num_w)
>>> cols = ['img' + str(x) for x in range(num_h*num_w)]
>>> dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
>>> dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
>>> dataset1 = dataset1.map(operations=slice_patches_op, input_columns=[
... "image"], output_columns=cols, column_order=cols)
"""
@check_slice_patches
def __init__(self, num_height=1, num_width=1, slice_mode=SliceMode.PAD, fill_value=0):
self.num_height = num_height
self.num_width = num_width
self.slice_mode = slice_mode
self.fill_value = fill_value
def parse(self):
# Translate the Python SliceMode into its C++ counterpart before building the operation.
return cde.SlicePatchesOperation(self.num_height, self.num_width,
DE_C_SLICE_MODE[self.slice_mode], self.fill_value)
class SoftDvppDecodeRandomCropResizeJpeg(ImageTensorOperation):
"""
A combination of `Crop`, `Decode` and `Resize` using the simulation algorithm of Ascend series chip DVPP module.

View File

@ -42,3 +42,7 @@ class ImageBatchFormat(IntEnum):
"""Image Batch Format"""
NHWC = 0
NCHW = 1
class SliceMode(IntEnum):
"""Mode for SlicePatches: PAD or DROP"""
PAD = 0
DROP = 1

View File

@ -22,7 +22,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio, check_odd
from .utils import Inter, Border, ImageBatchFormat
from .utils import Inter, Border, ImageBatchFormat, SliceMode
def check_crop_size(size):
@ -513,6 +513,27 @@ def check_pad(method):
return new_method
def check_slice_patches(method):
    """Decorator that validates the arguments of SlicePatches before calling the wrapped method."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        num_height, num_width, slice_mode, fill_value = parsed

        # Both patch counts share the same rule: an int within [1, INT32_MAX].
        for arg_name, count in (("num_height", num_height), ("num_width", num_width)):
            if count is not None:
                type_check(count, (int,), arg_name)
                check_value(count, (1, INT32_MAX), arg_name)

        if slice_mode is not None:
            type_check(slice_mode, (SliceMode,), "slice_mode")

        if fill_value is not None:
            type_check(fill_value, (int,), "fill_value")
            check_value(fill_value, [0, 255], "fill_value")

        return method(self, *args, **kwargs)

    return new_method
def check_random_perspective(method):
"""Wrapper method to check the parameters of random perspective."""

View File

@ -43,6 +43,7 @@ SET(DE_UT_SRCS
c_api_vision_random_subselect_policy_test.cc
c_api_vision_random_test.cc
c_api_vision_r_to_z_test.cc
c_api_vision_slice_patches_test.cc
c_api_vision_soft_dvpp_test.cc
c_api_vision_uniform_aug_test.cc
c_api_vision_vertical_flip_test.cc

View File

@ -0,0 +1,117 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
// Test fixture for the SlicePatches C++ API tests; adds nothing on top of UT::DatasetOpTesting.
class MindDataTestSlicePatches : public UT::DatasetOpTesting {
protected:
};
// Invalid patch counts must make iterator creation fail.
TEST_F(MindDataTestSlicePatches, TestSlicePacthesParamCheck) {
  MS_LOG(INFO) << "Doing TestSlicePatchesParamCheck with invalid parameters.";

  // Source dataset shared by both negative cases.
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Case 1: num_height is not positive.
  std::shared_ptr<TensorTransform> negative_height = std::make_shared<vision::SlicePatches>(-1);
  auto ds1 = ds->Map({negative_height});
  EXPECT_NE(ds1, nullptr);
  // Expect failure: invalid num_height for SlicePatches.
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  EXPECT_EQ(iter1, nullptr);

  // Case 2: num_width is not positive.
  std::shared_ptr<TensorTransform> zero_width = std::make_shared<vision::SlicePatches>(1, 0);
  auto ds2 = ds->Map({zero_width});
  EXPECT_NE(ds2, nullptr);
  // Expect failure: invalid num_width for SlicePatches.
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  EXPECT_EQ(iter2, nullptr);
}
// Run SlicePatches(2, 2) inside a dataset pipeline and check that every row carries four patch columns.
TEST_F(MindDataTestSlicePatches, TestSlicePatchesPipeline) {
  MS_LOG(INFO) << "Doing TestSlicePatchesPipeline.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create objects for the tensor ops
  std::shared_ptr<TensorTransform> slice_patches = std::make_shared<vision::SlicePatches>(2, 2);

  // Create a Map operation on ds: one output column per patch
  ds = ds->Map({slice_patches}, {"image"}, {"img0", "img1", "img2", "img3"}, {"img0", "img1", "img2", "img3"});
  EXPECT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Fatal assertion: the iterator is dereferenced right below.
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  uint64_t i = 0;
  while (!row.empty()) {
    i++;
    ASSERT_EQ(row.size(), 4);
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
}
// Run Decode followed by SlicePatches(2, 2) in eager mode and check that the call succeeds.
TEST_F(MindDataTestSlicePatches, TestSlicePatchesEager) {
  MS_LOG(INFO) << "Doing TestSlicePatchesEager.";

  // Read images
  auto image = ReadFileToTensor("data/dataset/apple.jpg");
  std::vector<mindspore::MSTensor> input{image};
  std::vector<mindspore::MSTensor> output;

  // Transform params
  auto decode = vision::Decode();
  auto slice_patches = vision::SlicePatches(2, 2);

  auto transform = Execute({decode, slice_patches});
  Status rc = transform(input, &output);
  EXPECT_EQ(rc, Status::OK());
}

View File

@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing SlicePatches Python API
"""
import functools
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.utils as mode
from mindspore import log as logger
from util import diff_mse, visualize_list
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_slice_patches_01(plot=False):
    """
    Slice an RGB image resized to (100, 200) into 2 x 2 patches in pad mode
    """
    slice_to_patches(ori_size=[100, 200], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
def test_slice_patches_02(plot=False):
    """
    Use 1 x 1 patches on a (100, 200) image, i.e. the no-op case
    """
    slice_to_patches(ori_size=[100, 200], num_h=1, num_w=1, pad_or_drop=True, plot=plot)
def test_slice_patches_03(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2 x 2 patches in pad mode
    """
    slice_to_patches(ori_size=[99, 199], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
def test_slice_patches_04(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2 x 2 patches in drop mode
    """
    slice_to_patches(ori_size=[99, 199], num_h=2, num_w=2, pad_or_drop=False, plot=plot)
def test_slice_patches_05(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2 x 2 patches in pad mode with fill value 255
    """
    slice_to_patches(ori_size=[99, 199], num_h=2, num_w=2, pad_or_drop=True, fill_value=255, plot=plot)
def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
"""
Tool function for slice patches.
Runs c_vision.SlicePatches and the numpy helper slice_patches on the same decoded and
resized images and asserts that the mse of every pair of patch columns is 0.
Args:
ori_size: size passed to c_vision.Resize (H, W).
num_h: number of patches in vertical direction.
num_w: number of patches in horizontal direction.
pad_or_drop: True selects SliceMode.PAD, False selects SliceMode.DROP.
fill_value: fill value passed to the pad-mode op and to the numpy helper.
plot: if True, visualize_list is called with the collected patches.
"""
logger.info("test_slice_patches_pipeline")
# One output column per patch.
cols = ['img' + str(x) for x in range(num_h*num_w)]
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
resize_op = c_vision.Resize(ori_size) # H, W
slice_patches_op = c_vision.SlicePatches(
num_h, num_w, mode.SliceMode.PAD, fill_value)
# In drop mode the pad-mode op built above is replaced.
if not pad_or_drop:
slice_patches_op = c_vision.SlicePatches(
num_h, num_w, mode.SliceMode.DROP)
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
dataset1 = dataset1.map(operations=slice_patches_op,
input_columns=["image"], output_columns=cols, column_order=cols)
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
# The numpy helper takes the image as its only remaining positional argument.
func_slice_patches = functools.partial(
slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
dataset2 = dataset2.map(operations=func_slice_patches,
input_columns=["image"], output_columns=cols, column_order=cols)
num_iter = 0
patches_c = []
patches_py = []
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
for x in range(num_h*num_w):
col = "img" + str(x)
mse = diff_mse(data1[col], data2[col])
logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
patches_c.append(data1[col])
patches_py.append(data2[col])
num_iter += 1
if plot:
visualize_list(patches_py, patches_c)
def test_slice_patches_exception_01():
    """
    Test SlicePatches with invalid parameters.

    Every case must raise the expected exception type with the expected message;
    a case that raises nothing fails the test instead of passing silently.
    """
    logger.info("test_Slice_Patches_exception")

    def expect_error(error_type, message, *op_args):
        """Build SlicePatches(*op_args) and require error_type with message in its text."""
        try:
            _ = c_vision.SlicePatches(*op_args)
        except error_type as e:
            logger.info("Got an exception in SlicePatches: {}".format(str(e)))
            assert message in str(e)
        else:
            raise AssertionError(
                "SlicePatches{} did not raise {}".format(op_args, error_type.__name__))

    expect_error(ValueError, "Input num_height is not within", 0, 2)
    expect_error(ValueError, "Input num_width is not within", 2, 0)
    expect_error(TypeError, "Argument slice_mode with value", 2, 2, 1)
    expect_error(ValueError, "Input fill_value is not within", 2, 2, mode.SliceMode.PAD, -1)
def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
    """
    Help function which slices patches with numpy.

    Args:
        image (numpy.ndarray): Image of shape (H, W, C) or (H, W).
        num_h (int): Number of patches in vertical direction.
        num_w (int): Number of patches in horizontal direction.
        pad_or_drop (bool): True pads the right and bottom border with fill_value so the image
            divides evenly, False drops the remainder pixels.
        fill_value (int): Value written into the padded area when pad_or_drop is True.

    Returns:
        The image itself if num_h == 1 and num_w == 1, otherwise a tuple of num_h * num_w
        patches in row-major order.
    """
    if num_h == 1 and num_w == 1:
        return image

    # Only the two leading axes are sliced; any trailing axes (channels) are kept as they are.
    height, width = image.shape[:2]
    patch_h, rest_h = divmod(height, num_h)
    patch_w, rest_w = divmod(width, num_w)

    if pad_or_drop:
        # Round the patch size up and pad the right and bottom border accordingly.
        if rest_h != 0:
            patch_h += 1
        if rest_w != 0:
            patch_w += 1
        padded_shape = (patch_h * num_h, patch_w * num_w) + tuple(image.shape[2:])
        # Keep the dtype of the input so that non-uint8 images are not silently converted.
        canvas = np.full(padded_shape, fill_value, dtype=image.dtype)
        canvas[:height, :width] = image
    else:
        canvas = image

    patches = [canvas[top * patch_h:(top + 1) * patch_h, left * patch_w:(left + 1) * patch_w]
               for top in range(num_h)
               for left in range(num_w)]
    return tuple(patches)
if __name__ == "__main__":
    # Run the plotting cases in order, then the parameter validation case.
    for plotted_case in (test_slice_patches_01,
                         test_slice_patches_02,
                         test_slice_patches_03,
                         test_slice_patches_04,
                         test_slice_patches_05):
        plotted_case(plot=True)
    test_slice_patches_exception_01()