!9564 Tidy up code in dataset compilation phase

From: @nsyca
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-08 05:11:29 +08:00 committed by Gitee
commit 49fd5308a4
15 changed files with 347 additions and 266 deletions

View File

@ -556,8 +556,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
character_coverage, model_type, params);
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
@ -588,8 +588,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>();
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
special_first);
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();

View File

@ -206,9 +206,6 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
}
}
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
// Getter function to get all of our parents.
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }

View File

@ -111,9 +111,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const;
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> children() const;
// Getter function to get all of our parents.
std::vector<DatasetOp *> parents() const;

View File

@ -233,40 +233,14 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
return shared_from_this();
}
DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) {
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
}
const bool DatasetNode::IsTree() const {
bool is_tree = true;
if (this->parent_.size() > 1) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
return false;
}
for (const auto &child : children_) {
is_tree = child->IsTree();
if (!is_tree) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
break;
}
}
return is_tree;
}
// this function will perform a deep copy of current node (and its descendants), the parent* pointer will not be copied
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy();
// temporary fix to set the num_workers to the new node.
new_node->SetNumWorkers(this->num_workers_);
for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy());
}
return new_node;
mappable_ = kNotADataSource;
}
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
@ -310,54 +284,105 @@ void DatasetNode::PrintNode(std::ostream &out, int *level) const {
}
// Add a node as a child, node's parent needs to be empty
// this function will allow child to be a nullptr, in which case it will simply skip
// This function will allow child to be a nullptr, in which case it will simply skip.
// This function is used only when building IR node one by one from parsing the user code.
// During the parsing, we allow a node to have more than one parent, possibly forming a graph.
// It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
if (child != nullptr && child->parent_.empty()) {
if (child != nullptr) {
children_.push_back(child);
child->parent_.push_back(this);
} else if (child != nullptr) {
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
children_.push_back(child);
child->parent_.push_back(this);
}
}
// Insert a node as a child of this node. This node's children becomes the children of the inserted node.
// Add the input node to be the next child of this node
// This function is used in doing a deep copy of the IR tree built by parsing the user code.
// This function assumes we walk the tree in DFS left-to-right.
// This is a temporary function to be replaced later by a set of better tree operations.
void DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
if (child != nullptr) {
if (child->parent_ != nullptr) {
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
}
children_.push_back(child);
child->parent_ = this;
}
}
// Add a node as a parent, node's parent needs to be empty (future use)
// Add a node as a parent, node's parent needs to be empty (future use)
// Insert this node above the input node: this node replaces the input node in its
// parent's children list (if it has a parent), and the input node becomes a child
// of this node.
Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
  CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
  if (node->parent_ != nullptr) {
    DatasetNode *parent = node->parent_;
    // Replace the input node with this node in its parent's children list.
    // Iterate with a signed index: children_.size() is unsigned, so the original
    // `auto i = size() - 1; i >= 0; --i` loop could never terminate (i wraps
    // around past zero instead of going negative).
    for (int64_t i = static_cast<int64_t>(parent->children_.size()) - 1; i >= 0; --i) {
      if (parent->children_[i] == node) {
        // Use shared_from_this() rather than constructing a shared_ptr directly
        // from `this`, which would create a second control block and lead to a
        // double delete. DatasetNode inherits std::enable_shared_from_this.
        parent->children_[i] = shared_from_this();
      }
    }
  }
  children_.push_back(node);
  node->parent_ = this;
  return Status::OK();
}
// Insert a node as a child of this node
// This node's children become the children of the inserted node.
Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) {
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children.");
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent.");
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent.");
for (auto child : children_) {
node->children_.push_back(child);
child->parent_.clear();
child->parent_.push_back(node.get());
child->parent_ = node.get();
}
// Then establish the new parent-child relationship with the new parent.
children_.clear();
children_.push_back(node);
node->parent_.clear();
node->parent_.push_back(this);
node->parent_ = this;
return Status::OK();
}
// Insert a node as a child next to this node (future use)
// Insert a node as a child next to this node (future use)
// The input node becomes the sibling immediately to the right of this node in
// their common parent's children list.
Status DatasetNode::InsertAfter(std::shared_ptr<DatasetNode> node) {
  CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must have a parent.");
  CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent.");
  auto size = parent_->children_.size();
  // Duplicate the last child to increase the size by 1
  parent_->children_.push_back(parent_->children_[size - 1]);
  // Shift each child to its right until we found the insertion point, then insert the input node
  bool found = false;
  // Iterate with a signed index: children_.size() is unsigned, so the original
  // `auto i = ...; i >= 0; --i` condition was always true and the index wrapped
  // around past zero instead of terminating.
  for (int64_t i = static_cast<int64_t>(parent_->children_.size()) - 2; i >= 0; --i) {
    if (parent_->children_[i].get() != this) {
      parent_->children_[i + 1] = parent_->children_[i];
    } else {
      parent_->children_[i + 1] = node;
      node->parent_ = parent_;
      found = true;
      break;
    }
  }
  // The original condition was inverted (`!found`), which reported an error
  // precisely when the insertion succeeded. Fail only when `this` was never
  // found in its parent's children list.
  CHECK_FAIL_RETURN_UNEXPECTED(found, "Insertion point not found.");
  return Status::OK();
}
// Remove this node from its parent. Add the child of this node to its parent.
// for now, this remove is limited to node with a single child or no child
Status DatasetNode::Remove() {
CHECK_FAIL_RETURN_UNEXPECTED(parent_.size() != 0, "Cannot remove root or a node without parent.");
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
if (children_.empty()) { // I am a leaf node, remove me from my parent's children list
parent_[0]->children_.erase(
std::remove(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()),
parent_[0]->children_.end()); // removal using "erase remove idiom"
} else { // replace my position in my parent's children list with my single child
auto itr = std::find(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this());
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_[0]->children_.end(), "I am not in my parent's children list.");
parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()),
parent_->children_.end()); // removal using "erase remove idiom"
} else { // replace my position in my parent's children list with my single child
auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent
*itr = std::move(children_[0]); // replace me in my parent's children list with my single child
children_.clear(); // release my single child from my children list
children_.clear(); // release my single child from my children list
}
parent_[0] = nullptr;
parent_ = nullptr;
return Status::OK();
}

View File

@ -146,10 +146,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
return out;
}
/// \brief Make a new copy of the tree from the current node
/// \return The new copy of the tree
std::shared_ptr<DatasetNode> DeepCopy();
/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
@ -175,43 +171,61 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
/// \brief Getter function for parents nodes
/// \return Parent nodes
const std::vector<DatasetNode *> Parent() const { return parent_; }
/// \brief Getter function for the parent node
/// \return The parent node (of a node from a cloned IR tree)
DatasetNode *Parent() const { return parent_; }
/// \brief Establish the parent-child relationship between this node and its child.
/// \brief Establish a parent-child relationship between this node and the input node.
/// Used when building the IR tree.
void AddChild(std::shared_ptr<DatasetNode> child);
/// \brief Establish a parent-child relationship between this node and the input node.
/// Used during the cloning of the user-input IR tree (temporary use)
void AppendChild(std::shared_ptr<DatasetNode> child);
/// \brief Establish the child-parent relationship between this node and the input node (future use)
Status InsertAbove(std::shared_ptr<DatasetNode> node);
/// \brief Insert the input node below this node. This node's children becomes the children of the inserted node.
Status InsertBelow(std::shared_ptr<DatasetNode> node);
/// \brief Add the input node as the next sibling (future use)
Status InsertAfter(std::shared_ptr<DatasetNode> node);
/// \brief detach this node from its parent, add its child (if any) to its parent
/// \return error code, return error if node has more than 1 children
Status Remove();
/// \brief Check if this node has cache
/// \brief Check if this node has cache
/// \return True if the data of this node will be cached
const bool IsCached() const { return (cache_ != nullptr); }
/// \brief Check if this node is a tree
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
const bool IsTree() const;
/// \brief Check if this node is a leaf node.
/// \brief Check if this node is a leaf node.
/// \return True if this is a leaf node.
const bool IsLeaf() const { return children_.empty(); }
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if this node is a mappable dataset
const bool IsMappable() const { return (mappable_ == kMappableSource); }
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
/// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes
/// \return True if this node is a non-mappable dataset
const bool IsNonMappable() const { return (mappable_ == kNonMappableSource); }
/// \brief Check if this node is not a data source node.
/// \return True if this node is not a data source node
const bool IsNotADataSource() const { return (mappable_ == kNotADataSource); }
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
/// \return True if a cache-enabled operator is an ancestor of this node
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
void HasCacheAbove() { descendant_of_cache_ = true; }
/// \brief Getter of the number of workers
int32_t num_workers() { return num_workers_; }
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
@ -247,7 +261,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
std::vector<DatasetNode *> parent_;
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
std::shared_ptr<DatasetCache> cache_;
int64_t dataset_size_ = -1;
int32_t num_workers_;
@ -257,7 +271,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const;
bool mappable_;
enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 };
enum DataSource mappable_;
bool descendant_of_cache_;
};
@ -265,12 +280,12 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
class MappableSourceNode : public DatasetNode {
public:
/// \brief Constructor
MappableSourceNode() : DatasetNode() { mappable_ = true; }
MappableSourceNode() : DatasetNode() { mappable_ = kMappableSource; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = true;
mappable_ = kMappableSource;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
}
@ -287,12 +302,12 @@ class MappableSourceNode : public DatasetNode {
class NonMappableSourceNode : public DatasetNode {
public:
/// \brief Constructor
NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
NonMappableSourceNode() : DatasetNode() { mappable_ = kNonMappableSource; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = false;
mappable_ = kNonMappableSource;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
}

View File

@ -27,13 +27,14 @@ namespace mindspore {
namespace dataset {
// Constructor for RootNode
RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) {
// The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode)
RootNode::RootNode(std::shared_ptr<DatasetNode> child) : DatasetNode() {
// The root node's parent must remain nullptr, which is set in the constructor of DatasetNode.
AddChild(child);
}
std::shared_ptr<DatasetNode> RootNode::Copy() {
auto node = std::make_shared<RootNode>(nullptr, num_epochs_);
auto node = std::make_shared<RootNode>(nullptr);
node->SetNumEpochs(num_epochs_);
return node;
}
@ -54,7 +55,7 @@ Status RootNode::ValidateParams() {
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (parent_.size() != 0) {
if (parent_ != nullptr) {
std::string err_msg = "Internal error: root node should not have a parent";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);

View File

@ -29,7 +29,10 @@ namespace dataset {
class RootNode : public DatasetNode {
public:
/// \brief Constructor
RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
RootNode() : DatasetNode() {}
/// \brief Constructor
explicit RootNode(std::shared_ptr<DatasetNode> child);
/// \brief Destructor
~RootNode() = default;
@ -54,6 +57,9 @@ class RootNode : public DatasetNode {
/// \brief Getter of number of epochs
int32_t num_epochs() { return num_epochs_; }
/// \brief Setter of number of epochs
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

View File

@ -7,6 +7,7 @@ add_library(engine-opt OBJECT
pre/cache_error_pass.cc
pre/cache_transform_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/epoch_ctrl_pass.cc
pre/epoch_injection_pass.cc
pre/getter_pass.cc

View File

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
namespace mindspore {
namespace dataset {
// Constructor: create a fresh RootNode to host the cloned tree and make it the
// initial insertion point (parent_) for the DFS walk done by Visit()/VisitAfter().
DeepCopyPass::DeepCopyPass() {
root_ = std::make_shared<RootNode>();
// parent_ is a non-owning cursor into the tree being built; root_ owns the nodes.
parent_ = root_.get();
}
// Pre-order visit: clone the visited node and append the clone under the current
// insertion point (parent_), then descend by making the clone the new parent_.
// Relies on the pass framework walking the tree DFS, children left-to-right.
Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
// Every visit produces a new cloned node, so the output tree is always "modified".
*modified = true;
// Do a nested-loop walk to check whether a node has the same child more than once.
// This is an artificial restriction. We can support it since we will do a clone of the input tree in this pass.
// Example: ds2 = ds1 + ds1;
auto children = node->Children();
if (children.size() > 0) {
for (auto it1 = children.begin(); it1 != children.end() - 1; ++it1) {
for (auto it2 = it1 + 1; it2 != children.end(); ++it2) {
if (*it1 == *it2) {
std::string err_msg = "The same node " + (*it1)->Name() + " is a child of its parent more than once.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
}
}
// Clone a new copy of this node
std::shared_ptr<DatasetNode> new_node = node->Copy();
// Temporary fix to set the num_workers to each cloned node.
// This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
// the cloned node. Each derived class's Copy() will need to include this method.
new_node->SetNumWorkers(node->num_workers());
// This method below assumes a DFS walk and from the first child to the last child.
// Future: A more robust implementation that does not depend on the above assumption.
parent_->AppendChild(new_node);
// Then set this node to be a new parent to accept a copy of its next child
parent_ = new_node.get();
return Status::OK();
}
// Post-order visit: the subtree rooted at this node has been fully cloned, so
// pop the insertion point back up one level to this node's (cloned) parent.
Status DeepCopyPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
*modified = true;
// After visit the node, move up to its parent
parent_ = parent_->Parent();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_
#define DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class DeepCopyPass
/// \brief This pass clones a new copy of IR tree. A new copy is used in the compilation to avoid any modification to
/// the IR tree associated with the user code.
class DeepCopyPass : public IRNodePass {
public:
/// \brief Constructor
DeepCopyPass();
/// \brief Destructor
~DeepCopyPass() = default;
/// \brief Clone a new copy of the node and append it under the current insertion point
/// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited
/// \return Status code
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Reset parent after walking its sub tree.
/// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited
/// \return Status code
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Getter method to retrieve the root node
/// \return the root node of the new cloned tree
std::shared_ptr<RootNode> Root() { return root_; }
private:
// Root of the cloned tree; owns all nodes created by this pass.
std::shared_ptr<RootNode> root_;
// Non-owning cursor: the cloned node under which the next clone will be appended.
// Maintained by Visit() (descend) and VisitAfter() (ascend) during the DFS walk.
DatasetNode *parent_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
@ -24,6 +25,19 @@ namespace dataset {
// Validate each IR node: run the node's own parameter validation, then enforce
// the structural invariant that data source nodes (mappable or non-mappable)
// appear only as leaves, and non-source operator nodes are never leaves.
Status InputValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
// This pass only inspects the tree; it never rewrites it.
*modified = false;
RETURN_IF_NOT_OK(node->ValidateParams());
// A data source node must be a leaf node
if ((node->IsMappable() || node->IsNonMappable()) && !node->IsLeaf()) {
std::string err_msg = node->Name() + " is a data source and must be a leaf node.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// A non-leaf node must not be a data source node
if (node->IsNotADataSource() && node->IsLeaf()) {
std::string err_msg = node->Name() + " is a dataset operator and must not be a leaf node.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
} // namespace dataset

View File

@ -20,6 +20,7 @@
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
@ -27,6 +28,11 @@
namespace mindspore {
namespace dataset {
// Constructor: start the compile-state machine in its initial state and read
// the optional-optimization flag from the OPTIMIZE environment variable.
TreeAdapter::TreeAdapter() {
  tree_state_ = kCompileStateInit;
  // The comparison already yields a bool; the redundant `? true : false` is dropped.
  optimize_ = common::GetEnv("OPTIMIZE") == "true";
}
Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
// Vector of actions in pre-pass phase
std::vector<std::unique_ptr<IRPass>> actions;
@ -86,7 +92,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
return Status::OK();
}
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
// Build the DatasetOp ExecutionTree from the optimized IR tree
std::vector<std::shared_ptr<DatasetOp>> ops;
RETURN_IF_NOT_OK(ir->Build(&ops));
@ -104,47 +110,20 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha
// Build the children of IR, once they return, add the return value to *op
for (std::shared_ptr<DatasetNode> child_ir : ir->Children()) {
std::shared_ptr<DatasetOp> child_op;
RETURN_IF_NOT_OK(BuildExecutionTree(child_ir, &child_op));
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(child_ir, &child_op));
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
}
return Status::OK();
}
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
optimize_ = true; // Always ON (temporary)
RETURN_UNEXPECTED_IF_NULL(input_ir);
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
// We will first walk the input tree to sanity check this is not a graph
// Flag an error when it is not a tree
CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)");
// Copy the input IR tree and insert under the root node
// Create a root node to host the new copy of the input IR tree to pass to the optimizer
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
// Pre-pass of the IR tree
RETURN_IF_NOT_OK(PrePass(root_ir));
// Optional phase of optimization
if (optimize_) {
RETURN_IF_NOT_OK(Optimize(root_ir));
}
// Post-pass of the IR tree
RETURN_IF_NOT_OK(PostPass(root_ir));
MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
// This will evolve in the long run
tree_ = std::make_unique<ExecutionTree>();
// Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree
std::shared_ptr<DatasetOp> root_op;
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op));
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
@ -165,6 +144,48 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
return Status::OK();
}
// Compile the user-built IR graph into an ExecutionTree:
// clone the IR into a tree, run pre/optimize/post passes, then build the
// execution tree. tree_state_ tracks each phase for lifecycle introspection.
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
RETURN_UNEXPECTED_IF_NULL(input_ir);
tree_state_ = kCompileStateIRGraphBuilt;
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
// Clone the input IR tree and insert under the root node
// Create a root node to host the new copy of the input IR tree
// This is done so that the compilation will process and modify the tree
// without changing the tree associated with the user code.
// The tree from the user code is permitted to form a graph where any node
// is consumed by more than one parent. However, this cloning process here
// will break the graph into a tree by copying each consumption of a node into a new copy.
// The "modified" output flag of the pass is not consulted; cloning always modifies.
bool m = false;
DeepCopyPass cloning_tree;
RETURN_IF_NOT_OK(cloning_tree.Run(input_ir, &m));
std::shared_ptr<RootNode> root_ir = cloning_tree.Root();
root_ir->SetNumEpochs(num_epochs);
tree_state_ = kCompileStateIRTreeCloned;
MS_LOG(INFO) << "Plan before optimization:" << '\n' << *root_ir << '\n';
// Pre-pass of the IR tree
RETURN_IF_NOT_OK(PrePass(root_ir));
// Optional phase of optimization
if (optimize_) {
RETURN_IF_NOT_OK(Optimize(root_ir));
}
// Post-pass of the IR tree
RETURN_IF_NOT_OK(PostPass(root_ir));
tree_state_ = kCompileStateOptimized;
MS_LOG(INFO) << "Plan after optimization:" << '\n' << *root_ir << '\n';
RETURN_IF_NOT_OK(Build(root_ir, num_epochs));
tree_state_ = kCompileStateReady;
return Status::OK();
}
Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_UNEXPECTED_IF_NULL(row);

View File

@ -33,7 +33,7 @@ class DatasetNode;
class TreeAdapter {
public:
TreeAdapter() = default;
TreeAdapter();
~TreeAdapter() = default;
@ -68,28 +68,40 @@ class TreeAdapter {
bool OptimizationEnabled() const { return optimize_; }
private:
// This function runs a mandatory pass checking the syntax and semantics of the IR tree.
// Run the mandatory pass checking the syntax and semantics of the IR tree
Status PrePass(std::shared_ptr<DatasetNode> ir);
// This function runs an optional optimization pass on the IR tree.
// Run the optional optimization pass on the IR tree
Status Optimize(std::shared_ptr<DatasetNode> ir);
// This function runs a mandatory pass augmenting the IR tree before the execution.
// Run the mandatory pass augmenting the IR tree
Status PostPass(std::shared_ptr<DatasetNode> ir);
// Build an Execution tree
Status Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs);
// This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree.
Status BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
std::unique_ptr<DataBuffer> cur_db_;
std::unordered_map<std::string, int32_t> column_name_map_;
std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling
int32_t num_epochs_;
std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling
bool optimize_; // Flag to enable optional optimization pass
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
int32_t cur_batch_num_; // current batch number, used for profiling
int32_t cur_connector_size_; // current connector size of root op, used for profiling
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
// State flags for the lifecycle of the tree
enum CompileState {
kCompileStateInit = 0, // The freshly initialized state
kCompileStateIRGraphBuilt, // User code has been parsed and its IR graph built
kCompileStateIRTreeCloned, // IR tree has been cloned from the IR graph
kCompileStateOptimized, // IR tree has been optimized
kCompileStateReady // Execution tree is generated from the optimized IR
};
CompileState tree_state_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -65,7 +65,6 @@ SET(DE_UT_SRCS
image_folder_op_test.cc
image_process_test.cc
interrupt_test.cc
ir_node_test.cc
jieba_tokenizer_op_test.cc
main_test.cc
map_op_test.cc

View File

@ -1,137 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
using namespace mindspore::dataset;
// Test fixture for dataset IR node tests. Provides two helpers:
//  - CompareTwoTrees: structurally walk two IR trees and check, node by node, whether
//    the underlying pointers are (or are not) shared — used to validate DeepCopy().
//  - PostOrderPrintTree: render a tree's node names in post-order for easy comparison.
class MindDataTestIRNodes : public UT::DatasetOpTesting {
 public:
  MindDataTestIRNodes() = default;

  // Compare the ptr of the nodes in two trees, used to test the deep copy of nodes.
  // \param root1 root of the first tree, must not be null
  // \param root2 root of the second tree, must not be null
  // \param flag expected result of every pointer comparison: true means the trees are
  //     expected to share every node (same tree), false means every node must be distinct
  //     (a proper deep copy)
  // \return error status if (ptr1 == ptr2) does not equal `flag` at any position, if node
  //     names differ, or if the two trees have different structures; OK otherwise
  Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
    CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
    if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
      std::string err_msg =
        "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    size_t num_child = root1->Children().size();
    // Fix: the original message concatenated the count directly onto "child" (e.g. "2child").
    CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(),
                                 root1->Name() + " has " + std::to_string(num_child) + " children, node #2 has " +
                                   std::to_string(root2->Children().size()) + " children.");
    for (size_t ind = 0; ind < num_child; ind++) {
      RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag));
    }
    return Status::OK();
  }

  // Append the node names of `ir` to `names` in post-order, each followed by "->".
  // \param ir root of the tree to print, must not be null
  // \param[out] names accumulator string the names are appended to
  // \return error status on a null node, OK otherwise
  Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
    RETURN_UNEXPECTED_IF_NULL(ir);
    for (auto child : ir->Children()) {
      RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
    }
    names += (ir->Name() + "->");
    return Status::OK();
  }
};
TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";

  // Build a linear pipeline and make a deep copy of its IR.
  auto original = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
  auto cloned = original->DeepCopy();

  // Both trees must render the same post-order sequence:
  //   RandomDataset->Repeat->Project->Shuffle->Batch->
  std::string names_original;
  std::string names_cloned;
  ASSERT_OK(PostOrderPrintTree(original, names_original));
  ASSERT_OK(PostOrderPrintTree(cloned, names_cloned));
  EXPECT_EQ(names_original, names_cloned);

  // A tree shares every pointer with itself, and must share none with its deep copy.
  ASSERT_OK(CompareTwoTrees(original, original, true));
  ASSERT_OK(CompareTwoTrees(original, cloned, false));

  // Sanity-check the comparison helper: demanding "all distinct" on one tree must fail.
  EXPECT_TRUE(CompareTwoTrees(cloned, cloned, false).IsError());
}
TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";

  // Assemble a two-branch pipeline joined by Zip, then deep-copy its IR.
  auto left = RandomData(44)->Project({"label"});
  auto right = RandomData(44)->Shuffle(10);
  auto original = Zip({left, right})->Batch(2)->IRNode();
  auto cloned = original->DeepCopy();

  // Both trees must render the same post-order sequence:
  //   RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
  std::string names_original;
  std::string names_cloned;
  ASSERT_OK(PostOrderPrintTree(original, names_original));
  ASSERT_OK(PostOrderPrintTree(cloned, names_cloned));
  EXPECT_EQ(names_original, names_cloned);

  // Pointers within one tree are shared with themselves...
  ASSERT_OK(CompareTwoTrees(original, original, true));
  // ...while the clone must consist of entirely new nodes.
  ASSERT_OK(CompareTwoTrees(original, cloned, false));
}
TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";
  auto branch1 = RandomData(44)->Project({"label"});
  auto branch2 = ImageFolder("path");
  auto tree = Zip({branch1, branch2})->IRNode();
  /***
   tree looks like this, we will remove node and test its functionalities
         Zip
        /   \
   Project ImageFolder
      /
   RandomData
  ***/
  // Case 1: remove an interior node (Project); its child (RandomData) takes its place.
  auto tree_copy_1 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_1->Children().size(), 2);
  ASSERT_OK(tree_copy_1->Children()[0]->Remove());  // remove Project from tree
  // Result must be structurally equal (but pointer-distinct) to Zip(RandomData, ImageFolder).
  ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
  std::string tree_1_names, tree_2_names;
  ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
  EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");
  // Case 2: remove a leaf node (ImageFolder) from a fresh copy of the tree.
  auto tree_copy_2 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_2->Children().size(), 2);
  // Fix: the Status returned by Remove() was silently dropped here; check it as above.
  ASSERT_OK(tree_copy_2->Children()[1]->Remove());
  ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
  EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
}