forked from mindspore-Ecosystem/mindspore
!49251 add schema & len for FileReader
Merge pull request !49251 from guozhijian/add_interface_for_mr
This commit is contained in:
commit
14e9afce2b
|
@ -29,3 +29,17 @@
|
|||
|
||||
异常:
|
||||
- **MRMUnsupportedSchemaError** - 当schema无效。
|
||||
|
||||
.. py:method:: schema()
|
||||
|
||||
返回当前MindRecord文件的Schema信息。
|
||||
|
||||
返回:
|
||||
dict,Schema信息。
|
||||
|
||||
.. py:method:: len()
|
||||
|
||||
返回当前MindRecord文件的样本个数。
|
||||
|
||||
返回:
|
||||
int,样本个数。
|
||||
|
|
|
@ -181,7 +181,8 @@ void BindShardReader(const py::module *m) {
|
|||
});
|
||||
return res;
|
||||
})
|
||||
.def("close", &ShardReader::Close);
|
||||
.def("close", &ShardReader::Close)
|
||||
.def("len", &ShardReader::GetNumRows);
|
||||
}
|
||||
|
||||
void BindShardIndexGenerator(const py::module *m) {
|
||||
|
|
|
@ -100,3 +100,21 @@ class FileReader:
|
|||
def close(self):
    """Shut down the underlying reader by delegating to its ``close()``."""
    underlying_reader = self._reader
    underlying_reader.close()
|
||||
|
||||
def schema(self):
    """
    Return the schema held by the header of the opened MindRecord.

    Returns:
        dict, the schema info.
    """
    header = self._header
    return header.schema
|
||||
|
||||
def len(self):
    """
    Return the sample count reported by the underlying reader.

    Returns:
        int, the number of the samples in MindRecord.
    """
    sample_count = self._reader.len()
    return sample_count
|
||||
|
|
|
@ -106,3 +106,12 @@ class ShardReader:
|
|||
def close(self):
    """Close the MindRecord file by delegating to the wrapped reader's ``close()``."""
    wrapped_reader = self._reader
    wrapped_reader.close()
|
||||
|
||||
def len(self):
    """
    Return the sample count obtained from the wrapped reader's ``len()``.

    Returns:
        int, the number of the samples in MindRecord.
    """
    total_samples = self._reader.len()
    return total_samples
|
||||
|
|
|
@ -1270,3 +1270,83 @@ def test_cv_file_overwrite_exception_02():
|
|||
writer.write_raw_data(data)
|
||||
assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value)
|
||||
remove_multi_files(mindrecord_file_name, FILES_NUM)
|
||||
|
||||
|
||||
def test_file_writer_schema_len(file_name=None, remove_file=True):
    """
    Feature: FileReader
    Description: check schema() and len() on single and multi-file MindRecord,
        both empty and with 10 samples
    Expectation: SUCCESS
    """
    if not file_name:
        file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    def remove_single():
        remove_one_file(file_name)
        remove_one_file(file_name + ".db")

    def remove_multi():
        remove_multi_files(file_name, FILES_NUM)

    # Each scenario: (shard count, path handed to FileReader, cleanup routine,
    #                 whether samples are written, expected len()).
    scenarios = (
        # single mindrecord file: 1. empty file, 2. with 10 samples
        (1, file_name, remove_single, False, 0),
        (1, file_name, remove_single, True, 10),
        # multi mindrecord file: 1. empty file, 2. with samples
        (FILES_NUM, file_name + "0", remove_multi, False, 0),
        (FILES_NUM, file_name + "0", remove_multi, True, 10),
    )

    for shard_num, reader_path, cleanup, with_samples, expected_len in scenarios:
        # Start from a clean slate so leftovers from a previous run cannot interfere.
        cleanup()

        # A fresh dict per scenario, so no scenario can be affected by another.
        cv_schema_json = {"file_name": {"type": "string"},
                          "label": {"type": "int64"}, "data": {"type": "bytes"}}
        writer = FileWriter(file_name, shard_num)
        writer.add_schema(cv_schema_json, "img_schema")
        if with_samples:
            writer.write_raw_data(get_data("../data/mindrecord/testImageNetData/"))
        writer.commit()

        # get the schema & len
        reader = FileReader(reader_path)
        try:
            assert cv_schema_json == reader.schema()
            assert reader.len() == expected_len
        finally:
            # Always release the reader, even when an assertion fails, and
            # before the files it has open are deleted below.
            reader.close()

        if remove_file:
            cleanup()
|
||||
|
|
Loading…
Reference in New Issue