fix: read multiple mindrecord files of which one is empty

This commit is contained in:
jonyguo 2022-07-15 18:55:16 +08:00
parent b971362607
commit 4d1db74ff3
2 changed files with 97 additions and 0 deletions

View File

@ -356,6 +356,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
// return -1 when page's size equals to 0.
auto last_page_id = shard_header_->GetLastPageId(shard_id);
if (static_cast<int>(last_page_id) == -1) {
// Empty mindrecord file which does not contain any samples
MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
<< " does not contain any samples, pls remove it.";
row_group_summary.emplace_back(shard_id, 0, 0, 0);
shard_sample_count_.push_back(total_count);
continue;
}
for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {

View File

@ -110,6 +110,53 @@ def add_and_remove_nlp_file():
os.remove("{}.db".format(x))
@pytest.fixture
def add_and_remove_three_nlp_file_with_sample_10_0_5():
    """Create three mindrecord files (10, 0 and 5 samples) for a test, then remove them.

    The three files are named after the current pytest test with suffixes
    ``_with_sample_10``, ``_with_sample_0`` and ``_with_sample_5``.  The
    0-sample file exercises the "empty mindrecord file" code path.

    Yields:
        str: the sentinel value "yield_nlp_data" once all files are written.
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = [file_name + "_with_sample_10",
             file_name + "_with_sample_0",
             file_name + "_with_sample_5"]

    def _remove_files():
        # Best-effort cleanup: when file creation fails partway through, some of
        # the six files may not exist yet, so guard each removal.  Removing
        # unconditionally would raise FileNotFoundError and mask the real error.
        for path in paths:
            for artifact in (path, "{}.db".format(path)):
                if os.path.exists(artifact):
                    os.remove(artifact)

    def _create_mindrecord_file(path, num_samples):
        # `path` intentionally does not shadow the outer `file_name`.
        if os.path.exists(path):
            os.remove("{}".format(path))
        writer = FileWriter(path)
        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                           "rating": {"type": "float32"},
                           "input_ids": {"type": "int64",
                                         "shape": [-1]},
                           "input_mask": {"type": "int64",
                                          "shape": [1, -1]},
                           "segment_ids": {"type": "int64",
                                           "shape": [2, -1]}
                           }
        writer.set_header_size(1 << 14)
        writer.set_page_size(1 << 15)
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.add_index(["id", "rating"])
        if num_samples > 0:
            # Materialize the generator before writing.
            data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, num_samples))
            writer.write_raw_data(data)
        writer.commit()

    try:
        _create_mindrecord_file(paths[0], 10)
        _create_mindrecord_file(paths[1], 0)
        _create_mindrecord_file(paths[2], 5)
        yield "yield_nlp_data"
    except Exception as error:
        _remove_files()
        raise error
    else:
        _remove_files()
@pytest.fixture
def add_and_remove_nlp_compress_file():
"""add/remove nlp file"""
@ -2884,6 +2931,51 @@ def test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file):
assert (next(dataset_iter3)["array_a"] == data[4]["array_a"]).all()
assert (next(dataset_iter3)["array_a"] == data[5]["array_a"]).all()
def test_minddataset_multi_files_with_empty_one(add_and_remove_three_nlp_file_with_sample_10_0_5):
    """
    Feature: MindDataset
    Description: Test reading multiple mindrecord files where one of them is empty
    Expectation: Output is equal to the expected output
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = [file_name + "_with_sample_10",
             file_name + "_with_sample_0",
             file_name + "_with_sample_5"]
    epoch_size = 3

    def read_files(file_names, expected_count, num_shards=None, shard_id=None):
        # One shared helper for both the sharded and non-sharded cases; the
        # empty file contributes zero rows, so counts come from the 10- and
        # 5-sample files only.
        if num_shards is None:
            data_set = ds.MindDataset(dataset_files=file_names)
        else:
            data_set = ds.MindDataset(dataset_files=file_names,
                                      num_shards=num_shards,
                                      shard_id=shard_id)
        assert data_set.get_dataset_size() == expected_count
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for _ in range(epoch_size):  # 3 epochs
            num_iter = 0
            for _ in dataset_iter:
                num_iter += 1
            assert num_iter == expected_count

    # Every ordering of the three files, with the empty file in each position.
    orderings = ([paths[0], paths[1], paths[2]],
                 [paths[1], paths[0], paths[2]],
                 [paths[0], paths[2], paths[1]])

    # with distribute: 15 samples over 4 shards -> 4 rows on shard 1
    for file_names in orderings:
        read_files(file_names, 4, num_shards=4, shard_id=1)

    # with non-distribute: all 10 + 0 + 5 = 15 samples
    for file_names in orderings:
        read_files(file_names, 15)
if __name__ == '__main__':
test_nlp_compress_data(add_and_remove_nlp_compress_file)
test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file)