forked from mindspore-Ecosystem/mindspore
!30942 Add hwc_to_chw operation on lite
Merge pull request !30942 from shenwei41/hwc_to_chw
This commit is contained in:
commit
3f1a7f45bb
|
@ -388,12 +388,14 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
|
|||
HorizontalFlip::HorizontalFlip() = default;
|
||||
|
||||
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); }
|
||||
#endif // not ENABLE_ANDROID
|
||||
|
||||
// HwcToChw Transform Operation.
|
||||
HWC2CHW::HWC2CHW() = default;
|
||||
|
||||
std::shared_ptr<TensorOperation> HWC2CHW::Parse() { return std::make_shared<HwcToChwOperation>(); }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Invert Transform Operation.
|
||||
Invert::Invert() = default;
|
||||
|
||||
|
|
|
@ -331,28 +331,6 @@ class MS_API HorizontalFlip final : public TensorTransform {
|
|||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
};
|
||||
|
||||
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||
class MS_API HWC2CHW final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({std::make_shared<vision::Decode>(),
|
||||
/// std::make_shared<vision::HWC2CHW>()}, // operations
|
||||
/// {"image"}); // input columns
|
||||
/// \endcode
|
||||
HWC2CHW();
|
||||
|
||||
/// \brief Destructor.
|
||||
~HWC2CHW() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
};
|
||||
|
||||
/// \brief Apply invert on the input image in RGB mode.
|
||||
class MS_API Invert final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -208,6 +208,28 @@ class MS_API GaussianBlur final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||
class MS_API HWC2CHW final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({std::make_shared<vision::Decode>(),
|
||||
/// std::make_shared<vision::HWC2CHW>()}, // operations
|
||||
/// {"image"}); // input columns
|
||||
/// \endcode
|
||||
HWC2CHW();
|
||||
|
||||
/// \brief Destructor.
|
||||
~HWC2CHW() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
};
|
||||
|
||||
/// \brief Normalize the input image with respect to mean and standard deviation.
|
||||
class MS_API Normalize final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
*/
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -2030,5 +2030,47 @@ bool ResizePreserveARWithFiller(LiteMat &src, LiteMat &dst, int h, int w, float
|
|||
return true;
|
||||
}
|
||||
|
||||
/// \brief Copy an interleaved (H, W, C) buffer into a planar (C, H, W) buffer.
/// \param[in] src_ptr Source elements, read as src_ptr[pixel * channel + c].
/// \param[out] dst_ptr Destination elements, written as dst_ptr[c * (width * height) + pixel].
///     Both buffers must hold at least height * width * channel elements.
/// \param[in] height Number of rows.
/// \param[in] width Number of columns.
/// \param[in] channel Number of channels per pixel.
/// \note The call is a no-op when either pointer is null or any dimension is not positive.
template <typename T>
void HWC2CHWImpl(const T *src_ptr, T *dst_ptr, int height, int width, int channel) {
  // Bail out on missing buffers or degenerate dimensions. A negative dimension could never
  // satisfy a `!=` loop bound, so the loops would otherwise run past the end of both buffers.
  if (src_ptr == nullptr || dst_ptr == nullptr || height <= 0 || width <= 0 || channel <= 0) {
    return;
  }
  // Index with 64-bit arithmetic: width * height * channel can exceed the range of int.
  const int64_t plane_size = static_cast<int64_t>(width) * height;
  for (int64_t pixel = 0; pixel < plane_size; ++pixel) {
    const T *src_pixel = src_ptr + pixel * channel;
    for (int c = 0; c < channel; ++c) {
      dst_ptr[c * plane_size + pixel] = src_pixel[c];
    }
  }
}
|
||||
|
||||
/// \brief Transpose an interleaved (H, W, C) LiteMat into planar (C, H, W) order.
/// \param[in] src Input image; the call fails if it is empty.
/// \param[out] dst Output image. It is reused as-is when it is non-empty and already satisfies
///     dst.width_ == src.height_, dst.height_ == src.channel_, dst.channel_ == src.width_ and has
///     the same data type as src; otherwise it is (re)initialized with those values.
/// \return true on success; false if src is empty, dst is still empty after initialization, or the
///     data type is not one of the element types handled below.
/// \note NOTE(review): the copy reads src while writing dst, so passing the same buffer for both is
///     presumably unsupported - confirm against callers.
bool HWC2CHW(LiteMat &src, LiteMat &dst) {
  if (src.IsEmpty()) {
    return false;
  }

  const bool needs_init = dst.IsEmpty() || dst.width_ != src.height_ || dst.height_ != src.channel_ ||
                          dst.channel_ != src.width_ || dst.data_type_ != src.data_type_;
  if (needs_init) {
    dst.Init(src.height_, src.channel_, src.width_, src.data_type_);
    // Do not write through dst if initialization left it without usable storage.
    if (dst.IsEmpty()) {
      return false;
    }
  }

  // Dispatch on the element type; src and dst are passed as typed element pointers.
  if (src.data_type_ == LDataType::FLOAT32) {
    HWC2CHWImpl<float>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT8) {
    HWC2CHWImpl<uint8_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT16) {
    HWC2CHWImpl<int16_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT32) {
    HWC2CHWImpl<int32_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::INT64) {
    HWC2CHWImpl<int64_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT16) {
    HWC2CHWImpl<uint16_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT32) {
    HWC2CHWImpl<uint32_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::UINT64) {
    HWC2CHWImpl<uint64_t>(src, dst, src.height_, src.width_, src.channel_);
  } else if (src.data_type_ == LDataType::DOUBLE) {
    HWC2CHWImpl<double>(src, dst, src.height_, src.width_, src.channel_);
  } else {
    // Unsupported element type.
    return false;
  }
  return true;
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -606,6 +606,22 @@ bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, Lit
|
|||
bool ResizePreserveARWithFiller(LiteMat &src, LiteMat &dst, int h, int w, float (*ratioShiftWShiftH)[3],
|
||||
float (*invM)[2][3], int img_orientation);
|
||||
|
||||
/// \brief Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||
/// \param[in] src Input image data.
|
||||
/// \param[out] dst Output image data.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Assume p_rgb is a pointer that points to an image with shape (width, height, channel) */
|
||||
/// LiteMat lite_mat_src;
|
||||
/// lite_mat_src.Init(width, height, channel, p_rgb, LDataType::UINT8);
|
||||
/// LiteMat lite_mat_dst;
|
||||
///
|
||||
/// HWC2CHW(lite_mat_src, lite_mat_dst);
|
||||
/// std::cout << lite_mat_dst.width_ << " " << lite_mat_dst.height_ << " " << lite_mat_dst.channel_ << std::endl;
|
||||
/// \endcode
|
||||
/// \return Return true if transform successfully.
|
||||
bool HWC2CHW(LiteMat &src, LiteMat &dst);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // IMAGE_PROCESS_H_
|
||||
|
|
|
@ -820,5 +820,33 @@ Status ValidateImageRank(const std::string &op_name, int32_t rank) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Convert an image tensor from <H,W,C> layout to <C,H,W> layout.
/// \param[in] input Tensor of shape <H,W,C> or <H,W>.
/// \param[out] output For <H,W,C> input, a newly created tensor of shape <C,H,W> with the same type
///     as the input; for <H,W> input, the input tensor itself (there is no channel axis to move).
/// \return Status::OK() on success; an unexpected-error status if an argument is null, the rank is
///     neither 2 nor 3, tensor creation fails, the transpose fails, or an exception is caught.
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
  if (input == nullptr || output == nullptr) {
    RETURN_STATUS_UNEXPECTED("HwcToChw: input or output is null.");
  }
  try {
    const auto rank = input->Rank();
    if (rank == 2) {
      // <H,W> image: nothing to transpose, and shape()[2] does not exist for this rank.
      *output = input;
      return Status::OK();
    }
    if (rank != 3) {
      RETURN_STATUS_UNEXPECTED("HwcToChw: input image is not in shape of <H,W,C> or <H,W>");
    }

    const int output_height = input->shape()[0];
    const int output_width = input->shape()[1];
    const int output_channel = input->shape()[2];

    // Wrap the input buffer without copying; the const_cast is needed because LiteMat takes void *.
    LiteMat lite_mat_hwc(output_width, output_height, output_channel,
                         const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
                         GetLiteCVDataType(input->type()));

    std::shared_ptr<Tensor> output_tensor;
    TensorShape new_shape = TensorShape({output_channel, output_height, output_width});
    RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor));

    // Point the destination LiteMat at the output tensor's buffer, using the dimensions HWC2CHW
    // expects for its destination, so the result is written directly into the tensor.
    uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
    LiteMat lite_mat_chw;
    lite_mat_chw.Init(output_height, output_channel, output_width, reinterpret_cast<void *>(buffer),
                      GetLiteCVDataType(input->type()));

    bool ret = HWC2CHW(lite_mat_hwc, lite_mat_chw);
    CHECK_FAIL_RETURN_UNEXPECTED(ret, "HwcToChw: HwcToChw failed.");
    *output = output_tensor;
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("HwcToChw: " + std::string(e.what()));
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -165,6 +165,11 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
/// \param[in] op_name operator name.
|
||||
/// \param[in] rank refers to the rank of input image shape.
|
||||
Status ValidateImageRank(const std::string &op_name, int32_t rank);
|
||||
|
||||
/// \brief Swaps the channels in the image, i.e. converts HWC to CHW
|
||||
/// \param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
|
||||
/// \param output: Tensor of shape <C,H,W> or <H,W> and same input type.
|
||||
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_LITE_IMAGE_UTILS_H_
|
||||
|
|
|
@ -17,16 +17,13 @@
|
|||
|
||||
#include "minddata/dataset/kernels/ir/vision/hwc_to_chw_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// HwcToChwOperation
|
||||
HwcToChwOperation::~HwcToChwOperation() = default;
|
||||
|
||||
|
@ -40,7 +37,6 @@ Status HwcToChwOperation::from_json(nlohmann::json op_params, std::shared_ptr<Te
|
|||
*operation = std::make_shared<vision::HwcToChwOperation>();
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -218,6 +218,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/kernels/image/crop_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/decode_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/gaussian_blur_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/normalize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
|
||||
|
@ -364,6 +365,7 @@ elseif(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
|
|||
${MINDDATA_DIR}/kernels/image/lite_image_utils.cc
|
||||
${MINDDATA_DIR}/kernels/image/center_crop_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/crop_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/hwc_to_chw_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/normalize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
|
||||
|
|
|
@ -42,6 +42,31 @@ void CompareMat(cv::Mat cv_mat, LiteMat lite_mat) {
|
|||
ASSERT_TRUE(cv_c == lite_c);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool compare_mat(LiteMat &expect_value, LiteMat &calculate_value) {
|
||||
int rows = expect_value.height_;
|
||||
int cols = expect_value.width_;
|
||||
int stride = rows * cols;
|
||||
for (int i = 0; i < stride; i++) {
|
||||
for (int j = 0; j < expect_value.channel_; j++) {
|
||||
T value = reinterpret_cast<T *>(calculate_value.data_ptr_)[i * calculate_value.channel_ + j];
|
||||
T value_expect = reinterpret_cast<T *>(expect_value.data_ptr_)[i * expect_value.channel_ + j];
|
||||
if (value != value_expect) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// \brief Check that calculate_value has the dimensions HWC2CHW gives its output for src.
/// \param[in] src Matrix that was (or would be) passed to HWC2CHW as input.
/// \param[in] calculate_value Matrix whose dimensions are checked.
/// \return true if calculate_value.width_ == src.height_, calculate_value.height_ == src.channel_
///     and calculate_value.channel_ == src.width_; false otherwise.
bool compare_mat_shape(LiteMat &src, LiteMat &calculate_value) {
  const bool width_ok = calculate_value.width_ == src.height_;
  const bool height_ok = calculate_value.height_ == src.channel_;
  const bool channel_ok = calculate_value.channel_ == src.width_;
  return width_ok && height_ok && channel_ok;
}
|
||||
|
||||
void Lite3CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) {
|
||||
bool ret;
|
||||
LiteMat lite_mat_resize;
|
||||
|
@ -1951,3 +1976,53 @@ TEST_F(MindDataImageProcess, testResizePreserveARWithFillervFail) {
|
|||
bool ret3 = ResizePreserveARWithFiller(lite_mat_rgb3, lite_mat_resize3, h3, w3, &ratioShiftWShiftH3, &invM3, 0);
|
||||
ASSERT_TRUE(ret3 == false);
|
||||
}
|
||||
|
||||
/// Feature: Test HWCTOCHW Operation successfully.
|
||||
/// Description: The input is a three-dimensional int array.
|
||||
/// Expectation: success and The final result should be consistent with expect_value_arr.
|
||||
TEST_F(MindDataImageProcess, TestHWCTOCHW) {
|
||||
std::vector<int32_t> a = {1, 2, 3, 4, 5, 6};
|
||||
LiteMat src(1, 2, 3, a.data(), LDataType(LDataType::INT32));
|
||||
LiteMat dst;
|
||||
std::vector<int32_t> expect_value_arr = {1, 4, 2, 5, 3, 6};
|
||||
LiteMat expect_value(2, 3, 1, expect_value_arr.data(), LDataType(LDataType::INT32));
|
||||
HWC2CHW(src, dst);
|
||||
|
||||
bool res = compare_mat<int32_t>(expect_value, dst);
|
||||
ASSERT_TRUE(res == true);
|
||||
|
||||
bool ret = compare_mat_shape(src, expect_value);
|
||||
ASSERT_TRUE(ret == true);
|
||||
}
|
||||
|
||||
/// Feature: Test HWCTOCHW Operation successfully.
|
||||
/// Description: The input is a three channel picture.
|
||||
/// Expectation: success and The final result should be consistent with expect_value.
|
||||
TEST_F(MindDataImageProcess, TestHWCTOCHWImage) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
cv::Mat image_resize;
|
||||
cv::resize(image, image_resize, cv::Size(10, 10));
|
||||
cv::Mat rgb_mat;
|
||||
cv::cvtColor(image_resize, rgb_mat, CV_BGR2RGB);
|
||||
|
||||
// Implements hwc conversion chw by opencv
|
||||
std::vector<uint8_t> dst_data;
|
||||
std::vector<cv::Mat> bgrChannels(3);
|
||||
cv::split(rgb_mat, bgrChannels);
|
||||
for (auto i = 0; i < bgrChannels.size(); i++) {
|
||||
std::vector<uint8_t> data = std::vector<uint8_t>(bgrChannels[i].reshape(1, 1));
|
||||
dst_data.insert(dst_data.end(), data.begin(), data.end());
|
||||
}
|
||||
|
||||
// HWC2CHW operation implementation.
|
||||
LiteMat expect_value(rgb_mat.rows, rgb_mat.channels(), rgb_mat.cols, dst_data.data(), LDataType(LDataType::UINT8));
|
||||
LiteMat src(rgb_mat.cols, rgb_mat.rows, rgb_mat.channels(), rgb_mat.data, LDataType(LDataType::UINT8));
|
||||
LiteMat dst;
|
||||
HWC2CHW(src, dst);
|
||||
bool res = compare_mat<uint8_t>(expect_value, dst);
|
||||
ASSERT_TRUE(res == true);
|
||||
|
||||
bool ret = compare_mat_shape(src, expect_value);
|
||||
ASSERT_TRUE(ret == true);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue