// Patch overview (commentary placed ahead of the first "diff --git" header; three files are touched):
//   1. tf_reader_op.cc, TFReaderOp::LoadExample: inserts the missing space before "does not exist."
//      in the error raised when current_col.name() is not found in the example's feature map.
//   2. datasets.h, TFRecord(...): adds an early `return nullptr` (after an MS_LOG(ERROR)) when
//      cache == nullptr, shard_equal_rows is false, and dataset_files.size() < num_shards.
//      NOTE(review): the declared type of num_shards is not visible in this hunk; if it is a signed
//      integer, the comparison against size() mixes signedness -- confirm.
//      NOTE(review): std::to_string(num_shards) looks redundant, since the context line just above
//      streams num_shards into MS_LOG directly -- confirm before simplifying.
//   3. c_api_dataset_tfrecord_test.cc: TestTFRecordDatasetShard now passes `files` instead of `{files}`
//      to TFRecord; TestTFRecordDatasetExeception gains a ds8 case (two files, sixth argument 3, which
//      the accompanying comment treats as num_shards) that is expected to return nullptr.
// NOTE(review): line breaks and template argument lists (e.g. `std::shared_ptr ds`,
// `std::is_same::value`) appear to have been lost in this copy of the patch; it presumably will not
// apply as-is -- regenerate it from the original commit before use.
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index 385aa42ebee..e6009d73881 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -642,7 +642,7 @@ Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_p const google::protobuf::Map &feature_map = example_features.feature(); auto iter_column = feature_map.find(current_col.name()); if (iter_column == feature_map.end()) { - RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + "does not exist."); + RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.name() + " does not exist."); } const dataengine::Feature &column_values_list = iter_column->second; RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col)); diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 5a5eb9305dc..27ec919c82f 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -484,6 +484,14 @@ std::shared_ptr TFRecord(const std::vector &dataset_f MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; return nullptr; } + + if (cache == nullptr && !shard_equal_rows && dataset_files.size() < num_shards) { + // This check only makes sense in a non-cache path. 
We should make sure there is at least one file per + // shard in file-based sharding + MS_LOG(ERROR) << "TFRecordNode: Invalid number of dataset files, should at least be " << std::to_string(num_shards); + return nullptr; + } + std::shared_ptr ds = nullptr; if constexpr (std::is_same::value || std::is_same>::value) { std::shared_ptr schema_obj = schema; diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc index 339dc5d60a1..04b7c897167 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -398,12 +398,14 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) { // Create a TFRecord Dataset // Each file has two columns("image", "label") and 3 rows - std::vector files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", - datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", - datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"}; - std::shared_ptr ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); + std::vector files = { + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data" + }; + std::shared_ptr ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); EXPECT_NE(ds1, nullptr); - std::shared_ptr ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); + std::shared_ptr ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); EXPECT_NE(ds2, nullptr); // Create an iterator over the result of the above dataset @@ -465,6 +467,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) { // This case expected to fail because shard_id is out_of_bound. 
std::shared_ptr ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3); EXPECT_EQ(ds7, nullptr); + + // This case expected to fail because the provided number of files < num_shards in file-based sharding. + std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; + std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data"; + std::shared_ptr ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3); + EXPECT_EQ(ds8, nullptr); } TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {