[feat] [assistant] [I4S2FB] add new data operator write_jpeg

This commit is contained in:
deng jian 2022-08-23 16:58:46 +08:00 committed by dengjian
parent 06bb96bd11
commit 91f99ca70a
12 changed files with 440 additions and 3 deletions

View File

@ -0,0 +1,20 @@
mindspore.dataset.vision.write_jpeg
===================================
.. py:function:: mindspore.dataset.vision.write_jpeg(filename, image, quality=75)
将图像数据保存为JPEG文件。
参数:
- **filename** (str) - 要写入的文件的路径。
- **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。
- **quality** (int, 可选) - 生成的JPEG文件的质量，取值范围为[1, 100]。默认值: 75。
异常:
- **TypeError** - 如果 `filename` 不是str类型。
- **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。
- **TypeError** - 如果 `quality` 不是int类型。
- **RuntimeError** - 如果 `filename` 指定的文件无法创建、无法解析为真实路径，或不是普通文件。
- **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。
- **RuntimeError** - 如果 `image` 的shape不是 <H, W> 或 <H, W, 1> 或 <H, W, 3>。
- **RuntimeError** - 如果 `quality` 小于1或大于100。

View File

@ -148,3 +148,4 @@ API样例中常用的导入模块如下
mindspore.dataset.vision.read_file
mindspore.dataset.vision.read_image
mindspore.dataset.vision.write_file
mindspore.dataset.vision.write_jpeg

View File

@ -98,3 +98,4 @@ Utilities
mindspore.dataset.vision.read_file
mindspore.dataset.vision.read_image
mindspore.dataset.vision.write_file
mindspore.dataset.vision.write_jpeg

View File

@ -924,5 +924,12 @@ PYBIND_REGISTER(WriteFileOperation, 1, ([](py::module *m) {
THROW_IF_ERROR(WriteFile(filename, data));
}));
}));
// Registers a "write_jpeg" function on the Python module that forwards its three arguments to
// WriteJpeg. THROW_IF_ERROR is expected to surface a failed Status to the Python caller
// (NOTE(review): macro is defined outside this view - confirm the exception type it raises).
PYBIND_REGISTER(WriteJPEGOperation, 1, ([](py::module *m) {
                  (void)m->def("write_jpeg",
                               ([](const std::string &filename, const std::shared_ptr<Tensor> &image, int quality) {
                                 THROW_IF_ERROR(WriteJpeg(filename, image, quality));
                               }));
                }));
} // namespace dataset
} // namespace mindspore

View File

@ -1457,6 +1457,14 @@ Status WriteFile(const std::string &filename, const mindspore::MSTensor &data) {
RETURN_IF_NOT_OK(mindspore::dataset::WriteFile(filename, de_tensor));
return Status::OK();
}
/// \brief Public WriteJpeg entry point operating on an MSTensor.
///     The MSTensor is first converted into a dataset Tensor and the actual work is delegated to
///     mindspore::dataset::WriteJpeg; the first failing step's Status is returned to the caller.
Status WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality) {
  std::shared_ptr<dataset::Tensor> de_image;
  RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(image, &de_image));
  RETURN_IF_NOT_OK(mindspore::dataset::WriteJpeg(filename, de_image, quality));
  return Status::OK();
}
#endif // not ENABLE_ANDROID
} // namespace vision
} // namespace dataset

View File

@ -2046,6 +2046,13 @@ class DATASET_API VerticalFlip final : public TensorTransform {
/// \param[in] data The tensor data.
/// \return The status code.
Status DATASET_API WriteFile(const std::string &filename, const mindspore::MSTensor &data);
/// \brief Write the image data into a JPEG file.
/// \param[in] filename The path to the file to be written.
/// \param[in] image The image tensor to be encoded. The implementation it is forwarded to requires uint8
///     elements and a shape of <H, W>, <H, W, 1> or <H, W, 3>.
/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75.
/// \return The status code.
Status DATASET_API WriteJpeg(const std::string &filename, const mindspore::MSTensor &image, int quality = 75);
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -2244,7 +2244,7 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
image_matrix = input_cv->mat();
if (!image_matrix.data) {
RETURN_STATUS_UNEXPECTED("[Internal ERROR] EncodeJpeg: load the image tensor failed.");
RETURN_STATUS_UNEXPECTED("EncodeJpeg: Load the image tensor failed.");
}
if (channels == kMinImageChannel) {
@ -2350,7 +2350,7 @@ Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &dat
}
auto realpath = FileUtils::GetRealPath(filename.c_str());
if (!realpath.has_value()) {
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " can not get the real path.");
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " failed to get the real path.");
}
struct stat sb;
stat(realpath.value().c_str(), &sb);
@ -2372,5 +2372,87 @@ Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &dat
fs.close();
return Status::OK();
}
/// \brief Encode an image tensor as JPEG and write the encoded bytes to a file.
/// \param[in] filename The path to the file to be written. The file is created when it does not exist and is
///     truncated when it does; the resolved path must refer to a regular file.
/// \param[in] image The image tensor. It must not be null, must hold uint8 elements and must have shape
///     <H, W>, <H, W, 1> or <H, W, 3>. Three-channel data is converted from RGB to BGR before encoding.
/// \param[in] quality The JPEG quality, in range of [kMinJpegQuality, kMaxJpegQuality].
/// \return Status code; an unexpected-error Status describing the first failed check otherwise.
Status WriteJpeg(const std::string &filename, const std::shared_ptr<Tensor> &image, int quality) {
  // Guard against a null tensor before it is dereferenced below.
  CHECK_FAIL_RETURN_UNEXPECTED(image != nullptr, "WriteJpeg: The input image is null.");

  std::string err_msg;
  if (image->type() != DataType::DE_UINT8) {
    err_msg = "WriteJpeg: The type of the elements of image should be UINT8, but got " + image->type().ToString() + ".";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  TensorShape shape = image->shape();
  int rank = shape.Rank();
  if (rank < kMinImageRank || rank > kDefaultImageRank) {
    err_msg = "WriteJpeg: The image has invalid dimensions. It should have two or three dimensions, but got ";
    err_msg += std::to_string(rank) + " dimensions.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // A two-dimensional image is treated as a single-channel image.
  int channels = 1;
  if (rank == kDefaultImageRank) {
    channels = static_cast<int>(shape[kMinImageRank]);
    if (channels != kMinImageChannel && channels != kDefaultImageChannel) {
      err_msg = "WriteJpeg: The image has invalid channels. It should have 1 or 3 channels, but got ";
      err_msg += std::to_string(channels) + " channels.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }

  if (quality < kMinJpegQuality || quality > kMaxJpegQuality) {
    err_msg = "WriteJpeg: Invalid quality " + std::to_string(quality) + ", should be in range of [" +
              std::to_string(kMinJpegQuality) + ", " + std::to_string(kMaxJpegQuality) + "].";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // Encode into memory first so that nothing is created on disk when the image cannot be encoded.
  std::vector<int> params = {cv::IMWRITE_JPEG_QUALITY,  quality, cv::IMWRITE_JPEG_PROGRESSIVE,  0,
                             cv::IMWRITE_JPEG_OPTIMIZE, 0,       cv::IMWRITE_JPEG_RST_INTERVAL, 0};
  std::vector<unsigned char> buffer;
  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(image);
  cv::Mat image_matrix = input_cv->mat();
  if (!image_matrix.data) {
    RETURN_STATUS_UNEXPECTED("WriteJpeg: Load the image tensor failed.");
  }
  if (channels == kMinImageChannel) {
    CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_matrix, buffer, params),
                                 "WriteJpeg: Failed to encode image.");
  } else {
    // The tensor is in RGB order while OpenCV encodes BGR; convert into a separate matrix so the
    // caller's tensor data is left untouched.
    cv::Mat image_bgr;
    cv::cvtColor(image_matrix, image_bgr, cv::COLOR_RGB2BGR);
    CHECK_FAIL_RETURN_UNEXPECTED(cv::imencode(".JPEG", image_bgr, buffer, params),
                                 "WriteJpeg: Failed to encode image.");
  }

  // Create the file when it is missing so that the real path can be resolved afterwards.
  Path file(filename);
  if (!file.Exists()) {
    int file_descriptor;
    RETURN_IF_NOT_OK(file.CreateFile(&file_descriptor));
    RETURN_IF_NOT_OK(file.CloseFile(file_descriptor));
  }
  auto realpath = FileUtils::GetRealPath(filename.c_str());
  if (!realpath.has_value()) {
    RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " failed to get the real path.");
  }

  // stat() leaves the buffer unspecified on failure, so its return value must be checked before
  // st_mode is inspected.
  struct stat sb {};
  if (stat(realpath.value().c_str(), &sb) != 0 || S_ISREG(sb.st_mode) == 0) {
    RETURN_STATUS_UNEXPECTED("WriteJpeg: Invalid file path, " + filename + " is not a regular file.");
  }

  std::ofstream fs(realpath.value().c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "WriteJpeg: Failed to open the file " + filename + " for writing.");
  fs.write(reinterpret_cast<const char *>(buffer.data()), static_cast<std::streamsize>(buffer.size()));
  // close() flushes the buffered bytes and sets failbit on failure, so checking the stream state
  // after closing also catches errors that only surface when the data actually reaches the file.
  fs.close();
  CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "WriteJpeg: Failed to write the file " + filename);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -534,6 +534,13 @@ Status ReadImage(const std::string &filename, std::shared_ptr<Tensor> *output,
/// \param[in] data The tensor data.
/// \return The status code.
Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &data);
/// \brief Write the image data into a JPEG file.
/// \param[in] filename The path to the file to be written.
/// \param[in] image The image tensor to be encoded. It must hold uint8 elements and have a shape of
///     <H, W>, <H, W, 1> or <H, W, 3>.
/// \param[in] quality The quality for JPEG file, in range of [1, 100]. Default: 75.
/// \return Status code.
Status WriteJpeg(const std::string &filename, const std::shared_ptr<Tensor> &image, int quality = 75);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

View File

@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, ImageReadMode, Inter, SliceMode, \
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file
encode_jpeg, get_image_num_channels, get_image_size, read_file, read_image, write_file, write_jpeg

View File

@ -574,3 +574,42 @@ def write_file(filename, data):
return cde.write_file(filename, cde.Tensor(data.asnumpy()))
raise TypeError("Input data is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
mindspore.Tensor, type(data)))
def write_jpeg(filename, image, quality=75):
    """
    Write the image data into a JPEG file.

    Args:
        filename (str): The path to the file to be written. The file is created if it does not exist.
        image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
        quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: 75.

    Raises:
        TypeError: If `filename` is not of type str.
        TypeError: If `image` is not of type numpy.ndarray or mindspore.Tensor.
        TypeError: If `quality` is not of type int (bool is rejected as well).
        RuntimeError: If the `filename` cannot be created or resolved, or is not a common file.
        RuntimeError: If the data type of `image` is not uint8.
        RuntimeError: If the shape of `image` is not <H, W> or <H, W, 1> or <H, W, 3>.
        RuntimeError: If `quality` is less than 1 or greater than 100.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> # Generate a random image with height=120, width=340, channels=3
        >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
        >>> vision.write_jpeg("/path/to/file", image)
    """
    if not isinstance(filename, str):
        raise TypeError("Input filename is not of type {0}, but got: {1}.".format(str, type(filename)))
    # bool is a subclass of int; reject it explicitly so that True/False are not silently
    # interpreted as the quality values 1/0.
    if isinstance(quality, bool) or not isinstance(quality, int):
        raise TypeError("Input quality is not of type {0}, but got: {1}.".format(int, type(quality)))
    if isinstance(image, np.ndarray):
        return cde.write_jpeg(filename, cde.Tensor(image), quality)
    if isinstance(image, mindspore.Tensor):
        return cde.write_jpeg(filename, cde.Tensor(image.asnumpy()), quality)
    raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
                                                                                 mindspore.Tensor, type(image)))

