[feat] [assistant] [I4S2FB] add new data operator write_jpeg
This commit is contained in:
parent
06bb96bd11
commit
91f99ca70a
|
@ -0,0 +1,20 @@
|
|||
mindspore.dataset.vision.write_jpeg
|
||||
===================================
|
||||
|
||||
.. py:function:: mindspore.dataset.vision.write_jpeg(filename, image, quality=75)
|
||||
|
||||
将图像数据保存为JPEG文件。
|
||||
|
||||
参数:
|
||||
- **filename** (str) - 要写入的文件的路径。
|
||||
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。
|
||||
- **quality** (int, 可选) - 生成的JPEG文件的质量,取值范围为[1, 100]。默认值: 75。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `filename` 不是str类型。
|
||||
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
|
||||
- **TypeError** - 如果 `quality` 不是int类型。
|
||||
- **RuntimeError** - 如果 `filename` 不存在或不是普通文件。
|
||||
- **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。
|
||||
- **RuntimeError** - 如果 `image` 的shape不是 <H, W> 或 <H, W, 1> 或 <H, W, 3>。
|
||||
- **RuntimeError** - 如果 `quality` 小于1或大于100。
|
|
@ -148,3 +148,4 @@ API样例中常用的导入模块如下:
|
|||
mindspore.dataset.vision.read_file
|
||||
mindspore.dataset.vision.read_image
|
||||
mindspore.dataset.vision.write_file
|
||||
mindspore.dataset.vision.write_jpeg
|
||||
|
|
|
@ -98,3 +98,4 @@ Utilities
|
|||
mindspore.dataset.vision.read_file
|
||||
mindspore.dataset.vision.read_image
|
||||
mindspore.dataset.vision.write_file
|
||||
mindspore.dataset.vision.write_jpeg
|
||||
|
|
|
@ -924,5 +924,12 @@ PYBIND_REGISTER(WriteFileOperation, 1, ([](py::module *m) {
|
|||
THROW_IF_ERROR(WriteFile(filename, data));
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(WriteJPEGOperation, 1, ([](py::module *m) {
|
||||
(void)m->def("write_jpeg",
|
||||
([](const std::string &filename, const std::shared_ptr<Tensor> &image, int quality) {
|
||||
THROW_IF_ERROR(WriteJpeg(filename, image, quality));
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1457,6 +1457,14 @@ Status WriteFile(const std::string &filename, const mindspore::MSTensor &data) {
|
|||
RETURN_IF_NOT_OK(mindspore::dataset::WriteFile(filename, de_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// WriteJpeg Function.
|
||||
Status WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality) {
|
||||
std::shared_ptr<dataset::Tensor> image_de_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &image_de_tensor));
|
||||
RETURN_IF_NOT_OK(mindspore::dataset::WriteJpeg(filename, image_de_tensor, quality));
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // not ENABLE_ANDROID
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
|
|
|
@ -2046,6 +2046,13 @@ class DATASET_API VerticalFlip final : public TensorTransform {
|
|||
/// \param[in] data The tensor data.
|
||||
/// \return The status code.
|
||||
Status DATASET_API WriteFile(const std::string &filename, const mindspore::MSTensor &data);
|
||||
|
||||
/// \brief Write the image data into a JPEG file.
|
||||
/// \param[in] filename The path to the file to be written.
|
||||
/// \param[in] image The data tensor.
|
||||
/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75.
|
||||
/// \return The status code.
|
||||
Status DATASET_API WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality = 75);
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -2244,7 +2244,7 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
|
|||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
|
||||
image_matrix = input_cv->mat();
|
||||
if (!image_matrix.data) {
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] EncodeJpeg: load the image tensor failed.");
|
||||
RETURN_STATUS_UNEXPECTED("EncodeJpeg: Load the image tensor failed.");
|
||||
}
|
||||
|
||||
if (channels == kMinImageChannel) {
|
||||
|
@ -2350,7 +2350,7 @@ Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &dat
|
|||
}
|
||||
auto realpath = FileUtils::GetRealPath(filename.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " can not get the real path.");
|
||||
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " failed to get the real path.");
|
||||
}
|
||||
struct stat sb;
|
||||
stat(realpath.value().c_str(), &sb);
|
||||
|
@ -2372,5 +2372,87 @@ Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &dat
|
|||
fs.close();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WriteJpeg(const std::string &filename, const std::shared_ptr<Tensor> &image, int quality) {
|
||||
std::string err_msg;
|
||||
|
||||
if (image->type() != DataType::DE_UINT8) {
|
||||
err_msg = "WriteJpeg: The type of the elements of image should be UINT8, but got " + image->type().ToString() + ".";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
TensorShape shape = image->shape();
|
||||
int rank = shape.Rank();
|
||||
if (rank < kMinImageRank || rank > kDefaultImageRank) {
|
||||
err_msg = "WriteJpeg: The image has invalid dimensions. It should have two or three dimensions, but got ";
|
||||
err_msg += std::to_string(rank) + " dimensions.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
int channels;
|
||||
if (rank == kDefaultImageRank) {
|
||||
channels = shape[kMinImageRank];
|
||||
if (channels != kMinImageChannel && channels != kDefaultImageChannel) {
|
||||
err_msg = "WriteJpeg: The image has invalid channels. It should have 1 or 3 channels, but got ";
|
||||
err_msg += std::to_string(channels) + " channels.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
} else {
|
||||
channels = 1;
|
||||
}
|
||||
|
||||
if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
|
||||
err_msg = "WriteJpeg: Invalid quality " + std::to_string(quality) + ", should be in range of [" +
|
||||
std::to_string(kMinJpegQuality) + ", " + std::to_string(kMaxJpegQuality) + "].";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
std::vector<int> params = {cv::IMWRITE_JPEG_QUALITY, quality, cv::IMWRITE_JPEG_PROGRESSIVE, 0,
|
||||
cv::IMWRITE_JPEG_OPTIMIZE, 0, cv::IMWRITE_JPEG_RST_INTERVAL, 0};
|
||||
|
||||
std::vector<unsigned char> buffer;
|
||||
cv::Mat image_matrix;
|
||||
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
|
||||
image_matrix = input_cv->mat();
|
||||
if (!image_matrix.data) {
|
||||
RETURN_STATUS_UNEXPECTED("WriteJpeg: Load the image tensor failed.");
|
||||
}
|
||||
|
||||
if (channels == kMinImageChannel) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_matrix, buffer, params),
|
||||
"WriteJpeg: Failed to encode image.");
|
||||
} else {
|
||||
cv::Mat image_bgr;
|
||||
cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_bgr, buffer, params),
|
||||
"WriteJpeg: Failed to encode image.");
|
||||
}
|
||||
|
||||
Path file(filename);
|
||||
if (!file.Exists()) {
|
||||
int file_descriptor;
|
||||
RETURN_IF_NOT_OK(file.CreateFile(&file_descriptor));
|
||||
RETURN_IF_NOT_OK(file.CloseFile(file_descriptor));
|
||||
}
|
||||
auto realpath = FileUtils::GetRealPath(filename.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " failed to get the real path.");
|
||||
}
|
||||
struct stat sb;
|
||||
stat(realpath.value().c_str(), &sb);
|
||||
if (S_ISREG(sb.st_mode) == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " is not a regular file.");
|
||||
}
|
||||
|
||||
std::ofstream fs(realpath.value().c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "WriteJpeg: Failed to open the file " + filename + " for writing.");
|
||||
|
||||
fs.write((const char *)buffer.data(), (long int)buffer.size());
|
||||
if (fs.fail()) {
|
||||
fs.close();
|
||||
RETURN_STATUS_UNEXPECTED("WriteJpeg: Failed to write the file " + filename);
|
||||
}
|
||||
fs.close();
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -534,6 +534,13 @@ Status ReadImage(const std::string &filename, std::shared_ptr<Tensor> *output,
|
|||
/// \param[in] data The tensor data.
|
||||
/// \return The status code.
|
||||
Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &data);
|
||||
|
||||
/// \brief Write the image data into a JPEG file.
|
||||
/// \param[in] filename The path to the file to be written.
|
||||
/// \param[in] image The data tensor.
|
||||
/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75.
|
||||
/// \return Status code.
|
||||
Status WriteJpeg(const std::string &filename, const std::shared_ptr<Tensor> &image, int quality = 75);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
||||
|
|
|
@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
|
|||
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
|
||||
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
|
||||
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \
|
||||
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file
|
||||
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg
|
||||
|
|
|
@ -574,3 +574,42 @@ def write_file(filename, data):
|
|||
return cde.write_file(filename, cde.Tensor(data.asnumpy()))
|
||||
raise TypeError("Input data is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
|
||||
mindspore.Tensor, type(data)))
|
||||
|
||||
|
||||
def write_jpeg(filename, image, quality=75):
|
||||
"""
|
||||
Write the image data into a JPEG file.
|
||||
|
||||
Args:
|
||||
filename (str): The path to the file to be written.
|
||||
image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
|
||||
quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: 75.
|
||||
|
||||
Raises:
|
||||
TypeError: If `filename` is not of type str.
|
||||
TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor.
|
||||
TypeError: If `quality` is not of type int.
|
||||
RuntimeError: If the `filename` does not exist or not a common file.
|
||||
RuntimeError: If the data type of `image` is not uint8.
|
||||
RuntimeError: If the shape of `image` is not <H, W> or <H, W, 1> or <H, W, 3>.
|
||||
RuntimeError: If `quality` is less than 1 or greater than 100.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> # Generate a random image with height=120, width=340, channels=3
|
||||
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
|
||||
>>> vision.write_jpeg("/path/to/file", image)
|
||||
"""
|
||||
if not isinstance(filename, str):
|
||||
raise TypeError("Input filename is not of type {0}, but got: {1}.".format(str, type(filename)))
|
||||
if not isinstance(quality, int):
|
||||
raise TypeError("Input quality is not of type {0}, but got: {1}.".format(int, type(quality)))
|
||||
if isinstance(image, np.ndarray):
|
||||
return cde.write_jpeg(filename, cde.Tensor(image), quality)
|
||||
if isinstance(image, mindspore.Tensor):
|
||||
return cde.write_jpeg(filename, cde.Tensor(image.asnumpy()), quality)
|
||||
raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
|
||||
mindspore.Tensor, type(image)))
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
|
@ -1348,3 +1350,89 @@ TEST_F(MindDataTestPipeline, TestWriteFileException) {
|
|||
data_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteFile(filename_2, data_tensor));
|
||||
}
|
||||
|
||||
/// Feature: WriteJpeg
|
||||
/// Description: Test WriteJpeg by writing the image into a JPEG file
|
||||
/// Expectation: The file should be written and removed
|
||||
TEST_F(MindDataTestPipeline, TestWriteJpegNormal) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteJpegNormal.";
|
||||
std::string folder_path = "./data/dataset/testFormats/";
|
||||
std::string filename_1;
|
||||
std::string filename_2;
|
||||
cv::Mat image_1;
|
||||
cv::Mat image_2;
|
||||
|
||||
filename_1 = folder_path + "apple.jpg";
|
||||
filename_2 = filename_1 + ".test_write_jpeg.jpg";
|
||||
|
||||
cv::Mat image_bgr = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED);
|
||||
cv::cvtColor(image_bgr, image_1, cv::COLOR_BGRA2RGB);
|
||||
|
||||
TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()});
|
||||
DataType pixel_type = DataType(DataType::DE_UINT8);
|
||||
|
||||
std::shared_ptr<Tensor> image_de_tensor;
|
||||
Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor);
|
||||
auto image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
|
||||
|
||||
int quality;
|
||||
for (quality = 20; quality <= 100 ; quality += 40) {
|
||||
ASSERT_OK(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, quality));
|
||||
image_2 = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED);
|
||||
remove(filename_2.c_str());
|
||||
EXPECT_EQ(image_1.total(), image_2.total());
|
||||
}
|
||||
}
|
||||
|
||||
/// Feature: WriteJpeg
|
||||
/// Description: Test WriteJpeg with invalid parameter
|
||||
/// Expectation: Error is caught when the parameter is invalid
|
||||
TEST_F(MindDataTestPipeline, TestWriteJpegException) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteJpegException.";
|
||||
std::string folder_path = "./data/dataset/testFormats/";
|
||||
std::string filename_1;
|
||||
std::string filename_2;
|
||||
cv::Mat image_1;
|
||||
|
||||
filename_1 = folder_path + "apple.jpg";
|
||||
filename_2 = filename_1 + ".test_write_jpeg.jpg";
|
||||
image_1 = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED);
|
||||
|
||||
TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()});
|
||||
DataType pixel_type = DataType(DataType::DE_UINT8);
|
||||
|
||||
std::shared_ptr<Tensor> image_de_tensor;
|
||||
Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor);
|
||||
auto image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
|
||||
|
||||
// Test with invalid quality 0, 101
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 0));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 101));
|
||||
|
||||
// Test with an invalid filename
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("/dev/cdrom/0", image_ms_tensor));
|
||||
|
||||
// Test with a directory name
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("./data/dataset/", image_ms_tensor));
|
||||
|
||||
// Test with an invalid image containing float elements
|
||||
std::shared_ptr<Tensor> float32_cde_tensor;
|
||||
Tensor::CreateEmpty(TensorShape({5, 4, 3 }), DataType(DataType::DE_FLOAT32), &float32_cde_tensor);
|
||||
image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(float32_cde_tensor));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));
|
||||
|
||||
// Test with an invalid image with only one dimension
|
||||
image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.size[1] * image_1.channels()}));
|
||||
image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));
|
||||
|
||||
// Test with an invalid image with four dimensions
|
||||
image_de_tensor->Reshape(TensorShape({image_1.size[0] / 2, image_1.size[1], image_1.channels(), 2}));
|
||||
image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));
|
||||
|
||||
// Test with an invalid image with two channels
|
||||
image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.channels() / 2, image_1.size[1], 2}));
|
||||
image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
|
||||
ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing write_jpeg
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.dataset import vision
|
||||
|
||||
|
||||
def test_write_jpeg_three_channels():
|
||||
"""
|
||||
Feature: write_jpeg
|
||||
Description: Write the image containing three channels into a JPEG file
|
||||
Expectation: The file should be written and removed
|
||||
"""
|
||||
|
||||
def write_jpeg_three_channels(filename_param, image_param, quality_param=75):
|
||||
"""
|
||||
a function used for writing with three channels image
|
||||
"""
|
||||
vision.write_jpeg(filename_param, image_param, quality_param)
|
||||
image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED)
|
||||
os.remove(filename_param)
|
||||
assert image_2_numpy.shape == (2268, 4032, 3)
|
||||
|
||||
filename_1 = "../data/dataset/apple.jpg"
|
||||
mode = cv2.IMREAD_UNCHANGED
|
||||
image_bgr = cv2.imread(filename_1, mode)
|
||||
image_1_numpy = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
image_1_tensor = Tensor.from_numpy(image_1_numpy)
|
||||
filename_2 = filename_1 + ".test_write_jpeg.jpg"
|
||||
|
||||
# Test writing numpy.ndarray
|
||||
write_jpeg_three_channels(filename_2, image_1_numpy)
|
||||
|
||||
# Test writing Tensor and quality 1, 75, 100
|
||||
for quality in (1, 75, 100):
|
||||
write_jpeg_three_channels(filename_2, image_1_tensor, quality)
|
||||
|
||||
# Test with three channels 2268*4032*3 random uint8, the quality is 50
|
||||
image_random = numpy.ndarray(shape=(2268, 4032, 3), dtype=numpy.uint8)
|
||||
write_jpeg_three_channels(filename_2, image_random, 50)
|
||||
|
||||
|
||||
def test_write_jpeg_one_channel():
|
||||
"""
|
||||
Feature: write_jpeg
|
||||
Description: Write the grayscale image into a JPEG file
|
||||
Expectation: The file should be written and removed
|
||||
"""
|
||||
|
||||
def write_jpeg_one_channel(filename_param, image_param, quality_param=75):
|
||||
"""
|
||||
a function used for writing with three channels image
|
||||
"""
|
||||
vision.write_jpeg(filename_param, image_param, quality_param)
|
||||
image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED)
|
||||
os.remove(filename_param)
|
||||
assert image_2_numpy.shape == (2268, 4032)
|
||||
|
||||
filename_1 = "../data/dataset/apple.jpg"
|
||||
mode = cv2.IMREAD_UNCHANGED
|
||||
image_1_numpy = cv2.imread(filename_1, mode)
|
||||
filename_2 = filename_1 + ".test_write_jpeg.jpg"
|
||||
image_grayscale = cv2.cvtColor(image_1_numpy, cv2.COLOR_BGR2GRAY)
|
||||
image_grayscale_tensor = Tensor.from_numpy(image_grayscale)
|
||||
|
||||
# Test writing numpy.ndarray
|
||||
write_jpeg_one_channel(filename_2, image_grayscale)
|
||||
|
||||
# Test writing Tensor and quality 1, 75, 100
|
||||
for quality in (1, 75, 100):
|
||||
write_jpeg_one_channel(filename_2, image_grayscale_tensor, quality)
|
||||
|
||||
# Test with three channels 2268*4032 random uint8
|
||||
image_random = numpy.ndarray(shape=(2268, 4032), dtype=numpy.uint8)
|
||||
write_jpeg_one_channel(filename_2, image_random)
|
||||
|
||||
# Test with one channels 2268*4032*1 random uint8, the quality is 50
|
||||
image_random = numpy.ndarray(shape=(2268, 4032, 1), dtype=numpy.uint8)
|
||||
write_jpeg_one_channel(filename_2, image_random, 50)
|
||||
|
||||
|
||||
def test_write_jpeg_exception():
|
||||
"""
|
||||
Feature: write_jpeg
|
||||
Description: Test write_jpeg with an invalid parameter
|
||||
Expectation: Error is caught when the parameter is invalid
|
||||
"""
|
||||
|
||||
def test_invalid_param(filename_param, image_param, quality_param, error, error_msg):
|
||||
"""
|
||||
a function used for checking correct error and message with invalid parameter
|
||||
"""
|
||||
with pytest.raises(error) as error_info:
|
||||
vision.write_jpeg(filename_param, image_param, quality_param)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
filename_1 = "../data/dataset/apple.jpg"
|
||||
mode = cv2.IMREAD_UNCHANGED
|
||||
image_1_numpy = cv2.imread(filename_1, mode)
|
||||
image_1_tensor = Tensor.from_numpy(image_1_numpy)
|
||||
|
||||
# Test with a directory name
|
||||
wrong_filename = "../data/dataset/"
|
||||
error_message = "Invalid file path, " + wrong_filename + " is not a regular file."
|
||||
test_invalid_param(wrong_filename, image_1_numpy, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with an invalid filename
|
||||
wrong_filename = "/dev/cdrom/0"
|
||||
error_message = "No such file or directory"
|
||||
test_invalid_param(wrong_filename, image_1_tensor, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with an invalid type for the filename
|
||||
error_message = "Input filename is not of type"
|
||||
test_invalid_param(0, image_1_numpy, 75, TypeError, error_message)
|
||||
|
||||
# Test with an invalid type for the data
|
||||
filename_2 = filename_1 + ".test_write_jpeg.jpg"
|
||||
error_message = "Input image is not of type"
|
||||
test_invalid_param(filename_2, 0, 75, TypeError, error_message)
|
||||
|
||||
# Test with invalid float elements
|
||||
invalid_data = numpy.ndarray(shape=(10, 10), dtype=float)
|
||||
error_message = "The type of the elements of image should be UINT8"
|
||||
test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with invalid image with only one dimension
|
||||
invalid_data = numpy.ndarray(shape=(10), dtype=numpy.uint8)
|
||||
error_message = "The image has invalid dimensions"
|
||||
test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with invalid image with four dimensions
|
||||
invalid_data = numpy.ndarray(shape=(1, 2, 3, 4), dtype=numpy.uint8)
|
||||
test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with invalid image with two channels
|
||||
invalid_data = numpy.ndarray(shape=(2, 3, 2), dtype=numpy.uint8)
|
||||
error_message = "The image has invalid channels"
|
||||
test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with invalid quality
|
||||
invalid_data = numpy.ndarray(shape=(2, 3, 2), dtype=numpy.uint8)
|
||||
error_message = "The image has invalid channels"
|
||||
test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)
|
||||
|
||||
# Test with an invalid integer for the quality 0, 101
|
||||
error_message = "Invalid quality"
|
||||
test_invalid_param(filename_2, image_1_numpy, 0, RuntimeError, error_message)
|
||||
test_invalid_param(filename_2, image_1_numpy, 101, RuntimeError, error_message)
|
||||
|
||||
# Test with an invalid type for the quality
|
||||
error_message = "Input quality is not of type"
|
||||
test_invalid_param(filename_2, image_1_numpy, 75.0, TypeError, error_message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_write_jpeg_three_channels()
|
||||
test_write_jpeg_one_channel()
|
||||
test_write_jpeg_exception()
|
Loading…
Reference in New Issue