!49251 add schema & len for FileReader

Merge pull request !49251 from guozhijian/add_interface_for_mr
This commit is contained in:
i-robot 2023-02-24 06:40:09 +00:00 committed by Gitee
commit 14e9afce2b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 123 additions and 1 deletions

View File

@ -29,3 +29,17 @@
异常:
- **MRMUnsupportedSchemaError** - 当schema无效。
.. py:method:: schema()
返回当前MindRecord文件的Schema信息。
返回:
dictSchema信息。
.. py:method:: len()
返回当前MindRecord文件的样本个数。
返回:
int样本个数。

View File

@ -181,7 +181,8 @@ void BindShardReader(const py::module *m) {
});
return res;
})
.def("close", &ShardReader::Close);
.def("close", &ShardReader::Close)
.def("len", &ShardReader::GetNumRows);
}
void BindShardIndexGenerator(const py::module *m) {

View File

@ -100,3 +100,21 @@ class FileReader:
def close(self):
"""Stop reader worker and close file."""
self._reader.close()
def schema(self):
"""
Get the schema of the MindRecord.
Returns:
dict, the schema info.
"""
return self._header.schema
def len(self):
"""
Get the number of the samples in MindRecord.
Returns:
int, the number of the samples in MindRecord.
"""
return self._reader.len()

View File

@ -106,3 +106,12 @@ class ShardReader:
def close(self):
"""close MindRecord File."""
self._reader.close()
def len(self):
"""
Get the number of the samples in MindRecord.
Returns:
int, the number of the samples in MindRecord.
"""
return self._reader.len()

View File

@ -1270,3 +1270,83 @@ def test_cv_file_overwrite_exception_02():
writer.write_raw_data(data)
assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value)
remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_file_writer_schema_len(file_name=None, remove_file=True):
"""
Feature: FileWriter
Description: writer for schema and len
Expectation: SUCCESS
"""
if not file_name:
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
## single mindrecord file
# 1. empty file
remove_one_file(file_name)
remove_one_file(file_name + ".db")
writer = FileWriter(file_name, 1)
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.commit()
# get the schema & len
reader = FileReader(file_name)
assert cv_schema_json == reader.schema()
assert reader.len() == 0
if remove_file:
remove_one_file(file_name)
remove_one_file(file_name + ".db")
# 2. with 10 samples
remove_one_file(file_name)
remove_one_file(file_name + ".db")
writer = FileWriter(file_name, 1)
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
data = get_data("../data/mindrecord/testImageNetData/")
writer.add_schema(cv_schema_json, "img_schema")
writer.write_raw_data(data)
writer.commit()
# get the schema & len
reader = FileReader(file_name)
assert cv_schema_json == reader.schema()
assert reader.len() == 10
if remove_file:
remove_one_file(file_name)
remove_one_file(file_name + ".db")
## multi mindrecord file
# 1. empty file
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.commit()
# get the schema & len
reader = FileReader(file_name + "0")
assert cv_schema_json == reader.schema()
assert reader.len() == 0
if remove_file:
remove_multi_files(file_name, FILES_NUM)
# 2. with samples
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
data = get_data("../data/mindrecord/testImageNetData/")
writer.add_schema(cv_schema_json, "img_schema")
writer.write_raw_data(data)
writer.commit()
# get the schema & len
reader = FileReader(file_name + "0")
assert cv_schema_json == reader.schema()
assert reader.len() == 10
if remove_file:
remove_multi_files(file_name, FILES_NUM)