!12480 Adding Affine API

From: @ezphlow
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-26 21:35:05 +08:00 committed by Gitee
commit 5556b12da4
22 changed files with 595 additions and 187 deletions

View File

@ -84,8 +84,8 @@ PYBIND_REGISTER(CutOutOperation, 1, ([](const py::module *m) {
PYBIND_REGISTER(DecodeOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(DecodeOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::DecodeOperation, TensorOperation, std::shared_ptr<vision::DecodeOperation>>( (void)py::class_<vision::DecodeOperation, TensorOperation, std::shared_ptr<vision::DecodeOperation>>(
*m, "DecodeOperation") *m, "DecodeOperation")
.def(py::init([]() { .def(py::init([](bool rgb) {
auto decode = std::make_shared<vision::DecodeOperation>(); auto decode = std::make_shared<vision::DecodeOperation>(rgb);
THROW_IF_ERROR(decode->ValidateParams()); THROW_IF_ERROR(decode->ValidateParams());
return decode; return decode;
})) }))

View File

@ -42,6 +42,19 @@ namespace vision {
// FUNCTIONS TO CREATE VISION TRANSFORM OPERATIONS // FUNCTIONS TO CREATE VISION TRANSFORM OPERATIONS
// (In alphabetical order) // (In alphabetical order)
Affine::Affine(float_t degrees, const std::vector<float> &translation, float scale, const std::vector<float> &shear,
InterpolationMode interpolation, const std::vector<uint8_t> &fill_value)
: degrees_(degrees),
translation_(translation),
scale_(scale),
shear_(shear),
interpolation_(interpolation),
fill_value_(fill_value) {}
std::shared_ptr<TensorOperation> Affine::Parse() {
return std::make_shared<AffineOperation>(degrees_, translation_, scale_, shear_, interpolation_, fill_value_);
}
// AutoContrast Transform Operation. // AutoContrast Transform Operation.
AutoContrast::AutoContrast(float cutoff, std::vector<uint32_t> ignore) : cutoff_(cutoff), ignore_(ignore) {} AutoContrast::AutoContrast(float cutoff, std::vector<uint32_t> ignore) : cutoff_(cutoff), ignore_(ignore) {}

View File

@ -35,7 +35,6 @@ class TensorOperation;
// Transform operations for performing computer vision. // Transform operations for performing computer vision.
namespace vision { namespace vision {
/// \brief AutoContrast TensorTransform. /// \brief AutoContrast TensorTransform.
/// \notes Apply automatic contrast on input image. /// \notes Apply automatic contrast on input image.
class AutoContrast : public TensorTransform { class AutoContrast : public TensorTransform {
@ -253,48 +252,6 @@ class Pad : public TensorTransform {
BorderType padding_mode_; BorderType padding_mode_;
}; };
/// \brief RandomAffine TensorTransform.
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
class RandomAffine : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] degrees A float vector of size 2, representing the starting and ending degree
/// \param[in] translate_range A float vector of size 2 or 4, representing percentages of translation on x and y axes.
/// if size is 2, (min_dx, max_dx, 0, 0)
/// if size is 4, (min_dx, max_dx, min_dy, max_dy)
/// all values are in range [-1, 1]
/// \param[in] scale_range A float vector of size 2, representing the starting and ending scales in the range.
/// \param[in] shear_ranges A float vector of size 2 or 4, representing the starting and ending shear degrees
/// vertically and horizontally.
/// if size is 2, (min_shear_x, max_shear_x, 0, 0)
/// if size is 4, (min_shear_x, max_shear_x, min_shear_y, max_shear_y)
/// \param[in] interpolation An enum for the mode of interpolation
/// \param[in] fill_value A vector representing the value to fill the area outside the transform
/// in the output image. If 1 value is provided, it is used for all RGB channels.
/// If 3 values are provided, it is used to fill R, G, B channels respectively.
explicit RandomAffine(const std::vector<float_t> &degrees,
const std::vector<float_t> &translate_range = {0.0, 0.0, 0.0, 0.0},
const std::vector<float_t> &scale_range = {1.0, 1.0},
const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor.
~RandomAffine() = default;
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
std::vector<float_t> degrees_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
std::vector<float_t> scale_range_; // min_scale, max_scale
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
InterpolationMode interpolation_;
std::vector<uint8_t> fill_value_;
};
/// \brief Blends an image with its grayscale version with random weights /// \brief Blends an image with its grayscale version with random weights
/// t and 1 - t generated from a given range. If the range is trivial /// t and 1 - t generated from a given range. If the range is trivial
/// then the weights are determinate and t equals the bound of the interval /// then the weights are determinate and t equals the bound of the interval

View File

@ -35,6 +35,41 @@ namespace vision {
// Forward Declarations // Forward Declarations
class RotateOperation; class RotateOperation;
/// \brief Affine TensorTransform.
/// \notes Apply affine transform on input image.
class Affine : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] degrees The degrees to rotate the image by
/// \param[in] translation The value representing vertical and horizontal translation (default = {0.0, 0.0})
/// The first value represent the x axis translation while the second represents y axis translation.
/// \param[in] scale The scaling factor for the image (default = 0.0)
/// \param[in] shear A float vector of size 2, representing the shear degrees (default = {0.0, 0.0})
/// \param[in] interpolation An enum for the mode of interpolation
/// \param[in] fill_value A vector representing the value to fill the area outside the transform
/// in the output image. If 1 value is provided, it is used for all RGB channels.
/// If 3 values are provided, it is used to fill R, G, B channels respectively.
explicit Affine(float_t degrees, const std::vector<float> &translation = {0.0, 0.0}, float scale = 0.0,
const std::vector<float> &shear = {0.0, 0.0},
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor.
~Affine() = default;
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
float degrees_;
std::vector<float> translation_;
float scale_;
std::vector<float> shear_;
InterpolationMode interpolation_;
std::vector<uint8_t> fill_value_;
};
/// \brief CenterCrop TensorTransform. /// \brief CenterCrop TensorTransform.
/// \notes Crops the input image at the center to the given size. /// \notes Crops the input image at the center to the given size.
class CenterCrop : public TensorTransform { class CenterCrop : public TensorTransform {
@ -125,6 +160,48 @@ class Normalize : public TensorTransform {
std::vector<float> std_; std::vector<float> std_;
}; };
/// \brief RandomAffine TensorTransform.
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
class RandomAffine : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] degrees A float vector of size 2, representing the starting and ending degree
/// \param[in] translate_range A float vector of size 2 or 4, representing percentages of translation on x and y axes.
/// if size is 2, (min_dx, max_dx, 0, 0)
/// if size is 4, (min_dx, max_dx, min_dy, max_dy)
/// all values are in range [-1, 1]
/// \param[in] scale_range A float vector of size 2, representing the starting and ending scales in the range.
/// \param[in] shear_ranges A float vector of size 2 or 4, representing the starting and ending shear degrees
/// vertically and horizontally.
/// if size is 2, (min_shear_x, max_shear_x, 0, 0)
/// if size is 4, (min_shear_x, max_shear_x, min_shear_y, max_shear_y)
/// \param[in] interpolation An enum for the mode of interpolation
/// \param[in] fill_value A vector representing the value to fill the area outside the transform
/// in the output image. If 1 value is provided, it is used for all RGB channels.
/// If 3 values are provided, it is used to fill R, G, B channels respectively.
explicit RandomAffine(const std::vector<float_t> &degrees,
const std::vector<float_t> &translate_range = {0.0, 0.0, 0.0, 0.0},
const std::vector<float_t> &scale_range = {1.0, 1.0},
const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor.
~RandomAffine() = default;
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
std::vector<float_t> degrees_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
std::vector<float_t> scale_range_; // min_scale, max_scale
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
InterpolationMode interpolation_;
std::vector<uint8_t> fill_value_;
};
/// \brief Resize TensorTransform. /// \brief Resize TensorTransform.
/// \notes Resize the input image to the given size. /// \notes Resize the input image to the given size.
class Resize : public TensorTransform { class Resize : public TensorTransform {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,7 +19,11 @@
#include <vector> #include <vector>
#include "minddata/dataset/kernels/image/affine_op.h" #include "minddata/dataset/kernels/image/affine_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/kernels/image/math_utils.h" #include "minddata/dataset/kernels/image/math_utils.h"
#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/random.h"
@ -45,59 +49,46 @@ AffineOp::AffineOp(float_t degrees, const std::vector<float_t> &translation, flo
Status AffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { Status AffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output); IO_CHECK(input, output);
try { float_t translation_x = translation_[0];
float_t translation_x = translation_[0]; float_t translation_y = translation_[1];
float_t translation_y = translation_[1]; float_t degrees = 0.0;
float_t degrees = 0.0; DegreesToRadians(degrees_, &degrees);
DegreesToRadians(degrees_, &degrees); float_t shear_x = shear_[0];
float_t shear_x = shear_[0]; float_t shear_y = shear_[1];
float_t shear_y = shear_[1]; DegreesToRadians(shear_x, &shear_x);
DegreesToRadians(shear_x, &shear_x); DegreesToRadians(-1 * shear_y, &shear_y);
DegreesToRadians(-1 * shear_y, &shear_y);
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
// Apply Affine Transformation // Apply Affine Transformation
// T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] // T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
// C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] // C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
// RSS is rotation with scale and shear matrix // RSS is rotation with scale and shear matrix
// RSS(a, s, (sx, sy)) = // RSS(a, s, (sx, sy)) =
// = R(a) * S(s) * SHy(sy) * SHx(sx) // = R(a) * S(s) * SHy(sy) * SHx(sx)
// = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ] // = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
// [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ] // [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
// [ 0 , 0 , 1 ] // [ 0 , 0 , 1 ]
// //
// where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: // where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
// SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] // SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
// [0, 1 ] [-tan(s), 1] // [0, 1 ] [-tan(s), 1]
// //
// Thus, the affine matrix is M = T * C * RSS * C^-1 // Thus, the affine matrix is M = T * C * RSS * C^-1
float_t cx = ((input_cv->mat().cols - 1) / 2.0); // image is hwc, rows = shape()[0]
float_t cy = ((input_cv->mat().rows - 1) / 2.0); float_t cx = ((input->shape()[1] - 1) / 2.0);
// Calculate RSS float_t cy = ((input->shape()[0] - 1) / 2.0);
std::vector<float_t> matrix{ // Calculate RSS
static_cast<float>(scale_ * cos(degrees + shear_y) / cos(shear_y)), std::vector<float_t> matrix{
static_cast<float>(scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))), static_cast<float>(scale_ * cos(degrees + shear_y) / cos(shear_y)),
0, static_cast<float>(scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))),
static_cast<float>(scale_ * sin(degrees + shear_y) / cos(shear_y)), 0,
static_cast<float>(scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))), static_cast<float>(scale_ * sin(degrees + shear_y) / cos(shear_y)),
0}; static_cast<float>(scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))),
// Compute T * C * RSS * C^-1 0};
matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x; // Compute T * C * RSS * C^-1
matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y; matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x;
cv::Mat affine_mat(matrix); matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y;
affine_mat = affine_mat.reshape(1, {2, 3}); RETURN_IF_NOT_OK(Affine(input, output, matrix, interpolation_, fill_value_[0], fill_value_[1], fill_value_[2]));
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(),
GetCVInterpolationMode(interpolation_), cv::BORDER_CONSTANT,
cv::Scalar(fill_value_[0], fill_value_[1], fill_value_[2]));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what()));
}
return Status::OK(); return Status::OK();
} }
} // namespace dataset } // namespace dataset

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,7 +21,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
@ -49,10 +48,6 @@ class AffineOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
/// Member variables
private:
std::string kAffineOp = "AffineOp";
protected: protected:
float_t degrees_; float_t degrees_;
std::vector<float_t> translation_; // translation_x and translation_y std::vector<float_t> translation_; // translation_x and translation_y

View File

@ -1101,5 +1101,25 @@ Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, in
jpeg_destroy_decompress(&cinfo); jpeg_destroy_decompress(&cinfo);
return Status::OK(); return Status::OK();
} }
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
cv::Mat affine_mat(mat);
affine_mat = affine_mat.reshape(1, {2, 3});
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(),
GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, cv::Scalar(fill_r, fill_g, fill_b));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what()));
}
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -299,6 +299,17 @@ Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
/// \param img_height: the jpeg image height /// \param img_height: the jpeg image height
Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height); Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height);
/// \brief Geometrically transform the input image
/// \param[in] input Input Tensor
/// \param[out] output Transformed Tensor
/// \param[in] mat The transformation matrix
/// \param[in] interpolation The interpolation mode
/// \param[in] fill_r Red fill value for pad
/// \param[in] fill_g Green fill value for pad
/// \param[in] fill_b Blue fill value for pad
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -621,5 +621,46 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
return RotateAngleWithMirror(input, output, orientation); return RotateAngleWithMirror(input, output, orientation);
} }
} }
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
try {
if (interpolation != InterpolationMode::kLinear) {
MS_LOG(WARNING) << "Only Bilinear interpolation supported for now";
}
int height = 0;
int width = 0;
double M[6] = {};
for (int i = 0; i < mat.size(); i++) {
M[i] = static_cast<double>(mat[i]);
}
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
GetLiteCVDataType(input->type()));
height = lite_mat_rgb.height_;
width = lite_mat_rgb.width_;
std::vector<size_t> dsize;
dsize.push_back(width);
dsize.push_back(height);
LiteMat lite_mat_affine;
std::shared_ptr<Tensor> output_tensor;
TensorShape new_shape = TensorShape({height, width, input->shape()[2]});
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor));
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
lite_mat_affine.Init(width, height, lite_mat_rgb.channel_, reinterpret_cast<void *>(buffer),
GetLiteCVDataType(input->type()));
bool ret = Affine(lite_mat_rgb, lite_mat_affine, M, dsize, UINT8_C3(fill_r, fill_g, fill_b));
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Affine: affine failed.");
*output = output_tensor;
return Status::OK();
} catch (std::runtime_error &e) {
RETURN_STATUS_UNEXPECTED("Affine: " + std::string(e.what()));
}
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -52,70 +52,81 @@ Status JpegCropAndDecode(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
int w = 0, int h = 0); int w = 0, int h = 0);
/// \brief Returns cropped ROI of an image /// \brief Returns cropped ROI of an image
/// \param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor. /// \param[in] input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
/// \param x: starting horizontal position of ROI /// \param[in] x Starting horizontal position of ROI
/// \param y: starting vertical position of ROI /// \param[in] y Starting vertical position of ROI
/// \param w: width of the ROI /// \param[in] w Width of the ROI
/// \param h: height of the ROI /// \param[in] h Height of the ROI
/// \param output: Cropped image Tensor of shape <h,w,C> or <h,w> and same input type. /// \param[out] output: Cropped image Tensor of shape <h,w,C> or <h,w> and same input type.
Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y, int w, int h); Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y, int w, int h);
/// \brief Returns Decoded image /// \brief Returns Decoded image
/// Supported images: /// Supported images:
/// BMP JPEG JPG PNG TIFF /// BMP JPEG JPG PNG TIFF
/// supported by opencv, if user need more image analysis capabilities, please compile opencv particularlly. /// supported by opencv, if user need more image analysis capabilities, please compile opencv particularlly.
/// \param input: CVTensor containing the not decoded image 1D bytes /// \param[in] input CVTensor containing the not decoded image 1D bytes
/// \param output: Decoded image Tensor of shape <H,W,C> and type DE_UINT8. Pixel order is RGB /// \param[out] output Decoded image Tensor of shape <H,W,C> and type DE_UINT8. Pixel order is RGB
Status Decode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); Status Decode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
/// \brief Get jpeg image width and height /// \brief Get jpeg image width and height
/// \param input: CVTensor containing the not decoded image 1D bytes /// \param[in] input CVTensor containing the not decoded image 1D bytes
/// \param img_width: the jpeg image width /// \param[in] img_width The jpeg image width
/// \param img_height: the jpeg image height /// \param[in] img_height The jpeg image height
Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height); Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height);
/// \brief Returns Normalized image /// \brief Returns Normalized image
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor. /// \param[in] input Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order /// \param[in] mean Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order /// \param[in] std Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
/// \param output: Normalized image Tensor of same input shape and type DE_FLOAT32 /// \param[out] output Normalized image Tensor of same input shape and type DE_FLOAT32
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std);
/// \brief Returns Resized image. /// \brief Returns Resized image.
/// \param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor. /// \param[in] input
/// \param output_height: height of output /// \param[in] output_height Height of output
/// \param output_width: width of output /// \param[in] output_width Width of output
/// \param fx: horizontal scale /// \param[in] fx Horizontal scale
/// \param fy: vertical scale /// \param[in] fy Vertical scale
/// \param InterpolationMode: the interpolation mode /// \param[in] InterpolationMode The interpolation mode
/// \param output: Resized image of shape <outputHeight,outputWidth,C> or <outputHeight,outputWidth> /// \param[out] output Resized image of shape <outputHeight,outputWidth,C> or <outputHeight,outputWidth>
/// and same type as input /// and same type as input
Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t output_height, Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t output_height,
int32_t output_width, double fx = 0.0, double fy = 0.0, int32_t output_width, double fx = 0.0, double fy = 0.0,
InterpolationMode mode = InterpolationMode::kLinear); InterpolationMode mode = InterpolationMode::kLinear);
/// \brief Pads the input image and puts the padded image in the output /// \brief Pads the input image and puts the padded image in the output
/// \param input: input Tensor /// \param[in] input: input Tensor
/// \param output: padded Tensor /// \param[out] output: padded Tensor
/// \param pad_top: amount of padding done in top /// \param[in] pad_top Amount of padding done in top
/// \param pad_bottom: amount of padding done in bottom /// \param[in] pad_bottom Amount of padding done in bottom
/// \param pad_left: amount of padding done in left /// \param[in] pad_left Amount of padding done in left
/// \param pad_right: amount of padding done in right /// \param[in] pad_right Amount of padding done in right
/// \param border_types: the interpolation to be done in the border /// \param[in] border_types The interpolation to be done in the border
/// \param fill_r: red fill value for pad /// \param[in] fill_r Red fill value for pad
/// \param fill_g: green fill value for pad /// \param[in] fill_g Green fill value for pad
/// \param fill_b: blue fill value for pad. /// \param[in] fill_b Blue fill value for pad
Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top, Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
/// \brief Rotate the input image by orientation /// \brief Rotate the input image by orientation
/// \param input: input Tensor /// \param[in] input Input Tensor
/// \param output: padded Tensor /// \param[out] output Rotated Tensor
/// \param orientation: the orientation of EXIF /// \param[in] orientation The orientation of EXIF
Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const uint64_t orientation); Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const uint64_t orientation);
/// \brief Geometrically transform the input image
/// \param[in] input Input Tensor
/// \param[out] output Transformed Tensor
/// \param[in] mat The transformation matrix
/// \param[in] interpolation The interpolation mode, support only bilinear for now
/// \param[in] fill_r Red fill value for pad
/// \param[in] fill_g Green fill value for pad
/// \param[in] fill_b Blue fill value for pad
Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::vector<float_t> &mat,
InterpolationMode interpolation, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -16,8 +16,6 @@
#include "minddata/dataset/kernels/image/math_utils.h" #include "minddata/dataset/kernels/image/math_utils.h"
#include <opencv2/imgproc/types_c.h>
#include <algorithm> #include <algorithm>
#include <string> #include <string>

