From 0c267681c3000eb7fec4eefb29a268d90ffb38be Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Wed, 22 Feb 2023 16:59:50 +0800
Subject: [PATCH] Add new interface schema & len for filereader

Add FileReader.schema(), which returns the schema stored in the file
header, and FileReader.len(), which returns the number of samples by
delegating to ShardReader.len() (bound to ShardReader::GetNumRows).

---
 .../mindspore.mindrecord.FileReader.rst       | 14 ++++
 .../mindrecord/common/shard_pybind.cc         |  3 +-
 .../python/mindspore/mindrecord/filereader.py | 18 +++++
 .../mindspore/mindrecord/shardreader.py       |  9 +++
 .../python/mindrecord/test_mindrecord_base.py | 84 +++++++++++++++++++
 5 files changed, 127 insertions(+), 1 deletion(-)

diff --git a/docs/api/api_python/mindrecord/mindspore.mindrecord.FileReader.rst b/docs/api/api_python/mindrecord/mindspore.mindrecord.FileReader.rst
index 3c81ab82b28..0080776188b 100644
--- a/docs/api/api_python/mindrecord/mindspore.mindrecord.FileReader.rst
+++ b/docs/api/api_python/mindrecord/mindspore.mindrecord.FileReader.rst
@@ -29,3 +29,17 @@
 
         异常:
             - **MRMUnsupportedSchemaError** - 当schema无效。
+
+    .. py:method:: schema()
+
+        返回当前MindRecord文件的Schema信息。
+
+        返回:
+            dict,Schema信息。
+
+    .. py:method:: len()
+
+        返回当前MindRecord文件的样本个数。
+
+        返回:
+            int,样本个数。
diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
index 8bd2ddf28fd..353a79a5c0b 100644
--- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
@@ -181,7 +181,8 @@ void BindShardReader(const py::module *m) {
                      });
            return res;
          })
-    .def("close", &ShardReader::Close);
+    .def("close", &ShardReader::Close)
+    .def("len", &ShardReader::GetNumRows);
 }
 
 void BindShardIndexGenerator(const py::module *m) {
diff --git a/mindspore/python/mindspore/mindrecord/filereader.py b/mindspore/python/mindspore/mindrecord/filereader.py
index 1f3d315bfd9..f7bfdece2e3 100644
--- a/mindspore/python/mindspore/mindrecord/filereader.py
+++ b/mindspore/python/mindspore/mindrecord/filereader.py
@@ -100,3 +100,21 @@ class FileReader:
     def close(self):
         """Stop reader worker and close file."""
         self._reader.close()
+
+    def schema(self):
+        """
+        Get the schema of the MindRecord.
+
+        Returns:
+            dict, the schema info.
+        """
+        return self._header.schema
+
+    def len(self):
+        """
+        Get the number of the samples in MindRecord.
+
+        Returns:
+            int, the number of the samples in MindRecord.
+        """
+        return self._reader.len()
diff --git a/mindspore/python/mindspore/mindrecord/shardreader.py b/mindspore/python/mindspore/mindrecord/shardreader.py
index ddf997bfecb..ba9100975d9 100644
--- a/mindspore/python/mindspore/mindrecord/shardreader.py
+++ b/mindspore/python/mindspore/mindrecord/shardreader.py
@@ -106,3 +106,12 @@ class ShardReader:
     def close(self):
         """close MindRecord File."""
         self._reader.close()
+
+    def len(self):
+        """
+        Get the number of the samples in MindRecord.
+
+        Returns:
+            int, the number of the samples in MindRecord.
+        """
+        return self._reader.len()
diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py
index 93278d1eb6d..74a4e84c661 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_base.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_base.py
@@ -1270,3 +1270,87 @@ def test_cv_file_overwrite_exception_02():
         writer.write_raw_data(data)
     assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value)
     remove_multi_files(mindrecord_file_name, FILES_NUM)
+
+
+def test_file_writer_schema_len(file_name=None, remove_file=True):
+    """
+    Feature: FileReader
+    Description: read back schema and len from files produced by FileWriter
+    Expectation: SUCCESS
+    """
+    if not file_name:
+        file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+
+    ## single mindrecord file
+    # 1. empty file
+    remove_one_file(file_name)
+    remove_one_file(file_name + ".db")
+    writer = FileWriter(file_name, 1)
+    cv_schema_json = {"file_name": {"type": "string"},
+                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.commit()
+
+    # get the schema & len
+    reader = FileReader(file_name)
+    assert cv_schema_json == reader.schema()
+    assert reader.len() == 0
+    reader.close()  # stop reader workers before the files are removed
+    if remove_file:
+        remove_one_file(file_name)
+        remove_one_file(file_name + ".db")
+
+    # 2. with 10 samples
+    remove_one_file(file_name)
+    remove_one_file(file_name + ".db")
+    writer = FileWriter(file_name, 1)
+    cv_schema_json = {"file_name": {"type": "string"},
+                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
+    data = get_data("../data/mindrecord/testImageNetData/")
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.write_raw_data(data)
+    writer.commit()
+
+    # get the schema & len
+    reader = FileReader(file_name)
+    assert cv_schema_json == reader.schema()
+    assert reader.len() == 10
+    reader.close()  # stop reader workers before the files are removed
+    if remove_file:
+        remove_one_file(file_name)
+        remove_one_file(file_name + ".db")
+
+    ## multi mindrecord file
+    # 1. empty file
+    remove_multi_files(file_name, FILES_NUM)
+    writer = FileWriter(file_name, FILES_NUM)
+    cv_schema_json = {"file_name": {"type": "string"},
+                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.commit()
+
+    # get the schema & len
+    reader = FileReader(file_name + "0")
+    assert cv_schema_json == reader.schema()
+    assert reader.len() == 0
+    reader.close()  # stop reader workers before the files are removed
+    if remove_file:
+        remove_multi_files(file_name, FILES_NUM)
+
+    # 2. with samples
+    remove_multi_files(file_name, FILES_NUM)
+    writer = FileWriter(file_name, FILES_NUM)
+    cv_schema_json = {"file_name": {"type": "string"},
+                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
+    data = get_data("../data/mindrecord/testImageNetData/")
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.write_raw_data(data)
+    writer.commit()
+
+    # get the schema & len
+    reader = FileReader(file_name + "0")
+    assert cv_schema_json == reader.schema()
+    assert reader.len() == 10
+    reader.close()  # stop reader workers before the files are removed
+    if remove_file:
+        remove_multi_files(file_name, FILES_NUM)