forked from mindspore-Ecosystem/mindspore
!8380 [MD]remove opencv api with lite_cv in lite train
From: @xulei2020 Reviewed-by: @heleiwang Signed-off-by:
This commit is contained in:
commit
bc12c647a5
|
@ -23,8 +23,8 @@
|
|||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#endif
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/crop_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
|
||||
|
@ -94,6 +94,7 @@ std::shared_ptr<BoundingBoxAugmentOperation> BoundingBoxAugment(std::shared_ptr<
|
|||
// Input validation
|
||||
return op->ValidateParams() ? op : nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function to create CenterCropOperation.
|
||||
std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
|
||||
|
@ -101,7 +102,6 @@ std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
|
|||
// Input validation
|
||||
return op->ValidateParams() ? op : nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function to create CropOperation.
|
||||
std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size) {
|
||||
|
@ -519,6 +519,7 @@ std::shared_ptr<TensorOp> BoundingBoxAugmentOperation::Build() {
|
|||
std::shared_ptr<BoundingBoxAugmentOp> tensor_op = std::make_shared<BoundingBoxAugmentOp>(transform_->Build(), ratio_);
|
||||
return tensor_op;
|
||||
}
|
||||
#endif
|
||||
|
||||
// CenterCropOperation
|
||||
CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
|
||||
|
@ -558,7 +559,6 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
#endif
|
||||
// CropOperation.
|
||||
CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size)
|
||||
: coordinates_(coordinates), size_(size) {}
|
||||
|
|
|
@ -35,8 +35,8 @@ namespace vision {
|
|||
#ifndef ENABLE_ANDROID
|
||||
class AutoContrastOperation;
|
||||
class BoundingBoxAugmentOperation;
|
||||
class CenterCropOperation;
|
||||
#endif
|
||||
class CenterCropOperation;
|
||||
class CropOperation;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CutMixBatchOperation;
|
||||
|
@ -96,6 +96,7 @@ std::shared_ptr<AutoContrastOperation> AutoContrast(float cutoff = 0.0, std::vec
|
|||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<BoundingBoxAugmentOperation> BoundingBoxAugment(std::shared_ptr<TensorOperation> transform,
|
||||
float ratio = 0.3);
|
||||
#endif
|
||||
|
||||
/// \brief Function to create a CenterCrop TensorOperation.
|
||||
/// \notes Crops the input image at the center to the given size.
|
||||
|
@ -104,7 +105,7 @@ std::shared_ptr<BoundingBoxAugmentOperation> BoundingBoxAugment(std::shared_ptr<
|
|||
/// If size has 2 values, it should be (height, width).
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
|
||||
#endif
|
||||
|
||||
/// \brief Function to create a Crop TensorOp
|
||||
/// \notes Crop an image based on location and crop size
|
||||
/// \param[in] coordinates Starting location of crop. Must be a vector of two values, in the form of {x_coor, y_coor}
|
||||
|
@ -502,6 +503,8 @@ class BoundingBoxAugmentOperation : public TensorOperation {
|
|||
float ratio_;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
class CenterCropOperation : public TensorOperation {
|
||||
public:
|
||||
explicit CenterCropOperation(std::vector<int32_t> size);
|
||||
|
@ -515,7 +518,7 @@ class CenterCropOperation : public TensorOperation {
|
|||
private:
|
||||
std::vector<int32_t> size_;
|
||||
};
|
||||
#endif
|
||||
|
||||
class CropOperation : public TensorOperation {
|
||||
public:
|
||||
CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size);
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include <string>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -73,6 +73,20 @@ LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) {
|
|||
Init(width, height, channel, data_type);
|
||||
}
|
||||
|
||||
LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) {
|
||||
data_type_ = data_type;
|
||||
InitElemSize(data_type);
|
||||
width_ = width;
|
||||
height_ = height;
|
||||
dims_ = 3;
|
||||
channel_ = channel;
|
||||
c_step_ = height_ * width_;
|
||||
size_ = c_step_ * channel_ * elem_size_;
|
||||
data_ptr_ = p_data;
|
||||
ref_count_ = new int[1];
|
||||
*ref_count_ = 0;
|
||||
}
|
||||
|
||||
LiteMat::~LiteMat() { Release(); }
|
||||
|
||||
int LiteMat::addRef(int *p, int value) {
|
||||
|
|
|
@ -195,6 +195,8 @@ class LiteMat {
|
|||
|
||||
LiteMat(int width, int height, int channel, LDataType data_type = LDataType::UINT8);
|
||||
|
||||
LiteMat(int width, int height, int channel, void *p_data, LDataType data_type = LDataType::UINT8);
|
||||
|
||||
~LiteMat();
|
||||
|
||||
LiteMat(const LiteMat &m);
|
||||
|
|
|
@ -229,6 +229,11 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
if (input->Rank() != 3 && input->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
|
||||
}
|
||||
|
||||
if (input->type() != DataType::DE_FLOAT32 || input->type() != DataType::DE_UINT8) {
|
||||
RETURN_STATUS_UNEXPECTED("Only float32, uint8 support in Crop");
|
||||
}
|
||||
|
||||
// account for integer overflow
|
||||
if (y < 0 || (y + h) > input->shape()[0] || (y + h) < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid y coordinate value for crop");
|
||||
|
@ -237,18 +242,17 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
if (x < 0 || (x + w) > input->shape()[1] || (x + w) < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid x coordinate value for crop");
|
||||
}
|
||||
// convert to lite Mat
|
||||
LiteMat lite_mat_rgb;
|
||||
// rows = height, this constructor takes: cols,rows
|
||||
bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0],
|
||||
lite_mat_rgb);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed");
|
||||
|
||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3,
|
||||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())), LDataType::UINT8);
|
||||
|
||||
try {
|
||||
TensorShape shape{h, w};
|
||||
int num_channels = input->shape()[2];
|
||||
if (input->Rank() == 3) shape = shape.AppendDim(num_channels);
|
||||
LiteMat lite_mat_cut;
|
||||
ret = Crop(lite_mat_rgb, lite_mat_cut, x, y, x + w, y + h);
|
||||
|
||||
bool ret = Crop(lite_mat_rgb, lite_mat_cut, x, y, x + w, y + h);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Crop failed in lite cv");
|
||||
// create output Tensor based off of lite_mat_cut
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
|
@ -287,14 +291,17 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
if (input->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Input tensor rank isn't 3");
|
||||
}
|
||||
LiteMat lite_mat_rgb;
|
||||
// rows = height, this constructor takes: cols,rows
|
||||
bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0],
|
||||
lite_mat_rgb);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed");
|
||||
|
||||
if (input->type() != DataType::DE_UINT8) {
|
||||
RETURN_STATUS_UNEXPECTED("Only uint8 support in Normalize");
|
||||
}
|
||||
|
||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3,
|
||||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())), LDataType::UINT8);
|
||||
|
||||
LiteMat lite_mat_float;
|
||||
// change input to float
|
||||
ret = ConvertTo(lite_mat_rgb, lite_mat_float, 1.0);
|
||||
bool ret = ConvertTo(lite_mat_rgb, lite_mat_float, 1.0);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Conversion of lite cv to float failed");
|
||||
|
||||
mean->Squeeze();
|
||||
|
@ -337,6 +344,9 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
if (input->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C>");
|
||||
}
|
||||
if (input->type() != DataType::DE_UINT8) {
|
||||
RETURN_STATUS_UNEXPECTED("Only uint8 support in Resize");
|
||||
}
|
||||
// resize image too large or too small
|
||||
if (output_height == 0 || output_height > input->shape()[0] * 1000 || output_width == 0 ||
|
||||
output_width > input->shape()[1] * 1000) {
|
||||
|
@ -345,10 +355,8 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
"1000 times the original image; 2) can not be 0.";
|
||||
return Status(StatusCode::kShapeMisMatch, err_msg);
|
||||
}
|
||||
LiteMat lite_mat_rgb;
|
||||
bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0],
|
||||
lite_mat_rgb);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed");
|
||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3,
|
||||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())), LDataType::UINT8);
|
||||
|
||||
try {
|
||||
TensorShape shape{output_height, output_width};
|
||||
|
@ -356,7 +364,7 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
if (input->Rank() == 3) shape = shape.AppendDim(num_channels);
|
||||
|
||||
LiteMat lite_mat_resize;
|
||||
ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, output_width, output_height);
|
||||
bool ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, output_width, output_height);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Resize failed in lite cv");
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -368,5 +376,39 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
|
||||
const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
|
||||
uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
|
||||
if (input->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C>");
|
||||
}
|
||||
|
||||
if (input->type() != DataType::DE_FLOAT32 || input->type() != DataType::DE_UINT8) {
|
||||
RETURN_STATUS_UNEXPECTED("Only float32, uint8 support in Pad");
|
||||
}
|
||||
|
||||
if (pad_top <= 0 || pad_bottom <= 0 || pad_left <= 0 || pad_right <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("The pad, top, bottom, left, right must be greater than 0");
|
||||
}
|
||||
|
||||
try {
|
||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3,
|
||||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())), LDataType::UINT8);
|
||||
|
||||
LiteMat lite_mat_pad;
|
||||
bool ret = Pad(lite_mat_rgb, lite_mat_pad, pad_top, pad_bottom, pad_left, pad_right,
|
||||
PaddBorderType::PADD_BORDER_CONSTANT, fill_r, fill_g, fill_b);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Pad failed in lite cv");
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(input->shape(), DataType(DataType::DE_FLOAT32),
|
||||
static_cast<uchar *>(lite_mat_pad.data_ptr_), &output_tensor));
|
||||
*output = output_tensor;
|
||||
} catch (std::runtime_error &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Error in image Pad.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -95,6 +95,20 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
int32_t output_width, double fx = 0.0, double fy = 0.0,
|
||||
InterpolationMode mode = InterpolationMode::kLinear);
|
||||
|
||||
/// \brief Pads the input image and puts the padded image in the output
|
||||
/// \param input: input Tensor
|
||||
/// \param output: padded Tensor
|
||||
/// \param pad_top: amount of padding done in top
|
||||
/// \param pad_bottom: amount of padding done in bottom
|
||||
/// \param pad_left: amount of padding done in left
|
||||
/// \param pad_right: amount of padding done in right
|
||||
/// \param border_types: the interpolation to be done in the border
|
||||
/// \param fill_r: red fill value for pad
|
||||
/// \param fill_g: green fill value for pad
|
||||
/// \param fill_b: blue fill value for pad.
|
||||
Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
|
||||
const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
|
||||
uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
||||
|
|
|
@ -144,7 +144,6 @@ if (BUILD_MINDDATA STREQUAL "full")
|
|||
"${MINDDATA_DIR}/kernels/image/auto_contrast_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/bounding_box_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/bounding_box_augment_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/center_crop_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/concatenate_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/cut_out_op.cc"
|
||||
"${MINDDATA_DIR}/kernels/image/cutmix_batch_op.cc"
|
||||
|
|
|
@ -107,6 +107,53 @@ TEST_F(MindDataImageProcess, testRGB) {
|
|||
cv::Mat dst_image(lite_mat_rgb.height_, lite_mat_rgb.width_, CV_8UC3, lite_mat_rgb.data_ptr_);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, testLoadByMemPtr) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
|
||||
cv::Mat rgba_mat;
|
||||
cv::cvtColor(image, rgba_mat, CV_BGR2RGB);
|
||||
|
||||
bool ret = false;
|
||||
int width = rgba_mat.cols;
|
||||
int height = rgba_mat.rows;
|
||||
uchar *p_rgb = (uchar *)malloc(width * height * 3 * sizeof(uchar));
|
||||
for (int i = 0; i < height; i++) {
|
||||
const uchar *current = rgba_mat.ptr<uchar>(i);
|
||||
for (int j = 0; j < width; j++) {
|
||||
p_rgb[i * width * 3 + 3 * j + 0] = current[3 * j + 0];
|
||||
p_rgb[i * width * 3 + 3 * j + 1] = current[3 * j + 1];
|
||||
p_rgb[i * width * 3 + 3 * j + 2] = current[3 * j + 2];
|
||||
}
|
||||
}
|
||||
|
||||
LiteMat lite_mat_rgb(width, height, 3, (void *)p_rgb, LDataType::UINT8);
|
||||
LiteMat lite_mat_resize;
|
||||
ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, 256, 256);
|
||||
ASSERT_TRUE(ret == true);
|
||||
LiteMat lite_mat_convert_float;
|
||||
ret = ConvertTo(lite_mat_resize, lite_mat_convert_float, 1.0);
|
||||
ASSERT_TRUE(ret == true);
|
||||
|
||||
LiteMat lite_mat_crop;
|
||||
ret = Crop(lite_mat_convert_float, lite_mat_crop, 16, 16, 224, 224);
|
||||
ASSERT_TRUE(ret == true);
|
||||
std::vector<float> means = {0.485, 0.456, 0.406};
|
||||
std::vector<float> stds = {0.229, 0.224, 0.225};
|
||||
LiteMat lite_norm_mat_cut;
|
||||
ret = SubStractMeanNormalize(lite_mat_crop, lite_norm_mat_cut, means, stds);
|
||||
|
||||
int pad_width = lite_norm_mat_cut.width_ + 20;
|
||||
int pad_height = lite_norm_mat_cut.height_ + 20;
|
||||
float *p_rgb_pad = (float *)malloc(pad_width * pad_height * 3 * sizeof(float));
|
||||
|
||||
LiteMat makeborder(pad_width, pad_height, 3, (void *)p_rgb_pad, LDataType::FLOAT32);
|
||||
ret = Pad(lite_norm_mat_cut, makeborder, 10, 30, 40, 10, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255);
|
||||
cv::Mat dst_image(pad_height, pad_width, CV_8UC3, p_rgb_pad);
|
||||
free(p_rgb);
|
||||
free(p_rgb_pad);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, test3C) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
|
@ -512,8 +559,7 @@ TEST_F(MindDataImageProcess, TestSubtractInt8) {
|
|||
LiteMat dst_int8;
|
||||
EXPECT_TRUE(Subtract(src1_int8, src2_int8, dst_int8));
|
||||
for (size_t i = 0; i < cols; i++) {
|
||||
EXPECT_EQ(static_cast<INT8_C1 *>(expect_int8.data_ptr_)[i].c1,
|
||||
static_cast<INT8_C1 *>(dst_int8.data_ptr_)[i].c1);
|
||||
EXPECT_EQ(static_cast<INT8_C1 *>(expect_int8.data_ptr_)[i].c1, static_cast<INT8_C1 *>(dst_int8.data_ptr_)[i].c1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -645,8 +691,7 @@ TEST_F(MindDataImageProcess, TestDivideInt8) {
|
|||
LiteMat dst_int8;
|
||||
EXPECT_TRUE(Divide(src1_int8, src2_int8, dst_int8));
|
||||
for (size_t i = 0; i < cols; i++) {
|
||||
EXPECT_EQ(static_cast<INT8_C1 *>(expect_int8.data_ptr_)[i].c1,
|
||||
static_cast<INT8_C1 *>(dst_int8.data_ptr_)[i].c1);
|
||||
EXPECT_EQ(static_cast<INT8_C1 *>(expect_int8.data_ptr_)[i].c1, static_cast<INT8_C1 *>(dst_int8.data_ptr_)[i].c1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue