!40711 【算子众智】【电子科技大学】【数据算子】【encode_png】encode a image as png

Merge pull request !40711 from dengjian/upstream_encode_png
This commit is contained in:
i-robot 2022-11-21 07:08:01 +00:00 committed by Gitee
commit 2c1ebbd63f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 408 additions and 12 deletions

View File

@ -7,10 +7,10 @@ mindspore.dataset.vision.encode_jpeg
参数:
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。
- **quality** (int, 可选) - 生成的JPEG数据的质量从1到100。默认值75。
- **quality** (int, 可选) - 生成的JPEG数据的质量取值范围为[1, 100]。默认值: 75。
返回:
- numpy:ndarray, 一维uint8类型数据。
- numpy.ndarray, 一维uint8类型数据。
异常:
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。

View File

@ -0,0 +1,20 @@
mindspore.dataset.vision.encode_png
===================================
.. py:function:: mindspore.dataset.vision.encode_png(image, compression_level=6)
将输入的图像编码为PNG数据。
参数:
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。
- **compression_level** (int, 可选) - 编码压缩因子,取值范围为[0, 9]。默认值: 6。
返回:
- numpy.ndarray, 一维uint8类型数据。
异常:
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
- **TypeError** - 如果 `compression_level` 不是int类型。
- **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。
- **RuntimeError** - 如果 `image` 的shape不是 <H, W> 或 <H, W, 1> 或 <H, W, 3>。
- **RuntimeError** - 如果 `compression_level` 小于0或大于9。

View File

@ -143,6 +143,7 @@ API样例中常用的导入模块如下
mindspore.dataset.vision.Inter
mindspore.dataset.vision.SliceMode
mindspore.dataset.vision.encode_jpeg
mindspore.dataset.vision.encode_png
mindspore.dataset.vision.get_image_num_channels
mindspore.dataset.vision.get_image_size
mindspore.dataset.vision.read_file

View File

@ -93,6 +93,7 @@ Utilities
mindspore.dataset.vision.Inter
mindspore.dataset.vision.SliceMode
mindspore.dataset.vision.encode_jpeg
mindspore.dataset.vision.encode_png
mindspore.dataset.vision.get_image_num_channels
mindspore.dataset.vision.get_image_size
mindspore.dataset.vision.read_file

View File

@ -279,6 +279,14 @@ PYBIND_REGISTER(EncodeJpegOperation, 1, ([](py::module *m) {
}));
}));
PYBIND_REGISTER(EncodePNGOperation, 1, ([](py::module *m) {
(void)m->def("encode_png", ([](const std::shared_ptr<Tensor> &image, int compression_level) {
std::shared_ptr<Tensor> output;
THROW_IF_ERROR(EncodePng(image, &output, compression_level));
return output;
}));
}));
PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
(void)
py::class_<vision::EqualizeOperation, TensorOperation, std::shared_ptr<vision::EqualizeOperation>>(

View File

@ -468,6 +468,7 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
}
#endif
#ifndef ENABLE_ANDROID
// EncodeJpeg Function.
Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality) {
RETURN_UNEXPECTED_IF_NULL(output);
@ -481,6 +482,19 @@ Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output,
return Status::OK();
}
// EncodePng Function.
Status EncodePng(const mindspore::MSTensor &image, mindspore::MSTensor *output, int compression_level) {
RETURN_UNEXPECTED_IF_NULL(output);
std::shared_ptr<dataset::Tensor> input;
RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &input));
TensorPtr de_tensor;
RETURN_IF_NOT_OK(mindspore::dataset::EncodePng(input, &de_tensor, compression_level));
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(),
"EncodePng: get an empty tensor with shape " + de_tensor->shape().ToString());
*output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
return Status::OK();
}
// Equalize Transform Operation.
Equalize::Equalize() = default;

View File

