!15672 [assistant][RGB2BGR]
Merge pull request !15672 from QingfengLi/rgb2bgr
This commit is contained in:
commit
dca5504fd4
|
@ -56,6 +56,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/rescale_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/resize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/resize_with_bbox_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
|
||||
|
@ -529,6 +530,17 @@ PYBIND_REGISTER(ResizeWithBBoxOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RgbToBgrOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<vision::RgbToBgrOperation, TensorOperation, std::shared_ptr<vision::RgbToBgrOperation>>(
|
||||
*m, "RgbToBgrOperation")
|
||||
.def(py::init([]() {
|
||||
auto rgb2bgr = std::make_shared<vision::RgbToBgrOperation>();
|
||||
THROW_IF_ERROR(rgb2bgr->ValidateParams());
|
||||
return rgb2bgr;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::RotateOperation, TensorOperation, std::shared_ptr<vision::RotateOperation>>(
|
||||
*m, "RotateOperation")
|
||||
|
|
|
@ -61,9 +61,10 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/resize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/resize_preserve_ar_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/resize_with_bbox_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgb_to_gray_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgba_to_bgr_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgba_to_rgb_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rgb_to_gray_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/rotate_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
|
||||
|
@ -181,9 +182,6 @@ std::shared_ptr<TensorOperation> CenterCrop::Parse(const MapTargetDevice &env) {
|
|||
return std::make_shared<CenterCropOperation>(data_->size_);
|
||||
}
|
||||
|
||||
// RGB2GRAY Transform Operation.
|
||||
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
|
||||
|
||||
// Crop Transform Operation.
|
||||
struct Crop::Data {
|
||||
Data(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size)
|
||||
|
@ -863,6 +861,12 @@ std::shared_ptr<TensorOperation> ResizeWithBBox::Parse() {
|
|||
return std::make_shared<ResizeWithBBoxOperation>(data_->size_, data_->interpolation_);
|
||||
}
|
||||
|
||||
// RGB2BGR Transform Operation.
|
||||
std::shared_ptr<TensorOperation> RGB2BGR::Parse() { return std::make_shared<RgbToBgrOperation>(); }
|
||||
|
||||
// RGB2GRAY Transform Operation.
|
||||
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
|
||||
|
||||
// RgbaToBgr Transform Operation.
|
||||
RGBA2BGR::RGBA2BGR() {}
|
||||
|
||||
|
|
|
@ -91,6 +91,22 @@ class CenterCrop final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief RGB2BGR TensorTransform.
|
||||
/// \notes Convert RGB image to BGR image
|
||||
class RGB2BGR final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
RGB2BGR() = default;
|
||||
|
||||
/// \brief Destructor.
|
||||
~RGB2BGR() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
};
|
||||
|
||||
/// \brief RGB2GRAY TensorTransform.
|
||||
/// \note Convert RGB image or color image to grayscale image.
|
||||
class RGB2GRAY final : public TensorTransform {
|
||||
|
|
|
@ -47,6 +47,7 @@ add_library(kernels-image OBJECT
|
|||
rescale_op.cc
|
||||
resize_op.cc
|
||||
resize_preserve_ar_op.cc
|
||||
rgb_to_bgr_op.cc
|
||||
rgb_to_gray_op.cc
|
||||
rgba_to_bgr_op.cc
|
||||
rgba_to_rgb_op.cc
|
||||
|
|
|
@ -1189,6 +1189,23 @@ Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
}
|
||||
}
|
||||
|
||||
Status RgbToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
|
||||
if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToBgr: image shape is not <H,W,C> or channel is not 3.");
|
||||
}
|
||||
TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1], 3});
|
||||
std::shared_ptr<CVTensor> output_cv;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv));
|
||||
cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_RGB2BGR));
|
||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||
return Status::OK();
|
||||
} catch (const cv::Exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToBgr: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -299,6 +299,12 @@ Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
/// \return Status code
|
||||
Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
/// \brief Take in a 3 channel image in RBG to BGR
|
||||
/// \param[in] input The input image
|
||||
/// \param[out] output The output image
|
||||
/// \return Status code
|
||||
Status RgbToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
/// \brief Take in a 3 channel image in RBG to GRAY
|
||||
/// \param[in] input The input image
|
||||
/// \param[out] output The output image
|
||||
|
|
|
@ -1602,6 +1602,38 @@ bool GetAffineTransform(std::vector<Point> src_point, std::vector<Point> dst_poi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvertRgbToBgr(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat) {
|
||||
if (data_type == LDataType::UINT8) {
|
||||
if (src.IsEmpty()) {
|
||||
return false;
|
||||
}
|
||||
if (mat.IsEmpty()) {
|
||||
mat.Init(w, h, 3, LDataType::UINT8);
|
||||
}
|
||||
if (mat.channel_ != 3) {
|
||||
return false;
|
||||
}
|
||||
if ((src.width_ != w) || (src.height_ != h)) {
|
||||
return false;
|
||||
}
|
||||
unsigned char *ptr = mat;
|
||||
const unsigned char *data_ptr = src;
|
||||
for (int y = 0; y < h; y++) {
|
||||
for (int x = 0; x < w; x++) {
|
||||
ptr[0] = data_ptr[2];
|
||||
ptr[1] = data_ptr[1];
|
||||
ptr[2] = data_ptr[0];
|
||||
|
||||
ptr += 3;
|
||||
data_ptr += 3;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat) {
|
||||
if (data_type == LDataType::UINT8) {
|
||||
if (src.IsEmpty()) {
|
||||
|
|
|
@ -151,6 +151,9 @@ bool ConvRowCol(const LiteMat &src, const LiteMat &kx, const LiteMat &ky, LiteMa
|
|||
bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize = 3, double scale = 1.0,
|
||||
PaddBorderType pad_type = PaddBorderType::PADD_BORDER_DEFAULT);
|
||||
|
||||
/// \brief Convert RGB image or color image to BGR image
|
||||
bool ConvertRgbToBgr(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat);
|
||||
|
||||
/// \brief Convert RGB image or color image to grayscale image
|
||||
bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat);
|
||||
|
||||
|
|
|
@ -444,6 +444,40 @@ Status ResizePreserve(const TensorRow &inputs, int32_t height, int32_t width, in
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RgbToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
if (input->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToBgr: input image is not in shape of <H,W,C>");
|
||||
}
|
||||
if (input->type() != DataType::DE_UINT8) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToBgr: image datatype is not uint8.");
|
||||
}
|
||||
|
||||
try {
|
||||
int output_height = input->shape()[0];
|
||||
int output_width = input->shape()[1];
|
||||
|
||||
LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],
|
||||
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
|
||||
GetLiteCVDataType(input->type()));
|
||||
LiteMat lite_mat_convert;
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
TensorShape new_shape = TensorShape({output_height, output_width, 3});
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor));
|
||||
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
|
||||
lite_mat_convert.Init(output_width, output_height, 3, reinterpret_cast<void *>(buffer),
|
||||
GetLiteCVDataType(input->type()));
|
||||
|
||||
bool ret =
|
||||
ConvertRgbToBgr(lite_mat_rgb, GetLiteCVDataType(input->type()), output_width, output_height, lite_mat_convert);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret, "RgbToBgr: RGBToBGR failed.");
|
||||
|
||||
*output = output_tensor;
|
||||
} catch (std::runtime_error &e) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToBgr: " + std::string(e.what()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
if (input->Rank() != 3) {
|
||||
RETURN_STATUS_UNEXPECTED("RgbToGray: input image is not in shape of <H,W,C>");
|
||||
|
|
|
@ -106,6 +106,12 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
Status ResizePreserve(const TensorRow &inputs, int32_t height, int32_t width, int32_t img_orientation,
|
||||
TensorRow *outputs);
|
||||
|
||||
/// \brief Take in a 3 channel image in RBG to BGR
|
||||
/// \param[in] input The input image
|
||||
/// \param[out] output The output image
|
||||
/// \return Status code
|
||||
Status RgbToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
/// \brief Take in a 3 channel image in RBG to GRAY
|
||||
/// \param[in] input The input image
|
||||
/// \param[out] output The output image
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/rgb_to_bgr_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status RgbToBgrOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
return RgbToBgr(input, output);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_BGR_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_BGR_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class RgbToBgrOp : public TensorOp {
|
||||
public:
|
||||
RgbToBgrOp() = default;
|
||||
|
||||
~RgbToBgrOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kRgbToBgrOp; }
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_BGR_OP_H_
|
|
@ -42,9 +42,10 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
|||
resize_ir.cc
|
||||
resize_preserve_ar_ir.cc
|
||||
resize_with_bbox_ir.cc
|
||||
rgb_to_bgr_ir.cc
|
||||
rgb_to_gray_ir.cc
|
||||
rgba_to_bgr_ir.cc
|
||||
rgba_to_rgb_ir.cc
|
||||
rgb_to_gray_ir.cc
|
||||
rotate_ir.cc
|
||||
softdvpp_decode_random_crop_resize_jpeg_ir.cc
|
||||
softdvpp_decode_resize_jpeg_ir.cc
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/rgb_to_bgr_ir.h"
|
||||
|
||||
#include "minddata/dataset/kernels/image/rgb_to_bgr_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace vision {
|
||||
|
||||
RgbToBgrOperation::RgbToBgrOperation() = default;
|
||||
|
||||
// RGB2BGROperation
|
||||
RgbToBgrOperation::~RgbToBgrOperation() = default;
|
||||
|
||||
std::string RgbToBgrOperation::Name() const { return kRgbToBgrOperation; }
|
||||
|
||||
Status RgbToBgrOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> RgbToBgrOperation::Build() { return std::make_shared<RgbToBgrOp>(); }
|
||||
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RGB_TO_BGR_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RGB_TO_BGR_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace vision {
|
||||
|
||||
constexpr char kRgbToBgrOperation[] = "RgbToBgr";
|
||||
|
||||
class RgbToBgrOperation : public TensorOperation {
|
||||
public:
|
||||
RgbToBgrOperation();
|
||||
|
||||
~RgbToBgrOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RGB_TO_BGR_IR_H_
|
|
@ -101,6 +101,7 @@ constexpr char kResizePreserveAROp[] = "ResizePreserveAROp";
|
|||
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
|
||||
constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
|
||||
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";
|
||||
constexpr char kRgbToBgrOp[] = "RgbToBgrOp";
|
||||
constexpr char kRgbToGrayOp[] = "RgbToGrayOp";
|
||||
constexpr char kRotateOp[] = "RotateOp";
|
||||
constexpr char kSharpnessOp[] = "SharpnessOp";
|
||||
|
|
|
@ -1462,6 +1462,23 @@ class ResizeWithBBox(ImageTensorOperation):
|
|||
return cde.ResizeWithBBoxOperation(size, DE_C_INTER_MODE[self.interpolation])
|
||||
|
||||
|
||||
class RgbToBgr(ImageTensorOperation):
|
||||
"""
|
||||
Convert RGB image to BGR.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.vision import Inter
|
||||
>>> decode_op = c_vision.Decode()
|
||||
>>> rgb2bgr_op = c_vision.RgbToBgr()
|
||||
>>> transforms_list = [decode_op, rgb2bgr_op]
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
def parse(self):
|
||||
return cde.RgbToBgrOperation()
|
||||
|
||||
|
||||
class Rotate(ImageTensorOperation):
|
||||
"""
|
||||
Rotate the input image by specified degrees.
|
||||
|
|
|
@ -31,7 +31,7 @@ from .validators import check_prob, check_center_crop, check_five_crop, check_re
|
|||
check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||
check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
|
||||
check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
|
||||
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast
|
||||
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr
|
||||
from .utils import Inter, Border
|
||||
from .py_transforms_util import is_pil
|
||||
|
||||
|
@ -1337,6 +1337,45 @@ class MixUp:
|
|||
return util.mix_up_muti(self, self.batch_size, image, label, self.alpha)
|
||||
|
||||
|
||||
class RgbToBgr:
|
||||
"""
|
||||
Convert a NumPy RGB image or a batch of NumPy RGB images to BGR images.
|
||||
|
||||
Args:
|
||||
is_hwc (bool): The flag of image shape, (H, W, C) or (N, H, W, C) if True
|
||||
and (C, H, W) or (N, C, H, W) if False (default=False).
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.transforms.py_transforms import Compose
|
||||
>>> transforms_list = Compose([py_vision.Decode(),
|
||||
... py_vision.CenterCrop(20),
|
||||
... py_vision.ToTensor(),
|
||||
... py_vision.RgbToBgr()])
|
||||
>>> # apply the transform to dataset through map function
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns="image")
|
||||
"""
|
||||
|
||||
@check_rgb_to_bgr
|
||||
def __init__(self, is_hwc=False):
|
||||
self.is_hwc = is_hwc
|
||||
self.random = False
|
||||
|
||||
def __call__(self, rgb_imgs):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
|
||||
or (C, H, W) or (N, C, H, W) to be converted.
|
||||
|
||||
Returns:
|
||||
bgr_img (numpy.ndarray), NumPy BGR images array with same shape of rgb_imgs.
|
||||
"""
|
||||
return util.rgb_to_bgrs(rgb_imgs, self.is_hwc)
|
||||
|
||||
|
||||
|
||||
class RgbToHsv:
|
||||
"""
|
||||
Convert a NumPy RGB image or a batch of NumPy RGB images to HSV images.
|
||||
|
|
|
@ -1223,6 +1223,71 @@ def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
|
|||
return mix_img, mix_label
|
||||
|
||||
|
||||
def rgb_to_bgr(np_rgb_img, is_hwc):
|
||||
"""
|
||||
Convert RGB img to BGR img.
|
||||
|
||||
Args:
|
||||
np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
|
||||
is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
|
||||
|
||||
Returns:
|
||||
np_bgr_img (numpy.ndarray), NumPy BGR image with same type of np_rgb_img.
|
||||
"""
|
||||
if is_hwc:
|
||||
np_bgr_img = np_rgb_img[:, :, ::-1]
|
||||
else:
|
||||
np_bgr_img = np_rgb_img[::-1, :, :]
|
||||
return np_bgr_img
|
||||
|
||||
def rgb_to_bgrs(np_rgb_imgs, is_hwc):
|
||||
"""
|
||||
Convert RGB imgs to BGR imgs.
|
||||
|
||||
Args:
|
||||
np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
|
||||
or (C, H, W) or (N, C, H, W) to be converted.
|
||||
is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
|
||||
If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).
|
||||
|
||||
Returns:
|
||||
np_bgr_imgs (numpy.ndarray), NumPy BGR images with same type of np_rgb_imgs.
|
||||
"""
|
||||
if not is_numpy(np_rgb_imgs):
|
||||
raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
|
||||
|
||||
if not isinstance(is_hwc, bool):
|
||||
raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
|
||||
|
||||
shape_size = len(np_rgb_imgs.shape)
|
||||
|
||||
if shape_size == 2:
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||
"Got (H, W).")
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_rgb_imgs.shape))
|
||||
|
||||
if shape_size == 3:
|
||||
batch_size = 0
|
||||
if is_hwc:
|
||||
num_channels = np_rgb_imgs.shape[2]
|
||||
else:
|
||||
num_channels = np_rgb_imgs.shape[0]
|
||||
else:
|
||||
batch_size = np_rgb_imgs.shape[0]
|
||||
if is_hwc:
|
||||
num_channels = np_rgb_imgs.shape[3]
|
||||
else:
|
||||
num_channels = np_rgb_imgs.shape[1]
|
||||
|
||||
if num_channels != 3:
|
||||
raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(num_channels))
|
||||
if batch_size == 0:
|
||||
return rgb_to_bgr(np_rgb_imgs, is_hwc)
|
||||
return np.array([rgb_to_bgr(img, is_hwc) for img in np_rgb_imgs])
|
||||
|
||||
def rgb_to_hsv(np_rgb_img, is_hwc):
|
||||
"""
|
||||
Convert RGB img to HSV img.
|
||||
|
|
|
@ -547,6 +547,17 @@ def check_mix_up(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_rgb_to_bgr(method):
|
||||
"""Wrapper method to check the parameters of rgb_to_bgr."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(is_hwc, (bool,), "is_hwc")
|
||||
return method(self, *args, **kwargs)
|
||||
return new_method
|
||||
|
||||
|
||||
def check_rgb_to_hsv(method):
|
||||
"""Wrapper method to check the parameters of rgb_to_hsv."""
|
||||
|
||||
|
|
|
@ -212,6 +212,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
|||
${MINDDATA_DIR}/kernels/image/normalize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/rgb_to_bgr_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/rotate_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/random_affine_op.cc
|
||||
|
@ -264,9 +265,10 @@ if(BUILD_MINDDATA STREQUAL "full")
|
|||
${MINDDATA_DIR}/kernels/ir/vision/resize_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/resize_preserve_ar_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/resize_with_bbox_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rgb_to_bgr_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rgb_to_gray_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rgba_to_bgr_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rgba_to_rgb_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rgb_to_gray_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/rotate_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.cc
|
||||
${MINDDATA_DIR}/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.cc
|
||||
|
|
|
@ -1795,6 +1795,45 @@ TEST_F(MindDataImageProcess, TestSobelFlag) {
|
|||
EXPECT_EQ(distance_x, 0.0f);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, testConvertRgbToBgr) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
cv::Mat rgb_mat1;
|
||||
|
||||
cv::cvtColor(image, rgb_mat1, CV_BGR2RGB);
|
||||
|
||||
LiteMat lite_mat_rgb;
|
||||
lite_mat_rgb.Init(rgb_mat1.cols, rgb_mat1.rows, rgb_mat1.channels(), rgb_mat1.data, LDataType::UINT8);
|
||||
LiteMat lite_mat_bgr;
|
||||
bool ret = ConvertRgbToBgr(lite_mat_rgb, LDataType::UINT8, image.cols, image.rows, lite_mat_bgr);
|
||||
ASSERT_TRUE(ret == true);
|
||||
|
||||
cv::Mat dst_image(lite_mat_bgr.height_, lite_mat_bgr.width_, CV_8UC1, lite_mat_bgr.data_ptr_);
|
||||
cv::imwrite("./mindspore_image.jpg", dst_image);
|
||||
CompareMat(image, lite_mat_bgr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, testConvertRgbToBgrFail) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
cv::Mat rgb_mat1;
|
||||
|
||||
cv::cvtColor(image, rgb_mat1, CV_BGR2RGB);
|
||||
|
||||
// The width and height of the output image is different from the original image.
|
||||
LiteMat lite_mat_rgb;
|
||||
lite_mat_rgb.Init(rgb_mat1.cols, rgb_mat1.rows, rgb_mat1.channels(), rgb_mat1.data, LDataType::UINT8);
|
||||
LiteMat lite_mat_bgr;
|
||||
bool ret = ConvertRgbToBgr(lite_mat_rgb, LDataType::UINT8, 1000, 1000, lite_mat_bgr);
|
||||
ASSERT_TRUE(ret == false);
|
||||
|
||||
// The input lite_mat_rgb object is null.
|
||||
LiteMat lite_mat_rgb1;
|
||||
LiteMat lite_mat_bgr1;
|
||||
bool ret1 = ConvertRgbToBgr(lite_mat_rgb1, LDataType::UINT8, image.cols, image.rows, lite_mat_bgr1);
|
||||
ASSERT_TRUE(ret1 == false);
|
||||
}
|
||||
|
||||
TEST_F(MindDataImageProcess, testConvertRgbToGray) {
|
||||
std::string filename = "data/dataset/apple.jpg";
|
||||
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "include/dataset/datasets.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/rgb_to_bgr_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::CVTensor;
|
||||
using mindspore::dataset::BorderType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
|
||||
class MindDataTestRgbToBgrOp : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
|
||||
TEST_F(MindDataTestRgbToBgrOp, TestOp1) {
|
||||
// Eager
|
||||
MS_LOG(INFO) << "Doing MindDataTestGaussianBlur-TestGaussianBlurEager.";
|
||||
|
||||
// Read images
|
||||
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||
|
||||
// Transform params
|
||||
auto decode = vision::Decode();
|
||||
auto rgb2bgr_op = vision::RGB2BGR();
|
||||
|
||||
auto transform = Execute({decode, rgb2bgr_op});
|
||||
Status rc = transform(image, &image);
|
||||
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestRgbToBgrOp, TestOp2) {
|
||||
// pipeline
|
||||
MS_LOG(INFO) << "Basic Function Test.";
|
||||
// create two imagenet dataset
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
auto rgb2bgr_op = vision::RGB2BGR();
|
||||
|
||||
ds1 = ds1->Map({rgb2bgr_op});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
EXPECT_NE(iter1, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row1;
|
||||
iter1->GetNextRow(&row1);
|
||||
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
EXPECT_NE(iter2, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row2;
|
||||
iter2->GetNextRow(&row2);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row1.size() != 0) {
|
||||
i++;
|
||||
auto image =row1["image"];
|
||||
iter1->GetNextRow(&row1);
|
||||
iter2->GetNextRow(&row2);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter1->Stop();
|
||||
iter2->Stop();
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RgbToBgr op in DE
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.py_transforms
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
import mindspore.dataset.vision.py_transforms_util as util
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def generate_numpy_random_rgb(shape):
|
||||
# Only generate floating points that are fractions like n / 256, since they
|
||||
# are RGB pixels. Some low-precision floating point types in this test can't
|
||||
# handle arbitrary precision floating points well.
|
||||
return np.random.randint(0, 256, shape) / 255.
|
||||
|
||||
|
||||
def test_rgb_bgr_hwc_py():
|
||||
# Eager
|
||||
rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
|
||||
rgb_np = rgb_flat.reshape((8, 8, 3))
|
||||
|
||||
bgr_np_pred = util.rgb_to_bgrs(rgb_np, True)
|
||||
r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
|
||||
bgr_np_gt = np.stack((b, g, r), axis=2)
|
||||
assert bgr_np_pred.shape == rgb_np.shape
|
||||
assert_allclose(bgr_np_pred.flatten(),
|
||||
bgr_np_gt.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
|
||||
|
||||
def test_rgb_bgr_hwc_c():
|
||||
# Eager
|
||||
rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
|
||||
rgb_np = rgb_flat.reshape((8, 8, 3))
|
||||
|
||||
rgb2bgr_op = vision.RgbToBgr()
|
||||
bgr_np_pred = rgb2bgr_op(rgb_np)
|
||||
r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
|
||||
bgr_np_gt = np.stack((b, g, r), axis=2)
|
||||
assert bgr_np_pred.shape == rgb_np.shape
|
||||
assert_allclose(bgr_np_pred.flatten(),
|
||||
bgr_np_gt.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
|
||||
|
||||
def test_rgb_bgr_chw_py():
|
||||
rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
|
||||
rgb_np = rgb_flat.reshape((3, 8, 8))
|
||||
|
||||
rgb_np_pred = util.rgb_to_bgrs(rgb_np, False)
|
||||
rgb_np_gt = rgb_np[::-1, :, :]
|
||||
assert rgb_np_pred.shape == rgb_np.shape
|
||||
assert_allclose(rgb_np_pred.flatten(),
|
||||
rgb_np_gt.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
|
||||
|
||||
def test_rgb_bgr_pipeline_py():
|
||||
# First dataset
|
||||
transforms1 = [py_vision.Decode(), py_vision.Resize([64, 64]), py_vision.ToTensor()]
|
||||
transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
|
||||
transforms1)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
|
||||
|
||||
# Second dataset
|
||||
transforms2 = [
|
||||
py_vision.Decode(),
|
||||
py_vision.Resize([64, 64]),
|
||||
py_vision.ToTensor(),
|
||||
py_vision.RgbToBgr()
|
||||
]
|
||||
transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
|
||||
transforms2)
|
||||
ds2 = ds.TFRecordDataset(DATA_DIR,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds2 = ds2.map(operations=transforms2, input_columns=["image"])
|
||||
|
||||
num_iter = 0
|
||||
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
|
||||
ds2.create_dict_iterator(num_epochs=1)):
|
||||
num_iter += 1
|
||||
ori_img = data1["image"].asnumpy()
|
||||
cvt_img = data2["image"].asnumpy()
|
||||
cvt_img_gt = ori_img[::-1, :, :]
|
||||
assert_allclose(cvt_img_gt.flatten(),
|
||||
cvt_img.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
assert ori_img.shape == cvt_img.shape
|
||||
|
||||
|
||||
def test_rgb_bgr_pipeline_c():
|
||||
# First dataset
|
||||
transforms1 = [
|
||||
vision.Decode(),
|
||||
vision.Resize([64, 64])
|
||||
]
|
||||
transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
|
||||
transforms1)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds1 = ds1.map(operations=transforms1, input_columns=["image"])
|
||||
|
||||
# Second dataset
|
||||
transforms2 = [
|
||||
vision.Decode(),
|
||||
vision.Resize([64, 64]),
|
||||
vision.RgbToBgr()
|
||||
]
|
||||
transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
|
||||
transforms2)
|
||||
ds2 = ds.TFRecordDataset(DATA_DIR,
|
||||
SCHEMA_DIR,
|
||||
columns_list=["image"],
|
||||
shuffle=False)
|
||||
ds2 = ds2.map(operations=transforms2, input_columns=["image"])
|
||||
|
||||
num_iter = 0
|
||||
for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
|
||||
ds2.create_dict_iterator(num_epochs=1)):
|
||||
num_iter += 1
|
||||
ori_img = data1["image"].asnumpy()
|
||||
cvt_img = data2["image"].asnumpy()
|
||||
cvt_img_gt = ori_img[:, :, ::-1]
|
||||
assert_allclose(cvt_img_gt.flatten(),
|
||||
cvt_img.flatten(),
|
||||
rtol=1e-5,
|
||||
atol=0)
|
||||
assert ori_img.shape == cvt_img.shape
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rgb_bgr_hwc_py()
|
||||
test_rgb_bgr_hwc_c()
|
||||
test_rgb_bgr_chw_py()
|
||||
test_rgb_bgr_pipeline_py()
|
||||
test_rgb_bgr_pipeline_c()
|
Loading…
Reference in New Issue