fix: read multiple mindrecord files when one of them is empty
parent b971362607
commit 4d1db74ff3
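Reading a set of mindrecord files fails when one of them is valid but contains no samples. Judging from the hunk below, ShardReader::ReadRowGroupSummary previously skipped such a shard outright: GetLastPageId returns -1 for a file with no pages, and the early continue left row_group_summary and shard_sample_count_ without entries for that shard, so sample indexing across the remaining files went out of alignment. The fix records a zero-sized row group and the running sample total for the empty shard before continuing, and logs a warning naming the file.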
@@ -356,6 +356,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
       // return -1 when page's size equals to 0.
       auto last_page_id = shard_header_->GetLastPageId(shard_id);
       if (static_cast<int>(last_page_id) == -1) {
+        // Empty mindrecord file which does not contain any samples
+        MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
+                        << " does not contain any samples, please remove it.";
+        row_group_summary.emplace_back(shard_id, 0, 0, 0);
+        shard_sample_count_.push_back(total_count);
         continue;
       }
       for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
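For context, a minimal sketch of the situation the branch above handles, mirroring the num_samples == 0 path of the fixture added below; the file names and the schema here are illustrative, not from the commit:

from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds

# Committing a writer that has a schema but was never given rows produces a
# valid mindrecord file with zero samples (illustrative names throughout).
writer = FileWriter("empty.mindrecord")
writer.add_schema({"id": {"type": "string"}}, "sketch_schema")
writer.commit()

# With the fix, mixing it with a non-empty file only triggers the WARNING
# above instead of breaking the read; "full.mindrecord" is assumed to exist.
data_set = ds.MindDataset(dataset_files=["full.mindrecord", "empty.mindrecord"])
for _ in data_set.create_dict_iterator(output_numpy=True):
    pass  # iterates over the non-empty file's samples only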
@@ -110,6 +110,53 @@ def add_and_remove_nlp_file():
         os.remove("{}.db".format(x))


+@pytest.fixture
+def add_and_remove_three_nlp_file_with_sample_10_0_5():
+    """add/remove three nlp files with 10, 0, and 5 samples"""
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = [file_name + "_with_sample_10",
+             file_name + "_with_sample_0",
+             file_name + "_with_sample_5"]
+    try:
+        def create_mindrecord_file(file_name, num_samples):
+            if os.path.exists(file_name):
+                os.remove("{}".format(file_name))
+            writer = FileWriter(file_name)
+            if num_samples > 0:
+                data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, num_samples)]
+            nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
+                               "rating": {"type": "float32"},
+                               "input_ids": {"type": "int64",
+                                             "shape": [-1]},
+                               "input_mask": {"type": "int64",
+                                              "shape": [1, -1]},
+                               "segment_ids": {"type": "int64",
+                                               "shape": [2, -1]}
+                               }
+            writer.set_header_size(1 << 14)
+            writer.set_page_size(1 << 15)
+            writer.add_schema(nlp_schema_json, "nlp_schema")
+            writer.add_index(["id", "rating"])
+            if num_samples > 0:
+                writer.write_raw_data(data)
+            writer.commit()
+
+        create_mindrecord_file(paths[0], 10)
+        create_mindrecord_file(paths[1], 0)
+        create_mindrecord_file(paths[2], 5)
+
+        yield "yield_nlp_data"
+    except Exception as error:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+        raise error
+    else:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+
+
 @pytest.fixture
 def add_and_remove_nlp_compress_file():
     """add/remove nlp file"""
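A note on the shape of the fixture above: the yield sits inside the try block, so an exception while creating the three files is caught by the except branch, which attempts to delete the partially written files and re-raises; the else branch is the ordinary teardown path, removing each file and its .db index once the test has run.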
@@ -2884,6 +2931,51 @@ def test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file):
     assert (next(dataset_iter3)["array_a"] == data[4]["array_a"]).all()
     assert (next(dataset_iter3)["array_a"] == data[5]["array_a"]).all()


+def test_minddataset_multi_files_with_empty_one(add_and_remove_three_nlp_file_with_sample_10_0_5):
+    """
+    Feature: MindDataset
+    Description: Test reading multiple mindrecord files when one of them is empty
+    Expectation: Output is equal to the expected output
+    """
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = [file_name + "_with_sample_10",
+             file_name + "_with_sample_0",
+             file_name + "_with_sample_5"]
+    epoch_size = 3
+
+    # with distribute
+    def read_multi_mindrecord_files_with_distribute(file_names):
+        data_set = ds.MindDataset(dataset_files=file_names,
+                                  num_shards=4,
+                                  shard_id=1)
+        assert data_set.get_dataset_size() == 4
+        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
+        for epoch in range(epoch_size):  # 3 epochs
+            num_iter = 0
+            for item in dataset_iter:
+                num_iter += 1
+            assert num_iter == 4
+
+    read_multi_mindrecord_files_with_distribute([paths[0], paths[1], paths[2]])
+    read_multi_mindrecord_files_with_distribute([paths[1], paths[0], paths[2]])
+    read_multi_mindrecord_files_with_distribute([paths[0], paths[2], paths[1]])
+
+    # with non-distribute
+    def read_multi_mindrecord_files(file_names):
+        data_set = ds.MindDataset(dataset_files=file_names)
+        assert data_set.get_dataset_size() == 15
+        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
+        for epoch in range(epoch_size):  # 3 epochs
+            num_iter = 0
+            for item in dataset_iter:
+                num_iter += 1
+            assert num_iter == 15
+
+    read_multi_mindrecord_files([paths[0], paths[1], paths[2]])
+    read_multi_mindrecord_files([paths[1], paths[0], paths[2]])
+    read_multi_mindrecord_files([paths[0], paths[2], paths[1]])
+
+
 if __name__ == '__main__':
     test_nlp_compress_data(add_and_remove_nlp_compress_file)
     test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file)
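Where the asserted sizes in the new test come from: the three files hold 10 + 0 + 5 = 15 samples, so the non-distributed reader yields 15 rows per epoch, and with num_shards=4 each shard gets ceil(15/4) = 4 rows. A quick check of that arithmetic (a reading of the asserts above, not a documented guarantee):

import math

total = 10 + 0 + 5                  # samples across the three test files
assert total == 15                  # the non-distributed per-epoch row count
assert math.ceil(total / 4) == 4    # per-shard count with num_shards=4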