forked from mindspore-Ecosystem/mindspore
add mindrecord overwrite
This commit is contained in:
parent
027adf9fff
commit
41e8280f55
|
@ -105,8 +105,8 @@ void BindShardWriter(py::module *m) {
|
|||
(void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
|
||||
.def(py::init<>())
|
||||
.def("open",
|
||||
[](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
|
||||
THROW_IF_ERROR(s.Open(paths, append));
|
||||
[](ShardWriter &s, const std::vector<std::string> &paths, bool append, bool overwrite) {
|
||||
THROW_IF_ERROR(s.Open(paths, append, overwrite));
|
||||
return SUCCESS;
|
||||
})
|
||||
.def("open_for_append",
|
||||
|
|
|
@ -54,9 +54,10 @@ class __attribute__((visibility("default"))) ShardWriter {
|
|||
|
||||
/// \brief Open file at the beginning
|
||||
/// \param[in] paths the file names list
|
||||
/// \param[in] append new data at the end of file if true, otherwise overwrite file
|
||||
/// \param[in] append new data at the end of file if true, otherwise try to overwrite file
|
||||
/// \param[in] overwrite a file with the same name if true
|
||||
/// \return Status
|
||||
Status Open(const std::vector<std::string> &paths, bool append = false);
|
||||
Status Open(const std::vector<std::string> &paths, bool append = false, bool overwrite = false);
|
||||
|
||||
/// \brief Open file at the ending
|
||||
/// \param[in] paths the file names list
|
||||
|
@ -215,7 +216,7 @@ class __attribute__((visibility("default"))) ShardWriter {
|
|||
Status GetFullPathFromFileName(const std::vector<std::string> &paths);
|
||||
|
||||
/// \brief Open files
|
||||
Status OpenDataFiles(bool append);
|
||||
Status OpenDataFiles(bool append, bool overwrite);
|
||||
|
||||
/// \brief Remove lock file
|
||||
Status RemoveLockFile();
|
||||
|
|
|
@ -64,7 +64,7 @@ Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &path
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardWriter::OpenDataFiles(bool append) {
|
||||
Status ShardWriter::OpenDataFiles(bool append, bool overwrite) {
|
||||
// Open files
|
||||
for (const auto &file : file_paths_) {
|
||||
std::optional<std::string> dir = "";
|
||||
|
@ -82,13 +82,33 @@ Status ShardWriter::OpenDataFiles(bool append) {
|
|||
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
if (!append) {
|
||||
// if not append and mindrecord file exist, return FAILED
|
||||
// if not append && mindrecord or db file exist
|
||||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
if (fs->good()) {
|
||||
std::ifstream fs_db(whole_path.value() + ".db");
|
||||
if (fs->good() || fs_db.good()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, Mindrecord files already existed in path: " + file);
|
||||
fs_db.close();
|
||||
if (overwrite) {
|
||||
auto res1 = std::remove(whole_path.value().c_str());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!std::ifstream(whole_path.value()) == true,
|
||||
"Failed to delete file, path: " + file);
|
||||
if (res1 == 0) {
|
||||
MS_LOG(WARNING) << "Succeed to delete file, path: " << file;
|
||||
}
|
||||
auto db_file = whole_path.value() + ".db";
|
||||
auto res2 = std::remove(db_file.c_str());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!std::ifstream(whole_path.value() + ".db") == true,
|
||||
"Failed to delete db file, path: " + file + ".db");
|
||||
if (res2 == 0) {
|
||||
MS_LOG(WARNING) << "Succeed to delete metadata file, path: " << file + ".db";
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, Mindrecord files already existed in path: " + file);
|
||||
}
|
||||
} else {
|
||||
fs->close();
|
||||
fs_db.close();
|
||||
}
|
||||
fs->close();
|
||||
// open the mindrecord file to write
|
||||
fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
|
||||
if (!fs->good()) {
|
||||
|
@ -131,7 +151,7 @@ Status ShardWriter::InitLockFile() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
||||
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append, bool overwrite) {
|
||||
shard_count_ = paths.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
|
||||
"Invalid data, schema_count_ must be less than or equal to " +
|
||||
|
@ -140,7 +160,7 @@ Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
|||
// Get full path from file name
|
||||
RETURN_IF_NOT_OK(GetFullPathFromFileName(paths));
|
||||
// Open files
|
||||
RETURN_IF_NOT_OK(OpenDataFiles(append));
|
||||
RETURN_IF_NOT_OK(OpenDataFiles(append, overwrite));
|
||||
// Init lock file
|
||||
RETURN_IF_NOT_OK(InitLockFile());
|
||||
return Status::OK();
|
||||
|
|
|
@ -41,11 +41,12 @@ class FileWriter:
|
|||
|
||||
Args:
|
||||
file_name (str): File name of MindRecord file.
|
||||
shard_num (int, optional): The Number of MindRecord file. Default: 1.
|
||||
shard_num (int, optional): The Number of MindRecord files. Default: 1.
|
||||
It should be between [1, 1000].
|
||||
overwrite (bool, optional): Overwrite MindRecord files if true. Default: False.
|
||||
|
||||
Raises:
|
||||
ParamValueError: If `file_name` or `shard_num` is invalid.
|
||||
ParamValueError: If `file_name` or `shard_num` or `overwrite` is invalid.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.mindrecord import FileWriter
|
||||
|
@ -68,19 +69,22 @@ class FileWriter:
|
|||
MSRStatus.SUCCESS
|
||||
"""
|
||||
|
||||
def __init__(self, file_name, shard_num=1):
|
||||
def __init__(self, file_name, shard_num=1, overwrite=False):
|
||||
check_filename(file_name)
|
||||
self._file_name = file_name
|
||||
|
||||
if shard_num is not None:
|
||||
if isinstance(shard_num, int):
|
||||
if shard_num < MIN_SHARD_COUNT or shard_num > MAX_SHARD_COUNT:
|
||||
raise ParamValueError("Shard number should between {} and {}."
|
||||
.format(MIN_SHARD_COUNT, MAX_SHARD_COUNT))
|
||||
raise ParamValueError("Parameter shard_num's value: {} should between {} and {}."
|
||||
.format(shard_num, MIN_SHARD_COUNT, MAX_SHARD_COUNT))
|
||||
else:
|
||||
raise ParamValueError("Shard num is illegal.")
|
||||
raise ParamValueError("Parameter shard_num's type is not int.")
|
||||
else:
|
||||
raise ParamValueError("Shard num is illegal.")
|
||||
raise ParamValueError("Parameter shard_num is None.")
|
||||
|
||||
if not isinstance(overwrite, bool):
|
||||
raise ParamValueError("Parameter overwrite's type is not bool.")
|
||||
|
||||
self._shard_num = shard_num
|
||||
self._index_generator = True
|
||||
|
@ -93,6 +97,7 @@ class FileWriter:
|
|||
str(x).rjust(suffix_shard_size, '0'))
|
||||
for x in range(self._shard_num)]
|
||||
|
||||
self._overwrite = overwrite
|
||||
self._append = False
|
||||
self._header = ShardHeader()
|
||||
self._writer = ShardWriter()
|
||||
|
@ -137,7 +142,7 @@ class FileWriter:
|
|||
check_filename(file_name)
|
||||
# construct ShardHeader
|
||||
reader = ShardReader()
|
||||
reader.open(file_name)
|
||||
reader.open(file_name, False)
|
||||
header = ShardHeader(reader.get_header())
|
||||
reader.close()
|
||||
|
||||
|
@ -268,7 +273,7 @@ class FileWriter:
|
|||
MRMSetHeaderError: If failed to set header.
|
||||
"""
|
||||
if not self._writer.is_open:
|
||||
ret = self._writer.open(self._paths)
|
||||
ret = self._writer.open(self._paths, self._overwrite)
|
||||
if not self._writer.get_shard_header():
|
||||
return self._writer.set_shard_header(self._header)
|
||||
return ret
|
||||
|
@ -296,7 +301,7 @@ class FileWriter:
|
|||
MRMWriteDatasetError: If failed to write dataset.
|
||||
"""
|
||||
if not self._writer.is_open:
|
||||
self._writer.open(self._paths)
|
||||
self._writer.open(self._paths, self._overwrite)
|
||||
if not self._writer.get_shard_header():
|
||||
self._writer.set_shard_header(self._header)
|
||||
if not isinstance(raw_data, list):
|
||||
|
@ -378,7 +383,7 @@ class FileWriter:
|
|||
MRMCommitError: If failed to flush data to disk.
|
||||
"""
|
||||
if not self._writer.is_open:
|
||||
self._writer.open(self._paths)
|
||||
self._writer.open(self._paths, self._overwrite)
|
||||
# permit commit without data
|
||||
if not self._writer.get_shard_header():
|
||||
self._writer.set_shard_header(self._header)
|
||||
|
|
|
@ -126,13 +126,13 @@ def check_parameter(func):
|
|||
check_filename(value)
|
||||
if name == 'num_consumer':
|
||||
if value is None:
|
||||
raise ParamValueError("Consumer number is illegal.")
|
||||
raise ParamValueError("Parameter num_consumer is None.")
|
||||
if isinstance(value, int):
|
||||
if value < MIN_CONSUMER_COUNT or value > MAX_CONSUMER_COUNT():
|
||||
raise ParamValueError("Consumer number should between {} and {}."
|
||||
.format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
|
||||
raise ParamValueError("Parameter num_consumer: {} should between {} and {}."
|
||||
.format(value, MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
|
||||
else:
|
||||
raise ParamValueError("Consumer number is illegal.")
|
||||
raise ParamValueError("Parameter num_consumer is not int.")
|
||||
return func(*args, **kw)
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -36,7 +36,7 @@ class ShardWriter:
|
|||
self._header = None
|
||||
self._is_open = False
|
||||
|
||||
def open(self, paths):
|
||||
def open(self, paths, override):
|
||||
"""
|
||||
Open a new MindRecord File and prepare to write raw data.
|
||||
|
||||
|
@ -49,7 +49,7 @@ class ShardWriter:
|
|||
Raises:
|
||||
MRMOpenError: If failed to open MindRecord File.
|
||||
"""
|
||||
ret = self._writer.open(paths, False)
|
||||
ret = self._writer.open(paths, False, override)
|
||||
if ret != ms.MSRStatus.SUCCESS:
|
||||
logger.critical("Failed to open paths")
|
||||
raise MRMOpenError
|
||||
|
|
|
@ -1156,5 +1156,90 @@ TEST_F(TestShardWriter, TestOpenForAppend) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Feature: OverWriting in FileWriter
|
||||
/// Description: old mindrecord files exist in output path
|
||||
/// Expectation: generated mindrecord files
|
||||
TEST_F(TestShardWriter, TestOverWrite) {
|
||||
|
||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("OverWrite imageNet"));
|
||||
|
||||
// load binary data
|
||||
std::vector<std::vector<uint8_t>> bin_data;
|
||||
std::vector<std::string> filenames;
|
||||
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
||||
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
||||
return;
|
||||
}
|
||||
mindrecord::Img2DataUint8(filenames, bin_data);
|
||||
|
||||
// init shardHeader
|
||||
ShardHeader header_data;
|
||||
MS_LOG(INFO) << "Init ShardHeader Already.";
|
||||
|
||||
// create schema
|
||||
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
||||
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
||||
if (anno_schema == nullptr) {
|
||||
MS_LOG(ERROR) << "Build annotation schema failed";
|
||||
return;
|
||||
}
|
||||
|
||||
// add schema to shardHeader
|
||||
int anno_schema_id = header_data.AddSchema(anno_schema);
|
||||
MS_LOG(INFO) << "Init Schema Already.";
|
||||
|
||||
// create index
|
||||
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
||||
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
||||
std::vector<std::pair<uint64_t, std::string>> fields;
|
||||
fields.push_back(index_field1);
|
||||
fields.push_back(index_field2);
|
||||
|
||||
// add index to shardHeader
|
||||
header_data.AddIndexFields(fields);
|
||||
MS_LOG(INFO) << "Init Index Fields Already.";
|
||||
// load meta data
|
||||
std::vector<json> annotations;
|
||||
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10);
|
||||
|
||||
// add data
|
||||
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
||||
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
||||
MS_LOG(INFO) << "Init Images Already.";
|
||||
|
||||
// init file_writer
|
||||
std::vector<std::string> file_names;
|
||||
int file_count = 4;
|
||||
for (int i = 1; i <= file_count; i++) {
|
||||
file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i));
|
||||
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
||||
}
|
||||
|
||||
std::ofstream outfile(file_names[0]);
|
||||
outfile << "dummy data!" << std::endl;
|
||||
outfile.close();
|
||||
MS_LOG(INFO) << "Init Output Files Already.";
|
||||
{
|
||||
ShardWriter fw_init;
|
||||
fw_init.Open(file_names, false, true);
|
||||
// set shardHeader
|
||||
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
// close file_writer
|
||||
fw_init.Commit();
|
||||
}
|
||||
|
||||
{
|
||||
mindrecord::ShardWriter fw;
|
||||
fw.OpenForAppend(file_names[0]);
|
||||
fw.WriteRawData(rawdatas, bin_data);
|
||||
fw.Commit();
|
||||
}
|
||||
|
||||
for (const auto &oneFile : file_names) {
|
||||
remove(common::SafeCStr(oneFile));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""test mindrecord base"""
|
||||
import os
|
||||
import uuid
|
||||
import pytest
|
||||
import numpy as np
|
||||
from utils import get_data, get_nlp_data
|
||||
|
||||
|
@ -1079,3 +1080,168 @@ def test_write_read_process_without_ndarray_type():
|
|||
|
||||
remove_one_file(mindrecord_file_name)
|
||||
remove_one_file(mindrecord_file_name + ".db")
|
||||
|
||||
def test_cv_file_overwrite_01():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: full mindrecord files exist
|
||||
Expectation: generated new mindrecord files
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
|
||||
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name + "0")
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
def test_cv_file_overwrite_02():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: lack 1 mindrecord file
|
||||
Expectation: generated new mindrecord files
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
|
||||
# remove 1 mindrecord file
|
||||
os.remove(mindrecord_file_name + "0")
|
||||
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name + "0")
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
def test_cv_file_overwrite_03():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: lack 1 db file
|
||||
Expectation: generated new mindrecord files
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
|
||||
# remove 1 db file
|
||||
os.remove(mindrecord_file_name + "0" + ".db")
|
||||
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name + "0")
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
def test_cv_file_overwrite_04():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: lack 1 db file and mindrecord file
|
||||
Expectation: generated new mindrecord files
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
test_cv_file_writer_tutorial(mindrecord_file_name, remove_file=False)
|
||||
|
||||
os.remove(mindrecord_file_name + "0")
|
||||
os.remove(mindrecord_file_name + "0" + ".db")
|
||||
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM, True)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name + "0")
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
def test_cv_file_overwrite_exception_01():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: default write mode, detect mindrecord file
|
||||
Expectation: exception occur
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with open(mindrecord_file_name + "0", 'w'):
|
||||
pass
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.write_raw_data(data)
|
||||
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
def test_cv_file_overwrite_exception_02():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: default write mode, detect db file
|
||||
Expectation: exception occur
|
||||
"""
|
||||
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with open(mindrecord_file_name + "0" + ".db", 'w'):
|
||||
pass
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
writer = FileWriter(mindrecord_file_name, FILES_NUM)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.write_raw_data(data)
|
||||
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
|
|
@ -42,37 +42,44 @@ def remove_file(file_name):
|
|||
|
||||
def test_cv_file_writer_shard_num_none():
|
||||
"""test cv file writer when shard num is None."""
|
||||
with pytest.raises(Exception, match="Shard num is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter shard_num is None."):
|
||||
FileWriter("/tmp/123454321", None)
|
||||
|
||||
def test_cv_file_writer_overwrite_int():
|
||||
"""
|
||||
Feature: Overwriting in FileWriter
|
||||
Description: invalid parameter
|
||||
Expectation: exception occur
|
||||
"""
|
||||
with pytest.raises(Exception, match="Parameter overwrite's type is not bool."):
|
||||
FileWriter("/tmp/123454321", 4, 1)
|
||||
|
||||
def test_cv_file_writer_shard_num_str():
|
||||
"""test cv file writer when shard num is string."""
|
||||
with pytest.raises(Exception, match="Shard num is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter shard_num's type is not int."):
|
||||
FileWriter("/tmp/123454321", "20")
|
||||
|
||||
|
||||
def test_cv_page_reader_consumer_num_none():
|
||||
"""test cv page reader when consumer number is None."""
|
||||
with pytest.raises(Exception, match="Consumer number is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter num_consumer is None."):
|
||||
MindPage("dummy.mindrecord", None)
|
||||
|
||||
|
||||
def test_cv_page_reader_consumer_num_str():
|
||||
"""test cv page reader when consumer number is string."""
|
||||
with pytest.raises(Exception, match="Consumer number is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter num_consumer is not int."):
|
||||
MindPage("dummy.mindrecord", "2")
|
||||
|
||||
|
||||
def test_nlp_file_reader_consumer_num_none():
|
||||
"""test nlp file reader when consumer number is None."""
|
||||
with pytest.raises(Exception, match="Consumer number is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter num_consumer is None."):
|
||||
FileReader("dummy.mindrecord", None)
|
||||
|
||||
|
||||
def test_nlp_file_reader_consumer_num_str():
|
||||
"""test nlp file reader when consumer number is string."""
|
||||
with pytest.raises(Exception, match="Consumer number is illegal."):
|
||||
with pytest.raises(Exception, match="Parameter num_consumer is not int."):
|
||||
FileReader("dummy.mindrecord", "4")
|
||||
|
||||
|
||||
|
@ -271,7 +278,7 @@ def test_overwrite_invalid_db():
|
|||
f.write('just for test')
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
create_cv_mindrecord(1, file_name)
|
||||
assert 'Unexpected error. Failed to write data to db.' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_read_after_close():
|
||||
|
@ -323,7 +330,7 @@ def test_cv_file_writer_shard_num_greater_than_1000():
|
|||
"""
|
||||
with pytest.raises(ParamValueError) as err:
|
||||
FileWriter('dummy.mindrecord', 1001)
|
||||
assert 'Shard number should between' in str(err.value)
|
||||
assert "Parameter shard_num's value: 1001 should between 1 and 1000." in str(err.value)
|
||||
|
||||
|
||||
def test_add_index_without_add_schema():
|
||||
|
|
Loading…
Reference in New Issue