forked from mindspore-Ecosystem/mindspore
add validation code to fix issue I1YKU6 and a corresponding UT; fix format issue I1YKXL
This commit is contained in:
parent
79b5fee04b
commit
e7346dd3a9
|
@ -642,7 +642,7 @@ Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_p
|
|||
const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
|
||||
auto iter_column = feature_map.find(current_col.name());
|
||||
if (iter_column == feature_map.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + "does not exist.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + " does not exist.");
|
||||
}
|
||||
const dataengine::Feature &column_values_list = iter_column->second;
|
||||
RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col));
|
||||
|
|
|
@ -482,6 +482,14 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
|
|||
MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (cache == nullptr && !shard_equal_rows && dataset_files.size() < num_shards) {
|
||||
// This check only makes sense in a non-cache path. We should make sure there is at least one file per
|
||||
// shard in file-based sharding
|
||||
MS_LOG(ERROR) << "TFRecordNode: Invalid number of dataset files, should at least be " << std::to_string(num_shards);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<TFRecordNode> ds = nullptr;
|
||||
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
||||
std::shared_ptr<SchemaObj> schema_obj = schema;
|
||||
|
|
|
@ -368,12 +368,14 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
|
|||
|
||||
// Create a TFRecord Dataset
|
||||
// Each file has two columns("image", "label") and 3 rows
|
||||
std::vector<std::string> files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"};
|
||||
std::shared_ptr<Dataset> ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
|
||||
std::vector<std::string> files = {
|
||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
|
||||
};
|
||||
std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
std::shared_ptr<Dataset> ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
|
||||
std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
|
@ -435,6 +437,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {
|
|||
// This case expected to fail because shard_id is out_of_bound.
|
||||
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
|
||||
EXPECT_EQ(ds7, nullptr);
|
||||
|
||||
// This case expected to fail because the provided number of files < num_shards in file-based sharding.
|
||||
std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
|
||||
std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data";
|
||||
std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3);
|
||||
EXPECT_EQ(ds8, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
|
||||
|
|
Loading…
Reference in New Issue