diff --git a/mindspore/python/mindspore/mindrecord/filewriter.py b/mindspore/python/mindspore/mindrecord/filewriter.py index ffc21878e0a..ec8b0761752 100644 --- a/mindspore/python/mindspore/mindrecord/filewriter.py +++ b/mindspore/python/mindspore/mindrecord/filewriter.py @@ -102,6 +102,7 @@ class FileWriter: self._overwrite = overwrite self._append = False + self._flush = False self._header = ShardHeader() self._writer = ShardWriter() self._generator = None @@ -316,6 +317,11 @@ class FileWriter: self._writer.set_shard_header(self._header) if not isinstance(raw_data, list): raise ParamTypeError('raw_data', 'list') + if self._flush and not self._append: + raise RuntimeError("Unexpected error. Not allow to call `write_raw_data` on flushed MindRecord files. " + "When creating new MindRecord files, please remove `commit` before `write_raw_data`. " + "In other cases, when appending to existing MindRecord files, " + "please call `open_for_append` first and then `write_raw_data`.") for each_raw in raw_data: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') @@ -392,6 +398,7 @@ class FileWriter: MRMGenerateIndexError: If failed to write to database. MRMCommitError: If failed to flush data to disk. 
""" + self._flush = True if not self._writer.is_open: self._writer.open(self._paths, self._overwrite) # permit commit without data diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py index 5901a6f891a..e8c81ddbeae 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_base.py +++ b/tests/ut/python/mindrecord/test_mindrecord_base.py @@ -1208,6 +1208,31 @@ def test_cv_file_overwrite_04(): remove_multi_files(mindrecord_file_name, FILES_NUM) + +def test_mindrecord_commit_exception_01(): + """ + Feature: commit exception + Description: write_raw_data after commit + Expectation: exception occur + """ + + mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + remove_multi_files(mindrecord_file_name, 4) + + writer = FileWriter(mindrecord_file_name, 4) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.write_raw_data(data[0:5]) + writer.commit() + with pytest.raises(RuntimeError) as err: + writer.write_raw_data(data[5:10]) + + assert 'Unexpected error. Not allow to call `write_raw_data` on flushed MindRecord files.' in str(err.value) + remove_multi_files(mindrecord_file_name, 4) + + def test_cv_file_overwrite_exception_01(): """ Feature: Overwriting in FileWriter