!7702 Redesigned GetOutputType and GetOutputShape

Merge pull request !7702 from Alex Yuyue/IR_dataset_input
mindspore-ci-bot 2020-10-24 14:42:29 +08:00 committed by Gitee
commit 3d194137da
5 changed files with 116 additions and 29 deletions
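
For context, the redesigned getters are meant to be called directly on a Dataset handle, mirroring the new TestCifar10MixGetter test at the bottom of this diff. A minimal usage sketch follows; the include path, the iostream printing, and the build setup are assumptions for illustration and are not part of this commit:

// Sketch only: assumes the C++ dataset API of this release and the
// minddata/dataset/include/datasets.h header; only the getter calls
// themselves come from this commit and its test.
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::DataType;
using mindspore::dataset::TensorShape;

void DumpColumnInfo(const std::string &folder_path) {
  // Create a Cifar10 dataset handle, as in the new test below.
  std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
  if (ds == nullptr) {
    return;
  }
  // The getters lazily build the execution tree on first use and reuse the
  // cached first row afterwards, so repeated calls stay cheap and consistent.
  std::vector<DataType> types = ds->GetOutputTypes();
  std::vector<TensorShape> shapes = ds->GetOutputShapes();
  for (size_t i = 0; i < types.size(); i++) {
    std::cout << "column " << i << ": type=" << types[i].ToString()
              << " shape=" << shapes[i].ToString() << std::endl;
  }
}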


@@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() {
     MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
     return -1;
   }
-  rc = tree_getters_->Init(ds);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
-    return -1;
+  if (!tree_getters_->isInitialized()) {
+    rc = tree_getters_->Init(ds);
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
+      return -1;
+    }
   }
   rc = tree_getters_->GetDatasetSize(&dataset_size);
   return rc.IsError() ? -1 : dataset_size;
 }
+
+std::vector<DataType> Dataset::GetOutputTypes() {
+  std::vector<DataType> types;
+  Status s;
+  if (!tree_getters_->isInitialized()) {
+    s = tree_getters_->Init(shared_from_this());
+    if (s.IsError()) {
+      MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
+      return types;
+    }
+  }
+  tree_getters_->GetOutputTypes(&types);
+  return types;
+}
+
+std::vector<TensorShape> Dataset::GetOutputShapes() {
+  std::vector<TensorShape> shapes;
+  Status s;
+  if (!tree_getters_->isInitialized()) {
+    s = tree_getters_->Init(shared_from_this());
+    if (s.IsError()) {
+      MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
+      return shapes;
+    }
+  }
+  tree_getters_->GetOutputShapes(&shapes);
+  return shapes;
+}
 
 // Constructor to initialize the cache
 Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }


@@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
 }
 #endif
 
-TreeGetters::TreeGetters() {
+TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
   tree_adapter_ = std::make_unique<TreeAdapter>();
-  dataset_size_ = -1;
 }
 
-Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); }
+Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
+  Status s = tree_adapter_->BuildAndPrepare(std::move(d));
+  if (!s.IsError()) {
+    init_flag_ = true;
+  }
+  return s;
+}
+
+bool TreeGetters::isInitialized() { return init_flag_; }
+
+Status TreeGetters::GetRow(TensorRow *row) {
+  if (row_flag_ == false) {
+    RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
+    row_flag_ = true;
+  }
+  return Status::OK();
+}
 
 Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ == -1) {
@@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
     CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
     RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
     dataset_size_ = *dataset_size;
-    TensorRow row;
     if (*dataset_size == -1) {
+      RETURN_IF_NOT_OK(GetRow(&row_));
       int64_t num_rows = 0;
-      RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
+      TensorRow row = row_;
       while (row.size() != 0) {
         num_rows++;
         RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
@@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
   *dataset_size = dataset_size_;
   return Status::OK();
 }
+
+Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
+  RETURN_IF_NOT_OK(GetRow(&row_));
+  for (auto ts : row_) {
+    DataType dt = ts->type();
+    types->push_back(dt);
+  }
+  return Status::OK();
+}
+
+Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
+  RETURN_IF_NOT_OK(GetRow(&row_));
+  for (auto ts : row_) {
+    TensorShape t = ts->shape();
+    shapes->push_back(t);
+  }
+  return Status::OK();
+}
 } // namespace mindspore::dataset
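
The TreeGetters changes above come down to two guards: init_flag_ makes Init() one-shot, and row_flag_ makes the first-row fetch one-shot, so GetOutputTypes, GetOutputShapes, and the row-counting fallback in GetDatasetSize can all share the same cached row. A simplified, standalone sketch of that caching pattern is below; every name in it is an invented stand-in, not a MindSpore class:

// Standalone illustration of the lazy first-row caching pattern.
#include <string>
#include <vector>

struct FakeRow {
  std::vector<std::string> column_types;
  std::vector<std::vector<int>> column_shapes;
};

class LazyGetters {
 public:
  std::vector<std::string> GetOutputTypes() {
    EnsureFirstRow();
    return row_.column_types;
  }

  std::vector<std::vector<int>> GetOutputShapes() {
    EnsureFirstRow();
    return row_.column_shapes;
  }

 private:
  // Fetch the first row only once; both getters reuse the cached copy in row_,
  // mirroring the row_flag_ / GetRow(&row_) guard in TreeGetters.
  void EnsureFirstRow() {
    if (!row_fetched_) {
      row_ = FetchFirstRow();  // stand-in for tree_adapter_->GetNext(&row_)
      row_fetched_ = true;
    }
  }

  FakeRow FetchFirstRow() { return {{"uint8", "uint32"}, {{32, 32, 3}, {}}}; }

  FakeRow row_;
  bool row_fetched_ = false;
};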


@@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer {
   TreeGetters();
   Status Init(std::shared_ptr<api::Dataset> d) override;
   Status GetDatasetSize(int64_t *size);
-  Status GetBatchSize(int32_t *batch_size) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetRepeatCount(int32_t *repeat_count) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetNumClasses(int32_t *num_classes) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputShapes(std::vector<TensorShape> *shapes) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputTypes(std::vector<DataType> *types) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputNames(std::vector<std::string> *names) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
+  Status GetOutputTypes(std::vector<DataType> *types);
+  Status GetOutputShapes(std::vector<TensorShape> *shapes);
+  bool isInitialized();
   std::string Name() override { return "TreeGetters"; }
+  Status GetRow(TensorRow *r);
 
  private:
   int64_t dataset_size_;
+  TensorRow row_;
+  bool init_flag_;  // indicates whether the tree has been initialized
+  bool row_flag_;   // indicates whether the first row has been stored in row_
 };
 } // namespace mindspore::dataset


@@ -27,7 +27,6 @@
 #include <vector>
 #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/consumers/tree_consumer.h"
 #include "minddata/dataset/engine/data_schema.h"
 #include "minddata/dataset/include/iterator.h"
@@ -584,6 +583,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \return status code
   int64_t GetDatasetSize();
 
+  /// \brief Gets the output type of each column
+  /// \return vector of DataType
+  std::vector<DataType> GetOutputTypes();
+
+  /// \brief Gets the output shape of each column
+  /// \return vector of TensorShape
+  std::vector<TensorShape> GetOutputShapes();
+
   /// \brief Setter function for runtime number of workers
   /// \param[in] num_workers The number of threads in this operator
   /// \return Shared pointer to the original object


@@ -34,6 +34,8 @@
 using namespace mindspore::dataset::api;
 using mindspore::dataset::Tensor;
+using mindspore::dataset::DataType;
+using mindspore::dataset::TensorShape;
 
 class MindDataTestPipeline : public UT::DatasetOpTesting {
  protected:
@@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
   EXPECT_EQ(ds->GetDatasetSize(), 10000);
 }
 
+TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";
+
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  std::vector<DataType> types = ds->GetOutputTypes();
+  std::vector<TensorShape> shapes = ds->GetOutputShapes();
+  EXPECT_EQ(types.size(), 2);
+  EXPECT_EQ(types[0].ToString(), "uint8");
+  EXPECT_EQ(types[1].ToString(), "uint32");
+  EXPECT_EQ(shapes.size(), 2);
+  EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
+  EXPECT_EQ(shapes[1].ToString(), "<>");
+
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  EXPECT_EQ(ds->GetOutputTypes(), types);
+  EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  EXPECT_EQ(ds->GetOutputTypes(), types);
+  EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+}
+
 TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";