forked from mindspore-Ecosystem/mindspore
add pil version cubic resize
This commit is contained in:
parent
24675b2081
commit
ecb8d6bc08
|
@ -117,6 +117,7 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) {
|
|||
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
|
||||
.value("DE_INTER_AREA", InterpolationMode::kArea)
|
||||
.value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
|
||||
.value("DE_INTER_PILCUBIC", InterpolationMode::kCubicPil)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 };
|
|||
enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 };
|
||||
|
||||
// Possible interpolation modes
|
||||
enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 };
|
||||
enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3, kCubicPil = 4 };
|
||||
|
||||
// Possible JiebaMode modes
|
||||
enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 };
|
||||
|
|
|
@ -56,6 +56,7 @@ add_library(kernels-image OBJECT
|
|||
random_resize_with_bbox_op.cc
|
||||
random_color_op.cc
|
||||
rotate_op.cc
|
||||
resize_cubic_op.cc
|
||||
)
|
||||
if(ENABLE_ACL)
|
||||
add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image)
|
||||
|
|
|
@ -21,11 +21,12 @@
|
|||
#include <utility>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/include/constants.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/include/constants.h"
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/kernels/image/resize_cubic_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
#define MAX_INT_PRECISION 16777216 // float int precision is 16777216
|
||||
|
@ -110,6 +111,19 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of <H,W,C> or <H,W>");
|
||||
}
|
||||
|
||||
if (mode == InterpolationMode::kCubicPil) {
|
||||
LiteMat imIn, imOut;
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
TensorShape new_shape = TensorShape({output_height, output_width, 3});
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor));
|
||||
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
|
||||
imOut.Init(output_width, output_height, input_cv->shape()[2], reinterpret_cast<void *>(buffer), LDataType::UINT8);
|
||||
imIn.Init(input_cv->shape()[1], input_cv->shape()[0], input_cv->shape()[2], input_cv->mat().data, LDataType::UINT8);
|
||||
ResizeCubic(imIn, imOut, output_width, output_height);
|
||||
*output = output_tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
cv::Mat in_image = input_cv->mat();
|
||||
// resize image too large or too small
|
||||
if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) {
|
||||
|
@ -569,6 +583,24 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
cv::Rect roi(x, y, crop_width, crop_height);
|
||||
auto cv_mode = GetCVInterpolationMode(mode);
|
||||
cv::Mat cv_in = input_cv->mat();
|
||||
|
||||
if (mode == InterpolationMode::kCubicPil) {
|
||||
cv::Mat input_roi = cv_in(roi);
|
||||
std::shared_ptr<CVTensor> input_image;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_roi, &input_image));
|
||||
LiteMat imIn, imOut;
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
TensorShape new_shape = TensorShape({target_height, target_width, 3});
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor));
|
||||
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
|
||||
imOut.Init(target_width, target_height, input_cv->shape()[2], reinterpret_cast<void *>(buffer), LDataType::UINT8);
|
||||
imIn.Init(input_image->shape()[1], input_image->shape()[0], input_image->shape()[2], input_image->mat().data,
|
||||
LDataType::UINT8);
|
||||
ResizeCubic(imIn, imOut, target_width, target_height);
|
||||
*output = output_tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorShape shape{target_height, target_width};
|
||||
int num_channels = input_cv->shape()[2];
|
||||
if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels);
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/resize_cubic_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// using 8 bits for result
|
||||
constexpr uint8_t PrecisionBits = 22;
|
||||
|
||||
// construct lookup table
|
||||
static const std::vector<uint8_t> _clip8_table = []() {
|
||||
std::vector<uint8_t> v1(896, 0);
|
||||
std::vector<uint8_t> v2(384, 255);
|
||||
for (int i = 0; i < 256; i++) {
|
||||
v1[i + 640] = i;
|
||||
}
|
||||
v1.insert(v1.end(), v2.begin(), v2.end());
|
||||
return v1;
|
||||
}();
|
||||
|
||||
static const uint8_t *clip8_table = &_clip8_table[640];
|
||||
|
||||
static inline uint8_t clip8(int input) { return clip8_table[input >> PrecisionBits]; }
|
||||
|
||||
static inline double cubic_interp(double x) {
|
||||
double a = -0.5;
|
||||
if (x < 0.0) {
|
||||
x = -x;
|
||||
}
|
||||
if (x < 1.0) {
|
||||
return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
|
||||
}
|
||||
if (x < 2.0) {
|
||||
return (((x - 5) * x + 8) * x - 4) * a;
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
struct interpolation {
|
||||
double (*interpolation)(double x);
|
||||
double threshold;
|
||||
};
|
||||
|
||||
int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp,
|
||||
std::vector<int> ®ions, std::vector<double> &coeffs_interp) {
|
||||
double threshold, scale, interp_scale;
|
||||
int kernel_size;
|
||||
|
||||
scale = static_cast<double>((input1 - input0)) / out_size;
|
||||
if (scale < 1.0) {
|
||||
interp_scale = 1.0;
|
||||
} else {
|
||||
interp_scale = scale;
|
||||
}
|
||||
|
||||
// obtain size
|
||||
threshold = interp->threshold * interp_scale;
|
||||
|
||||
// coefficients number
|
||||
kernel_size = static_cast<int>(ceil(threshold)) * 2 + 1;
|
||||
if (out_size > INT_MAX / (kernel_size * static_cast<int>(sizeof(double)))) {
|
||||
MS_LOG(WARNING) << "Unable to allocator memory as output Image size is so large.";
|
||||
return 0;
|
||||
}
|
||||
|
||||
// coefficient array
|
||||
std::vector<double> coeffs(out_size * kernel_size, 0.0);
|
||||
std::vector<int> region(out_size * 2, 0);
|
||||
|
||||
for (int xx = 0; xx < out_size; xx++) {
|
||||
double center = input0 + (xx + 0.5) * scale;
|
||||
double mm = 0.0, ss = 1.0 / interp_scale;
|
||||
int x;
|
||||
// Round for x_min
|
||||
int x_min = static_cast<int>((center - threshold + 0.5));
|
||||
if (x_min < 0) {
|
||||
x_min = 0;
|
||||
}
|
||||
// Round for x_max
|
||||
int x_max = static_cast<int>((center + threshold + 0.5));
|
||||
if (x_max > input_size) {
|
||||
x_max = input_size;
|
||||
}
|
||||
x_max -= x_min;
|
||||
double *coeff = &coeffs[xx * kernel_size];
|
||||
for (x = 0; x < x_max; x++) {
|
||||
double m = interp->interpolation((x + x_min - center + 0.5) * ss);
|
||||
coeff[x] = m;
|
||||
mm += m;
|
||||
}
|
||||
for (x = 0; x < x_max; x++) {
|
||||
if (mm != 0.0) {
|
||||
coeff[x] /= mm;
|
||||
}
|
||||
}
|
||||
// Remaining values should stay empty if they are used despite of x_max.
|
||||
for (; x < kernel_size; x++) {
|
||||
coeff[x] = 0;
|
||||
}
|
||||
region[xx * 2 + 0] = x_min;
|
||||
region[xx * 2 + 1] = x_max;
|
||||
}
|
||||
|
||||
regions = std::move(region);
|
||||
coeffs_interp = std::move(coeffs);
|
||||
return kernel_size;
|
||||
}
|
||||
|
||||
void normalize_coeff(int out_size, int kernel_size, const std::vector<double> &prekk, std::vector<int> &kk) {
|
||||
for (int x = 0; x < out_size * kernel_size; x++) {
|
||||
if (prekk[x] < 0) {
|
||||
kk[x] = static_cast<int>((-0.5 + prekk[x] * (1 << PrecisionBits)));
|
||||
} else {
|
||||
kk[x] = static_cast<int>((0.5 + prekk[x] * (1 << PrecisionBits)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size,
|
||||
const std::vector<int> ®ions, const std::vector<double> &prekk) {
|
||||
int ss0, ss1, ss2;
|
||||
int32_t *k;
|
||||
|
||||
// normalize previous calculated coefficients
|
||||
std::vector<int> kk(prekk.begin(), prekk.end());
|
||||
normalize_coeff(output.width_, kernel_size, prekk, kk);
|
||||
uint8_t *input_ptr = input;
|
||||
uint8_t *output_ptr = output;
|
||||
int32_t input_width = input.width_ * 3;
|
||||
int32_t output_width = output.width_ * 3;
|
||||
|
||||
for (int yy = 0; yy < output.height_; yy++) {
|
||||
// obtain the ptr of output, and put calculated value into it
|
||||
uint8_t *bgr_buf = output_ptr;
|
||||
for (int xx = 0; xx < output.width_; xx++) {
|
||||
int x_min = regions[xx * 2 + 0];
|
||||
int x_max = regions[xx * 2 + 1];
|
||||
k = &kk[xx * kernel_size];
|
||||
ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1);
|
||||
for (int x = 0; x < x_max; x++) {
|
||||
ss0 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 0]) * k[x];
|
||||
ss1 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 1]) * k[x];
|
||||
ss2 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 2]) * k[x];
|
||||
}
|
||||
bgr_buf[0] = clip8(ss0);
|
||||
bgr_buf[1] = clip8(ss1);
|
||||
bgr_buf[2] = clip8(ss2);
|
||||
bgr_buf += 3;
|
||||
}
|
||||
output_ptr += output_width;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size,
|
||||
const std::vector<int> ®ions, const std::vector<double> &prekk) {
|
||||
int ss0, ss1, ss2;
|
||||
|
||||
// normalize previous calculated coefficients
|
||||
std::vector<int> kk(prekk.begin(), prekk.end());
|
||||
normalize_coeff(output.height_, kernel_size, prekk, kk);
|
||||
uint8_t *input_ptr = input;
|
||||
uint8_t *output_ptr = output;
|
||||
const int32_t input_width = input.width_ * 3;
|
||||
const int32_t output_width = output.width_ * 3;
|
||||
|
||||
for (int yy = 0; yy < output.height_; yy++) {
|
||||
// obtain the ptr of output, and put calculated value into it
|
||||
uint8_t *bgr_buf = output_ptr;
|
||||
int32_t *k = &kk[yy * kernel_size];
|
||||
int y_min = regions[yy * 2 + 0];
|
||||
int y_max = regions[yy * 2 + 1];
|
||||
for (int xx = 0; xx < output.width_; xx++) {
|
||||
ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1);
|
||||
for (int y = 0; y < y_max; y++) {
|
||||
ss0 += (input_ptr[(y + y_min) * input_width + xx * 3 + 0]) * k[y];
|
||||
ss1 += (input_ptr[(y + y_min) * input_width + xx * 3 + 1]) * k[y];
|
||||
ss2 += (input_ptr[(y + y_min) * input_width + xx * 3 + 2]) * k[y];
|
||||
}
|
||||
bgr_buf[0] = clip8(ss0);
|
||||
bgr_buf[1] = clip8(ss1);
|
||||
bgr_buf[2] = clip8(ss2);
|
||||
bgr_buf += 3;
|
||||
}
|
||||
output_ptr += output_width;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp,
|
||||
int rect[4]) {
|
||||
int horizontal_interp, vertical_interp, horiz_kernel, vert_kernel, rect_y0, rect_y1;
|
||||
std::vector<int> horiz_region, vert_region;
|
||||
std::vector<double> horiz_coeff, vert_coeff;
|
||||
LiteMat temp;
|
||||
|
||||
horizontal_interp = x_size != input.width_ || rect[2] != x_size || rect[0];
|
||||
vertical_interp = y_size != input.height_ || rect[3] != y_size || rect[1];
|
||||
|
||||
horiz_kernel = calc_coeff(input.width_, x_size, rect[0], rect[2], interp, horiz_region, horiz_coeff);
|
||||
if (!horiz_kernel) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vert_kernel = calc_coeff(input.height_, y_size, rect[1], rect[3], interp, vert_region, vert_coeff);
|
||||
if (!vert_kernel) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// first and last used row in the input image
|
||||
rect_y0 = vert_region[0];
|
||||
rect_y1 = vert_region[y_size * 2 - 1] + vert_region[y_size * 2 - 2];
|
||||
|
||||
// two-direction resize, horizontal resize
|
||||
if (horizontal_interp) {
|
||||
// Shift region for vertical resize
|
||||
for (int i = 0; i < y_size; i++) {
|
||||
vert_region[i * 2] -= rect_y0;
|
||||
}
|
||||
temp.Init(x_size, rect_y1 - rect_y0, 3);
|
||||
|
||||
ImagingHorizontalInterp(temp, input, rect_y0, horiz_kernel, horiz_region, horiz_coeff);
|
||||
if (temp.IsEmpty()) {
|
||||
return false;
|
||||
}
|
||||
output = input = temp;
|
||||
}
|
||||
|
||||
/* vertical resize */
|
||||
if (vertical_interp) {
|
||||
output.Init(input.width_, y_size, 3);
|
||||
if (!output.IsEmpty()) {
|
||||
ImagingVerticalInterp(output, input, 0, vert_kernel, vert_region, vert_coeff);
|
||||
}
|
||||
if (output.IsEmpty()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h) {
|
||||
if (input.data_type_ != LDataType::UINT8 || input.channel_ != 3) {
|
||||
MS_LOG(ERROR) << "Unsupported data type, only support input image of uint8 dtype and 3 channel.";
|
||||
return false;
|
||||
}
|
||||
int x_size = dst_w, y_size = dst_h;
|
||||
int rect[4] = {0, 0, input.width_, input.height_};
|
||||
LiteMat output;
|
||||
|
||||
struct interpolation interp = {cubic_interp, 2.0};
|
||||
bool res = ImageInterpolation(input, output, x_size, y_size, &interp, rect);
|
||||
|
||||
memcpy_s(dst.data_ptr_, output.size_, output.data_ptr_, output.size_);
|
||||
return res;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <limits.h>
|
||||
#include <string.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <random>
|
||||
#include "lite_cv/lite_mat.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Calculate the coefficient for interpolation firstly
|
||||
int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp,
|
||||
std::vector<int> ®ions, std::vector<double> &coeffs_interp);
|
||||
|
||||
/// \brief Normalize the coefficient for interpolation
|
||||
void normalize_coeff(int out_size, int kernel_size, const std::vector<double> &prekk, std::vector<int> &kk);
|
||||
|
||||
/// \brief Apply horizontal interpolation on input image
|
||||
Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size,
|
||||
const std::vector<int> ®ions, const std::vector<double> &prekk);
|
||||
|
||||
/// \brief Apply Vertical interpolation on input image
|
||||
Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size,
|
||||
const std::vector<int> ®ions, const std::vector<double> &prekk);
|
||||
|
||||
/// \brief Mainly logic of Cubic interpolation
|
||||
bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp,
|
||||
int rect[4]);
|
||||
|
||||
/// \brief Apply cubic interpolation on input image and obtain the output image
|
||||
/// \param[in] input Input image
|
||||
/// \param[out] dst Output image
|
||||
/// \param[in] dst_w expected Output image width
|
||||
/// \param[in] dst_h expected Output image height
|
||||
bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_
|
|
@ -84,7 +84,8 @@ DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_
|
|||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
Inter.CUBIC: cde.InterpolationMode.DE_INTER_CUBIC,
|
||||
Inter.AREA: cde.InterpolationMode.DE_INTER_AREA}
|
||||
Inter.AREA: cde.InterpolationMode.DE_INTER_AREA,
|
||||
Inter.PILCUBIC: cde.InterpolationMode.DE_INTER_PILCUBIC}
|
||||
|
||||
|
||||
def parse_padding(padding):
|
||||
|
@ -930,6 +931,10 @@ class RandomResizedCrop(ImageTensorOperation):
|
|||
|
||||
- Inter.BICUBIC, means interpolation method is bicubic interpolation.
|
||||
|
||||
- Inter.AREA, means interpolation method is pixel area interpolation.
|
||||
|
||||
- Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow.
|
||||
|
||||
max_attempts (int, optional): The maximum number of attempts to propose a valid
|
||||
crop_area (default=10). If exceeded, fall back to use center_crop instead.
|
||||
|
||||
|
@ -1314,6 +1319,8 @@ class Resize(ImageTensorOperation):
|
|||
|
||||
- Inter.AREA, means interpolation method is pixel area interpolation.
|
||||
|
||||
- Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.vision import Inter
|
||||
>>> decode_op = c_vision.Decode()
|
||||
|
|
|
@ -23,6 +23,7 @@ class Inter(IntEnum):
|
|||
BILINEAR = LINEAR = 2
|
||||
BICUBIC = CUBIC = 3
|
||||
AREA = 4
|
||||
PILCUBIC = 5
|
||||
|
||||
|
||||
# Padding Mode, Border Type
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,14 +13,14 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <opencv2/imgproc/types_c.h>
|
||||
#include <fstream>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "lite_cv/lite_mat.h"
|
||||
#include "lite_cv/image_process.h"
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <opencv2/imgproc/types_c.h>
|
||||
|
||||
#include <fstream>
|
||||
#include "minddata/dataset/kernels/image/resize_cubic_op.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
class MindDataImageProcess : public UT::Common {
|
||||
|
@ -184,6 +184,24 @@ TEST_F(MindDataImageProcess, test3C) {
|
|||
CompareMat(cv_image, lite_norm_mat_cut);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, testCubic3C) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
cv::Mat rgb_mat;
|
||||
cv::cvtColor(image, rgb_mat, CV_BGR2RGB);
|
||||
|
||||
LiteMat imIn, imOut;
|
||||
int32_t output_width = 24;
|
||||
int32_t output_height = 24;
|
||||
imIn.Init(rgb_mat.cols, rgb_mat.rows, rgb_mat.channels(), rgb_mat.data, LDataType::UINT8);
|
||||
imOut.Init(output_width, output_height, 3, LDataType::UINT8);
|
||||
|
||||
bool ret = ResizeCubic(imIn, imOut, output_width, output_height);
|
||||
|
||||
ASSERT_TRUE(ret == true);
|
||||
return;
|
||||
}
|
||||
|
||||
bool ReadYUV(const char *filename, int w, int h, uint8_t **data) {
|
||||
FILE *f = fopen(filename, "rb");
|
||||
if (f == nullptr) {
|
||||
|
|
Loading…
Reference in New Issue