From f521532a062fb1207c40733ce627a37d1ad8fffe Mon Sep 17 00:00:00 2001
From: liyong
Date: Mon, 20 Jul 2020 15:09:47 +0800
Subject: [PATCH] fix field_name problem from tfrecord to mindrecord

---
 .../ccsrc/minddata/dataset/api/de_pipeline.cc |  20 ++++++--
 .../testTFRecordData/dummy.tfrecord           | Bin 0 -> 820 bytes
 tests/ut/python/dataset/test_save_op.py       |  46 +++++++++++++++++-
 3 files changed, 60 insertions(+), 6 deletions(-)
 create mode 100644 tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord

diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
index b31fdcf63b8..4302e12954d 100644
--- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
   }
 
   TensorRow row;
-  std::unordered_map<std::string, int32_t> column_name_id_map =
-    iterator_->GetColumnNameMap();  // map of column name, id
-  bool first_loop = true;  // build schema in first loop
+  std::unordered_map<std::string, int32_t> column_name_id_map;
+  for (auto el : iterator_->GetColumnNameMap()) {
+    std::string column_name = el.first;
+    std::transform(column_name.begin(), column_name.end(), column_name.begin(),
+                   [](unsigned char c) { return ispunct(c) ? '_' : c; });
+    column_name_id_map[column_name] = el.second;
+  }
+  bool first_loop = true;  // build schema in first loop
   do {
     json row_raw_data;
     std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
       std::vector<std::string> index_fields;
       s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
       RETURN_IF_NOT_OK(s);
-      mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
+      if (mindrecord::SUCCESS !=
+          mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
+        RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
+      }
       mr_writer->SetShardHeader(mr_header);
       first_loop = false;
     }
@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
     }
   } while (!row.empty());
   mr_writer->Commit();
-  mindrecord::ShardIndexGenerator::finalize(file_names);
+  if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
+    RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
+  }
 
   return Status::OK();
 }

diff --git a/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord b/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
new file mode 100644
index 0000000000000000000000000000000000000000..da4f853e2d7c948921750cdf31539bed30543566
GIT binary patch
literal 820
zcmZ=_fPl)+d8@b_xRkhfGjkKuQ}t8xlJiqiQ-lP$__=r!fglM8l7Ya0ONvVnq9!>f
zvA9@2Cow5CM~GF5nTwI(f}!>`CsIu`#A703fknC-$tD`%F_B5odh=G2O*F=1BD0~}
ze^ruAG{IvcOHi1!D9I+8;xUo+lUH~v$tIfNF_CR@M)MbvO*F@2BKzjV^C={oXo1H>
Xjw>AAA!sJzPQ9E&q<+izO&fv%6|3Gq

literal 0
HcmV?d00001

diff --git a/tests/ut/python/dataset/test_save_op.py b/tests/ut/python/dataset/test_save_op.py
index 2ed326276b3..2af14aec1ce 100644
--- a/tests/ut/python/dataset/test_save_op.py
+++ b/tests/ut/python/dataset/test_save_op.py
@@ -16,6 +16,7 @@
 This is the test module for saveOp.
""" import os +from string import punctuation import mindspore.dataset as ds from mindspore import log as logger from mindspore.mindrecord import FileWriter @@ -24,7 +25,7 @@ import pytest CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" - +TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord" FILES_NUM = 1 num_readers = 1 @@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file): with pytest.raises(Exception, match="tfrecord dataset format is not supported."): d1.save(CV_FILE_NAME2, 1, "tfrecord") + + +def cast_name(key): + """ + Cast schema names which containing special characters to valid names. + """ + special_symbols = set('{}{}'.format(punctuation, ' ')) + special_symbols.remove('_') + new_key = ['_' if x in special_symbols else x for x in key] + casted_key = ''.join(new_key) + return casted_key + + +def test_case_07(): + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False) + tf_data = [] + for x in d1.create_dict_iterator(): + tf_data.append(x) + d1.save(CV_FILE_NAME2, FILES_NUM) + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + mr_data = [] + for x in d2.create_dict_iterator(): + mr_data.append(x) + count = 0 + for x in tf_data: + for k, v in x.items(): + if isinstance(v, np.ndarray): + assert (v == mr_data[count][cast_name(k)]).all() + else: + assert v == mr_data[count][cast_name(k)] + count += 1 + assert count == 10 + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2))