!40746 【算子众智】【电子科技大学】【数据算子】【write_file】write a file using binary mode

Merge pull request !40746 from dengjian/upstream_write_file
This commit is contained in:
i-robot 2022-10-29 02:42:53 +00:00 committed by Gitee
commit df5f138d03
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 303 additions and 2 deletions

View File

@ -0,0 +1,17 @@
mindspore.dataset.vision.write_file
===================================
.. py:function:: mindspore.dataset.vision.write_file(filename, data)
使用二进制模式将一维uint8类型数据数组写到文件。
参数:
- **filename** (str) - 要写入的文件的路径。
- **data** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的一维uint8数据。
异常:
- **TypeError** - 如果 `filename` 不是str类型。
- **TypeError** - 如果 `data` 不是numpy.ndarray或mindspore.Tensor类型。
- **RuntimeError** - 如果 `filename` 路径不是普通文件。
- **RuntimeError** - 如果 `data` 的数据类型不是uint8类型。
- **RuntimeError** - 如果 `data` 的shape不是一维数组。

View File

@ -144,3 +144,4 @@ API样例中常用的导入模块如下
mindspore.dataset.vision.encode_jpeg
mindspore.dataset.vision.get_image_num_channels
mindspore.dataset.vision.get_image_size
mindspore.dataset.vision.write_file

View File

@ -94,3 +94,4 @@ Utilities
mindspore.dataset.vision.encode_jpeg
mindspore.dataset.vision.get_image_num_channels
mindspore.dataset.vision.get_image_size
mindspore.dataset.vision.write_file

View File

@ -902,5 +902,11 @@ PYBIND_REGISTER(
return vertical_flip;
}));
}));
PYBIND_REGISTER(WriteFileOperation, 1, ([](py::module *m) {
(void)m->def("write_file", ([](const std::string &filename, const std::shared_ptr<Tensor> &data) {
THROW_IF_ERROR(WriteFile(filename, data));
}));
}));
} // namespace dataset
} // namespace mindspore

8
mindspore/ccsrc/minddata/dataset/api/vision.cc Normal file → Executable file
View File

@ -1424,6 +1424,14 @@ std::shared_ptr<TensorOperation> UniformAugment::Parse() {
VerticalFlip::VerticalFlip() = default;
std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); }
// WriteFile Function.
Status WriteFile(const std::string &filename, const mindspore::MSTensor &data) {
std::shared_ptr<dataset::Tensor> de_tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromMSTensor(data, &de_tensor));
RETURN_IF_NOT_OK(mindspore::dataset::WriteFile(filename, de_tensor));
return Status::OK();
}
#endif // not ENABLE_ANDROID
} // namespace vision
} // namespace dataset

View File

@ -2021,6 +2021,12 @@ class DATASET_API VerticalFlip final : public TensorTransform {
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Write the one dimension uint8 data into a file using binary mode.
/// \param[in] filename The path to the file to be written.
/// \param[in] data The tensor data.
/// \return The status code.
Status DATASET_API WriteFile(const std::string &filename, const mindspore::MSTensor &data);
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@ -16,11 +16,13 @@
#include "minddata/dataset/kernels/image/image_utils.h"
#include <opencv2/imgproc/types_c.h>
#include <algorithm>
#include <fstream>
#include <limits>
#include <string>
#include <vector>
#include <stdexcept>
#include <opencv2/imgcodecs.hpp>
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
@ -2258,5 +2260,66 @@ Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor>
return Status::OK();
}
Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &data) {
std::string err_msg;
if (data->type() != DataType::DE_UINT8) {
err_msg = "WriteFile: The type of the elements of data should be UINT8, but got " + data->type().ToString() + ".";
RETURN_STATUS_UNEXPECTED(err_msg);
}
long int data_size = data->Size();
const char *data_buffer;
if (data_size >= kDeMaxDim || data_size < 0) {
err_msg = "WriteFile: Invalid data->Size() , should be >= 0 && < " + std::to_string(kDeMaxDim);
err_msg += " , but got " + std::to_string(data_size) + " for " + filename;
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (data_size > 0) {
data_buffer = (const char *)data->GetBuffer();
if (data_buffer == nullptr) {
err_msg = "WriteFile: Invalid data->GetBufferSize() , should not be nullptr.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
TensorShape shape = data->shape();
int rank = shape.Rank();
if (rank != kMinImageChannel) {
err_msg = "WriteFile: The data has invalid dimensions. It should have only one dimension, but got ";
err_msg += std::to_string(rank) + " dimensions.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
Path file(filename);
if (!file.Exists()) {
int file_descriptor;
RETURN_IF_NOT_OK(file.CreateFile(&file_descriptor));
RETURN_IF_NOT_OK(file.CloseFile(file_descriptor));
}
auto realpath = FileUtils::GetRealPath(filename.c_str());
if (!realpath.has_value()) {
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " can not get the real path.");
}
struct stat sb;
stat(realpath.value().c_str(), &sb);
if (S_ISREG(sb.st_mode) == 0) {
RETURN_STATUS_UNEXPECTED("WriteFile: Invalid file path, " + filename + " is not a regular file.");
}
std::ofstream fs(realpath.value().c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "WriteFile: Failed to open the file: " + filename + " for writing.");
if (data_size > 0) {
fs.write(data_buffer, data_size);
if (fs.fail()) {
err_msg = "WriteFile: Failed to write the file " + filename;
fs.close();
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
fs.close();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -509,6 +509,12 @@ Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
/// \param[in] quality The quality for the output tensor from 1 to 100. Default: 75.
/// \return The status code.
Status EncodeJpeg(const std::shared_ptr<Tensor> &image, std::shared_ptr<Tensor> *output, int quality = 75);
/// \brief Write the one dimension uint8 data into a file using binary mode.
/// \param[in] filename The path to the file to be written.
/// \param[in] data The tensor data.
/// \return The status code.
Status WriteFile(const std::string &filename, const std::shared_ptr<Tensor> &data);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_

2
mindspore/python/mindspore/dataset/vision/__init__.py Normal file → Executable file
View File

@ -85,4 +85,4 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
ResizeWithBBox, RgbToHsv, Rotate, SlicePatches, Solarize, TenCrop, ToNumpy, ToPIL, ToTensor, ToType, \
TrivialAugmentWide, UniformAugment, VerticalFlip, not_random
from .utils import AutoAugmentPolicy, Border, ConvertMode, ImageBatchFormat, Inter, SliceMode, encode_jpeg, \
get_image_num_channels, get_image_size
get_image_num_channels, get_image_size, write_file

29
mindspore/python/mindspore/dataset/vision/utils.py Normal file → Executable file
View File

@ -445,3 +445,32 @@ def parse_padding(padding):
if isinstance(padding, list):
padding = tuple(padding)
return padding
def write_file(filename, data):
"""
Write the one dimension uint8 data into a file using binary mode.
Args:
filename (str): The path to the file to be written.
data (Union[numpy.ndarray, mindspore.Tensor]): The one dimension uint8 data to be written.
Raises:
TypeError: If `filename` is not of type str.
TypeError: If `data` is not of type numpy.ndarray or mindspore.Tensor.
RuntimeError: If the `filename` path is not a common file.
RuntimeError: If the data type of `data` is not uint8.
RuntimeError: If the shape of `data` is not a one-dimensional array.
Examples:
>>> from mindspore.dataset import vision
>>> vision.write_file("/path/to/file", data)
"""
if not isinstance(filename, str):
raise TypeError("Input filename is not of type {0}, but got: {1}.".format(str, type(filename)))
if isinstance(data, np.ndarray):
return cde.write_file(filename, cde.Tensor(data))
if isinstance(data, mindspore.Tensor):
return cde.write_file(filename, cde.Tensor(data.asnumpy()))
raise TypeError("Input data is not of type {0} or {1}, but got: {2}.".format(np.ndarray,
mindspore.Tensor, type(data)))

50
tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc Normal file → Executable file
View File

@ -821,6 +821,7 @@ TEST_F(MindDataTestPipeline, TestToTensorOpInvalidInput) {
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_ERROR(iter->GetNextRow(&row));
}
/// Feature: ResizedCrop op
/// Description: Test ResizedCrop pipeline
/// Expectation: Input is processed as expected and all rows iterated correctly
@ -900,7 +901,6 @@ TEST_F(MindDataTestPipeline, TestResizedCropParamCheck) {
EXPECT_EQ(iter3, nullptr);
}
/// Feature: RandAugment
/// Description: test RandAugment pipeline
/// Expectation: create an ImageFolder dataset then do rand augmentation on it
@ -1131,3 +1131,51 @@ TEST_F(MindDataTestPipeline, TestRandAugmentMagGreNMBError) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
/// Feature: WriteFile
/// Description: Test WriteFile by writing the data into a file using binary mode
/// Expectation: The file should be writeen and removed successfully
TEST_F(MindDataTestPipeline, TestWriteFileNormal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteFileNormal.";
std::string folder_path = "./data/dataset/";
std::string filename_1, filename_2;
filename_1 = folder_path + "apple.jpg";
filename_2 = filename_1 + ".test_write_file";
std::shared_ptr<Tensor> de_tensor_1, de_tensor_2;
Tensor::CreateFromFile(filename_1, &de_tensor_1);
auto data_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor_1));
ASSERT_OK(mindspore::dataset::vision::WriteFile(filename_2, data_tensor));
Tensor::CreateFromFile(filename_2, &de_tensor_2);
EXPECT_EQ(de_tensor_1->shape(), de_tensor_2->shape());
remove(filename_2.c_str());
}
/// Feature: WriteFile
/// Description: Test WriFile with invalid parameter
/// Expectation: Error is caught when the parameter is invalid
TEST_F(MindDataTestPipeline, TestWriteFileException) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TesWriteFileException.";
std::string folder_path = "./data/dataset/";
std::string filename_1, filename_2;
filename_1 = folder_path + "apple.jpg";
filename_2 = filename_1 + ".test_write_file";
std::shared_ptr<Tensor> de_tensor_1;
Tensor::CreateFromFile(filename_1, &de_tensor_1);
auto data_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor_1));
// Test with a directory name
ASSERT_ERROR(mindspore::dataset::vision::WriteFile(folder_path, data_tensor));
// Test with an invalid filename
ASSERT_ERROR(mindspore::dataset::vision::WriteFile("/dev/cdrom/0", data_tensor));
// Test with invalid float elements
std::shared_ptr<Tensor> input;
std::vector<float> float_vector = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12};
Tensor::CreateFromVector(float_vector, TensorShape({12}), &input);
data_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
ASSERT_ERROR(mindspore::dataset::vision::WriteFile(filename_2, data_tensor));
}

