diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index f41831ed206..c0e0519cedc 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -200,12 +200,10 @@ int64_t Dataset::GetDatasetSize() { MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; return -1; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(this->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; - return -1; - } + rc = tree_getters_->Init(this->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; + return -1; } rc = tree_getters_->GetDatasetSize(&dataset_size); return rc.IsError() ? -1 : dataset_size; @@ -218,16 +216,12 @@ std::vector Dataset::GetOutputTypes() { rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; - types.clear(); return types; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(this->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed."; - types.clear(); - return types; - } + rc = tree_getters_->Init(this->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed."; + return types; } rc = tree_getters_->GetOutputTypes(&types); if (rc.IsError()) { @@ -245,16 +239,12 @@ std::vector Dataset::GetOutputShapes() { rc = runtime_context->Init(); if (rc.IsError()) { MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; - shapes.clear(); return shapes; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(this->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed."; - shapes.clear(); - return shapes; - } + rc = tree_getters_->Init(this->IRNode()); + if (rc.IsError()) { 
+ MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed."; + return shapes; } rc = tree_getters_->GetOutputShapes(&shapes); if (rc.IsError()) { @@ -275,17 +265,39 @@ int64_t Dataset::GetNumClasses() { MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; return -1; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(ds->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed."; - return -1; - } + rc = tree_getters_->Init(ds->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed."; + return -1; } rc = tree_getters_->GetNumClasses(&num_classes); return rc.IsError() ? -1 : num_classes; } +std::vector>> Dataset::GetClassIndexing() { + std::vector>> output_class_indexing; + auto ds = shared_from_this(); + Status rc; + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed."; + return output_class_indexing; + } + rc = tree_getters_->Init(ds->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed."; + return output_class_indexing; + } + rc = tree_getters_->GetClassIndexing(&output_class_indexing); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed."; + output_class_indexing.clear(); + return output_class_indexing; + } + return output_class_indexing; +} + /// \brief Function to create a SchemaObj /// \param[in] schema_file Path of schema file /// \return Shared pointer to the current schema @@ -580,12 +592,10 @@ int64_t Dataset::GetBatchSize() { MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; return -1; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(ds->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed."; - return -1; - } + rc = 
tree_getters_->Init(ds->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed."; + return -1; } rc = tree_getters_->GetBatchSize(&batch_size); return rc.IsError() ? -1 : batch_size; @@ -601,22 +611,22 @@ int64_t Dataset::GetRepeatCount() { MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; return -1; } - if (!tree_getters_->isInitialized()) { - rc = tree_getters_->Init(ds->IRNode()); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed."; - return -1; - } + rc = tree_getters_->Init(ds->IRNode()); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed."; + return -1; } rc = tree_getters_->GetRepeatCount(&repeat_count); return rc.IsError() ? 0 : repeat_count; } + std::shared_ptr Dataset::SetNumWorkers(int32_t num_workers) { if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) { return nullptr; } return shared_from_this(); } + #ifndef ENABLE_ANDROID std::shared_ptr Dataset::BuildSentencePieceVocab( const std::vector &col_names, uint32_t vocab_size, float character_coverage, diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 6154692df40..58863a89c48 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -384,6 +384,9 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal } Status TreeGetters::Init(std::shared_ptr d) { + if (init_flag_) { + return Status::OK(); + } Status s = tree_adapter_->Compile(std::move(d), 1); if (!s.IsError()) { init_flag_ = true; @@ -463,6 +466,13 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) { return Status::OK(); } +Status TreeGetters::GetClassIndexing(std::vector>> *output_class_indexing) { + std::shared_ptr root = 
std::shared_ptr(tree_adapter_->GetRoot()); + CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); + RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing)); + return Status::OK(); +} + Status BuildVocabConsumer::Init(std::shared_ptr d) { return tree_adapter_->Compile(std::move(d), 1); } Status BuildVocabConsumer::Start() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 07549cab82a..67d7dec5130 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -166,6 +166,7 @@ class TreeGetters : public TreeConsumer { Status GetBatchSize(int64_t *batch_size); Status GetRepeatCount(int64_t *repeat_count); Status GetNumClasses(int64_t *num_classes); + Status GetClassIndexing(std::vector>> *output_class_indexing); bool isInitialized(); std::string Name() override { return "TreeGetters"; } Status GetRow(TensorRow *r); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 951d28c4f40..905f9ead34e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -316,6 +316,14 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) { } } +Status DatasetOp::GetClassIndexing(std::vector>> *output_class_indexing) { + if (!child_.empty()) { + return child_[0]->GetClassIndexing(output_class_indexing); + } else { + RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree."); + } +} + // Performs handling for when an eoe message is received. // The base class implementation simply flows the eoe message to output. Derived classes // may override if they need to perform special eoe handling. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index b1e4ff8d4b2..d310b580218 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "minddata/dataset/callback/callback_manager.h" #include "minddata/dataset/core/constants.h" @@ -195,6 +196,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Status - The status code return virtual Status GetNumClasses(int64_t *num_classes); + /// \brief Gets the class indexing + /// \return Status - The status code return + virtual Status GetClassIndexing(std::vector>> *output_class_indexing); + /// \brief Performs handling for when an eoe message is received. /// The base class implementation simply flows the eoe message to output. Derived classes /// may override if they need to perform special eoe handling. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index 915873cff0a..c9adcc79def 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -710,5 +710,30 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) { dataset_size_ = *dataset_size; return Status::OK(); } + +Status CocoOp::GetClassIndexing(std::vector>> *output_class_indexing) { + if ((*output_class_indexing).empty()) { + if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) { + MS_LOG(ERROR) << "Class index only valid in \"Detection\" and \"Panoptic\" task."; + RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in CocoOp."); + } + std::shared_ptr op; + std::string task_type; + switch (task_type_) { + case TaskType::Detection: + task_type = "Detection"; + break; + case TaskType::Panoptic: + task_type = "Panoptic"; + break; + } + RETURN_IF_NOT_OK(Builder().SetDir(image_folder_path_).SetFile(annotation_path_).SetTask(task_type).Build(&op)); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + for (const auto label : op->label_index_) { + (*output_class_indexing).emplace_back(std::make_pair(label.first, label.second)); + } + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index 5b7600764ee..bed009412e5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -218,6 +218,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp { /// \return Status of the function Status GetDatasetSize(int64_t *dataset_size) override; + /// \brief Gets the class indexing + /// \return Status - The status code 
return + Status GetClassIndexing(std::vector>> *output_class_indexing) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index fe1919fc8af..8e0450ea32b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -491,5 +491,25 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) { return Status::OK(); } +Status ManifestOp::GetClassIndexing(std::vector>> *output_class_indexing) { + if ((*output_class_indexing).empty()) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); + RETURN_IF_NOT_OK(op->ParseManifestFile()); + RETURN_IF_NOT_OK(op->CountDatasetInfo()); + uint32_t count = 0; + for (const auto label : op->label_index_) { + if (!class_index_.empty()) { + (*output_class_indexing) + .emplace_back(std::make_pair(label.first, std::vector(1, class_index_[label.first]))); + } else { + (*output_class_indexing).emplace_back(std::make_pair(label.first, std::vector(1, count))); + } + count++; + } + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index fe35caeff80..b762a586cc3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -193,6 +193,10 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { /// \return Status of the function Status GetNumClasses(int64_t *num_classes) override; + /// \brief Gets the class indexing + /// \return Status - The status 
code return + Status GetClassIndexing(std::vector>> *output_class_indexing) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index c9589e7380e..a68b4fc59d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -542,5 +542,28 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) { dataset_size_ = *dataset_size; return Status::OK(); } + +Status VOCOp::GetClassIndexing(std::vector>> *output_class_indexing) { + if ((*output_class_indexing).empty()) { + if (task_type_ != TaskType::Detection) { + MS_LOG(ERROR) << "Class index only valid in \"Detection\" task."; + RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in VOCOp."); + } + std::shared_ptr op; + RETURN_IF_NOT_OK( + Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + for (const auto label : op->label_index_) { + if (!class_index_.empty()) { + (*output_class_indexing) + .emplace_back(std::make_pair(label.first, std::vector(1, class_index_[label.first]))); + } else { + (*output_class_indexing).emplace_back(std::make_pair(label.first, std::vector(1, label.second))); + } + } + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index 35aef73df2c..03256e2c67c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -221,6 +221,10 @@ class VOCOp : public ParallelOp, public 
RandomAccessOp { /// \return Status of the function Status GetDatasetSize(int64_t *dataset_size) override; + /// \brief Gets the class indexing + /// \return Status - The status code return + Status GetClassIndexing(std::vector>> *output_class_indexing) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index db2a8e8cee1..ccdaf22f543 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -119,6 +119,10 @@ class Dataset : public std::enable_shared_from_this { /// \return number of classes. If failed, return -1 int64_t GetNumClasses(); + /// \brief Gets the class indexing + /// \return a vector of (class name, index list) pairs. If failed, return an empty vector + std::vector>> GetClassIndexing(); + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object diff --git a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc index 48e3ee2e088..5d92d697bb6 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc @@ -266,6 +266,28 @@ TEST_F(MindDataTestPipeline, TestCocoPanoptic) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCocoPanopticGetClassIndex) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoPanopticGetClassIndex."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/panoptic.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file, "Panoptic", false, SequentialSampler(0, 2)); + EXPECT_NE(ds, nullptr); + + std::vector>> class_index1 = ds->GetClassIndexing(); + 
EXPECT_EQ(class_index1.size(), 3); + EXPECT_EQ(class_index1[0].first, "person"); + EXPECT_EQ(class_index1[0].second[0], 1); + EXPECT_EQ(class_index1[0].second[1], 1); + EXPECT_EQ(class_index1[1].first, "bicycle"); + EXPECT_EQ(class_index1[1].second[0], 2); + EXPECT_EQ(class_index1[1].second[1], 1); + EXPECT_EQ(class_index1[2].first, "car"); + EXPECT_EQ(class_index1[2].second[0], 3); + EXPECT_EQ(class_index1[2].second[1], 1); +} + TEST_F(MindDataTestPipeline, TestCocoStuff) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoStuff."; // Create a Coco Dataset diff --git a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc index 38f569b2e3c..26820d7e12e 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc @@ -70,6 +70,22 @@ TEST_F(MindDataTestPipeline, TestManifestGetters) { EXPECT_NE(ds2, nullptr); EXPECT_EQ(ds2->GetDatasetSize(), 4); EXPECT_EQ(ds2->GetNumClasses(), 3); + + std::vector>> class_index1 = ds1->GetClassIndexing(); + EXPECT_EQ(class_index1.size(), 2); + EXPECT_EQ(class_index1[0].first, "cat"); + EXPECT_EQ(class_index1[0].second[0], 0); + EXPECT_EQ(class_index1[1].first, "dog"); + EXPECT_EQ(class_index1[1].second[0], 1); + + std::vector>> class_index2 = ds2->GetClassIndexing(); + EXPECT_EQ(class_index2.size(), 3); + EXPECT_EQ(class_index2[0].first, "cat"); + EXPECT_EQ(class_index2[0].second[0], 0); + EXPECT_EQ(class_index2[1].first, "dog"); + EXPECT_EQ(class_index2[1].second[0], 1); + EXPECT_EQ(class_index2[2].first, "flower"); + EXPECT_EQ(class_index2[2].second[0], 2); } TEST_F(MindDataTestPipeline, TestManifestDecode) { @@ -151,6 +167,13 @@ TEST_F(MindDataTestPipeline, TestManifestClassIndex) { std::shared_ptr ds = Manifest(file_path, "train", RandomSampler(), map, true); EXPECT_NE(ds, nullptr); + std::vector>> class_index1 = ds->GetClassIndexing(); + EXPECT_EQ(class_index1.size(), 2); + EXPECT_EQ(class_index1[0].first, "cat"); + 
EXPECT_EQ(class_index1[0].second[0], 111); + EXPECT_EQ(class_index1[1].first, "dog"); + EXPECT_EQ(class_index1[1].second[0], 222); + // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); diff --git a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc index fe2e2749212..2a44d04bcc5 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc @@ -72,6 +72,28 @@ TEST_F(MindDataTestPipeline, TestVOCClassIndex) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestVOCGetClassIndex) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetClassIndex."; + // Create a VOC Dataset + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::map class_index; + class_index["car"] = 0; + class_index["cat"] = 1; + class_index["train"] = 9; + + std::shared_ptr ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6)); + EXPECT_NE(ds, nullptr); + + std::vector>> class_index1 = ds->GetClassIndexing(); + EXPECT_EQ(class_index1.size(), 3); + EXPECT_EQ(class_index1[0].first, "car"); + EXPECT_EQ(class_index1[0].second[0], 0); + EXPECT_EQ(class_index1[1].first, "cat"); + EXPECT_EQ(class_index1[1].second[0], 1); + EXPECT_EQ(class_index1[2].first, "train"); + EXPECT_EQ(class_index1[2].second[0], 9); +} + TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";