!8387 Add support for GetClassIndexing in C++ API

From: @alex-yuyue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-11 03:02:22 +08:00 committed by Gitee
commit 70f5775711
15 changed files with 225 additions and 40 deletions

View File

@ -200,12 +200,10 @@ int64_t Dataset::GetDatasetSize() {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetDatasetSize(&dataset_size);
return rc.IsError() ? -1 : dataset_size;
@ -218,16 +216,12 @@ std::vector<DataType> Dataset::GetOutputTypes() {
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
types.clear();
return types;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
types.clear();
return types;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
return types;
}
rc = tree_getters_->GetOutputTypes(&types);
if (rc.IsError()) {
@ -245,16 +239,12 @@ std::vector<TensorShape> Dataset::GetOutputShapes() {
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
shapes.clear();
return shapes;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
shapes.clear();
return shapes;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
return shapes;
}
rc = tree_getters_->GetOutputShapes(&shapes);
if (rc.IsError()) {
@ -275,17 +265,39 @@ int64_t Dataset::GetNumClasses() {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetNumClasses(&num_classes);
return rc.IsError() ? -1 : num_classes;
}
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
return output_class_indexing;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
return output_class_indexing;
}
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
output_class_indexing.clear();
return output_class_indexing;
}
return output_class_indexing;
}
/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
/// \return Shared pointer to the current schema
@ -580,12 +592,10 @@ int64_t Dataset::GetBatchSize() {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetBatchSize(&batch_size);
return rc.IsError() ? -1 : batch_size;
@ -601,22 +611,22 @@ int64_t Dataset::GetRepeatCount() {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;
}
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers Value forwarded to the IR node's SetNumWorkers
/// \return Shared pointer to this dataset, or nullptr when there is no IR node or the IR node's
///     SetNumWorkers returns nullptr
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
  // Nothing to configure without an IR node.
  if (ir_node_ == nullptr) {
    return nullptr;
  }
  // A nullptr from the IR node is propagated as a failure.
  if (ir_node_->SetNumWorkers(num_workers) == nullptr) {
    return nullptr;
  }
  return shared_from_this();
}
#ifndef ENABLE_ANDROID
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,

View File

@ -384,6 +384,9 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
}
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
if (init_flag_) {
return Status::OK();
}
Status s = tree_adapter_->Compile(std::move(d), 1);
if (!s.IsError()) {
init_flag_ = true;
@ -463,6 +466,13 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
return Status::OK();
}
/// \brief Gets the class indexing by querying the root op of the tree held by the tree adapter
/// \param[out] output_class_indexing Vector handed to the root op's GetClassIndexing to be filled
/// \return Status - The status code return. An unexpected error is returned when the tree has no root.
Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  std::shared_ptr<DatasetOp> root_op(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root_op != nullptr, "Root is a nullptr.");
  // The root op's status is passed to the caller unchanged.
  return root_op->GetClassIndexing(output_class_indexing);
}
/// \brief Initializes the consumer by compiling the given dataset node through the tree adapter
/// \param[in] d Dataset node that is moved into the tree adapter's Compile call
/// \return Status - The status code returned by Compile
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) {
  // NOTE(review): the second argument (1) is presumably the number of epochs - confirm against TreeAdapter.
  return tree_adapter_->Compile(std::move(d), 1);
}
Status BuildVocabConsumer::Start() {

View File

@ -166,6 +166,7 @@ class TreeGetters : public TreeConsumer {
Status GetBatchSize(int64_t *batch_size);
Status GetRepeatCount(int64_t *repeat_count);
Status GetNumClasses(int64_t *num_classes);
// Asks the root op of the tree for its class indexing and stores it in output_class_indexing.
// Returns an error status when the tree has no root or the root op reports a failure.
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
bool isInitialized();
std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r);

View File

@ -316,6 +316,14 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) {
}
}
// Gets the class indexing.
// The base class implementation simply forwards the request to the first child. Derived classes
// may override it; an op without children cannot answer the request and reports an error.
Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree.");
  }
  return child_[0]->GetClassIndexing(output_class_indexing);
}
// Performs handling for when an eoe message is received.
// The base class implementation simply flows the eoe message to output. Derived classes
// may override if they need to perform special eoe handling.

View File

@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/core/constants.h"
@ -195,6 +196,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The status code return
virtual Status GetNumClasses(int64_t *num_classes);
/// \brief Gets the class indexing
/// \note The base class implementation forwards the request to the first child op and returns an
///     error status when the op has no children. Derived classes may override it.
/// \param[out] output_class_indexing Vector of (class name, index values) pairs to be filled
/// \return Status - The status code return
virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
/// \brief Performs handling for when an eoe message is received.
/// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling.

View File

@ -710,5 +710,30 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
dataset_size_ = *dataset_size;
return Status::OK();
}
/// \brief Gets the class indexing of the Coco dataset
/// \note A separate CocoOp is built from this op's image folder, annotation file and task, and its
///     annotation ids are parsed to obtain the label index.
/// \param[out] output_class_indexing Receives one (class name, index values) pair per entry of the parsed
///     label index. It is only filled when it is empty on entry; a non-empty vector is left untouched.
/// \return Status - The status code return. An error is returned when the task is neither "Detection" nor
///     "Panoptic", or when building the op or parsing the annotations fails.
Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  // A non-empty output is treated as already populated.
  if (!output_class_indexing->empty()) {
    return Status::OK();
  }
  if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {
    MS_LOG(ERROR) << "Class index only valid in \"Detection\" and \"Panoptic\" task.";
    RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in CocoOp.");
  }
  // Only the two task types accepted by the guard above can reach this point.
  const std::string task_type = (task_type_ == TaskType::Detection) ? "Detection" : "Panoptic";

  std::shared_ptr<CocoOp> op;
  RETURN_IF_NOT_OK(Builder().SetDir(image_folder_path_).SetFile(annotation_path_).SetTask(task_type).Build(&op));
  RETURN_IF_NOT_OK(op->ParseAnnotationIds());

  output_class_indexing->reserve(op->label_index_.size());
  // Iterate by reference: every entry holds a string and a vector, so copying it per iteration is wasteful.
  for (const auto &label : op->label_index_) {
    output_class_indexing->emplace_back(label.first, label.second);
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -218,6 +218,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the class indexing
/// \note Only valid for the "Detection" and "Panoptic" tasks; any other task yields an error status.
/// \param[out] output_class_indexing Vector of (class name, index values) pairs. It is only filled when it
///     is empty on entry.
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -491,5 +491,25 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
return Status::OK();
}
/// \brief Gets the class indexing of the Manifest dataset
/// \note A separate ManifestOp is built from this op's manifest file, class index and usage; its manifest
///     file is parsed and its dataset info counted to obtain the label index.
/// \param[out] output_class_indexing Receives one (class name, single-element index vector) pair per label.
///     The index comes from class_index_ when that map is not empty, otherwise it is the position of the
///     label in the parsed label index. The vector is only filled when it is empty on entry.
/// \return Status - The status code return
Status ManifestOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  // A non-empty output is treated as already populated.
  if (!output_class_indexing->empty()) {
    return Status::OK();
  }
  std::shared_ptr<ManifestOp> op;
  RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
  RETURN_IF_NOT_OK(op->ParseManifestFile());
  RETURN_IF_NOT_OK(op->CountDatasetInfo());

  // Whether the user supplied a class index does not change while iterating, so decide once.
  const bool use_class_index = !class_index_.empty();
  output_class_indexing->reserve(op->label_index_.size());
  // Signed counter: it is stored into a vector of int32_t, so no implicit unsigned-to-signed conversion.
  int32_t count = 0;
  // Iterate by reference to avoid copying every (name, index) entry.
  for (const auto &label : op->label_index_) {
    const int32_t index = use_class_index ? class_index_[label.first] : count;
    output_class_indexing->emplace_back(label.first, std::vector<int32_t>(1, index));
    ++count;
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -193,6 +193,10 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the function
Status GetNumClasses(int64_t *num_classes) override;
/// \brief Gets the class indexing
/// \param[out] output_class_indexing Vector of (class name, single-element index vector) pairs. It is only
///     filled when it is empty on entry.
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -542,5 +542,28 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
dataset_size_ = *dataset_size;
return Status::OK();
}
/// \brief Gets the class indexing of the VOC dataset
/// \note A separate VOCOp is built from this op's folder, usage and class index with the "Detection" task;
///     its image ids and annotation ids are parsed to obtain the label index.
/// \param[out] output_class_indexing Receives one (class name, single-element index vector) pair per label.
///     The index comes from class_index_ when that map is not empty, otherwise from the parsed label index.
///     The vector is only filled when it is empty on entry.
/// \return Status - The status code return. An error is returned when the task is not "Detection".
Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  // A non-empty output is treated as already populated.
  if (!output_class_indexing->empty()) {
    return Status::OK();
  }
  if (task_type_ != TaskType::Detection) {
    MS_LOG(ERROR) << "Class index only valid in \"Detection\" task.";
    RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in VOCOp.");
  }
  std::shared_ptr<VOCOp> op;
  RETURN_IF_NOT_OK(
    Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
  RETURN_IF_NOT_OK(op->ParseImageIds());
  RETURN_IF_NOT_OK(op->ParseAnnotationIds());

  // Whether the user supplied a class index does not change while iterating, so decide once.
  const bool use_class_index = !class_index_.empty();
  output_class_indexing->reserve(op->label_index_.size());
  // Iterate by reference to avoid copying every (name, index) entry.
  for (const auto &label : op->label_index_) {
    const int32_t index = use_class_index ? class_index_[label.first] : label.second;
    output_class_indexing->emplace_back(label.first, std::vector<int32_t>(1, index));
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -221,6 +221,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the class indexing
/// \note Only valid for the "Detection" task; any other task yields an error status.
/// \param[out] output_class_indexing Vector of (class name, single-element index vector) pairs. It is only
///     filled when it is empty on entry.
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -119,6 +119,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return number of classes. If failed, return -1
int64_t GetNumClasses();
/// \brief Gets the class indexing
/// \return a vector of (class name, index values) pairs. If failed, return an empty vector
std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing();
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object

View File

@ -266,6 +266,28 @@ TEST_F(MindDataTestPipeline, TestCocoPanoptic) {
iter->Stop();
}
// Checks that GetClassIndexing on a Coco "Panoptic" dataset returns the expected class names together
// with the first two values stored for each class.
TEST_F(MindDataTestPipeline, TestCocoPanopticGetClassIndex) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoPanopticGetClassIndex.";
  // Create a Coco Dataset
  std::string folder_path = datasets_root_path_ + "/testCOCO/train";
  std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/panoptic.json";
  std::shared_ptr<Dataset> ds = Coco(folder_path, annotation_file, "Panoptic", false, SequentialSampler(0, 2));
  // Fatal assertion: the dataset pointer is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
  // Fatal assertions: the entries and their first two values are indexed directly afterwards, so a
  // size mismatch must stop the test instead of reading out of bounds.
  ASSERT_EQ(class_index1.size(), 3);
  for (const auto &entry : class_index1) {
    ASSERT_GE(entry.second.size(), 2);
  }

  EXPECT_EQ(class_index1[0].first, "person");
  EXPECT_EQ(class_index1[0].second[0], 1);
  EXPECT_EQ(class_index1[0].second[1], 1);
  EXPECT_EQ(class_index1[1].first, "bicycle");
  EXPECT_EQ(class_index1[1].second[0], 2);
  EXPECT_EQ(class_index1[1].second[1], 1);
  EXPECT_EQ(class_index1[2].first, "car");
  EXPECT_EQ(class_index1[2].second[0], 3);
  EXPECT_EQ(class_index1[2].second[1], 1);
}
TEST_F(MindDataTestPipeline, TestCocoStuff) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoStuff.";
// Create a Coco Dataset

View File

@ -70,6 +70,22 @@ TEST_F(MindDataTestPipeline, TestManifestGetters) {
EXPECT_NE(ds2, nullptr);
EXPECT_EQ(ds2->GetDatasetSize(), 4);
EXPECT_EQ(ds2->GetNumClasses(), 3);
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds1->GetClassIndexing();
EXPECT_EQ(class_index1.size(), 2);
EXPECT_EQ(class_index1[0].first, "cat");
EXPECT_EQ(class_index1[0].second[0], 0);
EXPECT_EQ(class_index1[1].first, "dog");
EXPECT_EQ(class_index1[1].second[0], 1);
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index2 = ds2->GetClassIndexing();
EXPECT_EQ(class_index2.size(), 3);
EXPECT_EQ(class_index2[0].first, "cat");
EXPECT_EQ(class_index2[0].second[0], 0);
EXPECT_EQ(class_index2[1].first, "dog");
EXPECT_EQ(class_index2[1].second[0], 1);
EXPECT_EQ(class_index2[2].first, "flower");
EXPECT_EQ(class_index2[2].second[0], 2);
}
TEST_F(MindDataTestPipeline, TestManifestDecode) {
@ -151,6 +167,13 @@ TEST_F(MindDataTestPipeline, TestManifestClassIndex) {
std::shared_ptr<Dataset> ds = Manifest(file_path, "train", RandomSampler(), map, true);
EXPECT_NE(ds, nullptr);
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
EXPECT_EQ(class_index1.size(), 2);
EXPECT_EQ(class_index1[0].first, "cat");
EXPECT_EQ(class_index1[0].second[0], 111);
EXPECT_EQ(class_index1[1].first, "dog");
EXPECT_EQ(class_index1[1].second[0], 222);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();

View File

@ -72,6 +72,28 @@ TEST_F(MindDataTestPipeline, TestVOCClassIndex) {
iter->Stop();
}
// Checks that GetClassIndexing on a VOC "Detection" dataset reports the user supplied class index.
TEST_F(MindDataTestPipeline, TestVOCGetClassIndex) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetClassIndex.";
  // Create a VOC Dataset
  std::string folder_path = datasets_root_path_ + "/testVOC2012_2";
  std::map<std::string, int32_t> class_index;
  class_index["car"] = 0;
  class_index["cat"] = 1;
  class_index["train"] = 9;
  std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
  // Fatal assertion: the dataset pointer is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
  // Fatal assertions: the entries and their first value are indexed directly afterwards, so a size
  // mismatch must stop the test instead of reading out of bounds.
  ASSERT_EQ(class_index1.size(), 3);
  for (const auto &entry : class_index1) {
    ASSERT_FALSE(entry.second.empty());
  }

  EXPECT_EQ(class_index1[0].first, "car");
  EXPECT_EQ(class_index1[0].second[0], 0);
  EXPECT_EQ(class_index1[1].first, "cat");
  EXPECT_EQ(class_index1[1].second[0], 1);
  EXPECT_EQ(class_index1[2].first, "train");
  EXPECT_EQ(class_index1[2].second[0], 9);
}
TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";