forked from mindspore-Ecosystem/mindspore
!8387 Add support for GetClassIndexing in C++ API
From: @alex-yuyue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
70f5775711
|
@ -200,12 +200,10 @@ int64_t Dataset::GetDatasetSize() {
|
|||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
||||
return rc.IsError() ? -1 : dataset_size;
|
||||
|
@ -218,16 +216,12 @@ std::vector<DataType> Dataset::GetOutputTypes() {
|
|||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
||||
types.clear();
|
||||
return types;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
|
||||
types.clear();
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->GetOutputTypes(&types);
|
||||
if (rc.IsError()) {
|
||||
|
@ -245,16 +239,12 @@ std::vector<TensorShape> Dataset::GetOutputShapes() {
|
|||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
||||
shapes.clear();
|
||||
return shapes;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
|
||||
shapes.clear();
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->GetOutputShapes(&shapes);
|
||||
if (rc.IsError()) {
|
||||
|
@ -275,17 +265,39 @@ int64_t Dataset::GetNumClasses() {
|
|||
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetNumClasses(&num_classes);
|
||||
return rc.IsError() ? -1 : num_classes;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
|
||||
output_class_indexing.clear();
|
||||
return output_class_indexing;
|
||||
}
|
||||
return output_class_indexing;
|
||||
}
|
||||
|
||||
/// \brief Function to create a SchemaObj
|
||||
/// \param[in] schema_file Path of schema file
|
||||
/// \return Shared pointer to the current schema
|
||||
|
@ -580,12 +592,10 @@ int64_t Dataset::GetBatchSize() {
|
|||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetBatchSize(&batch_size);
|
||||
return rc.IsError() ? -1 : batch_size;
|
||||
|
@ -601,22 +611,22 @@ int64_t Dataset::GetRepeatCount() {
|
|||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetRepeatCount(&repeat_count);
|
||||
return rc.IsError() ? 0 : repeat_count;
|
||||
}
|
||||
|
||||
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
|
||||
if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return shared_from_this();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
||||
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
|
||||
|
|
|
@ -384,6 +384,9 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
|
|||
}
|
||||
|
||||
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
|
||||
if (init_flag_) {
|
||||
return Status::OK();
|
||||
}
|
||||
Status s = tree_adapter_->Compile(std::move(d), 1);
|
||||
if (!s.IsError()) {
|
||||
init_flag_ = true;
|
||||
|
@ -463,6 +466,13 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }
|
||||
|
||||
Status BuildVocabConsumer::Start() {
|
||||
|
|
|
@ -166,6 +166,7 @@ class TreeGetters : public TreeConsumer {
|
|||
Status GetBatchSize(int64_t *batch_size);
|
||||
Status GetRepeatCount(int64_t *repeat_count);
|
||||
Status GetNumClasses(int64_t *num_classes);
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
|
||||
bool isInitialized();
|
||||
std::string Name() override { return "TreeGetters"; }
|
||||
Status GetRow(TensorRow *r);
|
||||
|
|
|
@ -316,6 +316,14 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) {
|
|||
}
|
||||
}
|
||||
|
||||
Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
if (!child_.empty()) {
|
||||
return child_[0]->GetClassIndexing(output_class_indexing);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree.");
|
||||
}
|
||||
}
|
||||
|
||||
// Performs handling for when an eoe message is received.
|
||||
// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
// may override if they need to perform special eoe handling.
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/callback/callback_manager.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
@ -195,6 +196,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Status - The status code return
|
||||
virtual Status GetNumClasses(int64_t *num_classes);
|
||||
|
||||
/// \brief Gets the class indexing
|
||||
/// \return Status - The status code return
|
||||
virtual Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
|
||||
|
||||
/// \brief Performs handling for when an eoe message is received.
|
||||
/// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
/// may override if they need to perform special eoe handling.
|
||||
|
|
|
@ -710,5 +710,30 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
if ((*output_class_indexing).empty()) {
|
||||
if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {
|
||||
MS_LOG(ERROR) << "Class index only valid in \"Detection\" and \"Panoptic\" task.";
|
||||
RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in CocoOp.");
|
||||
}
|
||||
std::shared_ptr<CocoOp> op;
|
||||
std::string task_type;
|
||||
switch (task_type_) {
|
||||
case TaskType::Detection:
|
||||
task_type = "Detection";
|
||||
break;
|
||||
case TaskType::Panoptic:
|
||||
task_type = "Panoptic";
|
||||
break;
|
||||
}
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(image_folder_path_).SetFile(annotation_path_).SetTask(task_type).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
for (const auto label : op->label_index_) {
|
||||
(*output_class_indexing).emplace_back(std::make_pair(label.first, label.second));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -218,6 +218,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Gets the class indexing
|
||||
/// \return Status - The status code return
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -491,5 +491,25 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ManifestOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
if ((*output_class_indexing).empty()) {
|
||||
std::shared_ptr<ManifestOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseManifestFile());
|
||||
RETURN_IF_NOT_OK(op->CountDatasetInfo());
|
||||
uint32_t count = 0;
|
||||
for (const auto label : op->label_index_) {
|
||||
if (!class_index_.empty()) {
|
||||
(*output_class_indexing)
|
||||
.emplace_back(std::make_pair(label.first, std::vector<int32_t>(1, class_index_[label.first])));
|
||||
} else {
|
||||
(*output_class_indexing).emplace_back(std::make_pair(label.first, std::vector<int32_t>(1, count)));
|
||||
}
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -193,6 +193,10 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
/// \return Status of the function
|
||||
Status GetNumClasses(int64_t *num_classes) override;
|
||||
|
||||
/// \brief Gets the class indexing
|
||||
/// \return Status - The status code return
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -542,5 +542,28 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
if ((*output_class_indexing).empty()) {
|
||||
if (task_type_ != TaskType::Detection) {
|
||||
MS_LOG(ERROR) << "Class index only valid in \"Detection\" task.";
|
||||
RETURN_STATUS_UNEXPECTED("GetClassIndexing: Get Class Index failed in VOCOp.");
|
||||
}
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
for (const auto label : op->label_index_) {
|
||||
if (!class_index_.empty()) {
|
||||
(*output_class_indexing)
|
||||
.emplace_back(std::make_pair(label.first, std::vector<int32_t>(1, class_index_[label.first])));
|
||||
} else {
|
||||
(*output_class_indexing).emplace_back(std::make_pair(label.first, std::vector<int32_t>(1, label.second)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -221,6 +221,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
// /// \brief Gets the class indexing
|
||||
// /// \return Status - The status code return
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -119,6 +119,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return number of classes. If failed, return -1
|
||||
int64_t GetNumClasses();
|
||||
|
||||
/// \brief Gets the class indexing
|
||||
/// \return a map of ClassIndexing. If failed, return an empty map
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing();
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
|
|
@ -266,6 +266,28 @@ TEST_F(MindDataTestPipeline, TestCocoPanoptic) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCocoPanopticGetClassIndex) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoPanopticGetClassIndex.";
|
||||
// Create a Coco Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCOCO/train";
|
||||
std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/panoptic.json";
|
||||
|
||||
std::shared_ptr<Dataset> ds = Coco(folder_path, annotation_file, "Panoptic", false, SequentialSampler(0, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
|
||||
EXPECT_EQ(class_index1.size(), 3);
|
||||
EXPECT_EQ(class_index1[0].first, "person");
|
||||
EXPECT_EQ(class_index1[0].second[0], 1);
|
||||
EXPECT_EQ(class_index1[0].second[1], 1);
|
||||
EXPECT_EQ(class_index1[1].first, "bicycle");
|
||||
EXPECT_EQ(class_index1[1].second[0], 2);
|
||||
EXPECT_EQ(class_index1[1].second[1], 1);
|
||||
EXPECT_EQ(class_index1[2].first, "car");
|
||||
EXPECT_EQ(class_index1[2].second[0], 3);
|
||||
EXPECT_EQ(class_index1[2].second[1], 1);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCocoStuff) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoStuff.";
|
||||
// Create a Coco Dataset
|
||||
|
|
|
@ -70,6 +70,22 @@ TEST_F(MindDataTestPipeline, TestManifestGetters) {
|
|||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_EQ(ds2->GetDatasetSize(), 4);
|
||||
EXPECT_EQ(ds2->GetNumClasses(), 3);
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds1->GetClassIndexing();
|
||||
EXPECT_EQ(class_index1.size(), 2);
|
||||
EXPECT_EQ(class_index1[0].first, "cat");
|
||||
EXPECT_EQ(class_index1[0].second[0], 0);
|
||||
EXPECT_EQ(class_index1[1].first, "dog");
|
||||
EXPECT_EQ(class_index1[1].second[0], 1);
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index2 = ds2->GetClassIndexing();
|
||||
EXPECT_EQ(class_index2.size(), 3);
|
||||
EXPECT_EQ(class_index2[0].first, "cat");
|
||||
EXPECT_EQ(class_index2[0].second[0], 0);
|
||||
EXPECT_EQ(class_index2[1].first, "dog");
|
||||
EXPECT_EQ(class_index2[1].second[0], 1);
|
||||
EXPECT_EQ(class_index2[2].first, "flower");
|
||||
EXPECT_EQ(class_index2[2].second[0], 2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestManifestDecode) {
|
||||
|
@ -151,6 +167,13 @@ TEST_F(MindDataTestPipeline, TestManifestClassIndex) {
|
|||
std::shared_ptr<Dataset> ds = Manifest(file_path, "train", RandomSampler(), map, true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
|
||||
EXPECT_EQ(class_index1.size(), 2);
|
||||
EXPECT_EQ(class_index1[0].first, "cat");
|
||||
EXPECT_EQ(class_index1[0].second[0], 111);
|
||||
EXPECT_EQ(class_index1[1].first, "dog");
|
||||
EXPECT_EQ(class_index1[1].second[0], 222);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
|
|
|
@ -72,6 +72,28 @@ TEST_F(MindDataTestPipeline, TestVOCClassIndex) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestVOCGetClassIndex) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetClassIndex.";
|
||||
// Create a VOC Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testVOC2012_2";
|
||||
std::map<std::string, int32_t> class_index;
|
||||
class_index["car"] = 0;
|
||||
class_index["cat"] = 1;
|
||||
class_index["train"] = 9;
|
||||
|
||||
std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> class_index1 = ds->GetClassIndexing();
|
||||
EXPECT_EQ(class_index1.size(), 3);
|
||||
EXPECT_EQ(class_index1[0].first, "car");
|
||||
EXPECT_EQ(class_index1[0].second[0], 0);
|
||||
EXPECT_EQ(class_index1[1].first, "cat");
|
||||
EXPECT_EQ(class_index1[1].second[0], 1);
|
||||
EXPECT_EQ(class_index1[2].first, "train");
|
||||
EXPECT_EQ(class_index1[2].second[0], 9);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue