add mindrecord overwrite

This commit is contained in:
liyong 2021-10-26 20:15:10 +08:00
parent 027adf9fff
commit 41e8280f55
9 changed files with 322 additions and 38 deletions

View File

@ -105,8 +105,8 @@ void BindShardWriter(py::module *m) {
(void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
.def(py::init<>())
.def("open",
[](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
THROW_IF_ERROR(s.Open(paths, append));
[](ShardWriter &s, const std::vector<std::string> &paths, bool append, bool overwrite) {
THROW_IF_ERROR(s.Open(paths, append, overwrite));
return SUCCESS;
})
.def("open_for_append",

View File

@ -54,9 +54,10 @@ class __attribute__((visibility("default"))) ShardWriter {
/// \brief Open file at the beginning
/// \param[in] paths the file names list
/// \param[in] append new data at the end of file if true, otherwise overwrite file
/// \param[in] append new data at the end of file if true, otherwise try to overwrite file
/// \param[in] overwrite a file with the same name if true
/// \return Status
Status Open(const std::vector<std::string> &paths, bool append = false);
Status Open(const std::vector<std::string> &paths, bool append = false, bool overwrite = false);
/// \brief Open file at the ending
/// \param[in] paths the file names list
@ -215,7 +216,7 @@ class __attribute__((visibility("default"))) ShardWriter {
Status GetFullPathFromFileName(const std::vector<std::string> &paths);
/// \brief Open files
Status OpenDataFiles(bool append);
Status OpenDataFiles(bool append, bool overwrite);
/// \brief Remove lock file
Status RemoveLockFile();

View File

@ -64,7 +64,7 @@ Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &path
return Status::OK();
}
Status ShardWriter::OpenDataFiles(bool append) {
Status ShardWriter::OpenDataFiles(bool append, bool overwrite) {
// Open files
for (const auto &file : file_paths_) {
std::optional<std::string> dir = "";
@ -82,13 +82,33 @@ Status ShardWriter::OpenDataFiles(bool append) {
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
if (!append) {
// if not append and mindrecord file exist, return FAILED
// if not append && mindrecord or db file exist
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
if (fs->good()) {
std::ifstream fs_db(whole_path.value() + ".db");
if (fs->good() || fs_db.good()) {
fs->close();
RETURN_STATUS_UNEXPECTED("Invalid file, Mindrecord files already existed in path: " + file);
fs_db.close();
if (overwrite) {
auto res1 = std::remove(whole_path.value().c_str());
CHECK_FAIL_RETURN_UNEXPECTED(!std::ifstream(whole_path.value()) == true,
"Failed to delete file, path: " + file);
if (res1 == 0) {
MS_LOG(WARNING) << "Succeed to delete file, path: " << file;
}
auto db_file = whole_path.value() + ".db";
auto res2 = std::remove(db_file.c_str());
CHECK_FAIL_RETURN_UNEXPECTED(!std::ifstream(whole_path.value() + ".db") == true,
"Failed to delete db file, path: " + file + ".db");
if (res2 == 0) {
MS_LOG(WARNING) << "Succeed to delete metadata file, path: " << file + ".db";
}
} else {
RETURN_STATUS_UNEXPECTED("Invalid file, Mindrecord files already existed in path: " + file);
}
} else {
fs->close();
fs_db.close();
}
fs->close();
// open the mindrecord file to write
fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
if (!fs->good()) {
@ -131,7 +151,7 @@ Status ShardWriter::InitLockFile() {
return Status::OK();
}
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append, bool overwrite) {
shard_count_ = paths.size();
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
"Invalid data, schema_count_ must be less than or equal to " +
@ -140,7 +160,7 @@ Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
// Get full path from file name
RETURN_IF_NOT_OK(GetFullPathFromFileName(paths));
// Open files
RETURN_IF_NOT_OK(OpenDataFiles(append));
RETURN_IF_NOT_OK(OpenDataFiles(append, overwrite));
// Init lock file
RETURN_IF_NOT_OK(InitLockFile());
return Status::OK();

View File

@ -41,11 +41,12 @@ class FileWriter:
Args:
file_name (str): File name of MindRecord file.
shard_num (int, optional): The Number of MindRecord file. Default: 1.
shard_num (int, optional): The Number of MindRecord files. Default: 1.
It should be between [1, 1000].
overwrite (bool, optional): Overwrite MindRecord files if true. Default: False.
Raises:
ParamValueError: If `file_name` or `shard_num` is invalid.
ParamValueError: If `file_name` or `shard_num` or `overwrite` is invalid.
Examples:
>>> from mindspore.mindrecord import FileWriter
@ -68,19 +69,22 @@ class FileWriter:
MSRStatus.SUCCESS
"""
def __init__(self, file_name, shard_num=1):
def __init__(self, file_name, shard_num=1, overwrite=False):
check_filename(file_name)
self._file_name = file_name
if shard_num is not None:
if isinstance(shard_num, int):
if shard_num < MIN_SHARD_COUNT or shard_num > MAX_SHARD_COUNT:
raise ParamValueError("Shard number should between {} and {}."
.format(MIN_SHARD_COUNT, MAX_SHARD_COUNT))
raise ParamValueError("Parameter shard_num's value: {} should between {} and {}."
.format(shard_num, MIN_SHARD_COUNT, MAX_SHARD_COUNT))
else:
raise ParamValueError("Shard num is illegal.")
raise ParamValueError("Parameter shard_num's type is not int.")
else:
raise ParamValueError("Shard num is illegal.")
raise ParamValueError("Parameter shard_num is None.")
if not isinstance(overwrite, bool):
raise ParamValueError("Parameter overwrite's type is not bool.")
self._shard_num = shard_num
self._index_generator = True
@ -93,6 +97,7 @@ class FileWriter:
str(x).rjust(suffix_shard_size, '0'))
for x in range(self._shard_num)]
self._overwrite = overwrite
self._append = False
self._header = ShardHeader()
self._writer = ShardWriter()
@ -137,7 +142,7 @@ class FileWriter:
check_filename(file_name)
# construct ShardHeader
reader = ShardReader()
reader.open(file_name)
reader.open(file_name, False)
header = ShardHeader(reader.get_header())
reader.close()
@ -268,7 +273,7 @@ class FileWriter:
MRMSetHeaderError: If failed to set header.
"""
if not self._writer.is_open:
ret = self._writer.open(self._paths)
ret = self._writer.open(self._paths, self._overwrite)
if not self._writer.get_shard_header():
return self._writer.set_shard_header(self._header)
return ret
@ -296,7 +301,7 @@ class FileWriter:
MRMWriteDatasetError: If failed to write dataset.
"""
if not self._writer.is_open:
self._writer.open(self._paths)
self._writer.open(self._paths, self._overwrite)
if not self._writer.get_shard_header():
self._writer.set_shard_header(self._header)
if not isinstance(raw_data, list):
@ -378,7 +383,7 @@ class FileWriter:
MRMCommitError: If failed to flush data to disk.
"""
if not self._writer.is_open:
self._writer.open(self._paths)
self._writer.open(self._paths, self._overwrite)
# permit commit without data
if not self._writer.get_shard_header():
self._writer.set_shard_header(self._header)

View File

@ -126,13 +126,13 @@ def check_parameter(func):
check_filename(value)
if name == 'num_consumer':
if value is None:
raise ParamValueError("Consumer number is illegal.")
raise ParamValueError("Parameter num_consumer is None.")
if isinstance(value, int):
if value < MIN_CONSUMER_COUNT or value > MAX_CONSUMER_COUNT():
raise ParamValueError("Consumer number should between {} and {}."
.format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
raise ParamValueError("Parameter num_consumer: {} should between {} and {}."
.format(value, MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
else:
raise ParamValueError("Consumer number is illegal.")
raise ParamValueError("Parameter num_consumer is not int.")
return func(*args, **kw)
return wrapper

View File

@ -36,7 +36,7 @@ class ShardWriter:
self._header = None
self._is_open = False
def open(self, paths):
def open(self, paths, override):
"""
Open a new MindRecord File and prepare to write raw data.
@ -49,7 +49,7 @@ class ShardWriter:
Raises:
MRMOpenError: If failed to open MindRecord File.
"""
ret = self._writer.open(paths, False)
ret = self._writer.open(paths, False, override)
if ret != ms.MSRStatus.SUCCESS:
logger.critical("Failed to open paths")
raise MRMOpenError

View File

@ -1156,5 +1156,90 @@ TEST_F(TestShardWriter, TestOpenForAppend) {
}
}
/// Feature: OverWriting in FileWriter
/// Description: old mindrecord files exist in output path
/// Expectation: generated mindrecord files
TEST_F(TestShardWriter, TestOverWrite) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("OverWrite imageNet"));
// load binary data
std::vector<std::vector<uint8_t>> bin_data;
std::vector<std::string> filenames;
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
return;
}
mindrecord::Img2DataUint8(filenames, bin_data);
// init shardHeader
ShardHeader header_data;
MS_LOG(INFO) << "Init ShardHeader Already.";
// create schema
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
if (anno_schema == nullptr) {
MS_LOG(ERROR) << "Build annotation schema failed";
return;
}
// add schema to shardHeader
int anno_schema_id = header_data.AddSchema(anno_schema);
MS_LOG(INFO) << "Init Schema Already.";
// create index
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
std::vector<std::pair<uint64_t, std::string>> fields;
fields.push_back(index_field1);
fields.push_back(index_field2);
// add index to shardHeader
header_data.AddIndexFields(fields);
MS_LOG(INFO) << "Init Index Fields Already.";
// load meta data
std::vector<json> annotations;
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10);
// add data
std::map<std::uint64_t, std::vector<json>> rawdatas;
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
MS_LOG(INFO) << "Init Images Already.";
// init file_writer
std::vector<std::string> file_names;
int file_count = 4;
for (int i = 1; i <= file_count; i++) {
file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i));
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
}
std::ofstream outfile(file_names[0]);
outfile << "dummy data!" << std::endl;
outfile.close();
MS_LOG(INFO) << "Init Output Files Already.";
{
ShardWriter fw_init;
fw_init.Open(file_names, false, true);
// set shardHeader
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
// close file_writer
fw_init.Commit();
}
{
mindrecord::ShardWriter fw;
fw.OpenForAppend(file_names[0]);
fw.WriteRawData(rawdatas, bin_data);
fw.Commit();
}
for (const auto &oneFile : file_names) {
remove(common::SafeCStr(oneFile));
}
}
} // namespace mindrecord
} // namespace mindspore

View File

@ -15,6 +15,7 @@
"""test mindrecord base"""
import os
import uuid
import pytest
import numpy as np
from utils import get_data, get_nlp_data
@ -1079,3 +1080,168 @@ def test_write_read_process_without_ndarray_type():
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")
def test_cv_file_overwrite_01():
"""
Feature: Overwriting in FileWriter
Description: full mindrecord files exist
Expectation: generated new mindrecord files
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_multi_files(mindrecord_file_name, FILES_NUM)
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name + "0")
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
reader.close()
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_cv_file_overwrite_02():
"""
Feature: Overwriting in FileWriter
Description: lack 1 mindrecord file
Expectation: generated new mindrecord files
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_multi_files(mindrecord_file_name, FILES_NUM)
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
# remove 1 mindrecord file
os.remove(mindrecord_file_name + "0")
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name + "0")
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
reader.close()
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_cv_file_overwrite_03():
"""
Feature: Overwriting in FileWriter
Description: lack 1 db file
Expectation: generated new mindrecord files
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_multi_files(mindrecord_file_name, FILES_NUM)
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
# remove 1 db file
os.remove(mindrecord_file_name + "0" + ".db")
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name + "0")
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
reader.close()
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_cv_file_overwrite_04():
"""
Feature: Overwriting in FileWriter
Description: lack 1 db file and mindrecord file
Expectation: generated new mindrecord files
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_multi_files(mindrecord_file_name, FILES_NUM)
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
os.remove(mindrecord_file_name + "0")
os.remove(mindrecord_file_name + "0" + ".db")
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name + "0")
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
reader.close()
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_cv_file_overwrite_exception_01():
"""
Feature: Overwriting in FileWriter
Description: default write mode, detect mindrecord file
Expectation: exception occur
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with open(mindrecord_file_name + "0", 'w'):
pass
with pytest.raises(RuntimeError) as err:
writer = FileWriter(mindrecord_file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.write_raw_data(data)
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_cv_file_overwrite_exception_02():
"""
Feature: Overwriting in FileWriter
Description: default write mode, detect db file
Expectation: exception occur
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with open(mindrecord_file_name + "0" + ".db", 'w'):
pass
with pytest.raises(RuntimeError) as err:
writer = FileWriter(mindrecord_file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.write_raw_data(data)
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
remove_multi_files(mindrecord_file_name, FILES_NUM)

View File

@ -42,37 +42,44 @@ def remove_file(file_name):
def test_cv_file_writer_shard_num_none():
"""test cv file writer when shard num is None."""
with pytest.raises(Exception, match="Shard num is illegal."):
with pytest.raises(Exception, match="Parameter shard_num is None."):
FileWriter("/tmp/123454321", None)
def test_cv_file_writer_overwrite_int():
"""
Feature: Overwriting in FileWriter
Description: invalid parameter
Expectation: exception occur
"""
with pytest.raises(Exception, match="Parameter overwrite's type is not bool."):
FileWriter("/tmp/123454321", 4, 1)
def test_cv_file_writer_shard_num_str():
"""test cv file writer when shard num is string."""
with pytest.raises(Exception, match="Shard num is illegal."):
with pytest.raises(Exception, match="Parameter shard_num's type is not int."):
FileWriter("/tmp/123454321", "20")
def test_cv_page_reader_consumer_num_none():
"""test cv page reader when consumer number is None."""
with pytest.raises(Exception, match="Consumer number is illegal."):
with pytest.raises(Exception, match="Parameter num_consumer is None."):
MindPage("dummy.mindrecord", None)
def test_cv_page_reader_consumer_num_str():
"""test cv page reader when consumer number is string."""
with pytest.raises(Exception, match="Consumer number is illegal."):
with pytest.raises(Exception, match="Parameter num_consumer is not int."):
MindPage("dummy.mindrecord", "2")
def test_nlp_file_reader_consumer_num_none():
"""test nlp file reader when consumer number is None."""
with pytest.raises(Exception, match="Consumer number is illegal."):
with pytest.raises(Exception, match="Parameter num_consumer is None."):
FileReader("dummy.mindrecord", None)
def test_nlp_file_reader_consumer_num_str():
"""test nlp file reader when consumer number is string."""
with pytest.raises(Exception, match="Consumer number is illegal."):
with pytest.raises(Exception, match="Parameter num_consumer is not int."):
FileReader("dummy.mindrecord", "4")
@ -271,7 +278,7 @@ def test_overwrite_invalid_db():
f.write('just for test')
with pytest.raises(RuntimeError) as err:
create_cv_mindrecord(1, file_name)
assert 'Unexpected error. Failed to write data to db.' in str(err.value)
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
remove_file(file_name)
def test_read_after_close():
@ -323,7 +330,7 @@ def test_cv_file_writer_shard_num_greater_than_1000():
"""
with pytest.raises(ParamValueError) as err:
FileWriter('dummy.mindrecord', 1001)
assert 'Shard number should between' in str(err.value)
assert "Parameter shard_num's value: 1001 should between 1 and 1000." in str(err.value)
def test_add_index_without_add_schema():