fix field_name probelem from tfrecord to mindrecord

This commit is contained in:
liyong 2020-07-20 15:09:47 +08:00
parent b5d8dad47d
commit f521532a06
3 changed files with 60 additions and 6 deletions

View File

@ -385,8 +385,13 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
} }
TensorRow row; TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map = std::unordered_map<std::string, int32_t> column_name_id_map;
iterator_->GetColumnNameMap(); // map of column name, id for (auto el : iterator_->GetColumnNameMap()) {
std::string column_name = el.first;
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
[](unsigned char c) { return ispunct(c) ? '_' : c; });
column_name_id_map[column_name] = el.second;
}
bool first_loop = true; // build schema in first loop bool first_loop = true; // build schema in first loop
do { do {
json row_raw_data; json row_raw_data;
@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std::vector<std::string> index_fields; std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields); s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id); if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
mr_writer->SetShardHeader(mr_header); mr_writer->SetShardHeader(mr_header);
first_loop = false; first_loop = false;
} }
@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
} }
} while (!row.empty()); } while (!row.empty());
mr_writer->Commit(); mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names); if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
return Status::OK(); return Status::OK();
} }

View File

@ -16,6 +16,7 @@
This is the test module for saveOp. This is the test module for saveOp.
""" """
import os import os
from string import punctuation
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import FileWriter from mindspore.mindrecord import FileWriter
@ -24,7 +25,7 @@ import pytest
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1 FILES_NUM = 1
num_readers = 1 num_readers = 1
@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
with pytest.raises(Exception, match="tfrecord dataset format is not supported."): with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
d1.save(CV_FILE_NAME2, 1, "tfrecord") d1.save(CV_FILE_NAME2, 1, "tfrecord")
def cast_name(key):
"""
Cast schema names which containing special characters to valid names.
"""
special_symbols = set('{}{}'.format(punctuation, ' '))
special_symbols.remove('_')
new_key = ['_' if x in special_symbols else x for x in key]
casted_key = ''.join(new_key)
return casted_key
def test_case_07():
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
tf_data = []
for x in d1.create_dict_iterator():
tf_data.append(x)
d1.save(CV_FILE_NAME2, FILES_NUM)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
mr_data = []
for x in d2.create_dict_iterator():
mr_data.append(x)
count = 0
for x in tf_data:
for k, v in x.items():
if isinstance(v, np.ndarray):
assert (v == mr_data[count][cast_name(k)]).all()
else:
assert v == mr_data[count][cast_name(k)]
count += 1
assert count == 10
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))