forked from mindspore-Ecosystem/mindspore
fix field_name probelem from tfrecord to mindrecord
This commit is contained in:
parent
b5d8dad47d
commit
f521532a06
|
@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
|
|||
}
|
||||
|
||||
TensorRow row;
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map =
|
||||
iterator_->GetColumnNameMap(); // map of column name, id
|
||||
bool first_loop = true; // build schema in first loop
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map;
|
||||
for (auto el : iterator_->GetColumnNameMap()) {
|
||||
std::string column_name = el.first;
|
||||
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
|
||||
[](unsigned char c) { return ispunct(c) ? '_' : c; });
|
||||
column_name_id_map[column_name] = el.second;
|
||||
}
|
||||
bool first_loop = true; // build schema in first loop
|
||||
do {
|
||||
json row_raw_data;
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
|
||||
|
@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
|
|||
std::vector<std::string> index_fields;
|
||||
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
|
||||
if (mindrecord::SUCCESS !=
|
||||
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
|
||||
}
|
||||
mr_writer->SetShardHeader(mr_header);
|
||||
first_loop = false;
|
||||
}
|
||||
|
@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
|
|||
}
|
||||
} while (!row.empty());
|
||||
mr_writer->Commit();
|
||||
mindrecord::ShardIndexGenerator::finalize(file_names);
|
||||
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
Binary file not shown.
|
@ -16,6 +16,7 @@
|
|||
This is the test module for saveOp.
|
||||
"""
|
||||
import os
|
||||
from string import punctuation
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
@ -24,7 +25,7 @@ import pytest
|
|||
|
||||
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
|
||||
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
|
||||
|
||||
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
|
||||
FILES_NUM = 1
|
||||
num_readers = 1
|
||||
|
||||
|
@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
|
|||
|
||||
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
|
||||
d1.save(CV_FILE_NAME2, 1, "tfrecord")
|
||||
|
||||
|
||||
def cast_name(key):
|
||||
"""
|
||||
Cast schema names which containing special characters to valid names.
|
||||
"""
|
||||
special_symbols = set('{}{}'.format(punctuation, ' '))
|
||||
special_symbols.remove('_')
|
||||
new_key = ['_' if x in special_symbols else x for x in key]
|
||||
casted_key = ''.join(new_key)
|
||||
return casted_key
|
||||
|
||||
|
||||
def test_case_07():
|
||||
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
||||
os.remove("{}".format(CV_FILE_NAME2))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME2))
|
||||
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
|
||||
tf_data = []
|
||||
for x in d1.create_dict_iterator():
|
||||
tf_data.append(x)
|
||||
d1.save(CV_FILE_NAME2, FILES_NUM)
|
||||
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
mr_data = []
|
||||
for x in d2.create_dict_iterator():
|
||||
mr_data.append(x)
|
||||
count = 0
|
||||
for x in tf_data:
|
||||
for k, v in x.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
assert (v == mr_data[count][cast_name(k)]).all()
|
||||
else:
|
||||
assert v == mr_data[count][cast_name(k)]
|
||||
count += 1
|
||||
assert count == 10
|
||||
|
||||
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
||||
os.remove("{}".format(CV_FILE_NAME2))
|
||||
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME2))
|
||||
|
|
Loading…
Reference in New Issue