fix field_name problem from tfrecord to mindrecord
parent b5d8dad47d
commit f521532a06
@@ -385,8 +385,13 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
     }
   }
   TensorRow row;
-  std::unordered_map<std::string, int32_t> column_name_id_map =
-    iterator_->GetColumnNameMap();  // map of column name, id
+  std::unordered_map<std::string, int32_t> column_name_id_map;
+  for (auto el : iterator_->GetColumnNameMap()) {
+    std::string column_name = el.first;
+    std::transform(column_name.begin(), column_name.end(), column_name.begin(),
+                   [](unsigned char c) { return ispunct(c) ? '_' : c; });
+    column_name_id_map[column_name] = el.second;
+  }
   bool first_loop = true;  // build schema in first loop
   do {
     json row_raw_data;
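For context: the removed lines copied iterator_->GetColumnNameMap() verbatim, so TFRecord feature names such as image/class/label reached MindRecord unchanged and produced invalid field names. The replacement loop rewrites each name character by character, mapping every character for which ispunct() is true to '_'. A minimal Python sketch of the same rule (the feature names are hypothetical, for illustration only):

from string import punctuation  # same character class as C's ispunct for ASCII


def sanitize(name):
    # Map every punctuation character to '_', as the C++ loop above does.
    return ''.join('_' if c in punctuation else c for c in name)


print(sanitize("image/class/label"))  # image_class_label
print(sanitize("bbox.xmin"))          # bbox_xmin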
@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
     std::vector<std::string> index_fields;
     s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
     RETURN_IF_NOT_OK(s);
-    mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
+    if (mindrecord::SUCCESS !=
+        mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
+      RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
+    }
     mr_writer->SetShardHeader(mr_header);
     first_loop = false;
   }
@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
     }
   } while (!row.empty());
   mr_writer->Commit();
-  mindrecord::ShardIndexGenerator::finalize(file_names);
+  if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
+    RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
+  }
   return Status::OK();
 }
 
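Both error-handling hunks above follow one pattern: mindrecord calls whose return codes were previously dropped (ShardHeader::initialize, ShardIndexGenerator::finalize) are now compared against mindrecord::SUCCESS and converted into an error Status via RETURN_STATUS_UNEXPECTED, so a failed save propagates to the caller instead of leaving a half-written file behind. On the Python side such a failure presumably surfaces as an exception from Dataset.save; a hedged sketch (the exception type is an assumption, not confirmed by this diff):

import mindspore.dataset as ds

data = ds.TFRecordDataset("../data/mindrecord/testTFRecordData/dummy.tfrecord", shuffle=False)
try:
    data.save("auto.mindrecord", 1)
except RuntimeError as err:  # assumed exception type for a failed save
    print("save failed:", err)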
Binary file not shown.
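(The binary file omitted by the viewer is presumably the dummy.tfrecord fixture that TFRECORD_FILES references in the test changes below.)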
@@ -16,6 +16,7 @@
 This is the test module for saveOp.
 """
 import os
+from string import punctuation
 import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.mindrecord import FileWriter
@@ -24,7 +25,8 @@ import pytest
 
 CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
 CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
+TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
 FILES_NUM = 1
 num_readers = 1
 
 
@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
 
     with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
         d1.save(CV_FILE_NAME2, 1, "tfrecord")
+
+
+def cast_name(key):
+    """
+    Cast schema names that contain special characters to valid names.
+    """
+    special_symbols = set('{}{}'.format(punctuation, ' '))
+    special_symbols.remove('_')
+    new_key = ['_' if x in special_symbols else x for x in key]
+    casted_key = ''.join(new_key)
+    return casted_key
+
+
+def test_case_07():
+    if os.path.exists("{}".format(CV_FILE_NAME2)):
+        os.remove("{}".format(CV_FILE_NAME2))
+    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
+        os.remove("{}.db".format(CV_FILE_NAME2))
+    d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
+    tf_data = []
+    for x in d1.create_dict_iterator():
+        tf_data.append(x)
+    d1.save(CV_FILE_NAME2, FILES_NUM)
+    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
+                        num_parallel_workers=num_readers,
+                        shuffle=False)
+    mr_data = []
+    for x in d2.create_dict_iterator():
+        mr_data.append(x)
+    count = 0
+    for x in tf_data:
+        for k, v in x.items():
+            if isinstance(v, np.ndarray):
+                assert (v == mr_data[count][cast_name(k)]).all()
+            else:
+                assert v == mr_data[count][cast_name(k)]
+        count += 1
+    assert count == 10
+
+    if os.path.exists("{}".format(CV_FILE_NAME2)):
+        os.remove("{}".format(CV_FILE_NAME2))
+    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
+        os.remove("{}.db".format(CV_FILE_NAME2))
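One subtlety between the two halves of this commit: the C++ path uses ispunct(), which is false for a space, while the test helper cast_name also folds spaces into '_' by adding ' ' to special_symbols. The divergence is easy to confirm against Python's standard library, whose string.punctuation matches C's ASCII punctuation class:

import string

# A space is not punctuation, so the C++ ispunct path leaves it alone;
# cast_name handles it explicitly and maps it to '_'.
print(' ' in string.punctuation)  # False
print(string.punctuation)         # !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~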