@ -443,10 +443,17 @@ class DATASET_API CutOut final : public TensorTransform {
/// \brief Encode the image as JPEG data.
/// \param[in] image The image to be encoded.
/// \param[out] output The Tensor data.
/// \param[in] quality The quality for the output tensor from 1 to 100. Default: 75.
/// \param[in] quality The quality for the output tensor, in range of [1, 100]. Default: 75.
/// \return The status code.
Status DATASET_API EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality = 75);
/// \brief Encode the image as PNG data.
/// \param[in] image The image to be encoded.
/// \param[out] output The Tensor data.
/// \param[in] compression_level The compression_level for encoding, in range of [0, 9]. Default: 6.
/// \return The status code.
Status DATASET_API EncodePng(const mindspore::MSTensor &image, mindspore::MSTensor *output, int compression_level = 6);
/// \brief Apply histogram equalization on the input image.
class DATASET_API Equalize final : public TensorTransform {
public:

View File

@ -2230,8 +2230,9 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
}
if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
err_msg = "EncodeJpeg: Invalid quality " + std::to_string(quality) + ", should be from " +
std::to_string(kMinJpegQuality) + " to " + std::to_string(kMaxJpegQuality) + ".";
err_msg = "EncodeJpeg: Invalid quality " + std::to_string(quality) + ", should be in range of [" +
std::to_string(kMinJpegQuality) + ", " + std::to_string(kMaxJpegQuality) + "].";
RETURN_STATUS_UNEXPECTED(err_msg);
}
@ -2248,11 +2249,74 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
}
if (channels == kMinImageChannel) {
cv::imencode(".JPEG", image_matrix, buffer, params);
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_matrix, buffer, params),
"EncodeJpeg: Failed to encode image.");
} else {
cv::Mat image_bgr;
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
cv::imencode(".JPEG", image_bgr, buffer, params);
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_bgr, buffer, params),
"EncodeJpeg: Failed to encode image.");
}
TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(tensor_shape, DataType(DataType::DE_UINT8), buffer.data(), output));
return Status::OK();
}
Status EncodePng(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int compression_level) {
RETURN_UNEXPECTED_IF_NULL(output);
std::string err_msg;
if (image->type() != DataType::DE_UINT8) {
err_msg = "EncodePng: The type of the image data should be UINT8, but got " + image->type().ToString() + ".";
RETURN_STATUS_UNEXPECTED(err_msg);
}
TensorShape shape = image->shape();
int rank = shape.Rank();
if (rank < kMinImageRank || rank > kDefaultImageRank) {
err_msg = "EncodePng: The image has invalid dimensions. It should have two or three dimensions, but got ";
err_msg += std::to_string(rank) + " dimensions.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
int channels;
if (rank == kDefaultImageRank) {
channels = shape[kMinImageRank];
if (channels != kMinImageChannel && channels != kDefaultImageChannel) {
err_msg = "EncodePng: The image has invalid channels. It should have 1 or 3 channels, but got ";
err_msg += std::to_string(channels) + " channels.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
} else {
channels = 1;
}
if (compression_level < kMinPngCompression || compression_level > kMaxPngCompression) {
err_msg = "EncodePng: Invalid compression_level " + std::to_string(compression_level) +
", should be in range of [" + std::to_string(kMinPngCompression) + ", " +
std::to_string(kMaxPngCompression) + "].";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<int> params = {cv::IMWRITE_PNG_COMPRESSION, compression_level, cv::IMWRITE_PNG_STRATEGY,
cv::IMWRITE_PNG_STRATEGY_RLE};
std::vector<unsigned char> buffer;
cv::Mat image_matrix;
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
image_matrix = input_cv->mat();
if (!image_matrix.data) {
RETURN_STATUS_UNEXPECTED("EncodePng: Load the image tensor failed.");
}
if (channels == kMinImageChannel) {
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".PNG", image_matrix, buffer, params),
"EncodePng: Failed to encode image.");
} else {
cv::Mat image_bgr;
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".PNG", image_bgr, buffer, params), "EncodePng: Failed to encode image.");
}
TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
@ -2296,7 +2360,7 @@ Status ReadImage(const std::string &filename, std::shared_ptr<Tensor> *output, I
int cv_mode = static_cast<int>(mode) - 1;
image = cv::imread(realpath.value(), cv_mode);
if (image.data == nullptr) {
RETURN_STATUS_UNEXPECTED("ReadImage: Can not read file " + filename);
RETURN_STATUS_UNEXPECTED("ReadImage: Failed to read file " + filename);
}
std::shared_ptr<CVTensor> output_cv;

View File

@ -57,6 +57,8 @@ constexpr dsize_t kHeightIndex = 0; // index of height of HWC images
constexpr dsize_t kWidthIndex = 1; // index of width of HWC images
constexpr dsize_t kMinJpegQuality = 1; // the minimum quality for JPEG
constexpr dsize_t kMaxJpegQuality = 100; // the maximum quality for JPEG
constexpr dsize_t kMinPngCompression = 0; // the minimum compression level for PNG
constexpr dsize_t kMaxPngCompression = 9; // the maximum compression level for PNG
void JpegErrorExitCustom(j_common_ptr cinfo);
@ -506,10 +508,17 @@ Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
/// \brief Encode the image as JPEG data.
/// \param[in] image The image to be encoded.
/// \param[out] output The Tensor data.
/// \param[in] quality The quality for the output tensor from 1 to 100. Default: 75.
/// \param[in] quality The quality for the output tensor, in range of [1, 100]. Default: 75.
/// \return The status code.
Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality = 75);
/// \brief Encode the image as PNG data.
/// \param[in] image The image to be encoded.
/// \param[out] output The Tensor data.
/// \param[in] compression_level The compression_level for encoding, in range of [0, 9]. Default: 6.
/// \return The status code.
Status EncodePng(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int compression_level = 6);
/// \brief Reads a file in binary mode.
/// \param[in] filename The path to the file to be read.
/// \param[out] output The binary data.