View File

@ -0,0 +1,116 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing write_file
"""
import os
import numpy
import pytest
from mindspore import Tensor
from mindspore.dataset import vision
def test_write_file_normal():
"""
Feature: write_file
Description: Test the write_file by writing the data into a file using binary mode
Expectation: The file should be writeen and removed
"""
filename_1 = "../data/dataset/apple.jpg"
data_1_numpy = numpy.fromfile(filename_1, dtype=numpy.uint8)
data_1_tensor = Tensor.from_numpy(data_1_numpy)
filename_2 = filename_1 + ".test_write_file"
# Test writing numpy.ndarray
vision.write_file(filename_2, data_1_numpy)
data_2_numpy = numpy.fromfile(filename_2, dtype=numpy.uint8)
os.remove(filename_2)
assert data_2_numpy.shape == (159109,)
# Test writing Tensor
vision.write_file(filename_2, data_1_tensor)
data_2_numpy = numpy.fromfile(filename_2, dtype=numpy.uint8)
os.remove(filename_2)
assert data_2_numpy.shape == (159109,)
# Test writing empty numpy.ndarray
empty_numpy = numpy.empty(0, dtype=numpy.uint8)
vision.utils.write_file(filename_2, empty_numpy)
data_2_numpy = numpy.fromfile(filename_2, dtype=numpy.uint8)
os.remove(filename_2)
assert data_2_numpy.shape == (0,)
# Test writing empty Tensor
empty_tensor = Tensor.from_numpy(empty_numpy)
vision.utils.write_file(filename_2, empty_tensor)
data_2_numpy = numpy.fromfile(filename_2, dtype=numpy.uint8)
os.remove(filename_2)
assert data_2_numpy.shape == (0,)
def test_write_file_exception():
"""
Feature: write_file
Description: Test the write_file with invalid parameter
Expectation: Error is caught when the parameter is invalid
"""
def test_invalid_param(filename_param, data_param, error, error_msg):
"""
a function used for checking correct error and message with invalid parameter
"""
with pytest.raises(error) as error_info:
vision.write_file(filename_param, data_param)
assert error_msg in str(error_info.value)
filename_1 = "../data/dataset/apple.jpg"
data_1_numpy = numpy.fromfile(filename_1, dtype=numpy.uint8)
data_1_tensor = Tensor.from_numpy(data_1_numpy)
# Test with a directory name
wrong_filename = "../data/dataset/"
error_message = "Invalid file path, " + wrong_filename + " is not a regular file."
test_invalid_param(wrong_filename, data_1_numpy, RuntimeError, error_message)
# Test with an invalid filename
wrong_filename = "/dev/cdrom/0"
error_message = "No such file or directory"
test_invalid_param(wrong_filename, data_1_tensor, RuntimeError, error_message)
# Test with an invalid type for the filename
error_message = "Input filename is not of type"
test_invalid_param(0, data_1_numpy, TypeError, error_message)
# Test with an invalid type for the data
filename_2 = filename_1 + ".test_write_file"
error_message = "Input data is not of type"
test_invalid_param(filename_2, 0, TypeError, error_message)
# Test with invalid float elements
invalid_data = numpy.ndarray(shape=(10), dtype=float)
error_message = "The type of the elements of data should be"
test_invalid_param(filename_2, invalid_data, RuntimeError, error_message)
# Test with invalid data
error_message = "The data has invalid dimensions"
invalid_data = numpy.ndarray(shape=(10, 10), dtype=numpy.uint8)
test_invalid_param(filename_2, invalid_data, RuntimeError, error_message)
if __name__ == "__main__":
test_write_file_normal()
test_write_file_exception()