From 91f99ca70ac8dce1c3574059918410b7a276c751 Mon Sep 17 00:00:00 2001 From: deng jian Date: Tue, 23 Aug 2022 16:58:46 +0800 Subject: [PATCH] [feat] [assistant] [I4S2FB] add new data operator write_jpeg --- .../mindspore.dataset.vision.write_jpeg.rst | 20 ++ .../api_python/mindspore.dataset.vision.rst | 1 + .../mindspore.dataset.vision.rst | 1 + .../dataset/kernels/ir/image/bindings.cc | 7 + .../ccsrc/minddata/dataset/api/vision.cc | 8 + .../minddata/dataset/include/dataset/vision.h | 7 + .../dataset/kernels/image/image_utils.cc | 86 ++++++++- .../dataset/kernels/image/image_utils.h | 7 + .../mindspore/dataset/vision/__init__.py | 2 +- .../python/mindspore/dataset/vision/utils.py | 39 ++++ .../cpp/dataset/c_api_vision_r_to_z_test.cc | 88 +++++++++ tests/ut/python/dataset/test_write_jpeg.py | 177 ++++++++++++++++++ 12 files changed, 440 insertions(+), 3 deletions(-) create mode 100755 docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst create mode 100755 tests/ut/python/dataset/test_write_jpeg.py diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst new file mode 100755 index 00000000000..bf6c14bd8b0 --- /dev/null +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst @@ -0,0 +1,20 @@ +mindspore.dataset.vision.write_jpeg +=================================== + +.. py:function:: mindspore.dataset.vision.write_jpeg(filename, image, quality=75) + + 将图像数据保存为JPEG文件。 + + 参数: + - **filename** (str) - 要写入的文件的路径。 + - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。 + - **quality** (int, 可选) - 生成的JPEG文件的质量,取值范围为[1, 100]。默认值: 75。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 + - **TypeError** - 如果 `quality` 不是int类型。 + - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 + - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `image` 的shape不是 。 + - **RuntimeError** - 如果 `quality` 小于1或大于100。 diff --git a/docs/api/api_python/mindspore.dataset.vision.rst b/docs/api/api_python/mindspore.dataset.vision.rst index e0ce5152daf..cbcf9eaf1db 100755 --- a/docs/api/api_python/mindspore.dataset.vision.rst +++ b/docs/api/api_python/mindspore.dataset.vision.rst @@ -148,3 +148,4 @@ API样例中常用的导入模块如下: mindspore.dataset.vision.read_file mindspore.dataset.vision.read_image mindspore.dataset.vision.write_file + mindspore.dataset.vision.write_jpeg diff --git a/docs/api/api_python_en/mindspore.dataset.vision.rst b/docs/api/api_python_en/mindspore.dataset.vision.rst index 884f0164ea4..34ed6d96a05 100755 --- a/docs/api/api_python_en/mindspore.dataset.vision.rst +++ b/docs/api/api_python_en/mindspore.dataset.vision.rst @@ -98,3 +98,4 @@ Utilities mindspore.dataset.vision.read_file mindspore.dataset.vision.read_image mindspore.dataset.vision.write_file + mindspore.dataset.vision.write_jpeg diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc index a5e4ed9ac8a..31aa3e36ab7 100755 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc @@ -924,5 +924,12 @@ PYBIND_REGISTER(WriteFileOperation, 1, ([](py::module *m) { THROW_IF_ERROR(WriteFile(filename, data)); })); })); + +PYBIND_REGISTER(WriteJPEGOperation, 1, ([](py::module *m) { + (void)m->def("write_jpeg", + ([](const std::string &filename, const std::shared_ptr &image, int quality) { + THROW_IF_ERROR(WriteJpeg(filename, image, quality)); + })); + })); } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index 8ee638b285e..45da3296cb6 100755 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -1457,6 +1457,14 @@ Status WriteFile(const std::string &filename, const mindspore::MSTensor &data) { RETURN_IF_NOT_OK(mindspore::dataset::WriteFile(filename, de_tensor)); return Status::OK(); } + +// WriteJpeg Function. +Status WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality) { + std::shared_ptr image_de_tensor; + RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &image_de_tensor)); + RETURN_IF_NOT_OK(mindspore::dataset::WriteJpeg(filename, image_de_tensor, quality)); + return Status::OK(); +} #endif // not ENABLE_ANDROID } // namespace vision } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h index 03a2b9d49c6..de77dbb07fc 100755 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h @@ -2046,6 +2046,13 @@ class DATASET_API VerticalFlip final : public TensorTransform { /// \param[in] data The tensor data. /// \return The status code. Status DATASET_API WriteFile(const std::string &filename, const mindspore::MSTensor &data); + +/// \brief Write the image data into a JPEG file. +/// \param[in] filename The path to the file to be written. +/// \param[in] image The data tensor. +/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75. +/// \return The status code. +Status DATASET_API WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality = 75); } // namespace vision } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index cb8aa14d793..9331d565293 100755 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -2244,7 +2244,7 @@ Status EncodeJpeg(const std::shared_ptr &image, std::shared_ptr std::shared_ptr input_cv = CVTensor::AsCVTensor(image); image_matrix = input_cv->mat(); if (!image_matrix.data) { - RETURN_STATUS_UNEXPECTED("[Internal ERROR] EncodeJpeg: load the image tensor failed."); + RETURN_STATUS_UNEXPECTED("EncodeJpeg: Load the image tensor failed."); } if (channels == kMinImageChannel) { @@ -2350,7 +2350,7 @@ Status WriteFile(const std::string &filename, const std::shared_ptr &dat } auto realpath = FileUtils::GetRealPath(filename.c_str()); if (!realpath.has_value()) { - RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " can not get the real path."); + RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " failed to get the real path."); } struct stat sb; stat(realpath.value().c_str(), &sb); @@ -2372,5 +2372,87 @@ Status WriteFile(const std::string &filename, const std::shared_ptr &dat fs.close(); return Status::OK(); } + +Status WriteJpeg(const std::string &filename, const std::shared_ptr &image, int quality) { + std::string err_msg; + + if (image->type() != DataType::DE_UINT8) { + err_msg = "WriteJpeg: The type of the elements of image should be UINT8, but got " + image->type().ToString() + "."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + TensorShape shape = image->shape(); + int rank = shape.Rank(); + if (rank < kMinImageRank || rank > kDefaultImageRank) { + err_msg = "WriteJpeg: The image has invalid dimensions. It should have two or three dimensions, but got "; + err_msg += std::to_string(rank) + " dimensions."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + int channels; + if (rank == kDefaultImageRank) { + channels = shape[kMinImageRank]; + if (channels != kMinImageChannel && channels != kDefaultImageChannel) { + err_msg = "WriteJpeg: The image has invalid channels. It should have 1 or 3 channels, but got "; + err_msg += std::to_string(channels) + " channels."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } else { + channels = 1; + } + + if (quality < kMinJpegQuality || quality > kMaxJpegQuality) { + err_msg = "WriteJpeg: Invalid quality " + std::to_string(quality) + ", should be in range of [" + + std::to_string(kMinJpegQuality) + ", " + std::to_string(kMaxJpegQuality) + "]."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::vector params = {cv::IMWRITE_JPEG_QUALITY, quality, cv::IMWRITE_JPEG_PROGRESSIVE, 0, + cv::IMWRITE_JPEG_OPTIMIZE, 0, cv::IMWRITE_JPEG_RST_INTERVAL, 0}; + + std::vector buffer; + cv::Mat image_matrix; + + std::shared_ptr input_cv = CVTensor::AsCVTensor(image); + image_matrix = input_cv->mat(); + if (!image_matrix.data) { + RETURN_STATUS_UNEXPECTED("WriteJpeg: Load the image tensor failed."); + } + + if (channels == kMinImageChannel) { + CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_matrix, buffer, params), + "WriteJpeg: Failed to encode image."); + } else { + cv::Mat image_bgr; + cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR); + CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_bgr, buffer, params), + "WriteJpeg: Failed to encode image."); + } + + Path file(filename); + if (!file.Exists()) { + int file_descriptor; + RETURN_IF_NOT_OK(file.CreateFile(&file_descriptor)); + RETURN_IF_NOT_OK(file.CloseFile(file_descriptor)); + } + auto realpath = FileUtils::GetRealPath(filename.c_str()); + if (!realpath.has_value()) { + RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " failed to get the real path."); + } + struct stat sb; + stat(realpath.value().c_str(), &sb); + if (S_ISREG(sb.st_mode) == 0) { + RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " is not a regular file."); + } + + std::ofstream fs(realpath.value().c_str(), std::ios::out | std::ios::trunc | std::ios::binary); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "WriteJpeg: Failed to open the file " + filename + " for writing."); + + fs.write((const char *)buffer.data(), (long int)buffer.size()); + if (fs.fail()) { + fs.close(); + RETURN_STATUS_UNEXPECTED("WriteJpeg: Failed to write the file " + filename); + } + fs.close(); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index dcf9509d9db..1da65e36b64 100755 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -534,6 +534,13 @@ Status ReadImage(const std::string &filename, std::shared_ptr *output, /// \param[in] data The tensor data. /// \return The status code. Status WriteFile(const std::string &filename, const std::shared_ptr &data); + +/// \brief Write the image data into a JPEG file. +/// \param[in] filename The path to the file to be written. +/// \param[in] image The data tensor. +/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75. +/// \return Status code. +Status WriteJpeg(const std::string &filename, const std::shared_ptr &image, int quality = 75); } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/python/mindspore/dataset/vision/__init__.py b/mindspore/python/mindspore/dataset/vision/__init__.py index eeaf1251dee..e883052a281 100755 --- a/mindspore/python/mindspore/dataset/vision/__init__.py +++ b/mindspore/python/mindspore/dataset/vision/__init__.py @@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \ TrivialAugmentWide, UniformAugment, VerticalFlip, not_random from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \ - encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file + encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg diff --git a/mindspore/python/mindspore/dataset/vision/utils.py b/mindspore/python/mindspore/dataset/vision/utils.py index 0879dff8ce3..78d7c52425a 100755 --- a/mindspore/python/mindspore/dataset/vision/utils.py +++ b/mindspore/python/mindspore/dataset/vision/utils.py @@ -574,3 +574,42 @@ def write_file(filename, data): return cde.write_file(filename, cde.Tensor(data.asnumpy())) raise TypeError("Input data is not of type {0} or {1}, but got: {2}.".format(np.ndarray, mindspore.Tensor, type(data))) + + +def write_jpeg(filename, image, quality=75): + """ + Write the image data into a JPEG file. + + Args: + filename (str): The path to the file to be written. + image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written. + quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: 75. + + Raises: + TypeError: If `filename` is not of type str. + TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor. + TypeError: If `quality` is not of type int. + RuntimeError: If the `filename` does not exist or not a common file. + RuntimeError: If the data type of `image` is not uint8. + RuntimeError: If the shape of `image` is not or or . + RuntimeError: If `quality` is less than 1 or greater than 100. + + Supported Platforms: + ``CPU`` + + Examples: + >>> import numpy as np + >>> # Generate a random image with height=120, width=340, channels=3 + >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8) + >>> vision.write_jpeg("/path/to/file", image) + """ + if not isinstance(filename, str): + raise TypeError("Input filename is not of type {0}, but got: {1}.".format(str, type(filename))) + if not isinstance(quality, int): + raise TypeError("Input quality is not of type {0}, but got: {1}.".format(int, type(quality))) + if isinstance(image, np.ndarray): + return cde.write_jpeg(filename, cde.Tensor(image), quality) + if isinstance(image, mindspore.Tensor): + return cde.write_jpeg(filename, cde.Tensor(image.asnumpy()), quality) + raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray, + mindspore.Tensor, type(image))) diff --git a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc index a407b086d35..def2b6445dc 100755 --- a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include + #include "common/common.h" #include "minddata/dataset/include/dataset/datasets.h" #include "minddata/dataset/include/dataset/transforms.h" @@ -1348,3 +1350,89 @@ TEST_F(MindDataTestPipeline, TestWriteFileException) { data_tensor = mindspore::MSTensor(std::make_shared(input)); ASSERT_ERROR(mindspore::dataset::vision::WriteFile(filename_2, data_tensor)); } + +/// Feature: WriteJpeg +/// Description: Test WriteJpeg by writing the image into a JPEG file +/// Expectation: The file should be written and removed +TEST_F(MindDataTestPipeline, TestWriteJpegNormal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteJpegNormal."; + std::string folder_path = "./data/dataset/testFormats/"; + std::string filename_1; + std::string filename_2; + cv::Mat image_1; + cv::Mat image_2; + + filename_1 = folder_path + "apple.jpg"; + filename_2 = filename_1 + ".test_write_jpeg.jpg"; + + cv::Mat image_bgr = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED); + cv::cvtColor(image_bgr, image_1, cv::COLOR_BGRA2RGB); + + TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()}); + DataType pixel_type = DataType(DataType::DE_UINT8); + + std::shared_ptr image_de_tensor; + Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor); + auto image_ms_tensor = mindspore::MSTensor(std::make_shared(image_de_tensor)); + + int quality; + for (quality = 20; quality <= 100 ; quality += 40) { + ASSERT_OK(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, quality)); + image_2 = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED); + remove(filename_2.c_str()); + EXPECT_EQ(image_1.total(), image_2.total()); + } +} + +/// Feature: WriteJpeg +/// Description: Test WriteJpeg with invalid parameter +/// Expectation: Error is caught when the parameter is invalid +TEST_F(MindDataTestPipeline, TestWriteJpegException) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteJpegException."; + std::string folder_path = "./data/dataset/testFormats/"; + std::string filename_1; + std::string filename_2; + cv::Mat image_1; + + filename_1 = folder_path + "apple.jpg"; + filename_2 = filename_1 + ".test_write_jpeg.jpg"; + image_1 = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED); + + TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()}); + DataType pixel_type = DataType(DataType::DE_UINT8); + + std::shared_ptr image_de_tensor; + Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor); + auto image_ms_tensor = mindspore::MSTensor(std::make_shared(image_de_tensor)); + + // Test with invalid quality 0, 101 + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 0)); + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 101)); + + // Test with an invalid filename + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("/dev/cdrom/0", image_ms_tensor)); + + // Test with a directory name + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("./data/dataset/", image_ms_tensor)); + + // Test with an invalid image containing float elements + std::shared_ptr float32_cde_tensor; + Tensor::CreateEmpty(TensorShape({5, 4, 3 }), DataType(DataType::DE_FLOAT32), &float32_cde_tensor); + image_ms_tensor = mindspore::MSTensor(std::make_shared(float32_cde_tensor)); + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor)); + + // Test with an invalid image with only one dimension + image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.size[1] * image_1.channels()})); + image_ms_tensor = mindspore::MSTensor(std::make_shared(image_de_tensor)); + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor)); + + // Test with an invalid image with four dimensions + image_de_tensor->Reshape(TensorShape({image_1.size[0] / 2, image_1.size[1], image_1.channels(), 2})); + image_ms_tensor = mindspore::MSTensor(std::make_shared(image_de_tensor)); + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor)); + + // Test with an invalid image with two channels + image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.channels() / 2, image_1.size[1], 2})); + image_ms_tensor = mindspore::MSTensor(std::make_shared(image_de_tensor)); + ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor)); +} diff --git a/tests/ut/python/dataset/test_write_jpeg.py b/tests/ut/python/dataset/test_write_jpeg.py new file mode 100755 index 00000000000..7be591bb963 --- /dev/null +++ b/tests/ut/python/dataset/test_write_jpeg.py @@ -0,0 +1,177 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing write_jpeg +""" +import os +import cv2 +import numpy +import pytest + +from mindspore import Tensor +from mindspore.dataset import vision + + +def test_write_jpeg_three_channels(): + """ + Feature: write_jpeg + Description: Write the image containing three channels into a JPEG file + Expectation: The file should be written and removed + """ + + def write_jpeg_three_channels(filename_param, image_param, quality_param=75): + """ + a function used for writing with three channels image + """ + vision.write_jpeg(filename_param, image_param, quality_param) + image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED) + os.remove(filename_param) + assert image_2_numpy.shape == (2268, 4032, 3) + + filename_1 = "../data/dataset/apple.jpg" + mode = cv2.IMREAD_UNCHANGED + image_bgr = cv2.imread(filename_1, mode) + image_1_numpy = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + image_1_tensor = Tensor.from_numpy(image_1_numpy) + filename_2 = filename_1 + ".test_write_jpeg.jpg" + + # Test writing numpy.ndarray + write_jpeg_three_channels(filename_2, image_1_numpy) + + # Test writing Tensor and quality 1, 75, 100 + for quality in (1, 75, 100): + write_jpeg_three_channels(filename_2, image_1_tensor, quality) + + # Test with three channels 2268*4032*3 random uint8, the quality is 50 + image_random = numpy.ndarray(shape=(2268, 4032, 3), dtype=numpy.uint8) + write_jpeg_three_channels(filename_2, image_random, 50) + + +def test_write_jpeg_one_channel(): + """ + Feature: write_jpeg + Description: Write the grayscale image into a JPEG file + Expectation: The file should be written and removed + """ + + def write_jpeg_one_channel(filename_param, image_param, quality_param=75): + """ + a function used for writing with three channels image + """ + vision.write_jpeg(filename_param, image_param, quality_param) + image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED) + os.remove(filename_param) + assert image_2_numpy.shape == (2268, 4032) + + filename_1 = "../data/dataset/apple.jpg" + mode = cv2.IMREAD_UNCHANGED + image_1_numpy = cv2.imread(filename_1, mode) + filename_2 = filename_1 + ".test_write_jpeg.jpg" + image_grayscale = cv2.cvtColor(image_1_numpy, cv2.COLOR_BGR2GRAY) + image_grayscale_tensor = Tensor.from_numpy(image_grayscale) + + # Test writing numpy.ndarray + write_jpeg_one_channel(filename_2, image_grayscale) + + # Test writing Tensor and quality 1, 75, 100 + for quality in (1, 75, 100): + write_jpeg_one_channel(filename_2, image_grayscale_tensor, quality) + + # Test with three channels 2268*4032 random uint8 + image_random = numpy.ndarray(shape=(2268, 4032), dtype=numpy.uint8) + write_jpeg_one_channel(filename_2, image_random) + + # Test with one channels 2268*4032*1 random uint8, the quality is 50 + image_random = numpy.ndarray(shape=(2268, 4032, 1), dtype=numpy.uint8) + write_jpeg_one_channel(filename_2, image_random, 50) + + +def test_write_jpeg_exception(): + """ + Feature: write_jpeg + Description: Test write_jpeg with an invalid parameter + Expectation: Error is caught when the parameter is invalid + """ + + def test_invalid_param(filename_param, image_param, quality_param, error, error_msg): + """ + a function used for checking correct error and message with invalid parameter + """ + with pytest.raises(error) as error_info: + vision.write_jpeg(filename_param, image_param, quality_param) + assert error_msg in str(error_info.value) + + filename_1 = "../data/dataset/apple.jpg" + mode = cv2.IMREAD_UNCHANGED + image_1_numpy = cv2.imread(filename_1, mode) + image_1_tensor = Tensor.from_numpy(image_1_numpy) + + # Test with a directory name + wrong_filename = "../data/dataset/" + error_message = "Invalid file path, " + wrong_filename + " is not a regular file." + test_invalid_param(wrong_filename, image_1_numpy, 75, RuntimeError, error_message) + + # Test with an invalid filename + wrong_filename = "/dev/cdrom/0" + error_message = "No such file or directory" + test_invalid_param(wrong_filename, image_1_tensor, 75, RuntimeError, error_message) + + # Test with an invalid type for the filename + error_message = "Input filename is not of type" + test_invalid_param(0, image_1_numpy, 75, TypeError, error_message) + + # Test with an invalid type for the data + filename_2 = filename_1 + ".test_write_jpeg.jpg" + error_message = "Input image is not of type" + test_invalid_param(filename_2, 0, 75, TypeError, error_message) + + # Test with invalid float elements + invalid_data = numpy.ndarray(shape=(10, 10), dtype=float) + error_message = "The type of the elements of image should be UINT8" + test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message) + + # Test with invalid image with only one dimension + invalid_data = numpy.ndarray(shape=(10), dtype=numpy.uint8) + error_message = "The image has invalid dimensions" + test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message) + + # Test with invalid image with four dimensions + invalid_data = numpy.ndarray(shape=(1, 2, 3, 4), dtype=numpy.uint8) + test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message) + + # Test with invalid image with two channels + invalid_data = numpy.ndarray(shape=(2, 3, 2), dtype=numpy.uint8) + error_message = "The image has invalid channels" + test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message) + + # Test with invalid quality + invalid_data = numpy.ndarray(shape=(2, 3, 2), dtype=numpy.uint8) + error_message = "The image has invalid channels" + test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message) + + # Test with an invalid integer for the quality 0, 101 + error_message = "Invalid quality" + test_invalid_param(filename_2, image_1_numpy, 0, RuntimeError, error_message) + test_invalid_param(filename_2, image_1_numpy, 101, RuntimeError, error_message) + + # Test with an invalid type for the quality + error_message = "Input quality is not of type" + test_invalid_param(filename_2, image_1_numpy, 75.0, TypeError, error_message) + + +if __name__ == "__main__": + test_write_jpeg_three_channels() + test_write_jpeg_one_channel() + test_write_jpeg_exception()