Add ConvertColor Operation

shenwei41 2021-08-23 20:32:19 +08:00
parent 28e63116f2
commit 1295744524
19 changed files with 673 additions and 3 deletions

View File

@@ -138,5 +138,30 @@ PYBIND_REGISTER(SliceMode, 0, ([](const py::module *m) {
.export_values();
}));
PYBIND_REGISTER(ConvertMode, 0, ([](const py::module *m) {
(void)py::enum_<ConvertMode>(*m, "ConvertMode", py::arithmetic())
.value("DE_COLOR_BGR2BGRA", ConvertMode::COLOR_BGR2BGRA)
.value("DE_COLOR_RGB2RGBA", ConvertMode::COLOR_RGB2RGBA)
.value("DE_COLOR_BGRA2BGR", ConvertMode::COLOR_BGRA2BGR)
.value("DE_COLOR_RGBA2RGB", ConvertMode::COLOR_RGBA2RGB)
.value("DE_COLOR_BGR2RGBA", ConvertMode::COLOR_BGR2RGBA)
.value("DE_COLOR_RGB2BGRA", ConvertMode::COLOR_RGB2BGRA)
.value("DE_COLOR_RGBA2BGR", ConvertMode::COLOR_RGBA2BGR)
.value("DE_COLOR_BGRA2RGB", ConvertMode::COLOR_BGRA2RGB)
.value("DE_COLOR_BGR2RGB", ConvertMode::COLOR_BGR2RGB)
.value("DE_COLOR_RGB2BGR", ConvertMode::COLOR_RGB2BGR)
.value("DE_COLOR_BGRA2RGBA", ConvertMode::COLOR_BGRA2RGBA)
.value("DE_COLOR_RGBA2BGRA", ConvertMode::COLOR_RGBA2BGRA)
.value("DE_COLOR_BGR2GRAY", ConvertMode::COLOR_BGR2GRAY)
.value("DE_COLOR_RGB2GRAY", ConvertMode::COLOR_RGB2GRAY)
.value("DE_COLOR_GRAY2BGR", ConvertMode::COLOR_GRAY2BGR)
.value("DE_COLOR_GRAY2RGB", ConvertMode::COLOR_GRAY2RGB)
.value("DE_COLOR_GRAY2BGRA", ConvertMode::COLOR_GRAY2BGRA)
.value("DE_COLOR_GRAY2RGBA", ConvertMode::COLOR_GRAY2RGBA)
.value("DE_COLOR_BGRA2GRAY", ConvertMode::COLOR_BGRA2GRAY)
.value("DE_COLOR_RGBA2GRAY", ConvertMode::COLOR_RGBA2GRAY)
.export_values();
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -22,6 +22,7 @@
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/convert_color_ir.h"
#include "minddata/dataset/kernels/ir/vision/crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutmix_batch_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
@@ -113,6 +114,17 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
ConvertColorOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::ConvertColorOperation, TensorOperation, std::shared_ptr<vision::ConvertColorOperation>>(
*m, "ConvertColorOperation", "Tensor operation to change the color space of the image.")
.def(py::init([](ConvertMode convert_mode) {
auto convert = std::make_shared<vision::ConvertColorOperation>(convert_mode);
THROW_IF_ERROR(convert->ValidateParams());
return convert;
}));
}));
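
For reference, a minimal sketch of how this binding is reached from Python; in normal use it is wrapped by the c_vision.ConvertColor transform added later in this commit, and the direct call into mindspore._c_dataengine below is shown only for illustration.

import mindspore._c_dataengine as cde

# Construct the bound operation directly (what c_vision.ConvertColor.parse() effectively does).
convert_op = cde.ConvertColorOperation(cde.ConvertMode.DE_COLOR_RGB2GRAY)
# ValidateParams() runs inside the py::init lambda above, so an out-of-range mode
# would raise here rather than at pipeline execution time.
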
PYBIND_REGISTER(CropOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::CropOperation, TensorOperation, std::shared_ptr<vision::CropOperation>>(
*m, "CropOperation", "Tensor operation to crop images")

View File

@@ -26,6 +26,7 @@
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/convert_color_ir.h"
#include "minddata/dataset/kernels/ir/vision/crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutmix_batch_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
@@ -197,6 +198,20 @@ std::shared_ptr<TensorOperation> CenterCrop::Parse(const MapTargetDevice &env) {
return std::make_shared<CenterCropOperation>(data_->size_);
}
#ifndef ENABLE_ANDROID
// ConvertColor Transform Operation.
struct ConvertColor::Data {
explicit Data(ConvertMode convert_mode) : convert_mode_(convert_mode) {}
ConvertMode convert_mode_;
};
ConvertColor::ConvertColor(ConvertMode convert_mode) : data_(std::make_shared<Data>(convert_mode)) {}
std::shared_ptr<TensorOperation> ConvertColor::Parse() {
return std::make_shared<ConvertColorOperation>(data_->convert_mode_);
}
#endif // not ENABLE_ANDROID
// Crop Transform Operation.
struct Crop::Data {
Data(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size)

View File

@@ -26,6 +26,30 @@ namespace dataset {
using uchar = unsigned char;
using dsize_t = int64_t;
/// \brief The color conversion code
enum class ConvertMode {
COLOR_BGR2BGRA = 0, ///< Add alpha channel to BGR image.
COLOR_RGB2RGBA = COLOR_BGR2BGRA, ///< Add alpha channel to RGB image.
COLOR_BGRA2BGR = 1, ///< Remove alpha channel from BGRA image.
COLOR_RGBA2RGB = COLOR_BGRA2BGR, ///< Remove alpha channel from RGBA image.
COLOR_BGR2RGBA = 2, ///< Convert BGR image to RGBA image.
COLOR_RGB2BGRA = COLOR_BGR2RGBA, ///< Convert RGB image to BGRA image.
COLOR_RGBA2BGR = 3, ///< Convert RGBA image to BGR image.
COLOR_BGRA2RGB = COLOR_RGBA2BGR, ///< Convert BGRA image to RGB image.
COLOR_BGR2RGB = 4, ///< Convert BGR image to RGB image.
COLOR_RGB2BGR = COLOR_BGR2RGB, ///< Convert RGB image to BGR image.
COLOR_BGRA2RGBA = 5, ///< Convert BGRA image to RGBA image.
COLOR_RGBA2BGRA = COLOR_BGRA2RGBA, ///< Convert RGBA image to BGRA image.
COLOR_BGR2GRAY = 6, ///< Convert BGR image to GRAY image.
COLOR_RGB2GRAY = 7, ///< Convert RGB image to GRAY image.
COLOR_GRAY2BGR = 8, ///< Convert GRAY image to BGR image.
COLOR_GRAY2RGB = COLOR_GRAY2BGR, ///< Convert GRAY image to RGB image.
COLOR_GRAY2BGRA = 9, ///< Convert GRAY image to BGRA image.
COLOR_GRAY2RGBA = COLOR_GRAY2BGRA, ///< Convert GRAY image to RGBA image.
COLOR_BGRA2GRAY = 10, ///< Convert BGRA image to GRAY image.
COLOR_RGBA2GRAY = 11 ///< Convert RGBA image to GRAY image.
};
/// \brief Target devices to perform map operation.
enum class MapTargetDevice {
kCpu, ///< CPU Device.

View File

@@ -113,6 +113,26 @@ class BoundingBoxAugment final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Change the color space of the image.
class ConvertColor final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] convert_mode The mode of image channel conversion.
explicit ConvertColor(ConvertMode convert_mode);
/// \brief Destructor.
~ConvertColor() = default;
protected:
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Mask a random section of each image with the corresponding part of another randomly
/// selected image in that batch.
class CutMixBatch final : public TensorTransform {

View File

@@ -11,6 +11,7 @@ add_library(kernels-image OBJECT
auto_contrast_op.cc
bounding_box.cc
center_crop_op.cc
convert_color_op.cc
crop_op.cc
cut_out_op.cc
cutmix_batch_op.cc

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <utility>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/convert_color_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
ConvertColorOp::ConvertColorOp(ConvertMode convert_mode) : convert_mode_(convert_mode) {}
Status ConvertColorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return ConvertColor(input, output, convert_mode_);
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,46 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CONVERT_COLOR_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CONVERT_COLOR_OP_H_
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class ConvertColorOp : public TensorOp {
public:
explicit ConvertColorOp(ConvertMode convert_mode);
~ConvertColorOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kConvertColorOp; }
private:
ConvertMode convert_mode_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CONVERT_COLOR_OP_H_

View File

@@ -63,6 +63,29 @@ int GetCVBorderType(BorderType type) {
}
}
Status GetConvertShape(ConvertMode convert_mode, const std::shared_ptr<CVTensor> &input_cv,
std::vector<dsize_t> *node) {
std::vector<ConvertMode> one_channels = {ConvertMode::COLOR_BGR2GRAY, ConvertMode::COLOR_RGB2GRAY,
ConvertMode::COLOR_BGRA2GRAY, ConvertMode::COLOR_RGBA2GRAY};
std::vector<ConvertMode> three_channels = {
ConvertMode::COLOR_BGRA2BGR, ConvertMode::COLOR_RGBA2RGB, ConvertMode::COLOR_RGBA2BGR, ConvertMode::COLOR_BGRA2RGB,
ConvertMode::COLOR_BGR2RGB, ConvertMode::COLOR_RGB2BGR, ConvertMode::COLOR_GRAY2BGR, ConvertMode::COLOR_GRAY2RGB};
std::vector<ConvertMode> four_channels = {ConvertMode::COLOR_BGR2BGRA, ConvertMode::COLOR_RGB2RGBA,
ConvertMode::COLOR_BGR2RGBA, ConvertMode::COLOR_RGB2BGRA,
ConvertMode::COLOR_BGRA2RGBA, ConvertMode::COLOR_RGBA2BGRA,
ConvertMode::COLOR_GRAY2BGRA, ConvertMode::COLOR_GRAY2RGBA};
if (std::find(three_channels.begin(), three_channels.end(), convert_mode) != three_channels.end()) {
*node = {input_cv->shape()[0], input_cv->shape()[1], 3};
} else if (std::find(four_channels.begin(), four_channels.end(), convert_mode) != four_channels.end()) {
*node = {input_cv->shape()[0], input_cv->shape()[1], 4};
} else if (std::find(one_channels.begin(), one_channels.end(), convert_mode) != one_channels.end()) {
*node = {input_cv->shape()[0], input_cv->shape()[1]};
} else {
RETURN_STATUS_UNEXPECTED("The mode of image channel conversion must be in ConvertMode.");
}
return Status::OK();
}
bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel) {
bool rc = false;
if (tensor->Rank() != DEFAULT_IMAGE_RANK ||
@@ -432,6 +455,35 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
}
}
Status ConvertColor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, ConvertMode convert_mode) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->Rank() != DEFAULT_IMAGE_RANK) {
RETURN_STATUS_UNEXPECTED("ConvertColor: invalid image shape, only images with shape <H,W,C> are supported.");
}
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("ConvertColor: load image failed.");
}
int num_channels = input_cv->shape()[CHANNEL_INDEX];
if (num_channels != 1 && num_channels != 3 && num_channels != 4) {
RETURN_STATUS_UNEXPECTED("ConvertColor: number of channels of image should be 1, 3 or 4.");
}
std::vector<dsize_t> node;
RETURN_IF_NOT_OK(GetConvertShape(convert_mode, input_cv, &node));
if (node.empty()) {
RETURN_STATUS_UNEXPECTED("ConvertColor: convert mode must be in ConvertMode.");
}
TensorShape out_shape = TensorShape(node);
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv));
cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(convert_mode));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("ConvertColor: " + std::string(e.what()));
}
}
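
A rough sketch of the shape behaviour encoded by GetConvertShape, exercised through the eager Python API added in this commit; the synthetic input below is an assumption used only for illustration.

import numpy as np
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.utils import ConvertMode

img = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)   # <H,W,3> BGR-style input
gray = c_vision.ConvertColor(ConvertMode.COLOR_BGR2GRAY)(img)   # one-channel target -> <H,W>
bgra = c_vision.ConvertColor(ConvertMode.COLOR_BGR2BGRA)(img)   # four-channel target -> <H,W,4>
assert gray.shape == (32, 32) and bgra.shape == (32, 32, 4)
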
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);

View File

@@ -133,6 +133,12 @@ Status Rescale(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
/// \param output: Cropped image Tensor of shape <h,w,C> or <h,w> and same input type.
Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y, int w, int h);
/// \brief Change the color space of the image.
/// \param input: The input image.
/// \param output: The output image.
/// \param convert_mode: The mode of image channel conversion.
Status ConvertColor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, ConvertMode convert_mode);
/// \brief Swaps the channels in the image, i.e. converts HWC to CHW
/// \param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
/// \param output: Tensor of shape <C,H,W> or <H,W> and same input type.

View File

@@ -7,6 +7,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
auto_contrast_ir.cc
bounding_box_augment_ir.cc
center_crop_ir.cc
convert_color_ir.cc
crop_ir.cc
cutmix_batch_ir.cc
cutout_ir.cc

View File

@@ -0,0 +1,68 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "minddata/dataset/kernels/ir/vision/convert_color_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/convert_color_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// ConvertColorOperation
ConvertColorOperation::ConvertColorOperation(ConvertMode convert_mode) : convert_mode_(convert_mode) {}
ConvertColorOperation::~ConvertColorOperation() = default;
std::string ConvertColorOperation::Name() const { return kConvertColorOperation; }
Status ConvertColorOperation::ValidateParams() {
if (convert_mode_ < ConvertMode::COLOR_BGR2BGRA || convert_mode_ > ConvertMode::COLOR_RGBA2GRAY) {
std::string err_msg = "ConvertColorOperation: convert_mode must be in ConvertMode.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::shared_ptr<TensorOp> ConvertColorOperation::Build() {
std::shared_ptr<ConvertColorOp> tensor_op = std::make_shared<ConvertColorOp>(convert_mode_);
return tensor_op;
}
Status ConvertColorOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["convert_mode"] = convert_mode_;
*out_json = args;
return Status::OK();
}
Status ConvertColorOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("convert_mode") != op_params.end(), "Failed to find convert_mode");
ConvertMode convert_mode = static_cast<ConvertMode>(op_params["convert_mode"]);
*operation = std::make_shared<vision::ConvertColorOperation>(convert_mode);
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore
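
The serialized form produced by to_json carries only the numeric conversion code; a sketch of the expected arguments for COLOR_BGR2GRAY, assuming nlohmann::json's default enum-to-integer conversion and ignoring any wrapping added by the pipeline serializer:

# Arguments emitted by ConvertColorOperation::to_json for ConvertMode::COLOR_BGR2GRAY (code 6).
convert_color_args = {"convert_mode": 6}
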

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_CONVERT_COLOR_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_CONVERT_COLOR_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kConvertColorOperation[] = "ConvertColor";
class ConvertColorOperation : public TensorOperation {
public:
explicit ConvertColorOperation(ConvertMode convert_mode);
~ConvertColorOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
private:
ConvertMode convert_mode_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_CONVERT_COLOR_IR_H_

View File

@@ -59,6 +59,7 @@ constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
constexpr char kDecodeOp[] = "DecodeOp";
constexpr char kCenterCropOp[] = "CenterCropOp";
constexpr char kConvertColorOp[] = "ConvertColorOp";
constexpr char kCutMixBatchOp[] = "CutMixBatchOp";
constexpr char kCutOutOp[] = "CutOutOp";
constexpr char kCropOp[] = "CropOp";

View File

@@ -47,11 +47,11 @@ import numpy as np
from PIL import Image
import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat, SliceMode
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, \
check_uniform_augment_cpp, check_convert_color, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
@@ -92,6 +92,28 @@ DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBO
DE_C_SLICE_MODE = {SliceMode.PAD: cde.SliceMode.DE_SLICE_PAD,
SliceMode.DROP: cde.SliceMode.DE_SLICE_DROP}
DE_C_CONVERTCOLOR_MODE = {ConvertMode.COLOR_BGR2BGRA: cde.ConvertMode.DE_COLOR_BGR2BGRA,
ConvertMode.COLOR_RGB2RGBA: cde.ConvertMode.DE_COLOR_RGB2RGBA,
ConvertMode.COLOR_BGRA2BGR: cde.ConvertMode.DE_COLOR_BGRA2BGR,
ConvertMode.COLOR_RGBA2RGB: cde.ConvertMode.DE_COLOR_RGBA2RGB,
ConvertMode.COLOR_BGR2RGBA: cde.ConvertMode.DE_COLOR_BGR2RGBA,
ConvertMode.COLOR_RGB2BGRA: cde.ConvertMode.DE_COLOR_RGB2BGRA,
ConvertMode.COLOR_RGBA2BGR: cde.ConvertMode.DE_COLOR_RGBA2BGR,
ConvertMode.COLOR_BGRA2RGB: cde.ConvertMode.DE_COLOR_BGRA2RGB,
ConvertMode.COLOR_BGR2RGB: cde.ConvertMode.DE_COLOR_BGR2RGB,
ConvertMode.COLOR_RGB2BGR: cde.ConvertMode.DE_COLOR_RGB2BGR,
ConvertMode.COLOR_BGRA2RGBA: cde.ConvertMode.DE_COLOR_BGRA2RGBA,
ConvertMode.COLOR_RGBA2BGRA: cde.ConvertMode.DE_COLOR_RGBA2BGRA,
ConvertMode.COLOR_BGR2GRAY: cde.ConvertMode.DE_COLOR_BGR2GRAY,
ConvertMode.COLOR_RGB2GRAY: cde.ConvertMode.DE_COLOR_RGB2GRAY,
ConvertMode.COLOR_GRAY2BGR: cde.ConvertMode.DE_COLOR_GRAY2BGR,
ConvertMode.COLOR_GRAY2RGB: cde.ConvertMode.DE_COLOR_GRAY2RGB,
ConvertMode.COLOR_GRAY2BGRA: cde.ConvertMode.DE_COLOR_GRAY2BGRA,
ConvertMode.COLOR_GRAY2RGBA: cde.ConvertMode.DE_COLOR_GRAY2RGBA,
ConvertMode.COLOR_BGRA2GRAY: cde.ConvertMode.DE_COLOR_BGRA2GRAY,
ConvertMode.COLOR_RGBA2GRAY: cde.ConvertMode.DE_COLOR_RGBA2GRAY,
}
def parse_padding(padding):
""" Parses and prepares the padding tuple"""
@@ -231,6 +253,31 @@ class CenterCrop(ImageTensorOperation):
return cde.CenterCropOperation(self.size)
class ConvertColor(ImageTensorOperation):
"""
Change the color space of the image.
Args:
convert_mode (ConvertMode): The mode of image channel conversion.
Examples:
>>> # Convert RGB images to GRAY images
>>> convert_op = c_vision.ConvertColor(ConvertMode.COLOR_RGB2GRAY)
>>> image_folder_dataset = image_folder_dataset.map(operations=convert_op,
... input_columns=["image"])
>>> # Convert RGB images to BGR images
>>> convert_op = c_vision.ConvertColor(ConvertMode.COLOR_RGB2BGR)
>>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=convert_op,
... input_columns=["image"])
"""
@check_convert_color
def __init__(self, convert_mode):
self.convert_mode = convert_mode
def parse(self):
return cde.ConvertColorOperation(DE_C_CONVERTCOLOR_MODE[self.convert_mode])
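
Besides the pipeline usage shown in the docstring, the transform can be applied eagerly to a decoded image, mirroring the eager test added at the end of this commit; the image path below is a hypothetical placeholder.

import cv2
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.utils import ConvertMode

img = cv2.imread("some_image.jpg")                            # OpenCV loads images as BGR
out = c_vision.ConvertColor(ConvertMode.COLOR_BGR2RGB)(img)   # eager call on a NumPy array
expected = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
assert (out == expected).all()
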
class Crop(ImageTensorOperation):
"""
Crop the input image at a specific location.

View File

@@ -74,6 +74,30 @@ class ImageBatchFormat(IntEnum):
NCHW = 1
class ConvertMode(IntEnum):
"""The color conversion code"""
COLOR_BGR2BGRA = 0
COLOR_RGB2RGBA = COLOR_BGR2BGRA
COLOR_BGRA2BGR = 1
COLOR_RGBA2RGB = COLOR_BGRA2BGR
COLOR_BGR2RGBA = 2
COLOR_RGB2BGRA = COLOR_BGR2RGBA
COLOR_RGBA2BGR = 3
COLOR_BGRA2RGB = COLOR_RGBA2BGR
COLOR_BGR2RGB = 4
COLOR_RGB2BGR = COLOR_BGR2RGB
COLOR_BGRA2RGBA = 5
COLOR_RGBA2BGRA = COLOR_BGRA2RGBA
COLOR_BGR2GRAY = 6
COLOR_RGB2GRAY = 7
COLOR_GRAY2BGR = 8
COLOR_GRAY2RGB = COLOR_GRAY2BGR
COLOR_GRAY2BGRA = 9
COLOR_GRAY2RGBA = COLOR_GRAY2BGRA
COLOR_BGRA2GRAY = 10
COLOR_RGBA2GRAY = 11
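
These numeric codes line up with OpenCV's cv2.COLOR_* constants (the C++ kernel passes the code straight to cv::cvtColor), and paired conversions such as RGB2BGR/BGR2RGB are aliases of a single code; a small sanity sketch, assuming cv2 is available:

import cv2
from mindspore.dataset.vision.utils import ConvertMode

# Aliased pairs share one underlying value.
assert ConvertMode.COLOR_RGB2BGR == ConvertMode.COLOR_BGR2RGB
# Values mirror the OpenCV conversion codes used internally.
assert ConvertMode.COLOR_BGR2GRAY == cv2.COLOR_BGR2GRAY == 6
assert ConvertMode.COLOR_RGBA2GRAY == cv2.COLOR_RGBA2GRAY == 11
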
class SliceMode(IntEnum):
"""
Mode to Slice Tensor into multiple parts.

View File

@@ -23,7 +23,7 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32
from .utils import Inter, Border, ImageBatchFormat, SliceMode
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
def check_crop_size(size):
@@ -964,3 +964,15 @@ def check_gaussian_blur(method):
return method(self, *args, **kwargs)
return new_method
def check_convert_color(method):
"""Wrapper method to check the parameters of convertcolor."""
@wraps(method)
def new_method(self, *args, **kwargs):
[convert_mode], _ = parse_user_args(method, *args, **kwargs)
if convert_mode is not None:
type_check(convert_mode, (ConvertMode,), "convert_mode")
return method(self, *args, **kwargs)
return new_method
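
A short sketch of what the validator enforces: only ConvertMode members are accepted, so a raw integer (e.g. an OpenCV code) is rejected, assuming type_check raises TypeError as in the other validators in this file.

import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.utils import ConvertMode

c_vision.ConvertColor(ConvertMode.COLOR_RGB2GRAY)   # accepted
try:
    c_vision.ConvertColor(7)                        # raw integer code, not a ConvertMode member
except TypeError as err:
    print(err)
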

View File

@@ -1137,3 +1137,136 @@ TEST_F(MindDataTestPipeline, TestPad) {
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestConvertColorSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConvertColorSuccess1.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 1));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({500, 1000}));
std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::ConvertColor(ConvertMode::COLOR_RGB2GRAY));
ds = ds->Map({resize_op, convert});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(image.Shape().size(), 2);
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestConvertColorSuccess2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConvertColorSuccess2.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 1));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({500, 1000}));
std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::ConvertColor(ConvertMode::COLOR_RGB2BGR));
ds = ds->Map({resize_op, convert});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(image.Shape()[2], 3);
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestConvertColorSuccess3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConvertColorSuccess3.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 1));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({500, 1000}));
std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::ConvertColor(ConvertMode::COLOR_RGB2RGBA));
ds = ds->Map({resize_op, convert});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(image.Shape()[2], 4);
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestConvertColorFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConvertColorFail.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 1));
EXPECT_NE(ds, nullptr);
ConvertMode error_convert_mode = static_cast<ConvertMode>(50);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({500, 1000}));
std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::ConvertColor(error_convert_mode));
ds = ds->Map({resize_op, convert});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}

View File

@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing ConvertColor op in DE
"""
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.utils as mode
from mindspore import log as logger
from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
def convert_color(ms_convert, cv_convert, plot=False):
"""
Apply ConvertColor with the given mode and compare with OpenCV's cvtColor result.
"""
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
convertcolor_op = c_vision.ConvertColor(ms_convert)
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=convertcolor_op, input_columns=["image"])
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
convertcolor_ms = data1["image"]
original = data2["image"]
convertcolor_cv = cv2.cvtColor(original, cv_convert)
mse = diff_mse(convertcolor_ms, convertcolor_cv)
logger.info("convertcolor_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, convertcolor_ms, mse, convertcolor_cv)
def test_convertcolor_pipeline(plot=False):
"""
Test ConvertColor of c_transforms
"""
logger.info("test_convertcolor_pipeline")
convert_color(mode.ConvertMode.COLOR_BGR2GRAY, cv2.COLOR_BGR2GRAY, plot)
convert_color(mode.ConvertMode.COLOR_BGR2RGB, cv2.COLOR_BGR2RGB, plot)
convert_color(mode.ConvertMode.COLOR_BGR2BGRA, cv2.COLOR_BGR2BGRA, plot)
def test_convertcolor_eager():
"""
Test ConvertColor with eager mode
"""
logger.info("test_convertcolor")
img = cv2.imread(IMAGE_FILE)
img_ms = c_vision.ConvertColor(mode.ConvertMode.COLOR_BGR2GRAY)(img)
img_expect = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
mse = diff_mse(img_ms, img_expect)
assert mse == 0
if __name__ == "__main__":
test_convertcolor_pipeline(plot=False)
test_convertcolor_eager()