View File

@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg
encode_jpeg, encode_png, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg

View File

@ -379,7 +379,7 @@ def encode_jpeg(image, quality=75):
Args:
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
quality (int, optional): Quality of the resulting JPEG data, from 1 to 100. Default: 75.
quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: 75.
Returns:
numpy.ndarray, one dimension uint8 data.
@ -410,6 +410,44 @@ def encode_jpeg(image, quality=75):
mindspore.Tensor, type(image)))
def encode_png(image, compression_level=6):
"""
Encode the input image as PNG data.
Args:
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
compression_level (int, optional): The compression_level for encoding, in range of [0, 9]. Default: 6.
Returns:
numpy.ndarray, one dimension uint8 data.
Raises:
TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor.
TypeError: If `compression_level` is not of type int.
RuntimeError: If the data type of `image` is not uint8.
RuntimeError: If the shape of `image` is not <H, W> or <H, W, 1> or <H, W, 3>.
RuntimeError: If `compression_level` is less than 0 or greater than 9.
Supported Platforms:
``CPU``
Examples:
>>> import numpy as np
>>> # Generate a random image with height=120, width=340, channels=3
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
>>> png_data = vision.encode_png(image)
"""
if not isinstance(compression_level, int):
raise TypeError("Input compression_level is not of type {0}, but got: {1}.".format(int,
type(compression_level)))
if isinstance(image, np.ndarray):
return cde.encode_png(cde.Tensor(image), compression_level).as_array()
if isinstance(image, mindspore.Tensor):
return cde.encode_png(cde.Tensor(image.asnumpy()), compression_level).as_array()
raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
mindspore.Tensor, type(image)))
def get_image_num_channels(image):
"""
Get the number of input image channels.

82
tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc Normal file → Executable file
View File

@ -2831,3 +2831,85 @@ TEST_F(MindDataTestPipeline, TestEncodeJpegException) {
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
}
/// Feature: EncodePng
/// Description: Test EncodePng by encoding the image as PNG data according to the compression_level
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestPipeline, TestEncodePngNormal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesEncodePngNormal.";
mindspore::MSTensor output;
std::string folder_path = "./data/dataset/";
std::string filename;
const UINT8 *data;
filename = folder_path + "apple.jpg";
cv::Mat image_bgr = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
cv::Mat image;
cv::cvtColor(image_bgr, image, cv::COLOR_BGRA2RGB);
TensorShape tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
DataType pixel_type = DataType(DataType::DE_UINT8);
std::shared_ptr<Tensor> input;
Tensor::CreateFromMemory(tensor_shape, pixel_type, image.data, &input);
auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_OK(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
data = (const UINT8 *) (output.Data().get());
EXPECT_EQ(data[0], 137);
EXPECT_EQ(data[1], 80);
EXPECT_EQ(data[2], 78);
EXPECT_EQ(data[3], 71);
int compression_level;
for (compression_level = 0; compression_level <= 9 ; compression_level++) {
ASSERT_OK(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, compression_level));
data = (const UINT8 *) (output.Data().get());
EXPECT_EQ(data[1], 80);
}
}
/// Feature: EncodePng
/// Description: Test EncodePng with invalid parameter
/// Expectation: Error is caught when the parameter is invalid
TEST_F(MindDataTestPipeline, TestEncodePngException) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesEncodePngException.";
mindspore::MSTensor output;
std::string folder_path = "./data/dataset/";
std::string filename;
filename = folder_path + "apple.jpg";
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
TensorShape tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
DataType pixel_type = DataType(DataType::DE_UINT8);
std::shared_ptr<Tensor> input;
Tensor::CreateFromMemory(tensor_shape, pixel_type, image.data, &input);
auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
// Test with an invalid compression_level integer
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, -1));
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output, 10));
// Test with an invalid image with the type of float
std::shared_ptr<Tensor> float32_de_tensor;
Tensor::CreateEmpty(TensorShape({5, 4, 3 }), DataType(DataType::DE_FLOAT32), &float32_de_tensor);
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(float32_de_tensor));
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
// Test with an invalid image with only one dimension
input->Reshape(TensorShape({image.size[0] * image.size[1] * image.channels()}));
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
// Test with an invalid image with four dimensions
input->Reshape(TensorShape({image.size[0] / 2, image.size[1], image.channels(), 2}));
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
// Test with an invalid image with two channels
input->Reshape(TensorShape({image.size[0] * image.channels() / 2, image.size[1], 2}));
input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::EncodePng(input_ms_tensor, &output));
}

View File

@ -0,0 +1,152 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing encode_png
"""
import cv2
import numpy
import pytest
import mindspore
def test_encode_png_three_channels():
"""
Feature: encode_png
Description: Test encode_png by encoding the three channels image as PNG data according to the compression_level
Expectation: Output is equal to the expected output
"""
filename = "../data/dataset/apple.jpg"
mode = cv2.IMREAD_UNCHANGED
image = cv2.imread(filename, mode)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Test with numpy:ndarray and default compression_level
encoded_png = mindspore.dataset.vision.encode_png(image_rgb)
assert encoded_png.dtype == numpy.uint8
assert encoded_png[0] == 137
assert encoded_png[1] == 80
assert encoded_png[2] == 78
assert encoded_png[3] == 71
# Test with Tensor and compression_level
input_tensor = mindspore.Tensor.from_numpy(image_rgb)
encoded_png_6 = mindspore.dataset.vision.encode_png(input_tensor, 6)
assert encoded_png_6[1] == 80
# Test with the minimum compression_level
encoded_png_0 = mindspore.dataset.vision.encode_png(input_tensor, 0)
assert encoded_png_0[1] == 80
# Test with the maximum compression_level
encoded_png_9 = mindspore.dataset.vision.encode_png(input_tensor, 9)
assert encoded_png_9[1] == 80
# Test with three channels 12*34*3 random uint8
image_random = numpy.ndarray(shape=(12, 34, 3), dtype=numpy.uint8)
encoded_png = mindspore.dataset.vision.encode_png(image_random)
assert encoded_png[1] == 80
encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_random))
assert encoded_png[1] == 80
def test_encode_png_one_channel():
"""
Feature: encode_png
Description: Test encode_png by encoding the one channel image as PNG data
Expectation: Output is equal to the expected output
"""
filename = "../data/dataset/apple.jpg"
mode = cv2.IMREAD_UNCHANGED
image = cv2.imread(filename, mode)
# Test with one channel image_grayscale
image_grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
encoded_png = mindspore.dataset.vision.encode_png(image_grayscale)
assert encoded_png[1] == 80
encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_grayscale))
assert encoded_png[1] == 80
# Test with one channel 12*34 random uint8
image_random = numpy.ndarray(shape=(12, 34), dtype=numpy.uint8)
encoded_png = mindspore.dataset.vision.encode_png(image_random)
assert encoded_png[1] == 80
encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_random))
assert encoded_png[1] == 80
# Test with one channel 12*34*1 random uint8
image_random = numpy.ndarray(shape=(12, 34, 1), dtype=numpy.uint8)
encoded_png = mindspore.dataset.vision.encode_png(image_random)
assert encoded_png[1] == 80
encoded_png = mindspore.dataset.vision.encode_png(mindspore.Tensor.from_numpy(image_random))
assert encoded_png[1] == 80
def test_encode_png_exception():
"""
Feature: encode_png
Description: Test encode_png with invalid parameter
Expectation: Error is caught when the parameter is invalid
"""
def test_invalid_param(image_param, compression_level_param, error, error_msg):
"""
a function used for checking correct error and message with invalid parameter
"""
with pytest.raises(error) as error_info:
mindspore.dataset.vision.encode_png(image_param, compression_level_param)
assert error_msg in str(error_info.value)
filename = "../data/dataset/apple.jpg"
mode = cv2.IMREAD_UNCHANGED
image = cv2.imread(filename, mode)
# Test with an invalid integer for the compression_level
error_message = "Invalid compression_level"
test_invalid_param(image, -1, RuntimeError, error_message)
test_invalid_param(image, 10, RuntimeError, error_message)
# Test with an invalid type for the compression_level
error_message = "Input compression_level is not of type"
test_invalid_param(image, 6.0, TypeError, error_message)
# Test with an invalid image containing the float elements
invalid_image = numpy.ndarray(shape=(10, 10, 3), dtype=float)
error_message = "The type of the image data"
test_invalid_param(invalid_image, 6, RuntimeError, error_message)
# Test with an invalid type for the image
error_message = "Input image is not of type"
test_invalid_param("invalid_image", 6, TypeError, error_message)
# Test with an invalid image with only one dimension
invalid_image = numpy.ndarray(shape=(10), dtype=numpy.uint8)
error_message = "The image has invalid dimensions"
test_invalid_param(invalid_image, 6, RuntimeError, error_message)
# Test with an invalid image with four dimensions
invalid_image = numpy.ndarray(shape=(10, 10, 10, 3), dtype=numpy.uint8)
test_invalid_param(invalid_image, 6, RuntimeError, error_message)
# Test with an invalid image with two channels
invalid_image = numpy.ndarray(shape=(10, 10, 2), dtype=numpy.uint8)
error_message = "The image has invalid channels"
test_invalid_param(invalid_image, 6, RuntimeError, error_message)
if __name__ == "__main__":
test_encode_png_three_channels()
test_encode_png_one_channel()
test_encode_png_exception()

View File

@ -147,7 +147,7 @@ def test_read_image_exception():
# Test with a not supported gif file
wrong_filename = "../data/dataset/testFormats/apple.gif"
error_message = "Can not read file " + wrong_filename
error_message = "Failed to read file " + wrong_filename
test_invalid_param(wrong_filename, ImageReadMode.COLOR, RuntimeError, error_message)
# Test with an invalid type for the filename