forked from mindspore-Ecosystem/mindspore
!7703 C++ API: Add parameter validation and fix the incorrectly formatted error message for TFRecord
Merge pull request !7703 from TinaMengtingZhang/fix_tfrecord_issue_I1YKU6_I1YKXL
This commit is contained in:
commit
fe2852df82
|
@ -484,6 +484,14 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
|
||||||
MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
|
MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cache == nullptr && !shard_equal_rows && dataset_files.size() < num_shards) {
|
||||||
|
// This check only makes sense in a non-cache path. We should make sure there is at least one file per
|
||||||
|
// shard in file-based sharding
|
||||||
|
MS_LOG(ERROR) << "TFRecordNode: Invalid number of dataset files, should at least be " << std::to_string(num_shards);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<TFRecordNode> ds = nullptr;
|
std::shared_ptr<TFRecordNode> ds = nullptr;
|
||||||
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
||||||
std::shared_ptr<SchemaObj> schema_obj = schema;
|
std::shared_ptr<SchemaObj> schema_obj = schema;
|
||||||
|
|
|
@ -398,12 +398,14 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
|
||||||
|
|
||||||
// Create a TFRecord Dataset
|
// Create a TFRecord Dataset
|
||||||
// Each file has two columns("image", "label") and 3 rows
|
// Each file has two columns("image", "label") and 3 rows
|
||||||
std::vector<std::string> files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
|
std::vector<std::string> files = {
|
||||||
|
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
|
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||||
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"};
|
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
|
||||||
std::shared_ptr<Dataset> ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
|
};
|
||||||
|
std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
std::shared_ptr<Dataset> ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
|
std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
|
||||||
EXPECT_NE(ds2, nullptr);
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
// Create an iterator over the result of the above dataset
|
// Create an iterator over the result of the above dataset
|
||||||
|
@ -465,6 +467,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {
|
||||||
// This case is expected to fail because shard_id is out of bounds.
|
// This case is expected to fail because shard_id is out of bounds.
|
||||||
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
|
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
|
||||||
EXPECT_EQ(ds7, nullptr);
|
EXPECT_EQ(ds7, nullptr);
|
||||||
|
|
||||||
|
// This case is expected to fail because the provided number of files < num_shards in file-based sharding.
|
||||||
|
std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
|
||||||
|
std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data";
|
||||||
|
std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3);
|
||||||
|
EXPECT_EQ(ds8, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
|
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
|
||||||
|
|
Loading…
Reference in New Issue