!30942 Add hwc_to_chw operation on lite

Merge pull request !30942 from shenwei41/hwc_to_chw
This commit is contained in:
i-robot 2022-03-10 01:21:54 +00:00 committed by Gitee
commit 3f1a7f45bb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 196 additions and 26 deletions

View File

@ -388,12 +388,14 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
HorizontalFlip::HorizontalFlip() = default;

// Build the IR node (HorizontalFlipOperation) that represents this transform.
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() {
  std::shared_ptr<TensorOperation> operation = std::make_shared<HorizontalFlipOperation>();
  return operation;
}
#endif // not ENABLE_ANDROID
// HwcToChw Transform Operation.
// Defined outside the ENABLE_ANDROID guards that surround it here.
HWC2CHW::HWC2CHW() = default;

// Build the IR node (HwcToChwOperation) that represents this transform.
std::shared_ptr<TensorOperation> HWC2CHW::Parse() {
  std::shared_ptr<TensorOperation> operation = std::make_shared<HwcToChwOperation>();
  return operation;
}
#ifndef ENABLE_ANDROID
// Invert Transform Operation.
// The constructor takes no arguments; it is defined out of line and defaulted.
Invert::Invert() = default;

View File

@ -331,28 +331,6 @@ class MS_API HorizontalFlip final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
/// \note The class is declared final and takes no construction parameters.
class MS_API HWC2CHW final : public TensorTransform {
public:
/// \brief Constructor. Takes no arguments.
/// \par Example
/// \code
/// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({std::make_shared<vision::Decode>(),
/// std::make_shared<vision::HWC2CHW>()}, // operations
/// {"image"}); // input columns
/// \endcode
HWC2CHW();
/// \brief Destructor (defaulted).
~HWC2CHW() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Apply invert on the input image in RGB mode.
class MS_API Invert final : public TensorTransform {
public:

View File

@ -208,6 +208,28 @@ class MS_API GaussianBlur final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
/// \note The class is declared final and takes no construction parameters.
class MS_API HWC2CHW final : public TensorTransform {
public:
/// \brief Constructor. Takes no arguments.
/// \par Example
/// \code
/// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({std::make_shared<vision::Decode>(),
/// std::make_shared<vision::HWC2CHW>()}, // operations
/// {"image"}); // input columns
/// \endcode
HWC2CHW();
/// \brief Destructor (defaulted).
~HWC2CHW() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Normalize the input image with respect to mean and standard deviation.
class MS_API Normalize final : public TensorTransform {
public:

View File

@ -15,7 +15,11 @@
*/
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {

View File

@ -2030,5 +2030,47 @@ bool ResizePreserveARWithFiller(LiteMat &src, LiteMat &dst, int h, int w, float
return true;
}
/// \brief Copy an interleaved (H, W, C) buffer into a planar (C, H, W) buffer.
/// \param[in] src_ptr Source buffer holding height * width * channel elements, channel-interleaved.
/// \param[out] dst_ptr Destination buffer of the same element count; receives one plane per channel.
/// \param[in] height Number of rows.
/// \param[in] width Number of columns.
/// \param[in] channel Number of channels per pixel.
template <typename T>
void HWC2CHWImpl(const T *src_ptr, T *dst_ptr, int height, int width, int channel) {
  // Number of pixels in one channel plane of the destination.
  const int plane_size = width * height;
  const T *pixel = src_ptr;
  for (int pos = 0; pos != plane_size; ++pos, pixel += channel) {
    // Scatter the channels of the current pixel into their respective planes.
    for (int ch = 0; ch != channel; ++ch) {
      dst_ptr[ch * plane_size + pos] = pixel[ch];
    }
  }
}
/// \brief Transpose an image from (H, W, C) layout to (C, H, W) layout.
/// \param[in] src Input image; must not be empty.
/// \param[out] dst Output image. It is (re)initialized when it is empty or when its
///     width_/height_/channel_/data_type_ differ from src.height_/src.channel_/src.width_/src.data_type_.
/// \return true on success; false if src is empty, if src and dst refer to the same
///     object or the same buffer, if dst has no usable buffer after initialization, or if
///     the data type is not one of the handled types.
bool HWC2CHW(LiteMat &src, LiteMat &dst) {
  if (src.IsEmpty()) {
    return false;
  }
  // The transpose is not an in-place algorithm, and re-initializing dst would also
  // reset src if both names refer to the same object.
  if (&src == &dst) {
    return false;
  }
  if (dst.IsEmpty() || dst.width_ != src.height_ || dst.height_ != src.channel_ || dst.channel_ != src.width_ ||
      dst.data_type_ != src.data_type_) {
    dst.Init(src.height_, src.channel_, src.width_, src.data_type_);
  }
  // Do not write through a missing buffer, nor into the buffer that is being read.
  if (dst.IsEmpty() || dst.data_ptr_ == src.data_ptr_) {
    return false;
  }
  if (src.data_type_ == LDataType::FLOAT32) {
    HWC2CHWImpl<float>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT8) {
    HWC2CHWImpl<uint8_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT16) {
    HWC2CHWImpl<int16_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT32) {
    HWC2CHWImpl<int32_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT64) {
    HWC2CHWImpl<int64_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT16) {
    HWC2CHWImpl<uint16_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT32) {
    HWC2CHWImpl<uint32_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT64) {
    HWC2CHWImpl<uint64_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::DOUBLE) {
    HWC2CHWImpl<double>(src, dst, src.height_, src.width_, src.channel_);
  } else {
    // Any other data type is not handled.
    return false;
  }
  return true;
}
} // namespace dataset
} // namespace mindspore

View File

@ -606,6 +606,22 @@ bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, Lit
bool ResizePreserveARWithFiller(LiteMat &src, LiteMat &dst, int h, int w, float (*ratioShiftWShiftH)[3],
float (*invM)[2][3], int img_orientation);
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
/// \param[in] src Input image data.
/// \param[out] dst Output image data. When it is empty, or its size or data type does not fit, it is
///     initialized so that dst.width_ == src.height_, dst.height_ == src.channel_ and
///     dst.channel_ == src.width_, with the same data type as src.
/// \par Example
/// \code
/// /* Assume p_rgb is a pointer that points to an image with shape (width, height, channel) */
/// LiteMat lite_mat_src;
/// lite_mat_src.Init(width, height, channel, p_rgb, LDataType::UINT8);
/// LiteMat lite_mat_dst;
///
/// HWC2CHW(lite_mat_src, lite_mat_dst);
/// std::cout << lite_mat_dst.width_ << " " << lite_mat_dst.height_ << " " << lite_mat_dst.channel_ << std::endl;
/// \endcode
/// \return Return true if the transform succeeds; false otherwise (for example when src is empty or its
///     data type is not handled).
bool HWC2CHW(LiteMat &src, LiteMat &dst);
} // namespace dataset
} // namespace mindspore
#endif // IMAGE_PROCESS_H_

View File

@ -820,5 +820,33 @@ Status ValidateImageRank(const std::string &op_name, int32_t rank) {
}
return Status::OK();
}
/// \brief Convert an image tensor from <H,W,C> layout to <C,H,W> layout using LiteMat.
/// \param[in] input Tensor of shape <H,W,C> or <H,W>.
/// \param[out] output For <H,W,C> input, a newly created tensor of shape <C,H,W> with the input's type;
///     for <H,W> input, the input tensor itself (there is no channel axis to move).
/// \return Status::OK() on success; an unexpected-error Status when a pointer is null, the rank is
///     neither 2 nor 3, the underlying transpose fails, or an exception is caught.
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
  CHECK_FAIL_RETURN_UNEXPECTED(input != nullptr, "HwcToChw: input tensor is null.");
  CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "HwcToChw: output pointer is null.");
  try {
    if (input->Rank() == 2) {
      // A <H,W> image has no channel dimension, so it is handed back unchanged.
      *output = input;
      return Status::OK();
    }
    if (input->Rank() != 3) {
      // Ranks other than 2 and 3 must be rejected before shape()[2] is read.
      RETURN_STATUS_UNEXPECTED("HwcToChw: input image is not in shape of <H,W,C> or <H,W>");
    }
    int output_height = input->shape()[0];
    int output_width = input->shape()[1];
    int output_channel = input->shape()[2];
    // LiteMat takes (width, height, channel); it wraps the input buffer without copying.
    LiteMat lite_mat_hwc(output_width, output_height, output_channel,
                         const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
                         GetLiteCVDataType(input->type()));
    LiteMat lite_mat_chw;
    std::shared_ptr<Tensor> output_tensor;
    TensorShape new_shape = TensorShape({output_channel, output_height, output_width});
    RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor));
    uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
    // Wrap the output tensor's buffer with the extents HWC2CHW expects for its destination
    // (width = H, height = C, channel = W), so the result is written straight into the tensor.
    lite_mat_chw.Init(output_height, output_channel, output_width, reinterpret_cast<void *>(buffer),
                      GetLiteCVDataType(input->type()));
    bool ret = HWC2CHW(lite_mat_hwc, lite_mat_chw);
    CHECK_FAIL_RETURN_UNEXPECTED(ret, "HwcToChw: HwcToChw failed.");
    *output = output_tensor;
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("HwcToChw: " + std::string(e.what()));
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -165,6 +165,11 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
/// \param[in] op_name operator name.
/// \param[in] rank refers to the rank of input image shape.
Status ValidateImageRank(const std::string &op_name, int32_t rank);
/// \brief Swaps the channels in the image, i.e. converts HWC to CHW.
/// \param[in] input Tensor of shape <H,W,C> or <H,W>; its data type is mapped with GetLiteCVDataType.
///     NOTE(review): earlier wording said "any OpenCv compatible type, see CVTensor" - confirm which
///     types the lite implementation accepts, and that <H,W> input is handled by the implementation.
/// \param[out] output Tensor of shape <C,H,W> or <H,W> with the same type as the input.
/// \return Status of the conversion.
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_LITE_IMAGE_UTILS_H_

View File

@ -17,16 +17,13 @@
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// HwcToChwOperation
// Destructor: defined out of line and defaulted.
HwcToChwOperation::~HwcToChwOperation() = default;
@ -40,7 +37,6 @@ Status HwcToChwOperation::from_json(nlohmann::json op_params, std::shared_ptr<Te
*operation = std::make_shared<vision::HwcToChwOperation>();
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -218,6 +218,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/decode_op.cc
${MINDDATA_DIR}/kernels/image/gaussian_blur_op.cc
${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
@ -364,6 +365,7 @@ elseif(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
${MINDDATA_DIR}/kernels/image/lite_image_utils.cc
${MINDDATA_DIR}/kernels/image/center_crop_op.cc
${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc

View File

@ -42,6 +42,31 @@ void CompareMat(cv::Mat cv_mat, LiteMat lite_mat) {
ASSERT_TRUE(cv_c == lite_c);
}
/// \brief Element-wise comparison of two LiteMat buffers interpreted as elements of type T.
/// \param[in] expect_value Reference matrix.
/// \param[in] calculate_value Matrix produced by the code under test.
/// \return true when both matrices have the same height, width and channel count and every
///     element is equal; false otherwise.
template <typename T>
bool compare_mat(LiteMat &expect_value, LiteMat &calculate_value) {
  // Matrices with different extents can never match; checking first also keeps the
  // loop below from indexing calculate_value with expect_value's extents.
  if (expect_value.height_ != calculate_value.height_ || expect_value.width_ != calculate_value.width_ ||
      expect_value.channel_ != calculate_value.channel_) {
    return false;
  }
  int total = expect_value.height_ * expect_value.width_ * expect_value.channel_;
  if (total > 0 && (expect_value.data_ptr_ == nullptr || calculate_value.data_ptr_ == nullptr)) {
    return false;
  }
  T *expected = reinterpret_cast<T *>(expect_value.data_ptr_);
  T *actual = reinterpret_cast<T *>(calculate_value.data_ptr_);
  // With equal channel counts, (pixel * channel + c) walks both buffers linearly.
  for (int idx = 0; idx < total; idx++) {
    if (actual[idx] != expected[idx]) {
      return false;
    }
  }
  return true;
}
/// \brief Check that a matrix has the extents HWC2CHW produces for a given source.
/// \param[in] src Source matrix handed to HWC2CHW.
/// \param[in] calculate_value Matrix whose extents are checked.
/// \return true when calculate_value.width_ == src.height_, calculate_value.height_ == src.channel_
///     and calculate_value.channel_ == src.width_.
bool compare_mat_shape(LiteMat &src, LiteMat &calculate_value) {
  const bool width_matches = calculate_value.width_ == src.height_;
  const bool height_matches = calculate_value.height_ == src.channel_;
  const bool channel_matches = calculate_value.channel_ == src.width_;
  return width_matches && height_matches && channel_matches;
}
void Lite3CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) {
bool ret;
LiteMat lite_mat_resize;
@ -1951,3 +1976,53 @@ TEST_F(MindDataImageProcess, testResizePreserveARWithFillervFail) {
bool ret3 = ResizePreserveARWithFiller(lite_mat_rgb3, lite_mat_resize3, h3, w3, &ratioShiftWShiftH3, &invM3, 0);
ASSERT_TRUE(ret3 == false);
}
/// Feature: Test HWCTOCHW Operation successfully.
/// Description: The input is a three-dimensional int array.
/// Expectation: Success, and the final result should be consistent with expect_value_arr.
TEST_F(MindDataImageProcess, TestHWCTOCHW) {
  std::vector<int32_t> a = {1, 2, 3, 4, 5, 6};
  // width = 1, height = 2, channel = 3
  LiteMat src(1, 2, 3, a.data(), LDataType(LDataType::INT32));
  LiteMat dst;

  std::vector<int32_t> expect_value_arr = {1, 4, 2, 5, 3, 6};
  LiteMat expect_value(2, 3, 1, expect_value_arr.data(), LDataType(LDataType::INT32));

  // The operation itself has to report success.
  bool ret = HWC2CHW(src, dst);
  ASSERT_TRUE(ret);

  // Check the shape of the computed result (not only of the fixture) before reading its data.
  ASSERT_TRUE(compare_mat_shape(src, dst));
  ASSERT_TRUE(compare_mat_shape(src, expect_value));

  bool res = compare_mat<int32_t>(expect_value, dst);
  ASSERT_TRUE(res);
}
/// Feature: Test HWCTOCHW Operation successfully.
/// Description: The input is a three channel picture.
/// Expectation: Success, and the final result should be consistent with expect_value.
TEST_F(MindDataImageProcess, TestHWCTOCHWImage) {
  std::string filename = "data/dataset/apple.jpg";
  cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
  // Fail with a clear assertion if the fixture image could not be loaded.
  ASSERT_FALSE(image.empty());
  cv::Mat image_resize;
  cv::resize(image, image_resize, cv::Size(10, 10));
  cv::Mat rgb_mat;
  cv::cvtColor(image_resize, rgb_mat, CV_BGR2RGB);

  // Implements hwc conversion chw by opencv: split into planes and append them one after another.
  std::vector<uint8_t> dst_data;
  std::vector<cv::Mat> rgb_channels(3);
  cv::split(rgb_mat, rgb_channels);
  for (size_t i = 0; i < rgb_channels.size(); i++) {
    std::vector<uint8_t> data = std::vector<uint8_t>(rgb_channels[i].reshape(1, 1));
    dst_data.insert(dst_data.end(), data.begin(), data.end());
  }

  // HWC2CHW operation implementation.
  LiteMat expect_value(rgb_mat.rows, rgb_mat.channels(), rgb_mat.cols, dst_data.data(), LDataType(LDataType::UINT8));
  LiteMat src(rgb_mat.cols, rgb_mat.rows, rgb_mat.channels(), rgb_mat.data, LDataType(LDataType::UINT8));
  LiteMat dst;

  // The operation itself has to report success.
  bool ret = HWC2CHW(src, dst);
  ASSERT_TRUE(ret);

  // Check the shape of the computed result (not only of the fixture) before reading its data.
  ASSERT_TRUE(compare_mat_shape(src, dst));
  ASSERT_TRUE(compare_mat_shape(src, expect_value));

  bool res = compare_mat<uint8_t>(expect_value, dst);
  ASSERT_TRUE(res);
}