View File

@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <opencv2/opencv.hpp>
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/transforms.h"
@ -1348,3 +1350,89 @@ TEST_F(MindDataTestPipeline, TestWriteFileException) {
data_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::WriteFile(filename_2, data_tensor));
}
/// Feature: WriteJpeg
/// Description: Test WriteJpeg by writing the image into a JPEG file
/// Expectation: The file should be written and removed
TEST_F(MindDataTestPipeline, TestWriteJpegNormal) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWriteJpegNormal.";
  const std::string folder_path = "./data/dataset/testFormats/";
  const std::string filename_1 = folder_path + "apple.jpg";
  const std::string filename_2 = filename_1 + ".test_write_jpeg.jpg";

  // Load the reference image; fail early if the fixture is missing instead of crashing in cvtColor.
  cv::Mat image_bgr = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED);
  ASSERT_FALSE(image_bgr.empty());
  cv::Mat image_1;
  cv::cvtColor(image_bgr, image_1, cv::COLOR_BGRA2RGB);

  TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()});
  DataType pixel_type = DataType(DataType::DE_UINT8);
  std::shared_ptr<Tensor> image_de_tensor;
  ASSERT_OK(Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor));
  auto image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));

  for (int quality = 20; quality <= 100; quality += 40) {
    ASSERT_OK(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, quality));
    // Read back the file that was just written (not the source image) so the output is really verified.
    cv::Mat image_2 = cv::imread(filename_2, cv::ImreadModes::IMREAD_UNCHANGED);
    remove(filename_2.c_str());
    EXPECT_EQ(image_1.total(), image_2.total());
  }
}
/// Feature: WriteJpeg
/// Description: Test WriteJpeg with invalid parameter
/// Expectation: Error is caught when the parameter is invalid
TEST_F(MindDataTestPipeline, TestWriteJpegException) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestWriteJpegException.";
  const std::string folder_path = "./data/dataset/testFormats/";
  const std::string filename_1 = folder_path + "apple.jpg";
  const std::string filename_2 = filename_1 + ".test_write_jpeg.jpg";

  // Fail early if the fixture is missing; otherwise every case below would fail for the wrong reason.
  cv::Mat image_1 = cv::imread(filename_1, cv::ImreadModes::IMREAD_UNCHANGED);
  ASSERT_FALSE(image_1.empty());

  TensorShape img_tensor_shape = TensorShape({image_1.size[0], image_1.size[1], image_1.channels()});
  DataType pixel_type = DataType(DataType::DE_UINT8);
  std::shared_ptr<Tensor> image_de_tensor;
  ASSERT_OK(Tensor::CreateFromMemory(img_tensor_shape, pixel_type, image_1.data, &image_de_tensor));
  auto image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));

  // Test with invalid quality 0, 101
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 0));
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor, 101));

  // Test with an invalid filename
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("/dev/cdrom/0", image_ms_tensor));

  // Test with a directory name
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg("./data/dataset/", image_ms_tensor));

  // Test with an invalid image containing float elements
  std::shared_ptr<Tensor> float32_cde_tensor;
  ASSERT_OK(Tensor::CreateEmpty(TensorShape({5, 4, 3}), DataType(DataType::DE_FLOAT32), &float32_cde_tensor));
  image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(float32_cde_tensor));
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));

  // The reshapes below must succeed; if one silently failed the tensor would keep a valid image
  // shape and the following ASSERT_ERROR would no longer exercise the intended case.

  // Test with an invalid image with only one dimension
  ASSERT_OK(image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.size[1] * image_1.channels()})));
  image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));

  // Test with an invalid image with four dimensions
  ASSERT_OK(
    image_de_tensor->Reshape(TensorShape({image_1.size[0] / 2, image_1.size[1], image_1.channels(), 2})));
  image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));

  // Test with an invalid image with two channels
  ASSERT_OK(
    image_de_tensor->Reshape(TensorShape({image_1.size[0] * image_1.channels() / 2, image_1.size[1], 2})));
  image_ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(image_de_tensor));
  ASSERT_ERROR(mindspore::dataset::vision::WriteJpeg(filename_2, image_ms_tensor));
}

View File

@ -0,0 +1,177 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing write_jpeg
"""
import os
import cv2
import numpy
import pytest
from mindspore import Tensor
from mindspore.dataset import vision
def test_write_jpeg_three_channels():
    """
    Feature: write_jpeg
    Description: Write the image containing three channels into a JPEG file
    Expectation: The file should be written and removed
    """

    def write_jpeg_three_channels(filename_param, image_param, quality_param=75):
        """
        Write a three-channel image, read the file back, remove it and check the decoded shape.
        """
        vision.write_jpeg(filename_param, image_param, quality_param)
        image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED)
        os.remove(filename_param)
        assert image_2_numpy.shape == (2268, 4032, 3)

    filename_1 = "../data/dataset/apple.jpg"
    mode = cv2.IMREAD_UNCHANGED
    image_bgr = cv2.imread(filename_1, mode)
    image_1_numpy = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_1_tensor = Tensor.from_numpy(image_1_numpy)
    filename_2 = filename_1 + ".test_write_jpeg.jpg"

    # Test writing numpy.ndarray
    write_jpeg_three_channels(filename_2, image_1_numpy)

    # Test writing Tensor and quality 1, 75, 100
    for quality in (1, 75, 100):
        write_jpeg_three_channels(filename_2, image_1_tensor, quality)

    # Test with three channels 2268*4032*3 random uint8, the quality is 50.
    # numpy.random.randint yields defined random pixels; numpy.ndarray(shape=...) would only
    # expose uninitialized memory.
    image_random = numpy.random.randint(0, 256, size=(2268, 4032, 3), dtype=numpy.uint8)
    write_jpeg_three_channels(filename_2, image_random, 50)
def test_write_jpeg_one_channel():
    """
    Feature: write_jpeg
    Description: Write the grayscale image into a JPEG file
    Expectation: The file should be written and removed
    """

    def write_jpeg_one_channel(filename_param, image_param, quality_param=75):
        """
        Write a one-channel image, read the file back, remove it and check the decoded shape.
        """
        vision.write_jpeg(filename_param, image_param, quality_param)
        image_2_numpy = cv2.imread(filename_param, cv2.IMREAD_UNCHANGED)
        os.remove(filename_param)
        assert image_2_numpy.shape == (2268, 4032)

    filename_1 = "../data/dataset/apple.jpg"
    mode = cv2.IMREAD_UNCHANGED
    image_1_numpy = cv2.imread(filename_1, mode)
    filename_2 = filename_1 + ".test_write_jpeg.jpg"
    image_grayscale = cv2.cvtColor(image_1_numpy, cv2.COLOR_BGR2GRAY)
    image_grayscale_tensor = Tensor.from_numpy(image_grayscale)

    # Test writing numpy.ndarray
    write_jpeg_one_channel(filename_2, image_grayscale)

    # Test writing Tensor and quality 1, 75, 100
    for quality in (1, 75, 100):
        write_jpeg_one_channel(filename_2, image_grayscale_tensor, quality)

    # Test with a two-dimensional 2268*4032 random uint8 image.
    # numpy.random.randint yields defined random pixels; numpy.ndarray(shape=...) would only
    # expose uninitialized memory.
    image_random = numpy.random.randint(0, 256, size=(2268, 4032), dtype=numpy.uint8)
    write_jpeg_one_channel(filename_2, image_random)

    # Test with one channel 2268*4032*1 random uint8, the quality is 50
    image_random = numpy.random.randint(0, 256, size=(2268, 4032, 1), dtype=numpy.uint8)
    write_jpeg_one_channel(filename_2, image_random, 50)
def test_write_jpeg_exception():
    """
    Feature: write_jpeg
    Description: Test write_jpeg with an invalid parameter
    Expectation: Error is caught when the parameter is invalid
    """

    def test_invalid_param(filename_param, image_param, quality_param, error, error_msg):
        """
        a function used for checking correct error and message with invalid parameter
        """
        with pytest.raises(error) as error_info:
            vision.write_jpeg(filename_param, image_param, quality_param)
        assert error_msg in str(error_info.value)

    filename_1 = "../data/dataset/apple.jpg"
    mode = cv2.IMREAD_UNCHANGED
    image_1_numpy = cv2.imread(filename_1, mode)
    image_1_tensor = Tensor.from_numpy(image_1_numpy)

    # Test with a directory name
    wrong_filename = "../data/dataset/"
    error_message = "Invalid file path, " + wrong_filename + " is not a regular file."
    test_invalid_param(wrong_filename, image_1_numpy, 75, RuntimeError, error_message)

    # Test with an invalid filename
    wrong_filename = "/dev/cdrom/0"
    error_message = "No such file or directory"
    test_invalid_param(wrong_filename, image_1_tensor, 75, RuntimeError, error_message)

    # Test with an invalid type for the filename
    error_message = "Input filename is not of type"
    test_invalid_param(0, image_1_numpy, 75, TypeError, error_message)

    # Test with an invalid type for the data
    filename_2 = filename_1 + ".test_write_jpeg.jpg"
    error_message = "Input image is not of type"
    test_invalid_param(filename_2, 0, 75, TypeError, error_message)

    # The invalid images below are built with numpy.zeros so that their contents are defined;
    # only their dtype and shape matter for the checks under test.

    # Test with invalid float elements
    invalid_data = numpy.zeros(shape=(10, 10), dtype=float)
    error_message = "The type of the elements of image should be UINT8"
    test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)

    # Test with invalid image with only one dimension
    invalid_data = numpy.zeros(shape=(10,), dtype=numpy.uint8)
    error_message = "The image has invalid dimensions"
    test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)

    # Test with invalid image with four dimensions
    invalid_data = numpy.zeros(shape=(1, 2, 3, 4), dtype=numpy.uint8)
    test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)

    # Test with invalid image with two channels
    invalid_data = numpy.zeros(shape=(2, 3, 2), dtype=numpy.uint8)
    error_message = "The image has invalid channels"
    test_invalid_param(filename_2, invalid_data, 75, RuntimeError, error_message)

    # Test with an invalid integer for the quality 0, 101
    error_message = "Invalid quality"
    test_invalid_param(filename_2, image_1_numpy, 0, RuntimeError, error_message)
    test_invalid_param(filename_2, image_1_numpy, 101, RuntimeError, error_message)

    # Test with an invalid type for the quality
    error_message = "Input quality is not of type"
    test_invalid_param(filename_2, image_1_numpy, 75.0, TypeError, error_message)
# Allow the test cases to be run directly as a script, without going through pytest.
if __name__ == "__main__":
    test_write_jpeg_three_channels()
    test_write_jpeg_one_channel()
    test_write_jpeg_exception()