!7703 C++ API: Add a parameter validation and fix incorrect format of error message for TFrecord

Merge pull request !7703 from TinaMengtingZhang/fix_tfrecord_issue_I1YKU6_I1YKXL
This commit is contained in:
mindspore-ci-bot 2020-10-24 07:33:26 +08:00 committed by Gitee
commit fe2852df82
3 changed files with 22 additions and 6 deletions

View File

@ -642,7 +642,7 @@ Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_p
const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature(); const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
auto iter_column = feature_map.find(current_col.name()); auto iter_column = feature_map.find(current_col.name());
if (iter_column == feature_map.end()) { if (iter_column == feature_map.end()) {
RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + "does not exist."); RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + " does not exist.");
} }
const dataengine::Feature &column_values_list = iter_column->second; const dataengine::Feature &column_values_list = iter_column->second;
RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col)); RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col));

View File

@ -484,6 +484,14 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
return nullptr; return nullptr;
} }
if (cache == nullptr && !shard_equal_rows && dataset_files.size() < num_shards) {
// This check only makes sense in a non-cache path. We should make sure there is at least one file per
// shard in file-based sharding
MS_LOG(ERROR) << "TFRecordNode: Invalid number of dataset files, should at least be " << std::to_string(num_shards);
return nullptr;
}
std::shared_ptr<TFRecordNode> ds = nullptr; std::shared_ptr<TFRecordNode> ds = nullptr;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema; std::shared_ptr<SchemaObj> schema_obj = schema;

View File

@ -398,12 +398,14 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
// Create a TFRecord Dataset // Create a TFRecord Dataset
// Each file has two columns("image", "label") and 3 rows // Each file has two columns("image", "label") and 3 rows
std::vector<std::string> files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", std::vector<std::string> files = {
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"}; datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
std::shared_ptr<Dataset> ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); };
std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
EXPECT_NE(ds2, nullptr); EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@ -465,6 +467,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {
// This case expected to fail because shard_id is out_of_bound. // This case expected to fail because shard_id is out_of_bound.
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3); std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
EXPECT_EQ(ds7, nullptr); EXPECT_EQ(ds7, nullptr);
// This case expected to fail because the provided number of files < num_shards in file-based sharding.
std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data";
std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3);
EXPECT_EQ(ds8, nullptr);
} }
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {