diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst
new file mode 100755
index 00000000000..b308bc6711c
--- /dev/null
+++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst
@@ -0,0 +1,20 @@
+mindspore.dataset.vision.encode_jpeg
+====================================
+
+.. py:function:: mindspore.dataset.vision.encode_jpeg(image, quality=75)
+
+    将输入的图像编码为JPEG数据。
+
+    参数:
+        - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 待编码的图像。
+        - **quality** (int, 可选) - 生成的JPEG数据的质量,从1到100。默认值:75。
+
+    返回:
+        - numpy.ndarray, 一维uint8类型数据。
+
+    异常:
+        - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
+        - **TypeError** - 如果 `quality` 不是int类型。
+        - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。
+        - **RuntimeError** - 如果 `image` 的shape不是 <H, W> 或 <H, W, 1> 或 <H, W, 3>。
+        - **RuntimeError** - 如果 `quality` 小于1或大于100。
diff --git a/docs/api/api_python/mindspore.dataset.vision.rst b/docs/api/api_python/mindspore.dataset.vision.rst
index 3764e8c9961..18fbe80a460 100644
--- a/docs/api/api_python/mindspore.dataset.vision.rst
+++ b/docs/api/api_python/mindspore.dataset.vision.rst
@@ -141,5 +141,6 @@ API样例中常用的导入模块如下:
     mindspore.dataset.vision.ImageBatchFormat
     mindspore.dataset.vision.Inter
     mindspore.dataset.vision.SliceMode
+    mindspore.dataset.vision.encode_jpeg
     mindspore.dataset.vision.get_image_num_channels
     mindspore.dataset.vision.get_image_size
diff --git a/docs/api/api_python_en/mindspore.dataset.vision.rst b/docs/api/api_python_en/mindspore.dataset.vision.rst
index f8a8c5fb73b..473d2ed0af3 100644
--- a/docs/api/api_python_en/mindspore.dataset.vision.rst
+++ b/docs/api/api_python_en/mindspore.dataset.vision.rst
@@ -91,5 +91,6 @@ Utilities
     mindspore.dataset.vision.ImageBatchFormat
     mindspore.dataset.vision.Inter
     mindspore.dataset.vision.SliceMode
+    mindspore.dataset.vision.encode_jpeg
    mindspore.dataset.vision.get_image_num_channels
    mindspore.dataset.vision.get_image_size
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
index e22753383a5..a851734974f 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
@@ -271,6 +271,14 @@ PYBIND_REGISTER(DecodeOperation, 1, ([](const py::module *m) {
                   }));
                 }));
 
+PYBIND_REGISTER(EncodeJpegOperation, 1, ([](py::module *m) {
+                  (void)m->def("encode_jpeg", ([](const std::shared_ptr<Tensor> &image, int quality) {
+                    std::shared_ptr<Tensor> output;
+                    THROW_IF_ERROR(EncodeJpeg(image, &output, quality));
+                    return output;
+                  }));
+                }));
+
 PYBIND_REGISTER(EqualizeOperation, 1, ([](const py::module *m) {
                   (void)py::class_<vision::EqualizeOperation, TensorOperation, std::shared_ptr<vision::EqualizeOperation>>(
diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc
index 2165b2d22bc..5ea18ce8622 100644
--- a/mindspore/ccsrc/minddata/dataset/api/vision.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc
@@ -465,6 +465,19 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env) {
 }
 #endif
 #ifndef ENABLE_ANDROID
+// EncodeJpeg Function.
+Status EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality) {
+  RETURN_UNEXPECTED_IF_NULL(output);
+  std::shared_ptr<dataset::Tensor> input;
+  RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &input));
+  std::shared_ptr<dataset::Tensor> de_tensor;
+  RETURN_IF_NOT_OK(mindspore::dataset::EncodeJpeg(input, &de_tensor, quality));
+  CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(),
+                               "EncodeJpeg: get an empty tensor with shape " + de_tensor->shape().ToString());
+  *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
+  return Status::OK();
+}
+
 // Equalize Transform Operation.
 Equalize::Equalize() = default;
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
index f837e40d74b..7a0bde694d4 100644
--- a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
+++ b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
@@ -440,6 +440,13 @@ class DATASET_API CutOut final : public TensorTransform {
   std::shared_ptr<Data> data_;
 };
 
+/// \brief Encode the image as JPEG data.
+/// \param[in] image The image to be encoded.
+/// \param[out] output The one-dimensional uint8 Tensor holding the encoded JPEG data.
+/// \param[in] quality The quality of the output JPEG data, from 1 to 100. Default: 75.
+/// \return The status code.
+Status DATASET_API EncodeJpeg(const mindspore::MSTensor &image, mindspore::MSTensor *output, int quality = 75);
+
 /// \brief Apply histogram equalization on the input image.
 class DATASET_API Equalize final : public TensorTransform {
  public:
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
index b3508d55b53..c3419e0f5cc 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
@@ -2198,5 +2198,65 @@ Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &op_name,
   return Status::OK();
 }
+
+Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality) {
+  RETURN_UNEXPECTED_IF_NULL(output);
+
+  std::string err_msg;
+  if (image->type() != DataType::DE_UINT8) {
+    err_msg = "EncodeJpeg: The type of the image data should be UINT8, but got " + image->type().ToString() + ".";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  TensorShape shape = image->shape();
+  int rank = shape.Rank();
+  if (rank < kMinImageRank || rank > kDefaultImageRank) {
+    err_msg = "EncodeJpeg: The image has invalid dimensions. It should have two or three dimensions, but got ";
+    err_msg += std::to_string(rank) + " dimensions.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  int channels;
+  if (rank == kDefaultImageRank) {
+    channels = shape[kMinImageRank];
+    if (channels != kMinImageChannel && channels != kDefaultImageChannel) {
+      err_msg = "EncodeJpeg: The image has invalid channels. It should have 1 or 3 channels, but got ";
+      err_msg += std::to_string(channels) + " channels.";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+  } else {
+    channels = 1;
+  }
+
+  if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
+    err_msg = "EncodeJpeg: Invalid quality " + std::to_string(quality) + ", should be from " +
+              std::to_string(kMinJpegQuality) + " to " + std::to_string(kMaxJpegQuality) + ".";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  std::vector<int> params = {cv::IMWRITE_JPEG_QUALITY, quality, cv::IMWRITE_JPEG_PROGRESSIVE, 0,
+                             cv::IMWRITE_JPEG_OPTIMIZE, 0, cv::IMWRITE_JPEG_RST_INTERVAL, 0};
+
+  std::vector<uchar> buffer;
+  cv::Mat image_matrix;
+
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
+  image_matrix = input_cv->mat();
+  if (!image_matrix.data) {
+    RETURN_STATUS_UNEXPECTED("[Internal ERROR] EncodeJpeg: load the image tensor failed.");
+  }
+
+  if (channels == kMinImageChannel) {
+    cv::imencode(".JPEG", image_matrix, buffer, params);
+  } else {
+    cv::Mat image_bgr;
+    cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
+    cv::imencode(".JPEG", image_bgr, buffer, params);
+  }
+
+  TensorShape tensor_shape = TensorShape({(long int)buffer.size()});
+  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(tensor_shape, DataType(DataType::DE_UINT8), buffer.data(), output));
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
index 2920b6a0868..c3725d8c196 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
@@ -55,6 +55,8 @@ constexpr dsize_t kGIndex = 1;       // index of green channel in RGB format
 constexpr dsize_t kBIndex = 2;       // index of blue channel in RGB format
 constexpr dsize_t kHeightIndex = 0;  // index of height of HWC images
 constexpr dsize_t kWidthIndex = 1;   // index of width of HWC images
+constexpr dsize_t kMinJpegQuality = 1;    // the minimum quality for JPEG
+constexpr dsize_t kMaxJpegQuality = 100;  // the maximum quality for JPEG
 
 void JpegErrorExitCustom(j_common_ptr cinfo);
 
@@ -500,6 +502,13 @@ float Round(float value);
 /// \param[in] fill_value Values used to fill.
 Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &op_name,
                     float magnitude, InterpolationMode interpolation, const std::vector<uint8_t> &fill_value);
+
+/// \brief Encode the image as JPEG data.
+/// \param[in] image The image to be encoded.
+/// \param[out] output The one-dimensional uint8 Tensor holding the encoded JPEG data.
+/// \param[in] quality The quality of the output JPEG data, from 1 to 100. Default: 75.
+/// \return The status code.
+Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality = 75);
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
diff --git a/mindspore/python/mindspore/dataset/vision/__init__.py b/mindspore/python/mindspore/dataset/vision/__init__.py
index c5b44ab4413..8aa6b37062b 100644
--- a/mindspore/python/mindspore/dataset/vision/__init__.py
+++ b/mindspore/python/mindspore/dataset/vision/__init__.py
@@ -84,5 +84,5 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
     RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, Resize, ResizedCrop, \
     ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
     TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
-from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, get_image_num_channels, \
-    get_image_size
+from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, encode_jpeg, \
+    get_image_num_channels, get_image_size
diff --git a/mindspore/python/mindspore/dataset/vision/utils.py b/mindspore/python/mindspore/dataset/vision/utils.py
index 0dbeb90944f..ae32bb1d423 100644
--- a/mindspore/python/mindspore/dataset/vision/utils.py
+++ b/mindspore/python/mindspore/dataset/vision/utils.py
@@ -20,6 +20,7 @@ import numbers
 import numpy as np
 from PIL import Image
 
+import mindspore
 import mindspore._c_dataengine as cde
 
 
@@ -333,6 +334,41 @@ class SliceMode(IntEnum):
         return c_values.get(mode)
 
 
+def encode_jpeg(image, quality=75):
+    """
+    Encode the input image as JPEG data.
+
+    Args:
+        image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
+        quality (int, optional): Quality of the resulting JPEG data, from 1 to 100. Default: 75.
+
+    Returns:
+        numpy.ndarray, one dimension uint8 data.
+
+    Raises:
+        TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor.
+        TypeError: If `quality` is not of type int.
+        RuntimeError: If the data type of `image` is not uint8.
+        RuntimeError: If the shape of `image` is not <H, W> or <H, W, 1> or <H, W, 3>.
+        RuntimeError: If `quality` is less than 1 or greater than 100.
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore.dataset import vision
+        >>> # Generate a random image with height=120, width=340, channels=3
+        >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
+        >>> jpeg_data = vision.encode_jpeg(image)
+    """
+    if not isinstance(quality, int):
+        raise TypeError("Input quality is not of type {0}, but got: {1}.".format(int, type(quality)))
+    if isinstance(image, np.ndarray):
+        return cde.encode_jpeg(cde.Tensor(image), quality).as_array()
+    if isinstance(image, mindspore.Tensor):
+        return cde.encode_jpeg(cde.Tensor(image.asnumpy()), quality).as_array()
+    raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
+                                                                                  mindspore.Tensor, type(image)))
+
+
 def get_image_num_channels(image):
     """
     Get the number of input image channels.
diff --git a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
index a4499fbb06d..6e62f02304a 100644
--- a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
+++ b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
+#include <opencv2/opencv.hpp>
+
 #include "common/common.h"
 #include "minddata/dataset/include/dataset/datasets.h"
 #include "minddata/dataset/include/dataset/transforms.h"
@@ -2661,7 +2663,6 @@ TEST_F(MindDataTestPipeline, TestAdjustContrastParamCheck) {
   EXPECT_EQ(iter1, nullptr);
 }
 
-
 /// Feature: Perspective
 /// Description: Test Perspective pipeline
 /// Expectation: The returned result is as expected
@@ -2749,3 +2750,84 @@ TEST_F(MindDataTestPipeline, TestPerspectiveParamCheck) {
   // Expect failure: invalid value of Perspective
   EXPECT_EQ(iter1, nullptr);
 }
+
+/// Feature: EncodeJpeg
+/// Description: Test EncodeJpeg by encoding the image as JPEG data according to the quality
+/// Expectation: Output is equal to the expected output
+TEST_F(MindDataTestPipeline, TestEncodeJpegNormal) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEncodeJpegNormal.";
+  mindspore::MSTensor output;
+  std::string folder_path = "./data/dataset/";
+  std::string filename;
+  const UINT8 *data;
+
+  filename = folder_path + "apple.jpg";
+  cv::Mat image_bgr = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
+  cv::Mat image;
+  cv::cvtColor(image_bgr, image, cv::COLOR_BGRA2RGB);
+
+  TensorShape img_tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
+  DataType pixel_type = DataType(DataType::DE_UINT8);
+
+  std::shared_ptr<Tensor> input;
+  Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image.data, &input);
+  auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+
+  ASSERT_OK(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
+  data = (const UINT8 *) (output.Data().get());
+  EXPECT_EQ(data[0], 255);
+  EXPECT_EQ(data[1], 216);
+  EXPECT_EQ(data[2], 255);
+
+  int quality;
+  for (quality = 20; quality <= 100; quality += 40) {
+    ASSERT_OK(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output, quality));
+    data = (const UINT8 *) (output.Data().get());
+    EXPECT_EQ(data[1], 216);
+  }
+}
+
+/// Feature: EncodeJpeg
+/// Description: Test EncodeJpeg with invalid parameter
+/// Expectation: Error is caught when the parameter is invalid
+TEST_F(MindDataTestPipeline, TestEncodeJpegException) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEncodeJpegException.";
+  mindspore::MSTensor output;
+  std::string folder_path = "./data/dataset/";
+  std::string filename;
+
+  filename = folder_path + "apple.jpg";
+  cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_UNCHANGED);
+
+  TensorShape img_tensor_shape = TensorShape({image.size[0], image.size[1], image.channels()});
+  DataType pixel_type = DataType(DataType::DE_UINT8);
+
+  std::shared_ptr<Tensor> input;
+  Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image.data, &input);
+  auto input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+
+  // Test with an invalid integer for the quality
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output, 0));
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output, 101));
+
+  // Test with an invalid image containing float32 elements
+  std::shared_ptr<Tensor> float32_de_tensor;
+  Tensor::CreateEmpty(TensorShape({5, 4, 3}), DataType(DataType::DE_FLOAT32), &float32_de_tensor);
+  input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(float32_de_tensor));
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
+
+  // Test with an invalid image with only one dimension
+  input->Reshape(TensorShape({image.size[0] * image.size[1] * image.channels()}));
+  input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
+
+  // Test with an invalid image with four dimensions
+  input->Reshape(TensorShape({image.size[0] / 2, image.size[1], image.channels(), 2}));
+  input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
+
+  // Test with an invalid image with two channels
+  input->Reshape(TensorShape({image.size[0] * image.channels() / 2, image.size[1], 2}));
+  input_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+  ASSERT_ERROR(mindspore::dataset::vision::EncodeJpeg(input_ms_tensor, &output));
+}
diff --git a/tests/ut/python/dataset/test_encode_jpeg.py b/tests/ut/python/dataset/test_encode_jpeg.py
new file mode 100755
index 00000000000..d2fc93e5366
--- /dev/null
+++ b/tests/ut/python/dataset/test_encode_jpeg.py
@@ -0,0 +1,152 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing encode_jpeg
+"""
+import cv2
+import numpy
+import pytest
+
+from mindspore import Tensor
+from mindspore.dataset import vision
+
+
+def test_encode_jpeg_three_channels():
+    """
+    Feature: encode_jpeg
+    Description: Test encode_jpeg by encoding the three channels image as JPEG data according to the quality
+    Expectation: Output is equal to the expected output
+    """
+    filename = "../data/dataset/apple.jpg"
+    mode = cv2.IMREAD_UNCHANGED
+    image = cv2.imread(filename, mode)
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Test with numpy.ndarray and default quality
+    encoded_jpeg = vision.encode_jpeg(image_rgb)
+    assert encoded_jpeg.dtype == numpy.uint8
+    assert encoded_jpeg[0] == 255
+    assert encoded_jpeg[1] == 216
+    assert encoded_jpeg[2] == 255
+
+    # Test with Tensor and quality
+    input_tensor = Tensor.from_numpy(image_rgb)
+    encoded_jpeg_75 = vision.encode_jpeg(input_tensor, 75)
+    assert encoded_jpeg_75[1] == 216
+
+    # Test with the minimum quality
+    encoded_jpeg_1 = vision.encode_jpeg(input_tensor, 1)
+    assert encoded_jpeg_1[1] == 216
+
+    # Test with the maximum quality
+    encoded_jpeg_100 = vision.encode_jpeg(input_tensor, 100)
+    assert encoded_jpeg_100[1] == 216
+
+    # Test with three channels 12*34*3 random uint8
+    image_random = numpy.ndarray(shape=(12, 34, 3), dtype=numpy.uint8)
+    encoded_jpeg = vision.encode_jpeg(image_random)
+    assert encoded_jpeg[1] == 216
+    encoded_jpeg = vision.encode_jpeg(Tensor.from_numpy(image_random))
+    assert encoded_jpeg[1] == 216
+
+
+def test_encode_jpeg_one_channel():
+    """
+    Feature: encode_jpeg
+    Description: Test encode_jpeg by encoding the one channel image as JPEG data
+    Expectation: Output is equal to the expected output
+    """
+    filename = "../data/dataset/apple.jpg"
+    mode = cv2.IMREAD_UNCHANGED
+    image = cv2.imread(filename, mode)
+
+    # Test with one channel image_grayscale
+    image_grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    encoded_jpeg = vision.encode_jpeg(image_grayscale)
+    assert encoded_jpeg[1] == 216
+    encoded_jpeg = vision.encode_jpeg(Tensor.from_numpy(image_grayscale))
+    assert encoded_jpeg[1] == 216
+
+    # Test with one channel 12*34 random uint8
+    image_random = numpy.ndarray(shape=(12, 34), dtype=numpy.uint8)
+    encoded_jpeg = vision.encode_jpeg(image_random)
+    assert encoded_jpeg[1] == 216
+    encoded_jpeg = vision.encode_jpeg(Tensor.from_numpy(image_random))
+    assert encoded_jpeg[1] == 216
+
+    # Test with one channel 12*34*1 random uint8
+    image_random = numpy.ndarray(shape=(12, 34, 1), dtype=numpy.uint8)
+    encoded_jpeg = vision.encode_jpeg(image_random)
+    assert encoded_jpeg[1] == 216
+    encoded_jpeg = vision.encode_jpeg(Tensor.from_numpy(image_random))
+    assert encoded_jpeg[1] == 216
+
+
+def test_encode_jpeg_exception():
+    """
+    Feature: encode_jpeg
+    Description: Test encode_jpeg with invalid parameter
+    Expectation: Error is caught when the parameter is invalid
+    """
+
+    def test_invalid_param(image_param, quality_param, error, error_msg):
+        """
+        a function used for checking correct error and message with invalid parameter
+        """
+        with pytest.raises(error) as error_info:
+            vision.encode_jpeg(image_param, quality_param)
+        assert error_msg in str(error_info.value)
+
+    filename = "../data/dataset/apple.jpg"
+    mode = cv2.IMREAD_UNCHANGED
+    image = cv2.imread(filename, mode)
+
+    # Test with an invalid integer for the quality
+    error_message = "Invalid quality"
+    test_invalid_param(image, 0, RuntimeError, error_message)
+    test_invalid_param(image, 101, RuntimeError, error_message)
+
+    # Test with an invalid type for the quality
+    error_message = "Input quality is not of type"
+    test_invalid_param(image, 75.0, TypeError, error_message)
+
+    # Test with an invalid image containing the float elements
+    invalid_image = numpy.ndarray(shape=(10, 10, 3), dtype=float)
+    error_message = "The type of the image data"
+    test_invalid_param(invalid_image, 75, RuntimeError, error_message)
+
+    # Test with an invalid type for the image
+    error_message = "Input image is not of type"
+    test_invalid_param("invalid_image", 75, TypeError, error_message)
+
+    # Test with an invalid image with only one dimension
+    invalid_image = numpy.ndarray(shape=(10,), dtype=numpy.uint8)
+    error_message = "The image has invalid dimensions"
+    test_invalid_param(invalid_image, 75, RuntimeError, error_message)
+
+    # Test with an invalid image with four dimensions
+    invalid_image = numpy.ndarray(shape=(10, 10, 10, 3), dtype=numpy.uint8)
+    test_invalid_param(invalid_image, 75, RuntimeError, error_message)
+
+    # Test with an invalid image with two channels
+    invalid_image = numpy.ndarray(shape=(10, 10, 2), dtype=numpy.uint8)
+    error_message = "The image has invalid channels"
+    test_invalid_param(invalid_image, 75, RuntimeError, error_message)
+
+
+if __name__ == "__main__":
+    test_encode_jpeg_three_channels()
+    test_encode_jpeg_one_channel()
+    test_encode_jpeg_exception()
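
Reviewer note (not part of the patch): the snippet below is a minimal usage sketch of the new mindspore.dataset.vision.encode_jpeg Python API added above. It assumes numpy and opencv-python are installed; the output file name "encoded_sample.jpg" is only illustrative. Since the C++ kernel converts three-channel input from RGB to BGR before calling cv::imencode, the input array is expected to be HWC uint8 in RGB order.

# Usage sketch (not part of the patch): round-trip check for the new encode_jpeg API.
import cv2
import numpy as np

from mindspore.dataset import vision

# Build an HWC uint8 RGB image; encode_jpeg accepts shapes <H, W>, <H, W, 1> or <H, W, 3>.
image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)

# Encode with the default quality (75) and with an explicit quality value.
jpeg_default = vision.encode_jpeg(image)
jpeg_q90 = vision.encode_jpeg(image, quality=90)

# The result is a 1-D uint8 array starting with the JPEG SOI marker 0xFF 0xD8.
assert jpeg_default.dtype == np.uint8 and jpeg_default.ndim == 1
assert jpeg_default[0] == 255 and jpeg_default[1] == 216

# Round-trip through OpenCV to confirm the bytes form a decodable JPEG stream.
decoded_bgr = cv2.imdecode(jpeg_q90, cv2.IMREAD_COLOR)
assert decoded_bgr.shape == (120, 340, 3)

# The encoded bytes can also be written straight to disk as a .jpg file (name is hypothetical).
with open("encoded_sample.jpg", "wb") as f:
    f.write(jpeg_default.tobytes())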