From 4d1db74ff3c398c3e159930779950be69e429e2b Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Fri, 15 Jul 2022 18:55:16 +0800
Subject: [PATCH] fix: read multiple mindrecord files when one of them is empty

---
 .../minddata/mindrecord/io/shard_reader.cc  |  5 +
 tests/ut/python/dataset/test_minddataset.py | 92 +++++++++++++++++++
 2 files changed, 97 insertions(+)

diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
index 0d5da1aa672..598ea09e091 100644
--- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
@@ -356,6 +356,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
     // return -1 when page's size equals to 0.
     auto last_page_id = shard_header_->GetLastPageId(shard_id);
     if (static_cast<int>(last_page_id) == -1) {
+      // Empty mindrecord file which does not contain any samples
+      MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
+                      << " does not contain any samples, please remove it.";
+      row_group_summary.emplace_back(shard_id, 0, 0, 0);
+      shard_sample_count_.push_back(total_count);
       continue;
     }
     for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index f2f80ddbce4..04f68d3cf3b 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -110,6 +110,53 @@ def add_and_remove_nlp_file():
             os.remove("{}.db".format(x))
 
 
+@pytest.fixture
+def add_and_remove_three_nlp_file_with_sample_10_0_5():
+    """add/remove nlp file"""
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = [file_name + "_with_sample_10",
+             file_name + "_with_sample_0",
+             file_name + "_with_sample_5"]
+    try:
+        def create_mindrecord_file(file_name, num_samples):
+            if os.path.exists(file_name):
+                os.remove("{}".format(file_name))
+            writer = FileWriter(file_name)
+            if num_samples > 0:
+                data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, num_samples)]
+            nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
+                               "rating": {"type": "float32"},
+                               "input_ids": {"type": "int64",
+                                             "shape": [-1]},
+                               "input_mask": {"type": "int64",
+                                              "shape": [1, -1]},
+                               "segment_ids": {"type": "int64",
+                                               "shape": [2, -1]}
+                               }
+            writer.set_header_size(1 << 14)
+            writer.set_page_size(1 << 15)
+            writer.add_schema(nlp_schema_json, "nlp_schema")
+            writer.add_index(["id", "rating"])
+            if num_samples > 0:
+                writer.write_raw_data(data)
+            writer.commit()
+
+        create_mindrecord_file(paths[0], 10)
+        create_mindrecord_file(paths[1], 0)
+        create_mindrecord_file(paths[2], 5)
+
+        yield "yield_nlp_data"
+    except Exception as error:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+        raise error
+    else:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+
+
 @pytest.fixture
 def add_and_remove_nlp_compress_file():
     """add/remove nlp file"""
@@ -2884,6 +2931,51 @@ def test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file):
     assert (next(dataset_iter3)["array_a"] == data[4]["array_a"]).all()
     assert (next(dataset_iter3)["array_a"] == data[5]["array_a"]).all()
 
+
+def test_minddataset_multi_files_with_empty_one(add_and_remove_three_nlp_file_with_sample_10_0_5):
+    """
+    Feature: MindDataset
+    Description: Test reading multiple mindrecord files when one of them is empty
+    Expectation: Output is equal to the expected output
+    """
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = [file_name + "_with_sample_10",
+             file_name + "_with_sample_0",
+             file_name + "_with_sample_5"]
+    epoch_size = 3
+
+    # with distribute
+    def read_multi_mindrecord_files_with_distribute(file_names):
+        data_set = ds.MindDataset(dataset_files=file_names,
+                                  num_shards=4,
+                                  shard_id=1)
+        assert data_set.get_dataset_size() == 4
+        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
+        for epoch in range(epoch_size):  # 3 epochs
+            num_iter = 0
+            for item in dataset_iter:
+                num_iter += 1
+            assert num_iter == 4
+
+    read_multi_mindrecord_files_with_distribute([paths[0], paths[1], paths[2]])
+    read_multi_mindrecord_files_with_distribute([paths[1], paths[0], paths[2]])
+    read_multi_mindrecord_files_with_distribute([paths[0], paths[2], paths[1]])
+
+    # with non-distribute
+    def read_multi_mindrecord_files(file_names):
+        data_set = ds.MindDataset(dataset_files=file_names)
+        assert data_set.get_dataset_size() == 15
+        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
+        for epoch in range(epoch_size):  # 3 epochs
+            num_iter = 0
+            for item in dataset_iter:
+                num_iter += 1
+            assert num_iter == 15
+
+    read_multi_mindrecord_files([paths[0], paths[1], paths[2]])
+    read_multi_mindrecord_files([paths[1], paths[0], paths[2]])
+    read_multi_mindrecord_files([paths[0], paths[2], paths[1]])
+
 if __name__ == '__main__':
     test_nlp_compress_data(add_and_remove_nlp_compress_file)
     test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file)
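
For reference, a minimal standalone sketch of the scenario the new test exercises. The file names and the simplified schema below are hypothetical (they are not the NLP schema used in the unit test), and the snippet assumes a MindSpore build that already contains this fix:

    import os
    import numpy as np
    import mindspore.dataset as ds
    from mindspore.mindrecord import FileWriter

    # Three single-shard mindrecord files with 10, 0 (empty) and 5 samples.
    paths = ["demo_with_sample_10", "demo_with_sample_0", "demo_with_sample_5"]
    schema = {"id": {"type": "int32"}, "data": {"type": "int64", "shape": [-1]}}

    for path, num_samples in zip(paths, [10, 0, 5]):
        # Remove stale files from a previous run, as the test fixture does.
        for stale in (path, path + ".db"):
            if os.path.exists(stale):
                os.remove(stale)
        writer = FileWriter(path)
        writer.add_schema(schema, "demo_schema")
        if num_samples > 0:
            writer.write_raw_data([{"id": i, "data": np.array([i, i + 1], dtype=np.int64)}
                                   for i in range(num_samples)])
        writer.commit()

    # With the fix, the empty shard only triggers the new WARNING in shard_reader.cc;
    # the dataset still reports and yields the 15 samples of the two non-empty files.
    data_set = ds.MindDataset(dataset_files=paths)
    assert data_set.get_dataset_size() == 15
    num_rows = sum(1 for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert num_rows == 15

The non-distributed assertions mirror the new unit test; the distributed variant in the test additionally shards the combined 15 samples across num_shards=4 readers.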