diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index eed410f22a9..97c1b4d7c42 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -556,8 +556,8 @@ std::shared_ptr Dataset::BuildSentencePieceVocab( const std::vector &col_names, uint32_t vocab_size, float character_coverage, SentencePieceModel model_type, const std::unordered_map ¶ms) { auto vocab = std::make_shared(); - auto ds = std::make_shared(IRNode()->DeepCopy(), vocab, col_names, vocab_size, - character_coverage, model_type, params); + auto ds = std::make_shared(IRNode(), vocab, col_names, vocab_size, character_coverage, + model_type, params); std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); @@ -588,8 +588,8 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum const std::pair &freq_range, int64_t top_k, const std::vector &special_tokens, bool special_first) { auto vocab = std::make_shared(); - auto ds = std::make_shared(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens, - special_first); + auto ds = + std::make_shared(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 0e9d9aa977d..c1b42bf935c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -206,9 +206,6 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const { } } -// Getter function to get all of our children. -std::vector> DatasetOp::children() const { return child_; } - // Getter function to get all of our parents. std::vector DatasetOp::parents() const { return parent_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 39ab602fe9a..811cb7ff087 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -111,9 +111,6 @@ class DatasetOp : public std::enable_shared_from_this { /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. void Parent(DatasetOp **parent, int32_t parent_index) const; - // Getter function to get all of our children. - std::vector> children() const; - // Getter function to get all of our parents. std::vector parents() const; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 1a314858d42..f9164723928 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -233,40 +233,14 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { return shared_from_this(); } -DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) { +DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); num_workers_ = cfg->num_parallel_workers(); rows_per_buffer_ = cfg->rows_per_buffer(); connector_que_size_ = cfg->op_connector_size(); worker_connector_size_ = cfg->worker_connector_size(); -} - -const bool DatasetNode::IsTree() const { - bool is_tree = true; - if (this->parent_.size() > 1) { - MS_LOG(WARNING) << Name() << " has more than one parent."; - return false; - } - for (const auto &child : children_) { - is_tree = child->IsTree(); - if (!is_tree) { - MS_LOG(WARNING) << Name() << " has more than one parent."; - break; - } - } - return is_tree; -} - -// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied -std::shared_ptr DatasetNode::DeepCopy() { - std::shared_ptr new_node = this->Copy(); - // temporary fix to set the num_workers to the new node. - new_node->SetNumWorkers(this->num_workers_); - for (const auto &child : children_) { - new_node->AddChild(child->DeepCopy()); - } - return new_node; + mappable_ = kNotADataSource; } std::string DatasetNode::PrintColumns(const std::vector &columns) const { @@ -310,54 +284,105 @@ void DatasetNode::PrintNode(std::ostream &out, int *level) const { } // Add a node as a child, node's parent needs to be empty -// this function will allow child to be a nullptr, in which case it will simply skip +// This function will allow child to be a nullptr, in which case it will simply skip. +// This function is used only when building IR node one by one from parsing the user code. +// During the parsing, we allow a node to have more than one parent, possibly forming a graph. +// It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure. void DatasetNode::AddChild(std::shared_ptr child) { - if (child != nullptr && child->parent_.empty()) { + if (child != nullptr) { children_.push_back(child); - child->parent_.push_back(this); - } else if (child != nullptr) { - MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent"; - children_.push_back(child); - child->parent_.push_back(this); } } -// Insert a node as a child of this node. This node's children becomes the children of the inserted node. +// Add the input node to be the next child of this node +// This function is used in doing a deep copy of the IR tree built by parsing the user code. +// This function assumes we walk the tree in DFS left-to-right. +// This is a temporary function to be replaced later by a set of better tree operations. +void DatasetNode::AppendChild(std::shared_ptr child) { + if (child != nullptr) { + if (child->parent_ != nullptr) { + MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent"; + } + children_.push_back(child); + child->parent_ = this; + } +} + +// Add a node as a parent, node's parent needs to be empty (future use) +Status DatasetNode::InsertAbove(std::shared_ptr node) { + CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); + + if (node->parent_ != nullptr) { + DatasetNode *parent = node->parent_; + for (auto i = parent->children_.size() - 1; i >= 0; --i) { + if (parent->children_[i] == node) { + parent->children_[i] = static_cast>(this); + } + } + } + children_.push_back(node); + node->parent_ = this; + + return Status::OK(); +} + +// Insert a node as a child of this node +// This node's children become the children of the inserted node. Status DatasetNode::InsertBelow(std::shared_ptr node) { CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children."); - CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent."); + CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent."); for (auto child : children_) { node->children_.push_back(child); - child->parent_.clear(); - child->parent_.push_back(node.get()); + child->parent_ = node.get(); } // Then establish the new parent-child relationship with the new parent. children_.clear(); children_.push_back(node); - node->parent_.clear(); - node->parent_.push_back(this); + node->parent_ = this; + return Status::OK(); +} + +// Insert a node as a child next to this node (future use) +Status DatasetNode::InsertAfter(std::shared_ptr node) { + CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must have a parent."); + CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent."); + auto size = parent_->children_.size(); + // Duplicate the last child to increase the size by 1 + parent_->children_.push_back(parent_->children_[size - 1]); + // Shift each child to its right until we found the insertion point, then insert the input node + bool found = false; + for (auto i = parent_->children_.size() - 2; i >= 0; --i) { + if (parent_->children_[i].get() != this) { + parent_->children_[i + 1] = parent_->children_[i]; + } else { + parent_->children_[i + 1] = node; + node->parent_ = parent_; + found = true; + break; + } + } + CHECK_FAIL_RETURN_UNEXPECTED(!found, "Insertion point not found."); return Status::OK(); } // Remove this node from its parent. Add the child of this node to its parent. // for now, this remove is limited to node with a single child or no child Status DatasetNode::Remove() { - CHECK_FAIL_RETURN_UNEXPECTED(parent_.size() != 0, "Cannot remove root or a node without parent."); + CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); if (children_.empty()) { // I am a leaf node, remove me from my parent's children list - parent_[0]->children_.erase( - std::remove(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()), - parent_[0]->children_.end()); // removal using "erase remove idiom" - } else { // replace my position in my parent's children list with my single child - auto itr = std::find(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()); - CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_[0]->children_.end(), "I am not in my parent's children list."); + parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), + parent_->children_.end()); // removal using "erase remove idiom" + } else { // replace my position in my parent's children list with my single child + auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); + CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent *itr = std::move(children_[0]); // replace me in my parent's children list with my single child - children_.clear(); // release my single child from my children list + children_.clear(); // release my single child from my children list } - parent_[0] = nullptr; + parent_ = nullptr; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index f037a460ba8..a2480b65d6f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -146,10 +146,6 @@ class DatasetNode : public std::enable_shared_from_this { return out; } - /// \brief Make a new copy of the tree from the current node - /// \return The new copy of the tree - std::shared_ptr DeepCopy(); - /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create /// \return Status Status::OK() if build successfully @@ -175,43 +171,61 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Child nodes const std::vector> Children() const { return children_; } - /// \brief Getter function for parents nodes - /// \return Parent nodes - const std::vector Parent() const { return parent_; } + /// \brief Getter function for the parent node + /// \return The parent node (of a node from a cloned IR tree) + DatasetNode *Parent() const { return parent_; } - /// \brief Establish the parent-child relationship between this node and its child. + /// \brief Establish a parent-child relationship between this node and the input node. + /// Used when building the IR tree. void AddChild(std::shared_ptr child); + /// \brief Establish a parent-child relationship between this node and the input node. + /// Used during the cloning of the user-input IR tree (temporary use) + void AppendChild(std::shared_ptr child); + + /// \brief Establish the child-parent relationship between this node and the input node (future use) + Status InsertAbove(std::shared_ptr node); + /// \brief Insert the input node below this node. This node's children becomes the children of the inserted node. Status InsertBelow(std::shared_ptr node); + /// \brief Add the input node as the next sibling (future use) + Status InsertAfter(std::shared_ptr node); + /// \brief detach this node from its parent, add its child (if any) to its parent /// \return error code, return error if node has more than 1 children Status Remove(); - /// \brief Check if this node has cache + /// \brief Check if this node has cache /// \return True if the data of this node will be cached const bool IsCached() const { return (cache_ != nullptr); } - /// \brief Check if this node is a tree - /// \return True if the structure is indeed a tree, i.e., no node has more than one parent - const bool IsTree() const; - - /// \brief Check if this node is a leaf node. + /// \brief Check if this node is a leaf node. /// \return True if this is a leaf node. const bool IsLeaf() const { return children_.empty(); } - /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes - /// \return True if the dataset represented by this node is a mappable dataset - const bool IsMappable() const { return mappable_; } + /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes + /// \return True if this node is a mappable dataset + const bool IsMappable() const { return (mappable_ == kMappableSource); } - /// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes + /// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes + /// \return True if this node is a non-mappable dataset + const bool IsNonMappable() const { return (mappable_ == kNonMappableSource); } + + /// \brief Check if this node is not a data source node. + /// \return True if this node is not a data source node + const bool IsNotADataSource() const { return (mappable_ == kNotADataSource); } + + /// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes /// \return True if a cache-enabled operator is an ancestor of this node const bool IsDescendantOfCache() const { return descendant_of_cache_; } - /// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes + /// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes void HasCacheAbove() { descendant_of_cache_ = true; } + /// \brief Getter of the number of workers + int32_t num_workers() { return num_workers_; } + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object @@ -247,7 +261,7 @@ class DatasetNode : public std::enable_shared_from_this { protected: std::vector> children_; - std::vector parent_; + DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase std::shared_ptr cache_; int64_t dataset_size_ = -1; int32_t num_workers_; @@ -257,7 +271,8 @@ class DatasetNode : public std::enable_shared_from_this { std::string PrintColumns(const std::vector &columns) const; Status AddCacheOp(std::vector> *node_ops); void PrintNode(std::ostream &out, int *level) const; - bool mappable_; + enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 }; + enum DataSource mappable_; bool descendant_of_cache_; }; @@ -265,12 +280,12 @@ class DatasetNode : public std::enable_shared_from_this { class MappableSourceNode : public DatasetNode { public: /// \brief Constructor - MappableSourceNode() : DatasetNode() { mappable_ = true; } + MappableSourceNode() : DatasetNode() { mappable_ = kMappableSource; } /// \brief Constructor that initializes the cache /// \param dataset_cache DatasetCache explicit MappableSourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) { - mappable_ = true; + mappable_ = kMappableSource; // Initially set to false, and set to true by the optimizer when conditions are met. descendant_of_cache_ = false; } @@ -287,12 +302,12 @@ class MappableSourceNode : public DatasetNode { class NonMappableSourceNode : public DatasetNode { public: /// \brief Constructor - NonMappableSourceNode() : DatasetNode() { mappable_ = false; } + NonMappableSourceNode() : DatasetNode() { mappable_ = kNonMappableSource; } /// \brief Constructor that initializes the cache /// \param dataset_cache DatasetCache explicit NonMappableSourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) { - mappable_ = false; + mappable_ = kNonMappableSource; // Initially set to false, and set to true by the optimizer when conditions are met. descendant_of_cache_ = false; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc index 55ed794e10a..b0375e27914 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc @@ -27,13 +27,14 @@ namespace mindspore { namespace dataset { // Constructor for RootNode -RootNode::RootNode(std::shared_ptr child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) { - // The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode) +RootNode::RootNode(std::shared_ptr child) : DatasetNode() { + // The root node's parent must remain nullptr, which is set in the constructor of DatasetNode. AddChild(child); } std::shared_ptr RootNode::Copy() { - auto node = std::make_shared(nullptr, num_epochs_); + auto node = std::make_shared(nullptr); + node->SetNumEpochs(num_epochs_); return node; } @@ -54,7 +55,7 @@ Status RootNode::ValidateParams() { MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (parent_.size() != 0) { + if (parent_ != nullptr) { std::string err_msg = "Internal error: root node should not have a parent"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h index 9dbdfecf820..a6a1ea2b598 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h @@ -29,7 +29,10 @@ namespace dataset { class RootNode : public DatasetNode { public: /// \brief Constructor - RootNode(std::shared_ptr child, int32_t num_epochs); + RootNode() : DatasetNode() {} + + /// \brief Constructor + explicit RootNode(std::shared_ptr child); /// \brief Destructor ~RootNode() = default; @@ -54,6 +57,9 @@ class RootNode : public DatasetNode { /// \brief Getter of number of epochs int32_t num_epochs() { return num_epochs_; } + /// \brief Setter of number of epochs + void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; } + /// \brief Parameters validation /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 44847ade49b..303e985ce41 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(engine-opt OBJECT pre/cache_error_pass.cc pre/cache_transform_pass.cc pre/cache_validation_pass.cc + pre/deep_copy_pass.cc pre/epoch_ctrl_pass.cc pre/epoch_injection_pass.cc pre/getter_pass.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc new file mode 100644 index 00000000000..cd20df18dbe --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" +#include "minddata/dataset/engine/ir/datasetops/root_node.h" + +namespace mindspore { +namespace dataset { + +DeepCopyPass::DeepCopyPass() { + root_ = std::make_shared(); + parent_ = root_.get(); +} + +Status DeepCopyPass::Visit(std::shared_ptr node, bool *modified) { + *modified = true; + // Do a nested-loop walk to check whether a node has the same child more than once. + // This is an artificial restriction. We can support it since we will do a clone of the input tree in this pass. + // Example: ds2 = ds1 + ds1; + auto children = node->Children(); + if (children.size() > 0) { + for (auto it1 = children.begin(); it1 != children.end() - 1; ++it1) { + for (auto it2 = it1 + 1; it2 != children.end(); ++it2) { + if (*it1 == *it2) { + std::string err_msg = "The same node " + (*it1)->Name() + " is a child of its parent more than once."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + } + } + + // Clone a new copy of this node + std::shared_ptr new_node = node->Copy(); + // Temporary fix to set the num_workers to each cloned node. + // This can be improved by adding a new method in the base class DatasetNode to transfer the properties to + // the cloned node. Each derived class's Copy() will need to include this method. + new_node->SetNumWorkers(node->num_workers()); + // This method below assumes a DFS walk and from the first child to the last child. + // Future: A more robust implementation that does not depend on the above assumption. + parent_->AppendChild(new_node); + + // Then set this node to be a new parent to accept a copy of its next child + parent_ = new_node.get(); + return Status::OK(); +} + +Status DeepCopyPass::VisitAfter(std::shared_ptr node, bool *modified) { + *modified = true; + // After visit the node, move up to its parent + parent_ = parent_->Parent(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.h new file mode 100644 index 00000000000..76daa72a71a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_ +#define DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class DeepCopyPass +/// \brief This pass clones a new copy of IR tree. A new copy is used in the compilation to avoid any modification to +/// the IR tree associated with the user code. +class DeepCopyPass : public IRNodePass { + public: + /// \brief Constructor + DeepCopyPass(); + + /// \brief Destructor + ~DeepCopyPass() = default; + + /// \brief Clone a new copy of the node + /// \param[in] node The node being visited + /// \param[inout] *modified indicates whether the node has been visited + /// \return Status code + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Reset parent after walking its sub tree. + /// \param[in] node The node being visited + /// \param[inout] *modified indicates whether the node has been visited + /// \return Status code + Status VisitAfter(std::shared_ptr node, bool *modified) override; + + /// \brief Getter method to retrieve the root node + /// \return the root node of the new cloned tree + std::shared_ptr Root() { return root_; } + + private: + std::shared_ptr root_; + DatasetNode *parent_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PRE_DEEP_COPY_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc index 8fff7473b03..5ede8c12d60 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "minddata/dataset/include/datasets.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" @@ -24,6 +25,19 @@ namespace dataset { Status InputValidationPass::Visit(std::shared_ptr node, bool *modified) { *modified = false; RETURN_IF_NOT_OK(node->ValidateParams()); + + // A data source node must be a leaf node + if ((node->IsMappable() || node->IsNonMappable()) && !node->IsLeaf()) { + std::string err_msg = node->Name() + " is a data source and must be a leaf node."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // A non-leaf node must not be a data source node + if (node->IsNotADataSource() && node->IsLeaf()) { + std::string err_msg = node->Name() + " is a dataset operator and must not be a leaf node."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index c74e750959a..b7cb9ce0b4b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -20,6 +20,7 @@ #include "minddata/dataset/engine/ir/datasetops/root_node.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" +#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" @@ -27,6 +28,11 @@ namespace mindspore { namespace dataset { +TreeAdapter::TreeAdapter() { + tree_state_ = kCompileStateInit; + optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false; +} + Status TreeAdapter::PrePass(std::shared_ptr ir) { // Vector of actions in pre-pass phase std::vector> actions; @@ -86,7 +92,7 @@ Status TreeAdapter::PostPass(std::shared_ptr ir) { return Status::OK(); } -Status TreeAdapter::BuildExecutionTree(std::shared_ptr ir, std::shared_ptr *op) { +Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr ir, std::shared_ptr *op) { // Build the DatasetOp ExecutionTree from the optimized IR tree std::vector> ops; RETURN_IF_NOT_OK(ir->Build(&ops)); @@ -104,47 +110,20 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr ir, std::sha // Build the children of IR, once they return, add the return value to *op for (std::shared_ptr child_ir : ir->Children()) { std::shared_ptr child_op; - RETURN_IF_NOT_OK(BuildExecutionTree(child_ir, &child_op)); + RETURN_IF_NOT_OK(BuildExecutionTreeRecur(child_ir, &child_op)); RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops } return Status::OK(); } -Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_epochs) { - optimize_ = true; // Always ON (temporary) - - RETURN_UNEXPECTED_IF_NULL(input_ir); - MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; - - // We will first walk the input tree to sanity check this is not a graph - // Flag an error when it is not a tree - CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)"); - - // Copy the input IR tree and insert under the root node - // Create a root node to host the new copy of the input IR tree to pass to the optimizer - auto root_ir = std::make_shared(input_ir->DeepCopy(), num_epochs); - MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; - - // Pre-pass of the IR tree - RETURN_IF_NOT_OK(PrePass(root_ir)); - - // Optional phase of optimization - if (optimize_) { - RETURN_IF_NOT_OK(Optimize(root_ir)); - } - - // Post-pass of the IR tree - RETURN_IF_NOT_OK(PostPass(root_ir)); - - MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n'; - +Status TreeAdapter::Build(std::shared_ptr root_ir, int32_t num_epochs) { // This will evolve in the long run tree_ = std::make_unique(); // Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree std::shared_ptr root_op; - RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op)); + RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); @@ -165,6 +144,48 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e return Status::OK(); } +Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_epochs) { + RETURN_UNEXPECTED_IF_NULL(input_ir); + + tree_state_ = kCompileStateIRGraphBuilt; + MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; + + // Clone the input IR tree and insert under the root node + // Create a root node to host the new copy of the input IR tree + // This is done so that the compilation will process and modify the tree + // without changing the tree associated with the user code. + // The tree from the user code is permitted to form a graph where any node + // is consumed by more than one parent. However, this cloning process here + // will break the graph into a tree by copying each consumption of a node into a new copy. + bool m = false; + DeepCopyPass cloning_tree; + RETURN_IF_NOT_OK(cloning_tree.Run(input_ir, &m)); + std::shared_ptr root_ir = cloning_tree.Root(); + root_ir->SetNumEpochs(num_epochs); + + tree_state_ = kCompileStateIRTreeCloned; + MS_LOG(INFO) << "Plan before optimization:" << '\n' << *root_ir << '\n'; + + // Pre-pass of the IR tree + RETURN_IF_NOT_OK(PrePass(root_ir)); + + // Optional phase of optimization + if (optimize_) { + RETURN_IF_NOT_OK(Optimize(root_ir)); + } + + // Post-pass of the IR tree + RETURN_IF_NOT_OK(PostPass(root_ir)); + + tree_state_ = kCompileStateOptimized; + MS_LOG(INFO) << "Plan after optimization:" << '\n' << *root_ir << '\n'; + + RETURN_IF_NOT_OK(Build(root_ir, num_epochs)); + tree_state_ = kCompileStateReady; + + return Status::OK(); +} + Status TreeAdapter::GetNext(TensorRow *row) { RETURN_UNEXPECTED_IF_NULL(tree_); RETURN_UNEXPECTED_IF_NULL(row); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 6e95ae0a60c..84a08cbc400 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -33,7 +33,7 @@ class DatasetNode; class TreeAdapter { public: - TreeAdapter() = default; + TreeAdapter(); ~TreeAdapter() = default; @@ -68,28 +68,40 @@ class TreeAdapter { bool OptimizationEnabled() const { return optimize_; } private: - // This function runs a mandatory pass checking the syntax and semantics of the IR tree. + // Run the mandatory pass checking the syntax and semantics of the IR tree Status PrePass(std::shared_ptr ir); - // This function runs an optional optimization pass on the IR tree. + // Run the optional optimization pass on the IR tree Status Optimize(std::shared_ptr ir); - // This function runs a mandatory pass augmenting the IR tree before the execution. + // Run the mandatory pass augmenting the IR tree Status PostPass(std::shared_ptr ir); + // Build an Execution tree + Status Build(std::shared_ptr root_ir, int32_t num_epochs); + // This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree. - Status BuildExecutionTree(std::shared_ptr ir, std::shared_ptr *op); + Status BuildExecutionTreeRecur(std::shared_ptr ir, std::shared_ptr *op); std::unique_ptr cur_db_; std::unordered_map column_name_map_; - std::unique_ptr tree_; // current connector capacity of root op, used for profiling - int32_t num_epochs_; + std::unique_ptr tree_; // current connector capacity of root op, used for profiling bool optimize_; // Flag to enable optional optimization pass std::shared_ptr tracing_; // trace profiling data int32_t cur_batch_num_; // current batch number, used for profiling int32_t cur_connector_size_; // current connector size of root op, used for profiling int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling std::function pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() + + // State flags for the lifecycle of the tree + enum CompileState { + kCompileStateInit = 0, // The freshly initialized state + kCompileStateIRGraphBuilt, // User code has been parsed and its IR graph built + kCompileStateIRTreeCloned, // IR tree has been cloned from the IR graph + kCompileStateOptimized, // IR tree has been optimized + kCompileStateReady // Execution tree is generated from the optimized IR + }; + CompileState tree_state_; }; } // namespace dataset } // namespace mindspore diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 498959fc3f4..a14b68b5ae4 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -65,7 +65,6 @@ SET(DE_UT_SRCS image_folder_op_test.cc image_process_test.cc interrupt_test.cc - ir_node_test.cc jieba_tokenizer_op_test.cc main_test.cc map_op_test.cc diff --git a/tests/ut/cpp/dataset/ir_node_test.cc b/tests/ut/cpp/dataset/ir_node_test.cc deleted file mode 100644 index 513de18a330..00000000000 --- a/tests/ut/cpp/dataset/ir_node_test.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "common/common.h" -#include "gtest/gtest.h" - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" -#include "minddata/dataset/engine/opt/pre/getter_pass.h" - -using namespace mindspore::dataset; - -class MindDataTestIRNodes : public UT::DatasetOpTesting { - public: - MindDataTestIRNodes() = default; - // compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code - // if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same) - Status CompareTwoTrees(std::shared_ptr root1, std::shared_ptr root2, bool flag) { - CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr."); - if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) { - std::string err_msg = - "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - size_t num_child = root1->Children().size(); - - CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(), - root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " + - std::to_string(root2->Children().size()) + " child."); - - for (size_t ind = 0; ind < num_child; ind++) { - RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag)); - } - return Status::OK(); - } - - // print the node's name in post order - Status PostOrderPrintTree(std::shared_ptr ir, std::string &names) { - RETURN_UNEXPECTED_IF_NULL(ir); - for (auto child : ir->Children()) { - RETURN_IF_NOT_OK(PostOrderPrintTree(child, names)); - } - names += (ir->Name() + "->"); - return Status::OK(); - } -}; - -TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) { - MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy."; - - auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode(); - - auto tree2 = tree1->DeepCopy(); - std::string tree_1_names, tree_2_names; - - ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); - ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); - - // expected output for the 2 names: - // RandomDataset->Repeat->Project->Shuffle->Batch-> - EXPECT_EQ(tree_1_names, tree_2_names); - - ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); - ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); - - // verify compare function is correct - EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError()); -} - -TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) { - MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy."; - - auto branch1 = RandomData(44)->Project({"label"}); - auto branch2 = RandomData(44)->Shuffle(10); - - auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode(); - - auto tree2 = tree1->DeepCopy(); - std::string tree_1_names, tree_2_names; - - ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); - ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); - - // expected output for the 2 names: - // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch-> - EXPECT_EQ(tree_1_names, tree_2_names); - - // verify the pointer within the same tree are the same - ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); - // verify two trees - ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); -} - -TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) { - MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove."; - - auto branch1 = RandomData(44)->Project({"label"}); - auto branch2 = ImageFolder("path"); - auto tree = Zip({branch1, branch2})->IRNode(); - /*** - tree looks like this, we will remove node and test its functionalities - Zip - / \ - Project ImageFolder - / - RandomData - ***/ - auto tree_copy_1 = tree->DeepCopy(); - ASSERT_EQ(tree_copy_1->Children().size(), 2); - // remove the project in the tree and test - ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree - ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false)); - // remove the ImageFolder, a leaf node from the tree - std::string tree_1_names, tree_2_names; - ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names)); - EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->"); - auto tree_copy_2 = tree->DeepCopy(); - ASSERT_EQ(tree_copy_2->Children().size(), 2); - tree_copy_2->Children()[1]->Remove(); - ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names)); - EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->"); -}