forked from mindspore-Ecosystem/mindspore

!7702 Redesigned GetOutputType and GetOutputShape
Merge pull request !7702 from Alex Yuyue/IR_dataset_input
Commit: 3d194137da
@@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() {
     MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
     return -1;
   }
+  if (!tree_getters_->isInitialized()) {
     rc = tree_getters_->Init(ds);
     if (rc.IsError()) {
       MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
       return -1;
     }
+  }
   rc = tree_getters_->GetDatasetSize(&dataset_size);
   return rc.IsError() ? -1 : dataset_size;
 }
 
+std::vector<DataType> Dataset::GetOutputTypes() {
+  std::vector<DataType> types;
+  Status s;
+  if (!tree_getters_->isInitialized()) {
+    s = tree_getters_->Init(shared_from_this());
+    if (s.IsError()) {
+      MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
+      return types;
+    }
+  }
+  tree_getters_->GetOutputTypes(&types);
+  return types;
+}
+
+std::vector<TensorShape> Dataset::GetOutputShapes() {
+  std::vector<TensorShape> shapes;
+  Status s;
+  if (!tree_getters_->isInitialized()) {
+    s = tree_getters_->Init(shared_from_this());
+    if (s.IsError()) {
+      MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
+      return shapes;
+    }
+  }
+  tree_getters_->GetOutputShapes(&shapes);
+  return shapes;
+}
+
 // Constructor to initialize the cache
 Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }

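From the user-facing side, the two new getters are called directly on a built dataset, just like GetDatasetSize(). A minimal usage sketch, assuming the Cifar10 factory and test data layout exercised by the c_api test further below; the header path and data path are assumptions, not taken from this diff:

    // Sketch only: header/path names are assumptions based on the test file below.
    #include "minddata/dataset/include/datasets.h"

    using namespace mindspore::dataset::api;
    using mindspore::dataset::DataType;
    using mindspore::dataset::TensorShape;

    void PrintCifar10Metadata() {
      // "all" selects both the train and test samples, as in the tests.
      std::shared_ptr<Dataset> ds = Cifar10("/path/to/testCifar10Data/", "all");
      if (ds == nullptr) return;

      int64_t size = ds->GetDatasetSize();                      // expected 10000 for the test data
      std::vector<DataType> types = ds->GetOutputTypes();       // expected {uint8, uint32}
      std::vector<TensorShape> shapes = ds->GetOutputShapes();  // expected {<32,32,3>, <>}

      MS_LOG(INFO) << "rows: " << size;
      for (size_t i = 0; i < types.size(); i++) {
        MS_LOG(INFO) << "column " << i << ": " << types[i].ToString() << " " << shapes[i].ToString();
      }
    }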
@@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
 }
 #endif
 
-TreeGetters::TreeGetters() {
+TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
   tree_adapter_ = std::make_unique<TreeAdapter>();
-  dataset_size_ = -1;
 }
 
-Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); }
+Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
+  Status s = tree_adapter_->BuildAndPrepare(std::move(d));
+  if (!s.IsError()) {
+    init_flag_ = true;
+  }
+  return s;
+}
+
+bool TreeGetters::isInitialized() { return init_flag_; }
+
+Status TreeGetters::GetRow(TensorRow *row) {
+  if (row_flag_ == false) {
+    RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
+    row_flag_ = true;
+  }
+  return Status::OK();
+}
 
 Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ == -1) {
@@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
     CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
     RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
     dataset_size_ = *dataset_size;
-    TensorRow row;
     if (*dataset_size == -1) {
+      RETURN_IF_NOT_OK(GetRow(&row_));
       int64_t num_rows = 0;
-      RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
+      TensorRow row = row_;
       while (row.size() != 0) {
         num_rows++;
         RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
@@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
   *dataset_size = dataset_size_;
   return Status::OK();
 }
+
+Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
+  RETURN_IF_NOT_OK(GetRow(&row_));
+  for (auto ts : row_) {
+    DataType dt = ts->type();
+    types->push_back(dt);
+  }
+  return Status::OK();
+}
+
+Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
+  RETURN_IF_NOT_OK(GetRow(&row_));
+  for (auto ts : row_) {
+    TensorShape t = ts->shape();
+    shapes->push_back(t);
+  }
+  return Status::OK();
+}
 } // namespace mindspore::dataset

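One design point runs through these getters: the first row of the pipeline is pulled at most once, cached in row_, and then reused by GetOutputTypes, GetOutputShapes, and the row-counting fallback in GetDatasetSize, while the computed size itself is cached in dataset_size_. A stripped-down, framework-free sketch of that memoization pattern (illustrative names only, not MindSpore code):

    #include <cstdint>
    #include <functional>
    #include <utility>
    #include <vector>

    // Stand-ins for TensorRow and the tree adapter's GetNext(); an empty Row means end of data.
    using Row = std::vector<int>;
    using NextFn = std::function<Row()>;

    class RowCachingGetter {
     public:
      explicit RowCachingGetter(NextFn next) : next_(std::move(next)) {}

      // Mirrors TreeGetters::GetRow(): fetch the first row only once.
      const Row &FirstRow() {
        if (!row_flag_) {
          row_ = next_();
          row_flag_ = true;
        }
        return row_;
      }

      // Mirrors the GetDatasetSize() fallback: when the tree cannot report a size,
      // count rows starting from the cached first row, then remember the result.
      int64_t Size() {
        if (dataset_size_ == -1) {
          int64_t num_rows = 0;
          Row row = FirstRow();
          while (!row.empty()) {
            num_rows++;
            row = next_();
          }
          dataset_size_ = num_rows;
        }
        return dataset_size_;
      }

     private:
      NextFn next_;
      Row row_;
      bool row_flag_ = false;      // has the first row been cached?
      int64_t dataset_size_ = -1;  // -1 means "not computed yet"
    };

Because both the first row and the size are cached, the back-to-back GetDatasetSize()/GetOutputTypes()/GetOutputShapes() calls exercised by the new test do not re-traverse the pipeline.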
@@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer {
   TreeGetters();
   Status Init(std::shared_ptr<api::Dataset> d) override;
   Status GetDatasetSize(int64_t *size);
-  Status GetBatchSize(int32_t *batch_size) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetRepeatCount(int32_t *repeat_count) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetNumClasses(int32_t *num_classes) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputShapes(std::vector<TensorShape> *shapes) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputTypes(std::vector<DataType> *types) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  Status GetOutputNames(std::vector<std::string> *names) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
+  Status GetOutputTypes(std::vector<DataType> *types);
+  Status GetOutputShapes(std::vector<TensorShape> *shapes);
+  bool isInitialized();
 
   std::string Name() override { return "TreeGetters"; }
+  Status GetRow(TensorRow *r);
 
  private:
   int64_t dataset_size_;
+  TensorRow row_;
+  bool init_flag_;  // indicate whether the tree has initialized
+  bool row_flag_;  // indicate whether the first row has been stored in row_
 };
 
 } // namespace mindspore::dataset

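With the not-implemented stubs removed, the header now declares only what TreeGetters actually implements. A hedged sketch of how a caller drives the consumer, mirroring how Dataset's getters use it in the first hunk above; the ds argument is assumed to be an already-built std::shared_ptr<api::Dataset>:

    // Sketch only: mirrors the consumer usage shown in the first hunk above.
    Status QueryMetadata(std::shared_ptr<api::Dataset> ds) {
      TreeGetters getters;
      if (!getters.isInitialized()) {
        RETURN_IF_NOT_OK(getters.Init(ds));  // builds and prepares the execution tree once
      }
      int64_t size = -1;
      std::vector<DataType> types;
      std::vector<TensorShape> shapes;
      RETURN_IF_NOT_OK(getters.GetDatasetSize(&size));
      RETURN_IF_NOT_OK(getters.GetOutputTypes(&types));   // both getters reuse the cached first row
      RETURN_IF_NOT_OK(getters.GetOutputShapes(&shapes));
      return Status::OK();
    }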
@@ -27,7 +27,6 @@
 #include <vector>
 #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
 #include "minddata/dataset/core/constants.h"
-
 #include "minddata/dataset/engine/consumers/tree_consumer.h"
 #include "minddata/dataset/engine/data_schema.h"
 #include "minddata/dataset/include/iterator.h"
@@ -584,6 +583,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \return status code
   int64_t GetDatasetSize();
 
+  /// \brief Gets the output type
+  /// \return status code
+  std::vector<DataType> GetOutputTypes();
+
+  /// \brief Gets the output shape
+  /// \return status code
+  std::vector<TensorShape> GetOutputShapes();
+
   /// \brief Setter function for runtime number of workers
   /// \param[in] num_workers The number of threads in this operator
   /// \return Shared pointer to the original object

@@ -34,6 +34,8 @@
 using namespace mindspore::dataset::api;
 using mindspore::dataset::Tensor;
+using mindspore::dataset::DataType;
+using mindspore::dataset::TensorShape;
 
 class MindDataTestPipeline : public UT::DatasetOpTesting {
  protected:
@@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
   EXPECT_EQ(ds->GetDatasetSize(), 10000);
 }
 
+TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";
+
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  std::vector<DataType> types = ds->GetOutputTypes();
+  std::vector<TensorShape> shapes = ds->GetOutputShapes();
+  EXPECT_EQ(types.size(), 2);
+  EXPECT_EQ(types[0].ToString(), "uint8");
+  EXPECT_EQ(types[1].ToString(), "uint32");
+  EXPECT_EQ(shapes.size(), 2);
+  EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
+  EXPECT_EQ(shapes[1].ToString(), "<>");
+
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  EXPECT_EQ(ds->GetOutputTypes(), types);
+  EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+  EXPECT_EQ(ds->GetOutputTypes(), types);
+  EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetDatasetSize(), 10000);
+}
+
 TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";

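A possible follow-up test (not part of this change) would cross-check the getter results against a real row produced by the pipeline. The sketch below assumes the CreateIterator()/GetNextRow() helpers and the "image"/"label" Cifar10 column names used elsewhere in this test suite:

    TEST_F(MindDataTestPipeline, TestCifar10GetterMatchesIterator) {
      // Hypothetical test: getter metadata should match an actual row.
      std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
      std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
      EXPECT_NE(ds, nullptr);

      std::vector<DataType> types = ds->GetOutputTypes();
      std::vector<TensorShape> shapes = ds->GetOutputShapes();

      std::shared_ptr<Iterator> iter = ds->CreateIterator();
      EXPECT_NE(iter, nullptr);
      std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
      iter->GetNextRow(&row);
      ASSERT_FALSE(row.empty());

      EXPECT_EQ(row["image"]->type().ToString(), types[0].ToString());
      EXPECT_EQ(row["image"]->shape().ToString(), shapes[0].ToString());
      EXPECT_EQ(row["label"]->type().ToString(), types[1].ToString());
      iter->Stop();
    }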