From 5005b434e665f35ff0c483b66e2fd3d9e6817326 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Fri, 30 Oct 2020 17:42:45 -0400 Subject: [PATCH] Fixed repeat issue with GetDatasetSize --- .../ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc | 2 +- .../minddata/dataset/engine/datasetops/source/manifest_op.cc | 5 ++++- tests/ut/cpp/dataset/c_api_dataset_album_test.cc | 4 +--- tests/ut/cpp/dataset/c_api_dataset_voc_test.cc | 3 +++ 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 25830a6dfb1..6afa4eda4da 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -194,7 +194,7 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) { // Get Dataset size Status RepeatOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0 || num_repeats_ == -1) { + if (dataset_size_ > 0) { *dataset_size = dataset_size_; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index e0d08f93bb9..48344f410ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -481,10 +481,13 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) { *num_classes = num_classes_; return Status::OK(); } + int64_t classes_count; std::shared_ptr op; RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); RETURN_IF_NOT_OK(op->ParseManifestFile()); - *num_classes = num_classes_; + classes_count = static_cast(op->label_index_.size()); + *num_classes = classes_count; + num_classes_ = classes_count; return Status::OK(); } diff --git a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc index bceec6e40f2..068ba02e1a3 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc @@ -79,8 +79,6 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) { std::shared_ptr ds = Album(folder_path, schema_file, column_names); EXPECT_NE(ds, nullptr); - int64_t dataset_size = ds->GetDatasetSize(); - EXPECT_EQ(dataset_size, 7); int64_t num_classes = ds->GetNumClasses(); EXPECT_EQ(num_classes, -1); int64_t batch_size = ds->GetBatchSize(); @@ -114,7 +112,7 @@ TEST_F(MindDataTestPipeline, TestAlbumDecode) { auto shape = image->shape(); MS_LOG(INFO) << "Tensor image shape size: " << shape.Size(); MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect + EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect iter->GetNextRow(&row); } diff --git a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc index b2b89436e75..191447527d1 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc @@ -99,6 +99,9 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) { std::shared_ptr ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6)); EXPECT_NE(ds, nullptr); + ds = ds->Batch(2); + ds = ds->Repeat(2); + EXPECT_EQ(ds->GetDatasetSize(), 6); }