!40711 【算子众智】【电子科技大学】【数据算子】【encode_png】encode an image as PNG
Merge pull request !40711 from dengjian/upstream_encode_png
This commit is contained in:
commit
2c1ebbd63f
|
@ -7,10 +7,10 @@ mindspore.dataset.vision.encode_jpeg
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。
|
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。
|
||||||
- **quality** (int, 可选) - 生成的JPEG数据的质量,从1到100。默认值75。
|
- **quality** (int, 可选) - 生成的JPEG数据的质量,取值范围为[1, 100]。默认值: 75。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
- numpy:ndarray, 一维uint8类型数据。
|
- numpy.ndarray, 一维uint8类型数据。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
|
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
mindspore.dataset.vision.encode_png
|
||||||
|
===================================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.dataset.vision.encode_png(image, compression_level=6)
|
||||||
|
|
||||||
|
将输入的图像编码为PNG数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。
|
||||||
|
- **compression_level** (int, 可选) - 编码压缩因子,取值范围为[0, 9]。默认值: 6。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- numpy.ndarray, 一维uint8类型数据。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
|
||||||
|
- **TypeError** - 如果 `compression_level` 不是int类型。
|
||||||
|
- **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。
|
||||||
|
- **RuntimeError** - 如果 `image` 的shape不是 <H, W> 或 <H, W, 1> 或 <H, W, 3>。
|
||||||
|
- **RuntimeError** - 如果 `compression_level` 小于0或大于9。
|
|
@ -143,6 +143,7 @@ API样例中常用的导入模块如下:
|
||||||
mindspore.dataset.vision.Inter
|
mindspore.dataset.vision.Inter
|
||||||
mindspore.dataset.vision.SliceMode
|
mindspore.dataset.vision.SliceMode
|
||||||
mindspore.dataset.vision.encode_jpeg
|
mindspore.dataset.vision.encode_jpeg
|
||||||
|
mindspore.dataset.vision.encode_png
|
||||||
mindspore.dataset.vision.get_image_num_channels
|
mindspore.dataset.vision.get_image_num_channels
|
||||||
mindspore.dataset.vision.get_image_size
|
mindspore.dataset.vision.get_image_size
|
||||||
mindspore.dataset.vision.read_file
|
mindspore.dataset.vision.read_file
|
||||||
|
|
|
@ -93,6 +93,7 @@ Utilities
|
||||||
mindspore.dataset.vision.Inter
|
mindspore.dataset.vision.Inter
|
||||||
mindspore.dataset.vision.SliceMode
|
mindspore.dataset.vision.SliceMode
|
||||||
mindspore.dataset.vision.encode_jpeg
|
mindspore.dataset.vision.encode_jpeg
|
||||||
|
mindspore.dataset.vision.encode_png
|
||||||
mindspore.dataset.vision.get_image_num_channels
|
mindspore.dataset.vision.get_image_num_channels
|
||||||
mindspore.dataset.vision.get_image_size
|
mindspore.dataset.vision.get_image_size
|
||||||
mindspore.dataset.vision.read_file
|
mindspore.dataset.vision.read_file
|
||||||
|
|
|
@ -279,6 +279,14 @@ PYBIND_REGISTER(EncodeJpegOperation, 1, ([](py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(EncodePNGOperation, 1, ([](py::module *m) {
                  // Callable exposed to Python as `encode_png(image, compression_level)`.
                  // A non-OK Status from EncodePng is handed to THROW_IF_ERROR; otherwise the
                  // encoded tensor is returned to the caller.
                  auto encode_png_fn = [](const std::shared_ptr<Tensor> &image, int compression_level) {
                    std::shared_ptr<Tensor> encoded;
                    THROW_IF_ERROR(EncodePng(image, &encoded, compression_level));
                    return encoded;
                  };
                  (void)m->def("encode_png", encode_png_fn);
                }));
|
||||||
|
|
||||||
PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
|
||||||
(void)
|
(void)
|
||||||
py::class_<vision::EqualizeOperation, TensorOperation, std::shared_ptr<vision::EqualizeOperation>>(
|
py::class_<vision::EqualizeOperation, TensorOperation, std::shared_ptr<vision::EqualizeOperation>>(
|
||||||
|
|
|
@ -468,6 +468,7 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
|
|
||||||
// EncodeJpeg Function.
|
// EncodeJpeg Function.
|
||||||
Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality) {
|
Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality) {
|
||||||
RETURN_UNEXPECTED_IF_NULL(output);
|
RETURN_UNEXPECTED_IF_NULL(output);
|
||||||
|
@ -481,6 +482,19 @@ Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EncodePng Function.
|
||||||
|
Status EncodePng(const mindspore::MSTensor &image, mindspore::MSTensor *output, int compression_level) {
  RETURN_UNEXPECTED_IF_NULL(output);

  // Turn the API-level MSTensor into the dataset tensor the kernel works on.
  std::shared_ptr<dataset::Tensor> source;
  RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &source));

  // The kernel overload performs the validation and the actual encoding.
  TensorPtr encoded;
  RETURN_IF_NOT_OK(mindspore::dataset::EncodePng(source, &encoded, compression_level));

  // A result that carries no data is reported as an error rather than returned.
  const bool has_data = encoded->HasData();
  CHECK_FAIL_RETURN_UNEXPECTED(has_data, "EncodePng: get an empty tensor with shape " + encoded->shape().ToString());

  // Wrap the encoded bytes back into an MSTensor for the caller.
  *output = mindspore::MSTensor(std::make_shared<DETensor>(encoded));
  return Status::OK();
}
|
||||||
|
|
||||||
// Equalize Transform Operation.
|
// Equalize Transform Operation.
|
||||||
Equalize::Equalize() = default;
|
Equalize::Equalize() = default;
|
||||||
|
|
||||||
|
|
|
@ -443,10 +443,17 @@ class DATASET_API CutOut final : public TensorTransform {
|
||||||
/// \brief Encode the image as JPEG data.
|
/// \brief Encode the image as JPEG data.
|
||||||
/// \param[in] image The image to be encoded.
|
/// \param[in] image The image to be encoded.
|
||||||
/// \param[out] output The Tensor data.
|
/// \param[out] output The Tensor data.
|
||||||
/// \param[in] quality The quality for the output tensor from 1 to 100. Default: 75.
|
/// \param[in] quality The quality for the output tensor, in range of [1, 100]. Default: 75.
|
||||||
/// \return The status code.
|
/// \return The status code.
|
||||||
Status DATASET_API EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality = 75);
|
Status DATASET_API EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality = 75);
|
||||||
|
|
||||||
|
/// \brief Encode the image as PNG data.
|
||||||
|
/// \param[in] image The image to be encoded.
|
||||||
|
/// \param[out] output The Tensor data.
|
||||||
|
/// \param[in] compression_level The compression_level for encoding, in range of [0, 9]. Default: 6.
|
||||||
|
/// \return The status code.
|
||||||
|
Status DATASET_API EncodePng(const mindspore::MSTensor &image, mindspore::MSTensor *output, int compression_level = 6);
|
||||||
|
|
||||||
/// \brief Apply histogram equalization on the input image.
|
/// \brief Apply histogram equalization on the input image.
|
||||||
class DATASET_API Equalize final : public TensorTransform {
|
class DATASET_API Equalize final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -2230,8 +2230,9 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
|
||||||
}
|
}
|
||||||
|
|
||||||
if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
|
if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
|
||||||
err_msg = "EncodeJpeg: Invalid quality " + std::to_string(quality) + ", should be from " +
|
err_msg = "EncodeJpeg: Invalid quality " + std::to_string(quality) + ", should be in range of [" +
|
||||||
std::to_string(kMinJpegQuality) + " to " + std::to_string(kMaxJpegQuality) + ".";
|
std::to_string(kMinJpegQuality) + ", " + std::to_string(kMaxJpegQuality) + "].";
|
||||||
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2248,11 +2249,74 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
|
||||||
}
|
}
|
||||||
|
|
||||||
if (channels == kMinImageChannel) {
|
if (channels == kMinImageChannel) {
|
||||||
cv::imencode(".JPEG", image_matrix, buffer, params);
|
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_matrix, buffer, params),
|
||||||
|
"EncodeJpeg: Failed to encode image.");
|
||||||
} else {
|
} else {
|
||||||
cv::Mat image_bgr;
|
cv::Mat image_bgr;
|
||||||
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
|
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
|
||||||
cv::imencode(".JPEG", image_bgr, buffer, params);
|
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_bgr, buffer, params),
|
||||||
|
"EncodeJpeg: Failed to encode image.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(tensor_shape, DataType(DataType::DE_UINT8), buffer.data(), output));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status EncodePng(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int compression_level) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(output);
|
||||||
|
|
||||||
|
std::string err_msg;
|
||||||
|
if (image->type() != DataType::DE_UINT8) {
|
||||||
|
err_msg = "EncodePng: The type of the image data should be UINT8, but got " + image->type().ToString() + ".";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape shape = image->shape();
|
||||||
|
int rank = shape.Rank();
|
||||||
|
if (rank < kMinImageRank || rank > kDefaultImageRank) {
|
||||||
|
err_msg = "EncodePng: The image has invalid dimensions. It should have two or three dimensions, but got ";
|
||||||
|
err_msg += std::to_string(rank) + " dimensions.";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
int channels;
|
||||||
|
if (rank == kDefaultImageRank) {
|
||||||
|
channels = shape[kMinImageRank];
|
||||||
|
if (channels != kMinImageChannel && channels != kDefaultImageChannel) {
|
||||||
|
err_msg = "EncodePng: The image has invalid channels. It should have 1 or 3 channels, but got ";
|
||||||
|
err_msg += std::to_string(channels) + " channels.";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
channels = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compression_level < kMinPngCompression || compression_level > kMaxPngCompression) {
|
||||||
|
err_msg = "EncodePng: Invalid compression_level " + std::to_string(compression_level) +
|
||||||
|
", should be in range of [" + std::to_string(kMinPngCompression) + ", " +
|
||||||
|
std::to_string(kMaxPngCompression) + "].";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> params = {cv::IMWRITE_PNG_COMPRESSION, compression_level, cv::IMWRITE_PNG_STRATEGY,
|
||||||
|
cv::IMWRITE_PNG_STRATEGY_RLE};
|
||||||
|
std::vector<unsigned char> buffer;
|
||||||
|
cv::Mat image_matrix;
|
||||||
|
|
||||||
|
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
|
||||||
|
image_matrix = input_cv->mat();
|
||||||
|
if (!image_matrix.data) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("EncodePng: Load the image tensor failed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (channels == kMinImageChannel) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".PNG", image_matrix, buffer, params),
|
||||||
|
"EncodePng: Failed to encode image.");
|
||||||
|
} else {
|
||||||
|
cv::Mat image_bgr;
|
||||||
|
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".PNG", image_bgr, buffer, params), "EncodePng: Failed to encode image.");
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
|
TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
|
||||||
|
@ -2296,7 +2360,7 @@ Status ReadImage(const std::string &filename, std::shared_ptr<Tensor> *output, I
|
||||||
int cv_mode = static_cast<int>(mode) - 1;
|
int cv_mode = static_cast<int>(mode) - 1;
|
||||||
image = cv::imread(realpath.value(), cv_mode);
|
image = cv::imread(realpath.value(), cv_mode);
|
||||||
if (image.data == nullptr) {
|
if (image.data == nullptr) {
|
||||||
RETURN_STATUS_UNEXPECTED("ReadImage: Can not read file " + filename);
|
RETURN_STATUS_UNEXPECTED("ReadImage: Failed to read file " + filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<CVTensor> output_cv;
|
std::shared_ptr<CVTensor> output_cv;
|
||||||
|
|
|
@ -57,6 +57,8 @@ constexpr dsize_t kHeightIndex = 0; // index of height of HWC images
|
||||||
constexpr dsize_t kWidthIndex = 1; // index of width of HWC images
|
constexpr dsize_t kWidthIndex = 1; // index of width of HWC images
|
||||||
constexpr dsize_t kMinJpegQuality = 1; // the minimum quality for JPEG
|
constexpr dsize_t kMinJpegQuality = 1; // the minimum quality for JPEG
|
||||||
constexpr dsize_t kMaxJpegQuality = 100; // the maximum quality for JPEG
|
constexpr dsize_t kMaxJpegQuality = 100; // the maximum quality for JPEG
|
||||||
|
constexpr dsize_t kMinPngCompression = 0; // the minimum compression level for PNG
|
||||||
|
constexpr dsize_t kMaxPngCompression = 9; // the maximum compression level for PNG
|
||||||
|
|
||||||
void JpegErrorExitCustom(j_common_ptr cinfo);
|
void JpegErrorExitCustom(j_common_ptr cinfo);
|
||||||
|
|
||||||
|
@ -506,10 +508,17 @@ Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
||||||
/// \brief Encode the image as JPEG data.
|
/// \brief Encode the image as JPEG data.
|
||||||
/// \param[in] image The image to be encoded.
|
/// \param[in] image The image to be encoded.
|
||||||
/// \param[out] output The Tensor data.
|
/// \param[out] output The Tensor data.
|
||||||
/// \param[in] quality The quality for the output tensor from 1 to 100. Default: 75.
|
/// \param[in] quality The quality for the output tensor, in range of [1, 100]. Default: 75.
|
||||||
/// \return The status code.
|
/// \return The status code.
|
||||||
Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality = 75);
|
Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality = 75);
|
||||||
|
|
||||||
|
/// \brief Encode the image as PNG data.
|
||||||
|
/// \param[in] image The image to be encoded.
|
||||||
|
/// \param[out] output The Tensor data.
|
||||||
|
/// \param[in] compression_level The compression_level for encoding, in range of [0, 9]. Default: 6.
|
||||||
|
/// \return The status code.
|
||||||
|
Status EncodePng(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int compression_level = 6);
|
||||||
|
|
||||||
/// \brief Reads a file in binary mode.
|
/// \brief Reads a file in binary mode.
|
||||||
/// \param[in] filename The path to the file to be read.
|
/// \param[in] filename The path to the file to be read.
|
||||||
/// \param[out] output The binary data.
|
/// \param[out] output The binary data.
|
||||||
|
|
|
@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
|
||||||
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
|
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
|
||||||
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
|
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
|
||||||
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \
|
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \
|
||||||
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg
|
encode_jpeg, encode_png, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg
|
||||||
|
|
|
@ -379,7 +379,7 @@ def encode_jpeg(image, quality=75):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
|
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
|
||||||
quality (int, optional): Quality of the resulting JPEG data, from 1 to 100. Default: 75.
|
quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: 75.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray, one dimension uint8 data.
|
numpy.ndarray, one dimension uint8 data.
|
||||||
|
@ -410,6 +410,44 @@ def encode_jpeg(image, quality=75):
|
||||||
mindspore.Tensor, type(image)))
|
mindspore.Tensor, type(image)))
|
||||||
|
|
||||||
|
|
||||||
|
def encode_png(image, compression_level=6):
    """
    Encode the input image as PNG data.

    Args:
        image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
        compression_level (int, optional): The compression_level for encoding, in range of [0, 9]. Default: 6.

    Returns:
        numpy.ndarray, one dimension uint8 data.

    Raises:
        TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor.
        TypeError: If `compression_level` is not of type int.
        RuntimeError: If the data type of `image` is not uint8.
        RuntimeError: If the shape of `image` is not <H, W> or <H, W, 1> or <H, W, 3>.
        RuntimeError: If `compression_level` is less than 0 or greater than 9.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> # Generate a random image with height=120, width=340, channels=3
        >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
        >>> png_data = vision.encode_png(image)
    """
    # The compression level is validated before the image, so it is reported first when both are wrong.
    if not isinstance(compression_level, int):
        level_error = "Input compression_level is not of type {0}, but got: {1}."
        raise TypeError(level_error.format(int, type(compression_level)))

    # Bring both accepted image types to a numpy array so there is a single encode call.
    if isinstance(image, np.ndarray):
        pixels = image
    elif isinstance(image, mindspore.Tensor):
        pixels = image.asnumpy()
    else:
        image_error = "Input image is not of type {0} or {1}, but got: {2}."
        raise TypeError(image_error.format(np.ndarray, mindspore.Tensor, type(image)))

    return cde.encode_png(cde.Tensor(pixels), compression_level).as_array()
|
||||||
|
|
||||||
|
|
||||||
def get_image_num_channels(image):
|
def get_image_num_channels(image):
|
||||||
"""
|
"""
|
||||||
Get the number of input image channels.
|
Get the number of input image channels.
|
||||||
|
|
|
@ -2831,3 +2831,85 @@ TEST_F(MindDataTestPipeline, TestEncodeJpegException) {
|
||||||
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||||
ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
|
ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Feature: EncodePng
|
||||||
|
/// Description: Test EncodePng by encoding the image as PNG data according to the compression_level
|
||||||
|
/// Expectation: Output is equal to the expected output
|
||||||
|
TEST_F(MindDataTestPipeline, TestEncodePngNormal) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEncodePngNormal.";

  // Checks that the encoded buffer starts with the PNG signature bytes 0x89 'P' 'N' 'G'.
  auto expect_png_signature = [](const mindspore::MSTensor &encoded) {
    const int png_signature[] = {137, 80, 78, 71};
    const UINT8 *data = static_cast<const UINT8 *>(encoded.Data().get());
    ASSERT_TRUE(data != nullptr);
    for (size_t i = 0; i < sizeof(png_signature) / sizeof(png_signature[0]); ++i) {
      EXPECT_EQ(data[i], png_signature[i]);
    }
  };

  const std::string filename = "./data/dataset/apple.jpg";
  cv::Mat image_bgr = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
  // Stop here with a readable failure instead of passing an empty matrix on to cvtColor.
  ASSERT_FALSE(image_bgr.empty()) << "Failed to read " << filename;
  cv::Mat image;
  cv::cvtColor(image_bgr, image, cv::COLOR_BGRA2RGB);

  TensorShape tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
  DataType pixel_type = DataType(DataType::DE_UINT8);

  std::shared_ptr<Tensor> input;
  // The creation status is checked so a failed setup cannot be mistaken for an encode failure.
  ASSERT_OK(Tensor::CreateFromMemory(tensor_shape, pixel_type, image.data, &input));
  auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));

  // Default compression_level.
  mindspore::MSTensor output;
  ASSERT_OK(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
  expect_png_signature(output);

  // Every compression_level in the documented range [0, 9].
  for (int compression_level = 0; compression_level <= 9; ++compression_level) {
    ASSERT_OK(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, compression_level));
    expect_png_signature(output);
  }
}
|
||||||
|
|
||||||
|
/// Feature: EncodePng
|
||||||
|
/// Description: Test EncodePng with invalid parameter
|
||||||
|
/// Expectation: Error is caught when the parameter is invalid
|
||||||
|
TEST_F(MindDataTestPipeline, TestEncodePngException) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEncodePngException.";
  mindspore::MSTensor output;

  const std::string filename = "./data/dataset/apple.jpg";
  cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
  // Without a loaded image the checks below would fail for the wrong reason.
  ASSERT_FALSE(image.empty()) << "Failed to read " << filename;

  TensorShape tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
  DataType pixel_type = DataType(DataType::DE_UINT8);

  std::shared_ptr<Tensor> input;
  ASSERT_OK(Tensor::CreateFromMemory(tensor_shape, pixel_type, image.data, &input));
  auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));

  // Test with an invalid compression_level integer
  ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, -1));
  ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, 10));

  // Builds a tensor of the given shape and type and expects EncodePng to reject it.
  // Each invalid tensor is created directly, with its status checked, so the case
  // cannot pass merely because an earlier setup step failed silently.
  auto expect_rejected = [](const TensorShape &shape, const DataType &type) {
    std::shared_ptr<Tensor> invalid_tensor;
    ASSERT_OK(Tensor::CreateEmpty(shape, type, &invalid_tensor));
    auto invalid_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(invalid_tensor));
    mindspore::MSTensor encoded;
    ASSERT_ERROR(mindspore::dataset::vision::EncodePng(invalid_ms_tensor, &encoded));
  };

  // Test with an invalid image with the type of float
  expect_rejected(TensorShape({5, 4, 3}), DataType(DataType::DE_FLOAT32));

  // Test with an invalid image with only one dimension
  expect_rejected(TensorShape({60}), DataType(DataType::DE_UINT8));

  // Test with an invalid image with four dimensions
  expect_rejected(TensorShape({5, 4, 3, 2}), DataType(DataType::DE_UINT8));

  // Test with an invalid image with two channels
  expect_rejected(TensorShape({5, 4, 2}), DataType(DataType::DE_UINT8));
}
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing encode_png
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
import numpy
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_png_three_channels():
    """
    Feature: encode_png
    Description: Test encode_png by encoding the three channels image as PNG data according to the compression_level
    Expectation: Output is equal to the expected output
    """
    # First four bytes of the PNG signature: 0x89 'P' 'N' 'G'
    png_signature = [137, 80, 78, 71]

    filename = "../data/dataset/apple.jpg"
    image = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    # cv2.imread returns None instead of raising when the file cannot be read
    assert image is not None, "Failed to read " + filename
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Test with numpy.ndarray and default compression_level
    encoded_png = mindspore.dataset.vision.encode_png(image_rgb)
    assert encoded_png.dtype == numpy.uint8
    assert encoded_png[:4].tolist() == png_signature

    # Test with Tensor and the default, minimum and maximum compression_level
    input_tensor = mindspore.Tensor.from_numpy(image_rgb)
    for compression_level in (6, 0, 9):
        encoded_png = mindspore.dataset.vision.encode_png(input_tensor, compression_level)
        assert encoded_png[:4].tolist() == png_signature

    # Test with three channels 12*34*3 random uint8
    # (randint gives defined pixel values; numpy.ndarray(shape=...) would expose uninitialised memory)
    image_random = numpy.random.randint(256, size=(12, 34, 3), dtype=numpy.uint8)
    encoded_png = mindspore.dataset.vision.encode_png(image_random)
    assert encoded_png[:4].tolist() == png_signature
    encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_random))
    assert encoded_png[:4].tolist() == png_signature
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_png_one_channel():
    """
    Feature: encode_png
    Description: Test encode_png by encoding the one channel image as PNG data
    Expectation: Output is equal to the expected output
    """
    # First four bytes of the PNG signature: 0x89 'P' 'N' 'G'
    png_signature = [137, 80, 78, 71]

    def check_both_input_types(image_array):
        """
        encode the array both as numpy.ndarray and as mindspore.Tensor and check the PNG signature
        """
        encoded_png = mindspore.dataset.vision.encode_png(image_array)
        assert encoded_png[:4].tolist() == png_signature
        encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_array))
        assert encoded_png[:4].tolist() == png_signature

    filename = "../data/dataset/apple.jpg"
    image = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    # cv2.imread returns None instead of raising when the file cannot be read
    assert image is not None, "Failed to read " + filename

    # Test with one channel image_grayscale
    image_grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    check_both_input_types(image_grayscale)

    # Test with one channel 12*34 random uint8
    # (randint gives defined pixel values; numpy.ndarray(shape=...) would expose uninitialised memory)
    check_both_input_types(numpy.random.randint(256, size=(12, 34), dtype=numpy.uint8))

    # Test with one channel 12*34*1 random uint8
    check_both_input_types(numpy.random.randint(256, size=(12, 34, 1), dtype=numpy.uint8))
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_png_exception():
    """
    Feature: encode_png
    Description: Test encode_png with invalid parameter
    Expectation: Error is caught when the parameter is invalid
    """

    def test_invalid_param(image_param, compression_level_param, error, error_msg):
        """
        a function used for checking correct error and message with invalid parameter
        """
        with pytest.raises(error) as error_info:
            mindspore.dataset.vision.encode_png(image_param, compression_level_param)
        # Checked after the with-block: statements placed after the raising call inside it never run
        assert error_msg in str(error_info.value)

    filename = "../data/dataset/apple.jpg"
    image = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    # cv2.imread returns None instead of raising when the file cannot be read
    assert image is not None, "Failed to read " + filename

    # Test with an invalid integer for the compression_level
    error_message = "Invalid compression_level"
    test_invalid_param(image, -1, RuntimeError, error_message)
    test_invalid_param(image, 10, RuntimeError, error_message)

    # Test with an invalid type for the compression_level
    error_message = "Input compression_level is not of type"
    test_invalid_param(image, 6.0, TypeError, error_message)

    # Test with an invalid image containing the float elements
    # (numpy.zeros gives defined contents; numpy.ndarray(shape=...) would expose uninitialised memory)
    invalid_image = numpy.zeros(shape=(10, 10, 3), dtype=float)
    error_message = "The type of the image data"
    test_invalid_param(invalid_image, 6, RuntimeError, error_message)

    # Test with an invalid type for the image
    error_message = "Input image is not of type"
    test_invalid_param("invalid_image", 6, TypeError, error_message)

    # Test with an invalid image with only one dimension
    invalid_image = numpy.zeros(shape=(10,), dtype=numpy.uint8)
    error_message = "The image has invalid dimensions"
    test_invalid_param(invalid_image, 6, RuntimeError, error_message)

    # Test with an invalid image with four dimensions
    invalid_image = numpy.zeros(shape=(10, 10, 10, 3), dtype=numpy.uint8)
    test_invalid_param(invalid_image, 6, RuntimeError, error_message)

    # Test with an invalid image with two channels
    invalid_image = numpy.zeros(shape=(10, 10, 2), dtype=numpy.uint8)
    error_message = "The image has invalid channels"
    test_invalid_param(invalid_image, 6, RuntimeError, error_message)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run every test case of this module in declaration order when executed as a script.
    for run_case in (test_encode_png_three_channels,
                     test_encode_png_one_channel,
                     test_encode_png_exception):
        run_case()
|
|
@ -147,7 +147,7 @@ def test_read_image_exception():
|
||||||
|
|
||||||
# Test with a not supported gif file
|
# Test with a not supported gif file
|
||||||
wrong_filename = "../data/dataset/testFormats/apple.gif"
|
wrong_filename = "../data/dataset/testFormats/apple.gif"
|
||||||
error_message = "Can not read file " + wrong_filename
|
error_message = "Failed to read file " + wrong_filename
|
||||||
test_invalid_param(wrong_filename, ImageReadMode.COLOR, RuntimeError, error_message)
|
test_invalid_param(wrong_filename, ImageReadMode.COLOR, RuntimeError, error_message)
|
||||||
|
|
||||||
# Test with an invalid type for the filename
|
# Test with an invalid type for the filename
|
||||||
|
|
Loading…
Reference in New Issue