diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index ed36277d93e..06720150293 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -192,15 +192,53 @@ int64_t Dataset::GetDatasetSize() { MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; return -1; } - rc = tree_getters_->Init(ds); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; - return -1; + if (!tree_getters_->isInitialized()) { + rc = tree_getters_->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; + return -1; + } } rc = tree_getters_->GetDatasetSize(&dataset_size); return rc.IsError() ? -1 : dataset_size; } +// Returns the output types from the tree getters (initialized on first use); failures are logged, not propagated. +std::vector Dataset::GetOutputTypes() { + std::vector types; + Status s; + if (!tree_getters_->isInitialized()) { + s = tree_getters_->Init(shared_from_this()); + if (s.IsError()) { + MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed."; + return types; + } + } + s = tree_getters_->GetOutputTypes(&types); + if (s.IsError()) { + MS_LOG(ERROR) << "GetOutputTypes: Getting output types failed."; + } + return types; +} + +// Returns the output shapes from the tree getters (initialized on first use); failures are logged, not propagated. +std::vector Dataset::GetOutputShapes() { + std::vector shapes; + Status s; + if (!tree_getters_->isInitialized()) { + s = tree_getters_->Init(shared_from_this()); + if (s.IsError()) { + MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed."; + return shapes; + } + } + s = tree_getters_->GetOutputShapes(&shapes); + if (s.IsError()) { + MS_LOG(ERROR) << "GetOutputShapes: Getting output shapes failed."; + } + return shapes; +} + // Constructor to initialize the cache Dataset::Dataset(const std::shared_ptr &dataset_cache) : Dataset() { cache_ = dataset_cache; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index f27b4d5e418..bcedbb7f916 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -351,12 +351,27 
@@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & } #endif -TreeGetters::TreeGetters() { +TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { tree_adapter_ = std::make_unique(); - dataset_size_ = -1; } -Status TreeGetters::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); } +Status TreeGetters::Init(std::shared_ptr d) { + Status s = tree_adapter_->BuildAndPrepare(std::move(d)); + if (!s.IsError()) { // remember a successful build so callers can skip re-initialization + init_flag_ = true; + } + return s; +} + +bool TreeGetters::isInitialized() { return init_flag_; } + +Status TreeGetters::GetRow(TensorRow *row) { + if (!row_flag_) { // only the first call pulls a row from the tree adapter + RETURN_IF_NOT_OK(tree_adapter_->GetNext(row)); + row_flag_ = true; + } + return Status::OK(); +} Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { if (dataset_size_ == -1) { @@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); dataset_size_ = *dataset_size; - TensorRow row; if (*dataset_size == -1) { + RETURN_IF_NOT_OK(GetRow(&row_)); int64_t num_rows = 0; - RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); + TensorRow row = row_; while (row.size() != 0) { num_rows++; RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); @@ -379,4 +394,24 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { *dataset_size = dataset_size_; return Status::OK(); } + +// Appends the type of every tensor in the cached first row (row_) to *types. +Status TreeGetters::GetOutputTypes(std::vector *types) { + RETURN_IF_NOT_OK(GetRow(&row_)); + for (const auto &ts : row_) { + DataType dt = ts->type(); + types->push_back(dt); + } + return Status::OK(); +} + +// Appends the shape of every tensor in the cached first row (row_) to *shapes. +Status TreeGetters::GetOutputShapes(std::vector *shapes) { + RETURN_IF_NOT_OK(GetRow(&row_)); + for (const auto &ts : row_) { + TensorShape t = ts->shape(); + shapes->push_back(t); + } + return Status::OK(); +} } // namespace mindspore::dataset diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 9dd66f9bcb6..30d289388c6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer { TreeGetters(); Status Init(std::shared_ptr d) override; Status GetDatasetSize(int64_t *size); - Status GetBatchSize(int32_t *batch_size) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status GetRepeatCount(int32_t *repeat_count) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status GetNumClasses(int32_t *num_classes) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status GetOutputShapes(std::vector *shapes) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status GetOutputTypes(std::vector *types) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status GetOutputNames(std::vector *names) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - + Status GetOutputTypes(std::vector *types); // appends the type of each tensor in the first row + Status GetOutputShapes(std::vector *shapes); // appends the shape of each tensor in the first row + bool isInitialized(); // true once Init() has succeeded std::string Name() override { return "TreeGetters"; } + Status GetRow(TensorRow *r); // pulls a row from the tree on the first call only private: int64_t dataset_size_; + TensorRow row_; // first row fetched from the tree, reused by the getters + bool init_flag_; // indicates whether the tree has been initialized + bool row_flag_; // indicates whether the first row has been stored in row_ }; } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 27ec919c82f..71c900ddfeb 100644 --- 
a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -27,7 +27,6 @@ #include #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" #include "minddata/dataset/core/constants.h" - #include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/include/iterator.h" @@ -584,6 +583,14 @@ class Dataset : public std::enable_shared_from_this { /// \return status code int64_t GetDatasetSize(); + /// \brief Gets the data type of every tensor in an output row + /// \return vector of data types; empty if the tree getters could not be initialized + std::vector GetOutputTypes(); + + /// \brief Gets the shape of every tensor in an output row + /// \return vector of tensor shapes; empty if the tree getters could not be initialized + std::vector GetOutputShapes(); + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object diff --git a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc index b7a33cea5f1..f0817647346 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc @@ -34,6 +34,8 @@ using namespace mindspore::dataset::api; using mindspore::dataset::Tensor; +using mindspore::dataset::DataType; +using mindspore::dataset::TensorShape; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: @@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) { EXPECT_EQ(ds->GetDatasetSize(), 10000); } +TEST_F(MindDataTestPipeline, TestCifar10MixGetter) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter."; + + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, "all"); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10000); + std::vector types = ds->GetOutputTypes(); + std::vector shapes = ds->GetOutputShapes(); + EXPECT_EQ(types.size(), 2); + 
EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "uint32"); + EXPECT_EQ(shapes.size(), 2); + EXPECT_EQ(shapes[0].ToString(), "<32,32,3>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + + EXPECT_EQ(ds->GetDatasetSize(), 10000); // repeated, interleaved getter calls must keep returning the same values + EXPECT_EQ(ds->GetOutputTypes(), types); + EXPECT_EQ(ds->GetOutputShapes(), shapes); + EXPECT_EQ(ds->GetDatasetSize(), 10000); + EXPECT_EQ(ds->GetOutputTypes(), types); + EXPECT_EQ(ds->GetOutputShapes(), shapes); + EXPECT_EQ(ds->GetDatasetSize(), 10000); +} + TEST_F(MindDataTestPipeline, TestCifar100Dataset) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";