added checking of first row crc to find invalid tfrecord files
addressed code review comments:
- added check in python layer to exclude directories and to raise an error if a pattern does not match any file
- fixed clang-format issues (whitespace adjustments)
- fixed cppcheck issues (used std::accumulate and std::copy_if; added const reference for string parameters in lambdas)
- regenerated tfrecord test file to contain a correct header (it was a dummy header before)
- changed print to logger.info
This commit is contained in:
parent
d8176a77f4
commit
9bc2134cb7
@@ -42,6 +42,7 @@
 #include "dataset/util/status.h"
 #include "dataset/util/task_manager.h"
 #include "dataset/util/wait_post.h"
+#include "utils/system/crc32c.h"
 
 namespace mindspore {
 namespace dataset {
@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
   builder_data_schema_ = std::make_unique<DataSchema>();
 }
 
+bool ValidateFirstRowCrc(const std::string &filename) {
+  std::ifstream reader;
+  reader.open(filename);
+  if (!reader) {
+    return false;
+  }
+
+  // read data
+  int64_t record_length = 0;
+  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
+
+  // read crc from file
+  uint32_t masked_crc = 0;
+  (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
+
+  // generate crc from data
+  uint32_t generated_crc =
+    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));
+
+  return masked_crc == generated_crc;
+}
+
 Status TFReaderOp::Builder::ValidateInputs() const {
   std::string err_msg;
-  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
-  if (!builder_equal_rows_per_shard_) {
-    err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)
-                 ? "No enough tf_file files provided\n"
-                 : "";
-  }
-  err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
+
+  if (builder_num_workers_ <= 0) {
+    err_msg += "Number of parallel workers is smaller or equal to 0\n";
+  }
+
+  if (!builder_equal_rows_per_shard_ &&
+      builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
+    err_msg += "Not enough tfrecord files provided\n";
+  }
+
+  if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
+    err_msg += "Wrong sharding configs\n";
+  }
+
+  std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
+  auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
+                         [](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
+  invalid_files.resize(std::distance(invalid_files.begin(), it));
+
+  if (!invalid_files.empty()) {
+    err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";
+
+    std::string accumulated_filenames = std::accumulate(
+      invalid_files.begin(), invalid_files.end(), std::string(""),
+      [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
+    err_msg += accumulated_filenames;
+  }
+
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }
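For reference, each record in a TFRecord file is framed as a little-endian uint64 length, a masked CRC-32C of those eight length bytes, the payload, and a masked CRC-32C of the payload; TensorFlow's mask is `((crc >> 15) | (crc << 17)) + 0xa282ead8` with uint32 wraparound. `ValidateFirstRowCrc` therefore needs only the first twelve bytes of a file to decide whether it plausibly is a TFRecord file. A minimal Python sketch of the same check, assuming the third-party `crc32c` package for the Castagnoli polynomial:

```python
import struct

import crc32c  # assumption: pip install crc32c (CRC-32C, Castagnoli polynomial)


def masked_crc32c(data: bytes) -> int:
    # TensorFlow's CRC masking: rotate right by 15 bits, then add a constant.
    crc = crc32c.crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def first_row_crc_is_valid(filename: str) -> bool:
    # Mirrors ValidateFirstRowCrc: read the first record's 8-byte length and
    # its stored masked CRC, then compare against a freshly computed CRC.
    try:
        with open(filename, "rb") as f:
            length_bytes = f.read(8)  # int64 record length (little-endian)
            crc_bytes = f.read(4)     # uint32 masked crc of the length bytes
    except OSError:
        return False
    if len(length_bytes) != 8 or len(crc_bytes) != 4:
        return False
    (stored_crc,) = struct.unpack("<I", crc_bytes)
    return masked_crc32c(length_bytes) == stored_crc
```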
@@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
     RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
     rows_read++;
   }
 
+  // ignore crc footer
   (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
   rows_total++;
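The four ignored bytes are the per-record payload CRC that closes each record; by this point in the loop the record's length, length CRC, and payload have already been consumed. As a rough illustration of the framing being stepped through, a plain Python iterator over a TFRecord file (hypothetical helper, no CRC verification):

```python
import struct


def iter_tfrecord_payloads(filename):
    # Walks the TFRecord framing: length, length-crc, payload, payload-crc.
    with open(filename, "rb") as f:
        while True:
            length_bytes = f.read(8)
            if len(length_bytes) < 8:
                return  # clean end of file
            (length,) = struct.unpack("<Q", length_bytes)
            f.read(4)   # masked crc32c of the length -- not checked here
            payload = f.read(length)
            f.read(4)   # masked crc32c of the payload (the "crc footer")
            yield payload
```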
@@ -900,13 +900,22 @@ class SourceDataset(Dataset):
             List, files.
         """
 
-        def flat(lists):
-            return list(np.array(lists).flatten())
-
         if not isinstance(patterns, list):
             patterns = [patterns]
 
-        file_list = flat([glob.glob(file, recursive=True) for file in patterns])
+        file_list = []
+        unmatched_patterns = []
+        for pattern in patterns:
+            matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
+
+            if matches:
+                file_list.extend(matches)
+            else:
+                unmatched_patterns.append(pattern)
+
+        if unmatched_patterns:
+            raise ValueError("The following patterns did not match any files: ", unmatched_patterns)
+
         if file_list:  # not empty
             return file_list
         raise ValueError("The list of path names matching the patterns is empty.")
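With this change a pattern that matches only directories, or nothing at all, fails fast with a ValueError naming the offending patterns, instead of silently yielding an empty or directory-polluted file list. A usage sketch (the paths are hypothetical):

```python
import mindspore.dataset as ds

try:
    data = ds.TFRecordDataset(["data/*.tfrecord", "missing/*.tfrecord"],
                              "datasetSchema.json")
except ValueError as err:
    # e.g. ('The following patterns did not match any files: ', ['missing/*.tfrecord'])
    print(err)
```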
@@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) {
   TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true);
   ASSERT_EQ(total_rows, 60);
 }
+
+TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
+  // Start with an empty execution tree
+  auto my_tree = std::make_shared<ExecutionTree>();
+
+  std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+  std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
+  std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
+  std::string nonexistent_file = "this/file/doesnt/exist";
+
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16);
+
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(schema_file, {});
+  builder.SetDataSchema(std::move(schema));
+
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(!rc.IsOk());
+
+  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16);
+
+  schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(schema_file, {});
+  builder.SetDataSchema(std::move(schema));
+
+  rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(!rc.IsOk());
+}
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
+this is just a text file, not a valid tfrecord file.
@@ -32,7 +32,7 @@ def test_case_tf_shape():
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)
     for data in ds1.create_dict_iterator():
-        print(data)
+        logger.info(data)
     output_shape = ds1.output_shapes()
     assert (len(output_shape[-1]) == 1)
@@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list():
         a = row["col_sint32"]
     assert "col_sint32" in str(info.value)
 
+
+def test_case_invalid_files():
+    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
+    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
+    files = [invalid_file, valid_file, SCHEMA_FILE]
+
+    data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+
+    with pytest.raises(RuntimeError) as info:
+        row = data.create_dict_iterator().get_next()
+    assert "cannot be opened" in str(info.value)
+    assert "not valid tfrecord files" in str(info.value)
+    assert valid_file not in str(info.value)
+    assert invalid_file in str(info.value)
+    assert SCHEMA_FILE in str(info.value)
+
+    nonexistent_file = "this/file/does/not/exist"
+    files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]
+
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+    assert "did not match any files" in str(info.value)
+    assert valid_file not in str(info.value)
+    assert invalid_file not in str(info.value)
+    assert SCHEMA_FILE not in str(info.value)
+    assert nonexistent_file in str(info.value)
+
+
 if __name__ == '__main__':
     test_case_tf_shape()
     test_case_tf_file()
@@ -212,3 +238,4 @@ if __name__ == '__main__':
     test_tf_record_schema()
     test_tf_record_shuffle()
     test_tf_shard_equal_rows()
+    test_case_invalid_files()