From 447420eb9a650219e993d634661b57aa368b67f3 Mon Sep 17 00:00:00 2001
From: liyong
Date: Fri, 22 May 2020 11:27:56 +0800
Subject: [PATCH] fix bug when append data by absolute path

---
 mindspore/ccsrc/mindrecord/io/shard_writer.cc |  2 +-
 .../python/mindrecord/test_mindrecord_base.py | 41 +++++++++++++++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc
index 0b0acf52d7c..9756b475e5e 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc
@@ -201,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
   if (ret == FAILED) {
     return FAILED;
   }
-  ret = Open(json_header["shard_addresses"], true);
+  ret = Open(real_addresses, true);
   if (ret == FAILED) {
     MS_LOG(ERROR) << "Open file failed";
     return FAILED;
diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py
index 65f4b9a305d..4c3e062e8dd 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_base.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_base.py
@@ -25,6 +25,7 @@ FILES_NUM = 4
 CV_FILE_NAME = "./imagenet.mindrecord"
 CV2_FILE_NAME = "./imagenet_loop.mindrecord"
 CV3_FILE_NAME = "./imagenet_append.mindrecord"
+CV4_FILE_NAME = "/tmp/imagenet_append.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"
 
 
@@ -170,6 +171,46 @@ def test_cv_file_append_writer():
         os.remove("{}.db".format(x))
 
 
+def test_cv_file_append_writer_absolute_path():
+    """Append to a dataset opened through an absolute path and read it back.
+
+    Writes data[0:5] to CV4_FILE_NAME, reopens CV4_FILE_NAME + "0" with
+    FileWriter.open_for_append to add data[5:10], then checks that the
+    reader yields 10 items, each of length 3.
+    """
+    paths = ["{}{}".format(CV4_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    reader = None
+    try:
+        writer = FileWriter(CV4_FILE_NAME, FILES_NUM)
+        data = get_data("../data/mindrecord/testImageNetData/")
+        cv_schema_json = {"file_name": {"type": "string"},
+                          "label": {"type": "int64"}, "data": {"type": "bytes"}}
+        writer.add_schema(cv_schema_json, "img_schema")
+        writer.add_index(["file_name", "label"])
+        writer.write_raw_data(data[0:5])
+        writer.commit()
+        write_append = FileWriter.open_for_append(CV4_FILE_NAME + "0")
+        write_append.write_raw_data(data[5:10])
+        write_append.commit()
+        reader = FileReader(CV4_FILE_NAME + "0")
+        count = 0
+        for index, x in enumerate(reader.get_next()):
+            assert len(x) == 3
+            count = count + 1
+            logger.info("#item{}: {}".format(index, x))
+        assert count == 10
+    finally:
+        # CV4_FILE_NAME is a fixed path under /tmp, so clean up even when an
+        # assertion fails; otherwise stale files are left behind for later runs.
+        if reader is not None:
+            reader.close()
+        for x in paths:
+            for name in (x, "{}.db".format(x)):
+                if os.path.exists(name):
+                    os.remove(name)
+
+
 def test_cv_file_writer_loop_and_read():
     """tutorial for cv dataset loop writer."""
     writer = FileWriter(CV2_FILE_NAME, FILES_NUM)