View File

@ -21,6 +21,8 @@
#include <vector> #include <vector>
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#define CV_PI 3.1415926535897932384626433832795
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

View File

@ -19,7 +19,11 @@
#include <vector> #include <vector>
#include "minddata/dataset/kernels/image/random_affine_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/kernels/image/math_utils.h" #include "minddata/dataset/kernels/image/math_utils.h"
#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/random.h"

View File

@ -21,7 +21,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/image/affine_op.h" #include "minddata/dataset/kernels/image/affine_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
@ -51,7 +50,6 @@ class RandomAffineOp : public AffineOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
private: private:
std::string kRandomAffineOp = "RandomAffineOp";
std::vector<float_t> degrees_range_; // min_degree, max_degree std::vector<float_t> degrees_range_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
std::vector<float_t> scale_range_; // min_scale, max_scale std::vector<float_t> scale_range_; // min_scale, max_scale

View File

@ -21,6 +21,7 @@
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
#endif #endif
// Kernel image headers (in alphabetical order) // Kernel image headers (in alphabetical order)
#include "minddata/dataset/kernels/image/affine_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/auto_contrast_op.h" #include "minddata/dataset/kernels/image/auto_contrast_op.h"
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" #include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
@ -42,7 +43,9 @@
#include "minddata/dataset/kernels/image/normalize_pad_op.h" #include "minddata/dataset/kernels/image/normalize_pad_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/pad_op.h"
#endif
#include "minddata/dataset/kernels/image/random_affine_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/random_color_op.h" #include "minddata/dataset/kernels/image/random_color_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
@ -88,6 +91,59 @@ namespace vision {
/* ####################################### Derived TensorOperation classes ################################# */ /* ####################################### Derived TensorOperation classes ################################# */
// (In alphabetical order) // (In alphabetical order)
// AffineOperation
AffineOperation::AffineOperation(float_t degrees, const std::vector<float> &translation, float scale,
const std::vector<float> &shear, InterpolationMode interpolation,
const std::vector<uint8_t> &fill_value)
: degrees_(degrees),
translation_(translation),
scale_(scale),
shear_(shear),
interpolation_(interpolation),
fill_value_(fill_value) {}
Status AffineOperation::ValidateParams() {
// Translate
if (translation_.size() != 2) {
std::string err_msg =
"Affine: translate expecting size 2, got: translation.size() = " + std::to_string(translation_.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar("Affine", "translate", translation_[0], {-1, 1}, false, false));
RETURN_IF_NOT_OK(ValidateScalar("Affine", "translate", translation_[1], {-1, 1}, false, false));
// Shear
if (shear_.size() != 2) {
std::string err_msg = "Affine: shear_ranges expecting size 2, got: shear.size() = " + std::to_string(shear_.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Fill Value
RETURN_IF_NOT_OK(ValidateVectorFillvalue("Affine", fill_value_));
return Status::OK();
}
std::shared_ptr<TensorOp> AffineOperation::Build() {
std::shared_ptr<AffineOp> tensor_op =
std::make_shared<AffineOp>(degrees_, translation_, scale_, shear_, interpolation_, fill_value_);
return tensor_op;
}
Status AffineOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["degrees"] = degrees_;
args["translate"] = translation_;
args["scale"] = scale_;
args["shear"] = shear_;
args["resample"] = interpolation_;
args["fill_value"] = fill_value_;
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// AutoContrastOperation // AutoContrastOperation
@ -257,6 +313,7 @@ Status CutOutOperation::to_json(nlohmann::json *out_json) {
*out_json = args; *out_json = args;
return Status::OK(); return Status::OK();
} }
#endif
// DecodeOperation // DecodeOperation
DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
@ -269,6 +326,7 @@ Status DecodeOperation::to_json(nlohmann::json *out_json) {
(*out_json)["rgb"] = rgb_; (*out_json)["rgb"] = rgb_;
return Status::OK(); return Status::OK();
} }
#ifndef ENABLE_ANDROID
// EqualizeOperation // EqualizeOperation
Status EqualizeOperation::ValidateParams() { return Status::OK(); } Status EqualizeOperation::ValidateParams() { return Status::OK(); }

View File

@ -35,6 +35,7 @@ namespace dataset {
namespace vision { namespace vision {
// Char arrays storing name of corresponding classes (in alphabetical order) // Char arrays storing name of corresponding classes (in alphabetical order)
constexpr char kAffineOperation[] = "Affine";
constexpr char kAutoContrastOperation[] = "AutoContrast"; constexpr char kAutoContrastOperation[] = "AutoContrast";
constexpr char kBoundingBoxAugmentOperation[] = "BoundingBoxAugment"; constexpr char kBoundingBoxAugmentOperation[] = "BoundingBoxAugment";
constexpr char kCenterCropOperation[] = "CenterCrop"; constexpr char kCenterCropOperation[] = "CenterCrop";
@ -81,9 +82,34 @@ constexpr char kUniformAugOperation[] = "UniformAug";
/* ####################################### Derived TensorOperation classes ################################# */ /* ####################################### Derived TensorOperation classes ################################# */
class AffineOperation : public TensorOperation {
public:
explicit AffineOperation(float_t degrees, const std::vector<float> &translation, float scale,
const std::vector<float> &shear, InterpolationMode interpolation,
const std::vector<uint8_t> &fill_value);
~AffineOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kAffineOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float degrees_;
std::vector<float> translation_;
float scale_;
std::vector<float> shear_;
InterpolationMode interpolation_;
std::vector<uint8_t> fill_value_;
};
class AutoContrastOperation : public TensorOperation { class AutoContrastOperation : public TensorOperation {
public: public:
explicit AutoContrastOperation(float cutoff = 0.0, std::vector<uint32_t> ignore = {}); explicit AutoContrastOperation(float cutoff, std::vector<uint32_t> ignore);
~AutoContrastOperation() = default; ~AutoContrastOperation() = default;
@ -102,7 +128,7 @@ class AutoContrastOperation : public TensorOperation {
class BoundingBoxAugmentOperation : public TensorOperation { class BoundingBoxAugmentOperation : public TensorOperation {
public: public:
explicit BoundingBoxAugmentOperation(std::shared_ptr<TensorOperation> transform, float ratio = 0.3); explicit BoundingBoxAugmentOperation(std::shared_ptr<TensorOperation> transform, float ratio);
~BoundingBoxAugmentOperation() = default; ~BoundingBoxAugmentOperation() = default;
@ -156,7 +182,7 @@ class CropOperation : public TensorOperation {
class CutMixBatchOperation : public TensorOperation { class CutMixBatchOperation : public TensorOperation {
public: public:
explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0); explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob);
~CutMixBatchOperation() = default; ~CutMixBatchOperation() = default;
@ -176,7 +202,7 @@ class CutMixBatchOperation : public TensorOperation {
class CutOutOperation : public TensorOperation { class CutOutOperation : public TensorOperation {
public: public:
explicit CutOutOperation(int32_t length, int32_t num_patches = 1); explicit CutOutOperation(int32_t length, int32_t num_patches);
~CutOutOperation() = default; ~CutOutOperation() = default;
@ -195,7 +221,7 @@ class CutOutOperation : public TensorOperation {
class DecodeOperation : public TensorOperation { class DecodeOperation : public TensorOperation {
public: public:
explicit DecodeOperation(bool rgb = true); explicit DecodeOperation(bool rgb);
~DecodeOperation() = default; ~DecodeOperation() = default;
@ -246,7 +272,7 @@ class InvertOperation : public TensorOperation {
class MixUpBatchOperation : public TensorOperation { class MixUpBatchOperation : public TensorOperation {
public: public:
explicit MixUpBatchOperation(float alpha = 1); explicit MixUpBatchOperation(float alpha);
~MixUpBatchOperation() = default; ~MixUpBatchOperation() = default;
@ -283,8 +309,7 @@ class NormalizeOperation : public TensorOperation {
class NormalizePadOperation : public TensorOperation { class NormalizePadOperation : public TensorOperation {
public: public:
NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, const std::string &dtype);
const std::string &dtype = "float32");
~NormalizePadOperation() = default; ~NormalizePadOperation() = default;
@ -304,8 +329,7 @@ class NormalizePadOperation : public TensorOperation {
class PadOperation : public TensorOperation { class PadOperation : public TensorOperation {
public: public:
PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode);
BorderType padding_mode = BorderType::kConstant);
~PadOperation() = default; ~PadOperation() = default;
@ -325,11 +349,9 @@ class PadOperation : public TensorOperation {
class RandomAffineOperation : public TensorOperation { class RandomAffineOperation : public TensorOperation {
public: public:
RandomAffineOperation(const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0}, RandomAffineOperation(const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range,
const std::vector<float_t> &scale_range = {1.0, 1.0}, const std::vector<float_t> &scale_range, const std::vector<float_t> &shear_ranges,
const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0}, InterpolationMode interpolation, const std::vector<uint8_t> &fill_value);
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0});
~RandomAffineOperation() = default; ~RandomAffineOperation() = default;
@ -371,8 +393,8 @@ class RandomColorOperation : public TensorOperation {
class RandomColorAdjustOperation : public TensorOperation { class RandomColorAdjustOperation : public TensorOperation {
public: public:
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0}, RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, std::vector<float> saturation,
std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0}); std::vector<float> hue);
~RandomColorAdjustOperation() = default; ~RandomColorAdjustOperation() = default;
@ -393,9 +415,8 @@ class RandomColorAdjustOperation : public TensorOperation {
class RandomCropOperation : public TensorOperation { class RandomCropOperation : public TensorOperation {
public: public:
RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0}, RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0}, std::vector<uint8_t> fill_value, BorderType padding_mode);
BorderType padding_mode = BorderType::kConstant);
~RandomCropOperation() = default; ~RandomCropOperation() = default;
@ -417,10 +438,8 @@ class RandomCropOperation : public TensorOperation {
class RandomResizedCropOperation : public TensorOperation { class RandomResizedCropOperation : public TensorOperation {
public: public:
RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
std::vector<float> ratio = {3. / 4., 4. / 3.}, InterpolationMode interpolation, int32_t max_attempts);
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
int32_t max_attempts = 10);
/// \brief default copy constructor /// \brief default copy constructor
explicit RandomResizedCropOperation(const RandomResizedCropOperation &) = default; explicit RandomResizedCropOperation(const RandomResizedCropOperation &) = default;
@ -461,9 +480,8 @@ class RandomCropDecodeResizeOperation : public RandomResizedCropOperation {
class RandomCropWithBBoxOperation : public TensorOperation { class RandomCropWithBBoxOperation : public TensorOperation {
public: public:
RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0}, RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0}, std::vector<uint8_t> fill_value, BorderType padding_mode);
BorderType padding_mode = BorderType::kConstant);
~RandomCropWithBBoxOperation() = default; ~RandomCropWithBBoxOperation() = default;
@ -485,7 +503,7 @@ class RandomCropWithBBoxOperation : public TensorOperation {
class RandomHorizontalFlipOperation : public TensorOperation { class RandomHorizontalFlipOperation : public TensorOperation {
public: public:
explicit RandomHorizontalFlipOperation(float probability = 0.5); explicit RandomHorizontalFlipOperation(float probability);
~RandomHorizontalFlipOperation() = default; ~RandomHorizontalFlipOperation() = default;
@ -503,7 +521,7 @@ class RandomHorizontalFlipOperation : public TensorOperation {
class RandomHorizontalFlipWithBBoxOperation : public TensorOperation { class RandomHorizontalFlipWithBBoxOperation : public TensorOperation {
public: public:
explicit RandomHorizontalFlipWithBBoxOperation(float probability = 0.5); explicit RandomHorizontalFlipWithBBoxOperation(float probability);
~RandomHorizontalFlipWithBBoxOperation() = default; ~RandomHorizontalFlipWithBBoxOperation() = default;
@ -521,7 +539,7 @@ class RandomHorizontalFlipWithBBoxOperation : public TensorOperation {
class RandomPosterizeOperation : public TensorOperation { class RandomPosterizeOperation : public TensorOperation {
public: public:
explicit RandomPosterizeOperation(const std::vector<uint8_t> &bit_range = {4, 8}); explicit RandomPosterizeOperation(const std::vector<uint8_t> &bit_range);
~RandomPosterizeOperation() = default; ~RandomPosterizeOperation() = default;
@ -575,10 +593,9 @@ class RandomResizeWithBBoxOperation : public TensorOperation {
class RandomResizedCropWithBBoxOperation : public TensorOperation { class RandomResizedCropWithBBoxOperation : public TensorOperation {
public: public:
explicit RandomResizedCropWithBBoxOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, explicit RandomResizedCropWithBBoxOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio = {3. / 4., 4. / 3.}, std::vector<float> ratio, InterpolationMode interpolation,
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, int32_t max_attempts);
int32_t max_attempts = 10);
~RandomResizedCropWithBBoxOperation() = default; ~RandomResizedCropWithBBoxOperation() = default;
@ -642,7 +659,7 @@ class RandomSelectSubpolicyOperation : public TensorOperation {
class RandomSharpnessOperation : public TensorOperation { class RandomSharpnessOperation : public TensorOperation {
public: public:
explicit RandomSharpnessOperation(std::vector<float> degrees = {0.1, 1.9}); explicit RandomSharpnessOperation(std::vector<float> degrees);
~RandomSharpnessOperation() = default; ~RandomSharpnessOperation() = default;
@ -678,7 +695,7 @@ class RandomSolarizeOperation : public TensorOperation {
class RandomVerticalFlipOperation : public TensorOperation { class RandomVerticalFlipOperation : public TensorOperation {
public: public:
explicit RandomVerticalFlipOperation(float probability = 0.5); explicit RandomVerticalFlipOperation(float probability);
~RandomVerticalFlipOperation() = default; ~RandomVerticalFlipOperation() = default;
@ -696,7 +713,7 @@ class RandomVerticalFlipOperation : public TensorOperation {
class RandomVerticalFlipWithBBoxOperation : public TensorOperation { class RandomVerticalFlipWithBBoxOperation : public TensorOperation {
public: public:
explicit RandomVerticalFlipWithBBoxOperation(float probability = 0.5); explicit RandomVerticalFlipWithBBoxOperation(float probability);
~RandomVerticalFlipWithBBoxOperation() = default; ~RandomVerticalFlipWithBBoxOperation() = default;
@ -733,8 +750,7 @@ class RescaleOperation : public TensorOperation {
class ResizeOperation : public TensorOperation { class ResizeOperation : public TensorOperation {
public: public:
explicit ResizeOperation(std::vector<int32_t> size, explicit ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation_mode);
InterpolationMode interpolation_mode = InterpolationMode::kLinear);
~ResizeOperation() = default; ~ResizeOperation() = default;
@ -753,8 +769,7 @@ class ResizeOperation : public TensorOperation {
class ResizeWithBBoxOperation : public TensorOperation { class ResizeWithBBoxOperation : public TensorOperation {
public: public:
explicit ResizeWithBBoxOperation(std::vector<int32_t> size, explicit ResizeWithBBoxOperation(std::vector<int32_t> size, InterpolationMode interpolation_mode);
InterpolationMode interpolation_mode = InterpolationMode::kLinear);
~ResizeWithBBoxOperation() = default; ~ResizeWithBBoxOperation() = default;
@ -870,7 +885,7 @@ class SwapRedBlueOperation : public TensorOperation {
class UniformAugOperation : public TensorOperation { class UniformAugOperation : public TensorOperation {
public: public:
explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops = 2); explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops);
~UniformAugOperation() = default; ~UniformAugOperation() = default;

View File

@ -53,6 +53,7 @@ namespace dataset {
constexpr char kTensorOp[] = "TensorOp"; constexpr char kTensorOp[] = "TensorOp";
// image // image
constexpr char kAffineOp[] = "AffineOp";
constexpr char kAutoContrastOp[] = "AutoContrastOp"; constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
constexpr char kDecodeOp[] = "DecodeOp"; constexpr char kDecodeOp[] = "DecodeOp";
@ -73,6 +74,7 @@ constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
constexpr char kNormalizeOp[] = "NormalizeOp"; constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kNormalizePadOp[] = "NormalizePadOp"; constexpr char kNormalizePadOp[] = "NormalizePadOp";
constexpr char kPadOp[] = "PadOp"; constexpr char kPadOp[] = "PadOp";
constexpr char kRandomAffineOp[] = "RandomAffineOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";
constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp"; constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp";

View File

@ -229,7 +229,7 @@ def set_auto_num_workers(enable):
If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the
num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by num_parallel_workers passed in by user or the default value (if user doesn't pass anything) set by
ds.config.set_num_parallel_workers(). ds.config.set_num_parallel_workers().
For now, this function is only optimized for Yolo3 dataset with per_batch_map (running map in batch). For now, this function is only optimized for YoloV3 dataset with per_batch_map (running map in batch).
This feature aims to provide a baseline for optimized num_workers assignment for each op. This feature aims to provide a baseline for optimized num_workers assignment for each op.
Op whose num_parallel_workers is adjusted to a new value will be logged. Op whose num_parallel_workers is adjusted to a new value will be logged.

View File

@ -192,9 +192,13 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/kernels/image/lite_image_utils.cc ${MINDDATA_DIR}/kernels/image/lite_image_utils.cc
${MINDDATA_DIR}/kernels/image/center_crop_op.cc ${MINDDATA_DIR}/kernels/image/center_crop_op.cc
${MINDDATA_DIR}/kernels/image/crop_op.cc ${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/decode_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/affine_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/rotate_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc
${MINDDATA_DIR}/kernels/image/random_affine_op.cc
${MINDDATA_DIR}/kernels/image/math_utils.cc
${MINDDATA_DIR}/kernels/data/compose_op.cc ${MINDDATA_DIR}/kernels/data/compose_op.cc
${MINDDATA_DIR}/kernels/data/duplicate_op.cc ${MINDDATA_DIR}/kernels/data/duplicate_op.cc
${MINDDATA_DIR}/kernels/data/one_hot_op.cc ${MINDDATA_DIR}/kernels/data/one_hot_op.cc
@ -350,7 +354,6 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
"${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc" "${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc"
"${MINDDATA_DIR}/kernels/image/image_utils.cc" "${MINDDATA_DIR}/kernels/image/image_utils.cc"
"${MINDDATA_DIR}/kernels/image/invert_op.cc" "${MINDDATA_DIR}/kernels/image/invert_op.cc"
"${MINDDATA_DIR}/kernels/image/math_utils.cc"
"${MINDDATA_DIR}/kernels/image/mixup_batch_op.cc" "${MINDDATA_DIR}/kernels/image/mixup_batch_op.cc"
"${MINDDATA_DIR}/kernels/image/pad_op.cc" "${MINDDATA_DIR}/kernels/image/pad_op.cc"
"${MINDDATA_DIR}/kernels/image/posterize_op.cc" "${MINDDATA_DIR}/kernels/image/posterize_op.cc"

View File

@ -1,6 +1,7 @@
include(GoogleTest) include(GoogleTest)
SET(DE_UT_SRCS SET(DE_UT_SRCS
affine_op_test.cc
execute_test.cc execute_test.cc
album_op_test.cc album_op_test.cc
arena_test.cc arena_test.cc
@ -11,6 +12,7 @@ SET(DE_UT_SRCS
btree_test.cc btree_test.cc
buddy_test.cc buddy_test.cc
build_vocab_test.cc build_vocab_test.cc
c_api_affine_test.cc
c_api_cache_test.cc c_api_cache_test.cc
c_api_dataset_album_test.cc c_api_dataset_album_test.cc
c_api_dataset_cifar_test.cc c_api_dataset_cifar_test.cc

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/cvop_common.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/image/affine_op.h"
#include "minddata/dataset/kernels/image/math_utils.h"
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/types_c.h>
#include "lite_cv/lite_mat.h"
#include "lite_cv/image_process.h"
using namespace mindspore::dataset;
using mindspore::dataset::InterpolationMode;
class MindDataTestAffineOp : public UT::CVOP::CVOpCommon {
public:
MindDataTestAffineOp() : CVOpCommon() {}
};
// Helper function, consider moving this to helper class for UT
double Mse(cv::Mat img1, cv::Mat img2) {
// clone to get around open cv optimization
cv::Mat output1 = img1.clone();
cv::Mat output2 = img2.clone();
// input check
if (output1.rows < 0 || output1.rows != output2.rows || output1.cols < 0 || output1.cols != output2.cols) {
return 10000.0;
}
return cv::norm(output1, output2, cv::NORM_L1);
}
// helper function to generate corresponding affine matrix
std::vector<double> GenerateMatrix(const std::shared_ptr<Tensor> &input, float_t degrees,
const std::vector<float_t> &translation, float_t scale,
const std::vector<float_t> &shear) {
float_t translation_x = translation[0];
float_t translation_y = translation[1];
DegreesToRadians(degrees, &degrees);
float_t shear_x = shear[0];
float_t shear_y = shear[1];
DegreesToRadians(shear_x, &shear_x);
DegreesToRadians(-1 * shear_y, &shear_y);
float_t cx = ((input->shape()[1] - 1) / 2.0);
float_t cy = ((input->shape()[0] - 1) / 2.0);
// Calculate RSS
std::vector<double> matrix{
static_cast<double>(scale * cos(degrees + shear_y) / cos(shear_y)),
static_cast<double>(scale * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees))),
0,
static_cast<double>(scale * sin(degrees + shear_y) / cos(shear_y)),
static_cast<double>(scale * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees))),
0};
// Compute T * C * RSS * C^-1
matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x;
matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y;
return matrix;
}
TEST_F(MindDataTestAffineOp, TestAffineLite) {
MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineLite.";
// create input tensor and
float degree = 0.0;
std::vector<float> translation = {0.0, 0.0};
float scale = 0.0;
std::vector<float> shear = {0.0, 0.0};
// Create affine object with default values
std::shared_ptr<AffineOp> op(new AffineOp(degree, translation, scale, shear, InterpolationMode::kLinear));
// output tensor
std::shared_ptr<Tensor> output_tensor;
// output
LiteMat dst;
LiteMat lite_mat_rgb(input_tensor_->shape()[1], input_tensor_->shape()[0], input_tensor_->shape()[2],
const_cast<void *>(reinterpret_cast<const void *>(input_tensor_->GetBuffer())),
LDataType::UINT8);
std::vector<double> matrix = GenerateMatrix(input_tensor_, degree, translation, scale, shear);
int height = lite_mat_rgb.height_;
int width = lite_mat_rgb.width_;
std::vector<size_t> dsize;
dsize.push_back(width);
dsize.push_back(height);
double M[6] = {};
for (int i = 0; i < matrix.size(); i++) {
M[i] = static_cast<double>(matrix[i]);
}
EXPECT_TRUE(Affine(lite_mat_rgb, dst, M, dsize, UINT8_C3(0, 0, 0)));
Status s = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(s.IsOk());
// output tensor is a cv tenosr, we can compare mat values
cv::Mat lite_cv_out(dst.height_, dst.width_, CV_8UC3, dst.data_ptr_);
double mse = Mse(lite_cv_out, CVTensor(output_tensor).mat());
MS_LOG(INFO) << "mse: " << std::to_string(mse) << std::endl;
EXPECT_LT(mse, 1); // predetermined magic number
}

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
using namespace mindspore::dataset;
using mindspore::dataset::InterpolationMode;
using mindspore::dataset::Tensor;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestAffineAPI) {
MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineAPI.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 5));
// Create a Repeat operation on ds
int32_t repeat_num = 3;
ds = ds->Repeat(repeat_num);
// Create auto contrast object with default values
std::shared_ptr<TensorTransform> crop(new vision::RandomCrop({256, 256}));
std::shared_ptr<TensorTransform> affine(
new vision::Affine(0.0, {0.0, 0.0}, 0.0, {0.0, 0.0}, InterpolationMode::kLinear));
// Create a Map operation on ds
ds = ds->Map({crop, affine});
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
// EXPECT_EQ(row["image"].Shape()[0], 256);
}
EXPECT_EQ(i, 15);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAffineAPIFail) {
MS_LOG(INFO) << "Doing MindDataTestAffine-TestAffineAPI.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 5));
// Create a Repeat operation on ds
int32_t repeat_num = 3;
ds = ds->Repeat(repeat_num);
// Create auto contrast object with default values
std::shared_ptr<TensorTransform> crop(new vision::RandomCrop({256, 256}));
std::shared_ptr<TensorTransform> affine(
new vision::Affine(0.0, {2.0, -1.0}, 0.0, {0.0, 0.0}, InterpolationMode::kLinear));
// Create a Map operation on ds
ds = ds->Map({crop, affine});
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}