forked from mindspore-Ecosystem/mindspore
add skeleton code for tree_adapter
stage II tree_adapter add test case for tree_adapter use dfs to build tree add more test case add test case fix ci fix ci round 2 fix ci round 3 fix ci round 4 fix complie error fix ci round 6 fix ci round 7 fix ci round 8 fix ci round 9
This commit is contained in:
parent
cc4aa65743
commit
aa8442e0a3
|
@ -14,6 +14,7 @@ add_library(engine OBJECT
|
|||
data_buffer.cc
|
||||
data_schema.cc
|
||||
dataset_iterator.cc
|
||||
tree_adapter.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32_t num_epoch) {
|
||||
// Check whether this function has been called before. If so, return fail
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root_ir);
|
||||
// GlobalInit, might need to be moved to the proper place once RuntimeConext is complete
|
||||
RETURN_IF_NOT_OK(GlobalInit());
|
||||
|
||||
// this will evolve in the long run
|
||||
tree_ = std::make_unique<ExecutionTree>();
|
||||
|
||||
std::shared_ptr<DatasetOp> root_op;
|
||||
RETURN_IF_NOT_OK(DFSBuildTree(root_ir, &root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
|
||||
// Prepare the tree
|
||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epoch));
|
||||
|
||||
// after the tree is prepared, the col_name_id_map can safely be obtained
|
||||
column_name_map_ = tree_->root()->column_name_id_map();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::GetNext(TensorRow *row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
row->clear(); // make sure row is empty
|
||||
// cur_db_ being a nullptr means this is the first call to get_next, launch ExecutionTree
|
||||
if (cur_db_ == nullptr) {
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached.");
|
||||
|
||||
if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_));
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(cur_db_->PopRow(row));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op) {
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");
|
||||
|
||||
(*op) = ops.front(); // return the first op to be added as child by the caller of this function
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(*op));
|
||||
|
||||
for (size_t i = 1; i < ops.size(); i++) {
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(ops[i]));
|
||||
RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i]));
|
||||
}
|
||||
|
||||
// build the children of ir, once they return, add the return value to *op
|
||||
for (std::shared_ptr<api::Dataset> child_ir : ir->children) {
|
||||
std::shared_ptr<DatasetOp> child_op;
|
||||
RETURN_IF_NOT_OK(DFSBuildTree(child_ir, &child_op));
|
||||
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TreeAdapter {
|
||||
public:
|
||||
TreeAdapter() = default;
|
||||
|
||||
~TreeAdapter() = default;
|
||||
|
||||
// This will construct a ExeTree from a Dataset root and Prepare() the ExeTree
|
||||
// This function is only meant to be called once and needs to be called before GetNext
|
||||
// ExeTree will be launched when the first GetNext is called
|
||||
Status BuildAndPrepare(std::shared_ptr<api::Dataset> root, int32_t num_epoch = -1);
|
||||
|
||||
// This is the main method TreeConsumer uses to interact with TreeAdapter
|
||||
// 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared)
|
||||
// 2. GetNext will return empty row when eoe/eof is obtained
|
||||
Status GetNext(TensorRow *);
|
||||
|
||||
// this function will return the column_name_map once BuildAndPrepare() is called
|
||||
std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; }
|
||||
|
||||
// this function returns the TaskGroup associated with ExeTree, this is needed by DeviceQueueConsumer
|
||||
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
|
||||
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
|
||||
|
||||
private:
|
||||
// this RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. ir could build a vector of ops. In
|
||||
// such case, the first node is returned. Op is added as child when the current function returns.
|
||||
Status DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op);
|
||||
|
||||
std::unique_ptr<DataBuffer> cur_db_;
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
std::unique_ptr<ExecutionTree> tree_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_
|
|
@ -45,6 +45,7 @@ class DatasetOp;
|
|||
class DataSchema;
|
||||
class Tensor;
|
||||
class TensorShape;
|
||||
class TreeAdapter;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class Vocab;
|
||||
#endif
|
||||
|
@ -458,7 +459,9 @@ std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datase
|
|||
/// \brief A base class to represent a dataset in the data pipeline.
|
||||
class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||
public:
|
||||
// need friend class so they can access the children_ field
|
||||
friend class Iterator;
|
||||
friend class mindspore::dataset::TreeAdapter;
|
||||
|
||||
/// \brief Constructor
|
||||
Dataset();
|
||||
|
|
|
@ -70,12 +70,13 @@ SET(DE_UT_SRCS
|
|||
stand_alone_samplers_test.cc
|
||||
status_test.cc
|
||||
task_manager_test.cc
|
||||
tensor_test.cc
|
||||
tensor_row_test.cc
|
||||
tensor_string_test.cc
|
||||
tensor_test.cc
|
||||
tensorshape_test.cc
|
||||
tfReader_op_test.cc
|
||||
to_float16_op_test.cc
|
||||
tree_adapter_test.cc
|
||||
type_cast_op_test.cc
|
||||
zip_op_test.cc
|
||||
random_resize_op_test.cc
|
||||
|
@ -112,8 +113,8 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_csv_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_randomdata_test.cc
|
||||
c_api_dataset_textfile_test.cc
|
||||
c_api_dataset_tfrecord_test.cc
|
||||
c_api_dataset_textfile_test.cc
|
||||
c_api_dataset_tfrecord_test.cc
|
||||
c_api_dataset_voc_test.cc
|
||||
c_api_datasets_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestTreeAdapter : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeAdapter.";
|
||||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<api::Dataset> ds = Mnist(folder_path, "all", api::SequentialSampler(0, 4));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->Batch(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
mindspore::dataset::TreeAdapter tree_adapter;
|
||||
|
||||
Status rc = tree_adapter.BuildAndPrepare(ds, 1);
|
||||
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
EXPECT_EQ(row.size(), sz);
|
||||
}
|
||||
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsError());
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestTreeAdapterWithRepeat.";
|
||||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<api::Dataset> ds = Mnist(folder_path, "all", api::SequentialSampler(0, 3));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->Batch(2, false);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
mindspore::dataset::TreeAdapter tree_adapter;
|
||||
|
||||
Status rc = tree_adapter.BuildAndPrepare(ds, 2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
EXPECT_EQ(row.size(), sz);
|
||||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<api::Dataset> ds = ImageFolder(folder_path, true, api::SequentialSampler(0, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<api::TensorOperation> one_hot = api::transforms::OneHot(10);
|
||||
EXPECT_NE(one_hot, nullptr);
|
||||
|
||||
// Create a Map operation, this will automatically add a project after map
|
||||
ds = ds->Map({one_hot}, {"label"}, {"label"}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
mindspore::dataset::TreeAdapter tree_adapter;
|
||||
|
||||
Status rc = tree_adapter.BuildAndPrepare(ds, 2);
|
||||
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0};
|
||||
TensorRow row;
|
||||
|
||||
for (size_t sz : row_sizes) {
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
EXPECT_EQ(row.size(), sz);
|
||||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
}
|
Loading…
Reference in New Issue