add yuv420sp to bgr, split and extractchannel

This commit is contained in:
xulei2020 2020-10-14 17:52:42 +08:00
parent 1e678a84dc
commit 442cee15ad
6 changed files with 361 additions and 59 deletions

View File

@ -232,6 +232,17 @@ bool ResizeBilinear(const LiteMat &src, LiteMat &dst, int dst_w, int dst_h) {
return true;
}
static bool ConvertBGR(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) {
if (data_type == LDataType::UINT8) {
mat.Init(w, h, 3, LDataType::UINT8);
unsigned char *dst_ptr = mat;
(void)memcpy(dst_ptr, data, w * h * 3 * sizeof(unsigned char));
} else {
return false;
}
return true;
}
static bool ConvertRGBAToBGR(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) {
if (data_type == LDataType::UINT8) {
mat.Init(w, h, 3, LDataType::UINT8);
@ -272,6 +283,76 @@ static bool ConvertRGBAToRGB(const unsigned char *data, LDataType data_type, int
return true;
}
static bool ConvertYUV420SPToBGR(const uint8_t *data, LDataType data_type, bool flag, int w, int h, LiteMat &mat) {
if (data == nullptr || w <= 0 || h <= 0) {
return false;
}
if (data_type == LDataType::UINT8) {
mat.Init(w, h, 3, LDataType::UINT8);
const uint8_t *y_ptr = data;
const uint8_t *uv_ptr = y_ptr + w * h;
uint8_t *bgr_ptr = mat;
int bgr_stride = 3 * w;
for (int y = 0; y < h; ++y) {
uint8_t *bgr_buf = bgr_ptr;
const uint8_t *uv_buf = uv_ptr;
const uint8_t *y_buf = y_ptr;
uint8_t u;
uint8_t v;
for (int x = 0; x < w - 1; x += 2) {
if (flag) {
// NV21
u = uv_buf[1];
v = uv_buf[0];
} else {
// NV12
u = uv_buf[0];
v = uv_buf[1];
}
uint32_t tmp_y = (uint32_t)(y_buf[0] * YSCALE * YTOG) >> 16;
// b
bgr_buf[0] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255);
// g
bgr_buf[1] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255);
// r
bgr_buf[2] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255);
tmp_y = (uint32_t)(y_buf[1] * YSCALE * YTOG) >> 16;
bgr_buf[3] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255);
bgr_buf[4] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255);
bgr_buf[5] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255);
y_buf += 2;
uv_buf += 2;
bgr_buf += 6;
}
if (w & 1) {
if (flag) {
// NV21
u = uv_buf[1];
v = uv_buf[0];
} else {
// NV12
u = uv_buf[0];
v = uv_buf[1];
}
uint32_t tmp_y = (uint32_t)(y_buf[0] * YSCALE * YTOG) >> 16;
bgr_buf[0] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255);
bgr_buf[1] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255);
bgr_buf[2] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255);
}
bgr_ptr += bgr_stride;
y_ptr += w;
if (y & 1) {
uv_ptr += w;
}
}
}
return true;
}
static bool ConvertRGBAToGRAY(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) {
if (data_type == LDataType::UINT8) {
mat.Init(w, h, 1, LDataType::UINT8);
@ -300,12 +381,24 @@ bool InitFromPixel(const unsigned char *data, LPixelType pixel_type, LDataType d
if (w <= 0 || h <= 0) {
return false;
}
if (data_type != LDataType::UINT8) {
return false;
}
if (pixel_type == LPixelType::RGBA2BGR) {
return ConvertRGBAToBGR(data, data_type, w, h, m);
} else if (pixel_type == LPixelType::RGBA2GRAY) {
return ConvertRGBAToGRAY(data, data_type, w, h, m);
} else if (pixel_type == LPixelType::RGBA2RGB) {
return ConvertRGBAToRGB(data, data_type, w, h, m);
} else if (pixel_type == LPixelType::NV212BGR) {
return ConvertYUV420SPToBGR(data, data_type, true, w, h, m);
} else if (pixel_type == LPixelType::NV122BGR) {
return ConvertYUV420SPToBGR(data, data_type, false, w, h, m);
} else if (pixel_type == LPixelType::BGR) {
return ConvertBGR(data, data_type, w, h, m);
} else if (pixel_type == LPixelType::RGB) {
return ConvertBGR(data, data_type, w, h, m);
} else {
return false;
}
@ -322,8 +415,8 @@ bool ConvertTo(const LiteMat &src, LiteMat &dst, double scale) {
float *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
for (int w = 0; w < src.width_; w++) {
uint32_t index = (h * src.width_ + w) * src.channel_;
for (int c = 0; c < src.channel_; c++) {
int index = (h * src.width_ + w) * src.channel_;
dst_start_p[index + c] = (static_cast<float>(src_start_p[index + c] * scale));
}
}
@ -418,8 +511,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
if ((!mean.empty()) && std.empty()) {
for (int h = 0; h < src.height_; h++) {
for (int w = 0; w < src.width_; w++) {
uint32_t src_start = (h * src.width_ + w) * src.channel_;
for (int c = 0; c < src.channel_; c++) {
int index = (h * src.width_ + w) * src.channel_ + c;
uint32_t index = src_start + c;
dst_start_p[index] = src_start_p[index] - mean[c];
}
}
@ -427,8 +521,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
} else if (mean.empty() && (!std.empty())) {
for (int h = 0; h < src.height_; h++) {
for (int w = 0; w < src.width_; w++) {
uint32_t src_start = (h * src.width_ + w) * src.channel_;
for (int c = 0; c < src.channel_; c++) {
int index = (h * src.width_ + w) * src.channel_ + c;
uint32_t index = src_start + c;
dst_start_p[index] = src_start_p[index] / std[c];
}
}
@ -436,8 +531,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
} else if ((!mean.empty()) && (!std.empty())) {
for (int h = 0; h < src.height_; h++) {
for (int w = 0; w < src.width_; w++) {
uint32_t src_start = (h * src.width_ + w) * src.channel_;
for (int c = 0; c < src.channel_; c++) {
int index = (h * src.width_ + w) * src.channel_ + c;
uint32_t index = src_start + c;
dst_start_p[index] = (src_start_p[index] - mean[c]) / std[c];
}
}
@ -458,7 +554,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
// padd top
for (int h = 0; h < top; h++) {
for (int w = 0; w < dst.width_; w++) {
int index = (h * dst.width_ + w) * dst.channel_;
uint32_t index = (h * dst.width_ + w) * dst.channel_;
if (dst.channel_ == 1) {
dst_start_p[index] = fill_b_or_gray;
} else if (dst.channel_ == 3) {
@ -472,7 +568,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
// padd bottom
for (int h = dst.height_ - bottom; h < dst.height_; h++) {
for (int w = 0; w < dst.width_; w++) {
int index = (h * dst.width_ + w) * dst.channel_;
uint32_t index = (h * dst.width_ + w) * dst.channel_;
if (dst.channel_ == 1) {
dst_start_p[index] = fill_b_or_gray;
} else if (dst.channel_ == 3) {
@ -487,7 +583,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
// padd left
for (int h = top; h < dst.height_ - bottom; h++) {
for (int w = 0; w < left; w++) {
int index = (h * dst.width_ + w) * dst.channel_;
uint32_t index = (h * dst.width_ + w) * dst.channel_;
if (dst.channel_ == 1) {
dst_start_p[index] = fill_b_or_gray;
} else if (dst.channel_ == 3) {
@ -502,7 +598,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
// padd right
for (int h = top; h < dst.height_ - bottom; h++) {
for (int w = dst.width_ - right; w < dst.width_; w++) {
int index = (h * dst.width_ + w) * dst.channel_;
uint32_t index = (h * dst.width_ + w) * dst.channel_;
if (dst.channel_ == 1) {
dst_start_p[index] = fill_b_or_gray;
} else if (dst.channel_ == 3) {
@ -522,6 +618,86 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
}
}
bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col) {
if (src.IsEmpty() || col < 0 || col > src.dims_ - 1) {
return false;
}
(void)dst.Init(src.width_, src.height_, 1, src.data_type_);
if (src.data_type_ == LDataType::FLOAT32) {
const float *src_start_p = src;
float *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_ + col;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
return true;
} else if (src.data_type_ == LDataType::UINT8) {
const uint8_t *src_start_p = src;
uint8_t *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_ + col;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
return true;
} else {
return false;
}
return false;
}
bool Split(const LiteMat &src, std::vector<LiteMat> &mv) {
if (src.data_type_ == LDataType::FLOAT32) {
const float *src_start_p = src;
for (int c = 0; c < src.channel_; c++) {
LiteMat dst;
(void)dst.Init(src.width_, src.height_, 1, src.data_type_);
float *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_ + c;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
mv.emplace_back(dst);
}
return true;
} else if (src.data_type_ == LDataType::UINT8) {
const uint8_t *src_start_p = src;
for (int c = 0; c < src.channel_; c++) {
LiteMat dst;
(void)dst.Init(src.width_, src.height_, 1, src.data_type_);
uint8_t *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_ + c;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
mv.emplace_back(dst);
}
return true;
} else {
return false;
}
return false;
}
bool Pad(const LiteMat &src, LiteMat &dst, int top, int bottom, int left, int right, PaddBorderType pad_type,
uint8_t fill_b_or_gray, uint8_t fill_g, uint8_t fill_r) {
if (top <= 0 || bottom <= 0 || left <= 0 || right <= 0) {

View File

@ -35,6 +35,17 @@ namespace dataset {
#define B2GRAY 29
#define GRAYSHIFT 8
#define YSCALE 0x0101
#define UTOB (-128)
#define UTOG 25
#define VTOR (-102)
#define VTOG 52
#define YTOG 18997
#define YTOGB (-1160)
#define BTOB (UTOB * 128 + YTOGB)
#define BTOG (UTOG * 128 + VTOG * 128 + YTOGB)
#define BTOR (VTOR * 128 + YTOGB)
enum PaddBorderType { PADD_BORDER_CONSTANT = 0, PADD_BORDER_REPLICATE = 1 };
struct BoxesConfig {
@ -70,6 +81,10 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<
bool Pad(const LiteMat &src, LiteMat &dst, int top, int bottom, int left, int right, PaddBorderType pad_type,
uint8_t fill_b_or_gray, uint8_t fill_g, uint8_t fill_r);
bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col);
bool Split(const LiteMat &src, std::vector<LiteMat> &mv);
/// \brief Apply affine transformation for 1 channel image
bool Affine(LiteMat &src, LiteMat &out_img, double M[6], std::vector<size_t> dsize, UINT8_C1 borderValue);

View File

@ -92,6 +92,8 @@ enum LPixelType {
RGBA2GRAY = 3,
RGBA2BGR = 4,
RGBA2RGB = 5,
NV212BGR = 6,
NV122BGR = 7,
};
class LDataType {
@ -159,7 +161,6 @@ class LDataType {
class LiteMat {
// Class that represents a lite Mat of a Image.
// -# The pixel type of Lite Mat is RGBRGB...RGB.
public:
LiteMat();

View File

@ -19,7 +19,6 @@
#include "lite_cv/image_process.h"
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/types_c.h>
#include "utils/log_adapter.h"
#include <fstream>
@ -43,32 +42,22 @@ void CompareMat(cv::Mat cv_mat, LiteMat lite_mat) {
ASSERT_TRUE(cv_c == lite_c);
}
LiteMat Lite3CImageProcess(LiteMat &lite_mat_bgr) {
void Lite3CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) {
bool ret;
LiteMat lite_mat_resize;
ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256);
if (!ret) {
MS_LOG(ERROR) << "ResizeBilinear error";
}
ASSERT_TRUE(ret == true);
LiteMat lite_mat_convert_float;
ret = ConvertTo(lite_mat_resize, lite_mat_convert_float, 1.0);
if (!ret) {
MS_LOG(ERROR) << "ConvertTo error";
}
ASSERT_TRUE(ret == true);
LiteMat lite_mat_crop;
ret = Crop(lite_mat_convert_float, lite_mat_crop, 16, 16, 224, 224);
if (!ret) {
MS_LOG(ERROR) << "Crop error";
}
ASSERT_TRUE(ret == true);
std::vector<float> means = {0.485, 0.456, 0.406};
std::vector<float> stds = {0.229, 0.224, 0.225};
LiteMat lite_norm_mat_cut;
SubStractMeanNormalize(lite_mat_crop, lite_norm_mat_cut, means, stds);
return lite_norm_mat_cut;
return;
}
cv::Mat cv3CImageProcess(cv::Mat &image) {
@ -103,11 +92,25 @@ cv::Mat cv3CImageProcess(cv::Mat &image) {
return imgR2;
}
TEST_F(MindDataImageProcess, testRGB) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
cv::Mat rgba_mat;
cv::cvtColor(image, rgba_mat, CV_BGR2RGB);
bool ret = false;
LiteMat lite_mat_rgb;
ret = InitFromPixel(rgba_mat.data, LPixelType::RGB, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_rgb);
ASSERT_TRUE(ret == true);
cv::Mat dst_image(lite_mat_rgb.height_, lite_mat_rgb.width_, CV_8UC3, lite_mat_rgb.data_ptr_);
}
TEST_F(MindDataImageProcess, test3C) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
cv::Mat cv_image = cv3CImageProcess(image);
// cv::imwrite("/home/xlei/test_3cv.jpg", cv_image);
// convert to RGBA for Android bitmap(rgba)
cv::Mat rgba_mat;
@ -117,34 +120,142 @@ TEST_F(MindDataImageProcess, test3C) {
LiteMat lite_mat_bgr;
ret =
InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
if (!ret) {
MS_LOG(ERROR) << "Init From RGBA error";
}
LiteMat lite_norm_mat_cut = Lite3CImageProcess(lite_mat_bgr);
ASSERT_TRUE(ret == true);
LiteMat lite_norm_mat_cut;
Lite3CImageProcess(lite_mat_bgr, lite_norm_mat_cut);
cv::Mat dst_image(lite_norm_mat_cut.height_, lite_norm_mat_cut.width_, CV_32FC3, lite_norm_mat_cut.data_ptr_);
// cv::imwrite("/home/xlei/test_3clite.jpg", dst_image);
CompareMat(cv_image, lite_norm_mat_cut);
}
LiteMat Lite1CImageProcess(LiteMat &lite_mat_bgr) {
bool ReadYUV(const char *filename, int w, int h, uint8_t **data) {
FILE *f = fopen(filename, "rb");
if (f == nullptr) {
return false;
}
fseek(f, 0, SEEK_END);
int size = ftell(f);
int expect_size = w * h + 2 * ((w + 1) / 2) * ((h + 1) / 2);
if (size != expect_size) {
fclose(f);
return false;
}
fseek(f, 0, SEEK_SET);
*data = (uint8_t *)malloc(size);
size_t re = fread(*data, 1, size, f);
if (re != size) {
fclose(f);
return false;
}
fclose(f);
return true;
}
TEST_F(MindDataImageProcess, testNV21ToBGR) {
// ffmpeg -i ./data/dataset/apple.jpg -s 1024*800 -pix_fmt nv21 ./data/dataset/yuv/test_nv21.yuv
const char *filename = "data/dataset/yuv/test_nv21.yuv";
int w = 1024;
int h = 800;
uint8_t *yuv_data = nullptr;
bool ret = ReadYUV(filename, w, h, &yuv_data);
ASSERT_TRUE(ret == true);
cv::Mat yuvimg(h * 3 / 2, w, CV_8UC1);
memcpy(yuvimg.data, yuv_data, w * h * 3 / 2);
cv::Mat rgbimage;
cv::cvtColor(yuvimg, rgbimage, cv::COLOR_YUV2BGR_NV21);
LiteMat lite_mat_bgr;
ret = InitFromPixel(yuv_data, LPixelType::NV212BGR, LDataType::UINT8, w, h, lite_mat_bgr);
ASSERT_TRUE(ret == true);
cv::Mat dst_image(lite_mat_bgr.height_, lite_mat_bgr.width_, CV_8UC3, lite_mat_bgr.data_ptr_);
}
TEST_F(MindDataImageProcess, testNV12ToBGR) {
// ffmpeg -i ./data/dataset/apple.jpg -s 1024*800 -pix_fmt nv12 ./data/dataset/yuv/test_nv12.yuv
const char *filename = "data/dataset/yuv/test_nv12.yuv";
int w = 1024;
int h = 800;
uint8_t *yuv_data = nullptr;
bool ret = ReadYUV(filename, w, h, &yuv_data);
ASSERT_TRUE(ret == true);
cv::Mat yuvimg(h * 3 / 2, w, CV_8UC1);
memcpy(yuvimg.data, yuv_data, w * h * 3 / 2);
cv::Mat rgbimage;
cv::cvtColor(yuvimg, rgbimage, cv::COLOR_YUV2BGR_NV12);
LiteMat lite_mat_bgr;
ret = InitFromPixel(yuv_data, LPixelType::NV122BGR, LDataType::UINT8, w, h, lite_mat_bgr);
ASSERT_TRUE(ret == true);
cv::Mat dst_image(lite_mat_bgr.height_, lite_mat_bgr.width_, CV_8UC3, lite_mat_bgr.data_ptr_);
}
TEST_F(MindDataImageProcess, testExtractChannel) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat src_image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
cv::Mat dst_image;
cv::extractChannel(src_image, dst_image, 2);
// convert to RGBA for Android bitmap(rgba)
cv::Mat rgba_mat;
cv::cvtColor(src_image, rgba_mat, CV_BGR2RGBA);
bool ret = false;
LiteMat lite_mat_bgr;
ret =
InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
ASSERT_TRUE(ret == true);
LiteMat lite_B;
ret = ExtractChannel(lite_mat_bgr, lite_B, 0);
ASSERT_TRUE(ret == true);
LiteMat lite_R;
ret = ExtractChannel(lite_mat_bgr, lite_R, 2);
ASSERT_TRUE(ret == true);
cv::Mat dst_imageR(lite_R.height_, lite_R.width_, CV_8UC1, lite_R.data_ptr_);
// cv::imwrite("./test_lite_r.jpg", dst_imageR);
}
TEST_F(MindDataImageProcess, testSplit) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat src_image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
std::vector<cv::Mat> dst_images;
cv::split(src_image, dst_images);
// convert to RGBA for Android bitmap(rgba)
cv::Mat rgba_mat;
cv::cvtColor(src_image, rgba_mat, CV_BGR2RGBA);
bool ret = false;
LiteMat lite_mat_bgr;
ret =
InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
ASSERT_TRUE(ret == true);
std::vector<LiteMat> lite_all;
ret = Split(lite_mat_bgr, lite_all);
ASSERT_TRUE(ret == true);
ASSERT_TRUE(lite_all.size() == 3);
LiteMat lite_r = lite_all[2];
cv::Mat dst_imageR(lite_r.height_, lite_r.width_, CV_8UC1, lite_r.data_ptr_);
}
void Lite1CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) {
LiteMat lite_mat_resize;
ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256);
int ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256);
ASSERT_TRUE(ret == true);
LiteMat lite_mat_convert_float;
ConvertTo(lite_mat_resize, lite_mat_convert_float);
ret = ConvertTo(lite_mat_resize, lite_mat_convert_float);
ASSERT_TRUE(ret == true);
LiteMat lite_mat_cut;
Crop(lite_mat_convert_float, lite_mat_cut, 16, 16, 224, 224);
ret = Crop(lite_mat_convert_float, lite_mat_cut, 16, 16, 224, 224);
ASSERT_TRUE(ret == true);
std::vector<float> means = {0.485};
std::vector<float> stds = {0.229};
LiteMat lite_norm_mat_cut;
SubStractMeanNormalize(lite_mat_cut, lite_norm_mat_cut, means, stds);
return lite_norm_mat_cut;
ret = SubStractMeanNormalize(lite_mat_cut, lite_norm_mat_cut, means, stds);
ASSERT_TRUE(ret == true);
return;
}
cv::Mat cv1CImageProcess(cv::Mat &image) {
@ -183,18 +294,17 @@ TEST_F(MindDataImageProcess, test1C) {
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
cv::Mat cv_image = cv1CImageProcess(image);
// cv::imwrite("/home/xlei/test_c1v.jpg", cv_image);
// convert to RGBA for Android bitmap(rgba)
cv::Mat rgba_mat;
cv::cvtColor(image, rgba_mat, CV_BGR2RGBA);
LiteMat lite_mat_bgr;
InitFromPixel(rgba_mat.data, LPixelType::RGBA2GRAY, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
LiteMat lite_norm_mat_cut = Lite1CImageProcess(lite_mat_bgr);
bool ret =
InitFromPixel(rgba_mat.data, LPixelType::RGBA2GRAY, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
ASSERT_TRUE(ret == true);
LiteMat lite_norm_mat_cut;
Lite1CImageProcess(lite_mat_bgr, lite_norm_mat_cut);
cv::Mat dst_image(lite_norm_mat_cut.height_, lite_norm_mat_cut.width_, CV_32FC1, lite_norm_mat_cut.data_ptr_);
// cv::imwrite("/home/xlei/test_c1lite.jpg", dst_image);
CompareMat(cv_image, lite_norm_mat_cut);
}
@ -211,22 +321,20 @@ TEST_F(MindDataImageProcess, TestPadd) {
cv::Mat b_image;
cv::Scalar color = cv::Scalar(255, 255, 255);
cv::copyMakeBorder(resize_256_image, b_image, top, bottom, left, right, cv::BORDER_CONSTANT, color);
// cv::imwrite("/home/xlei/test_ccc.jpg", b_image);
cv::Mat rgba_mat;
cv::cvtColor(image, rgba_mat, CV_BGR2RGBA);
LiteMat lite_mat_bgr;
InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
bool ret =
InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr);
ASSERT_TRUE(ret == true);
LiteMat lite_mat_resize;
ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256);
ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256);
ASSERT_TRUE(ret == true);
LiteMat makeborder;
Pad(lite_mat_resize, makeborder, top, bottom, left, right, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255);
ret = Pad(lite_mat_resize, makeborder, top, bottom, left, right, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255);
ASSERT_TRUE(ret == true);
cv::Mat dst_image(256 + top + bottom, 256 + left + right, CV_8UC3, makeborder.data_ptr_);
// cv::imwrite("/home/xlei/test_liteccc.jpg", dst_image);
}
TEST_F(MindDataImageProcess, TestGetDefaultBoxes) {

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long