forked from mindspore-Ecosystem/mindspore
!7702 Redesigned GetOutputType and GetOutputShape
Merge pull request !7702 from Alex Yuyue/IR_dataset_input
This commit is contained in:
commit
3d194137da
|
@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() {
|
|||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
rc = tree_getters_->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
||||
return rc.IsError() ? -1 : dataset_size;
|
||||
}
|
||||
|
||||
std::vector<DataType> Dataset::GetOutputTypes() {
|
||||
std::vector<DataType> types;
|
||||
Status s;
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
s = tree_getters_->Init(shared_from_this());
|
||||
if (s.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return types;
|
||||
}
|
||||
}
|
||||
tree_getters_->GetOutputTypes(&types);
|
||||
return types;
|
||||
}
|
||||
|
||||
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
||||
std::vector<TensorShape> shapes;
|
||||
Status s;
|
||||
if (!tree_getters_->isInitialized()) {
|
||||
s = tree_getters_->Init(shared_from_this());
|
||||
if (s.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return shapes;
|
||||
}
|
||||
}
|
||||
tree_getters_->GetOutputShapes(&shapes);
|
||||
return shapes;
|
||||
}
|
||||
|
||||
// Constructor to initialize the cache
|
||||
Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }
|
||||
|
||||
|
|
|
@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
|
|||
}
|
||||
#endif
|
||||
|
||||
TreeGetters::TreeGetters() {
|
||||
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
|
||||
tree_adapter_ = std::make_unique<TreeAdapter>();
|
||||
dataset_size_ = -1;
|
||||
}
|
||||
|
||||
Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); }
|
||||
Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
|
||||
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
|
||||
if (!s.IsError()) {
|
||||
init_flag_ = true;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
bool TreeGetters::isInitialized() { return init_flag_; }
|
||||
|
||||
Status TreeGetters::GetRow(TensorRow *row) {
|
||||
if (row_flag_ == false) {
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
|
||||
row_flag_ = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ == -1) {
|
||||
|
@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
|
||||
dataset_size_ = *dataset_size;
|
||||
TensorRow row;
|
||||
if (*dataset_size == -1) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
int64_t num_rows = 0;
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
|
||||
TensorRow row = row_;
|
||||
while (row.size() != 0) {
|
||||
num_rows++;
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
|
||||
|
@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
|||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
for (auto ts : row_) {
|
||||
DataType dt = ts->type();
|
||||
types->push_back(dt);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
for (auto ts : row_) {
|
||||
TensorShape t = ts->shape();
|
||||
shapes->push_back(t);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer {
|
|||
TreeGetters();
|
||||
Status Init(std::shared_ptr<api::Dataset> d) override;
|
||||
Status GetDatasetSize(int64_t *size);
|
||||
Status GetBatchSize(int32_t *batch_size) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetRepeatCount(int32_t *repeat_count) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetNumClasses(int32_t *num_classes) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputShapes(std::vector<TensorShape> *shapes) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputTypes(std::vector<DataType> *types) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputNames(std::vector<std::string> *names) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
|
||||
Status GetOutputTypes(std::vector<DataType> *types);
|
||||
Status GetOutputShapes(std::vector<TensorShape> *shapes);
|
||||
bool isInitialized();
|
||||
std::string Name() override { return "TreeGetters"; }
|
||||
Status GetRow(TensorRow *r);
|
||||
|
||||
private:
|
||||
int64_t dataset_size_;
|
||||
TensorRow row_;
|
||||
bool init_flag_; // indicate whether the tree has initialized
|
||||
bool row_flag_; // indicate whether the first row has been stored in row_
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include <vector>
|
||||
#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/include/iterator.h"
|
||||
|
@ -584,6 +583,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return status code
|
||||
int64_t GetDatasetSize();
|
||||
|
||||
/// \brief Gets the output type
|
||||
/// \return status code
|
||||
std::vector<DataType> GetOutputTypes();
|
||||
|
||||
/// \brief Gets the output shape
|
||||
/// \return status code
|
||||
std::vector<TensorShape> GetOutputShapes();
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
|
|
@ -34,6 +34,8 @@
|
|||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
|
|||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
std::vector<DataType> types = ds->GetOutputTypes();
|
||||
std::vector<TensorShape> shapes = ds->GetOutputShapes();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "uint32");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
EXPECT_EQ(ds->GetOutputTypes(), types);
|
||||
EXPECT_EQ(ds->GetOutputShapes(), shapes);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
EXPECT_EQ(ds->GetOutputTypes(), types);
|
||||
EXPECT_EQ(ds->GetOutputShapes(), shapes);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue