From e28fb6ce4de1f702dd1d09e69750025157fa80ef Mon Sep 17 00:00:00 2001 From: Nat Sutyanyong Date: Fri, 8 Jan 2021 22:28:43 -0500 Subject: [PATCH] Tested with new test cases and all dataset UTs passed --- .../engine/ir/datasetops/concat_node.cc | 1 + .../engine/ir/datasetops/dataset_node.cc | 350 ++++++++--- .../engine/ir/datasetops/dataset_node.h | 37 +- .../dataset/engine/ir/datasetops/zip_node.cc | 1 + .../engine/opt/pre/cache_validation_pass.cc | 2 +- .../dataset/engine/opt/pre/deep_copy_pass.cc | 2 +- .../dataset/engine/opt/pre/epoch_ctrl_pass.cc | 11 +- .../engine/opt/pre/input_validation_pass.cc | 2 +- .../engine/opt/pre/node_removal_pass.cc | 36 +- .../engine/opt/pre/node_removal_pass.h | 21 +- .../minddata/dataset/engine/tree_adapter.cc | 4 +- .../minddata/dataset/engine/tree_adapter.h | 4 + tests/ut/cpp/dataset/CMakeLists.txt | 12 +- .../dataset/tree_modifying_function_test.cc | 567 ++++++++++++++++++ 14 files changed, 883 insertions(+), 167 deletions(-) create mode 100644 tests/ut/cpp/dataset/tree_modifying_function_test.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index b8d46dcf543..997c13fd6e9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -35,6 +35,7 @@ ConcatNode::ConcatNode(const std::vector> &datasets : sampler_(sampler), children_flag_and_nums_(children_flag_and_nums), children_start_end_index_(children_start_end_index) { + nary_op_ = true; for (auto const &child : datasets) AddChild(child); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 11f99465cc5..5d7c6a36a0d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -221,15 +221,20 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { return shared_from_this(); } -DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}), dataset_size_(-1) { +DatasetNode::DatasetNode() + : cache_(nullptr), + parent_(nullptr), + children_({}), + dataset_size_(-1), + mappable_(kNotADataSource), + nary_op_(false), + descendant_of_cache_(false) { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); num_workers_ = cfg->num_parallel_workers(); rows_per_buffer_ = cfg->rows_per_buffer(); connector_que_size_ = cfg->op_connector_size(); worker_connector_size_ = cfg->worker_connector_size(); - mappable_ = kNotADataSource; - descendant_of_cache_ = false; } std::string DatasetNode::PrintColumns(const std::vector &columns) const { @@ -283,95 +288,268 @@ void DatasetNode::AddChild(std::shared_ptr child) { } } -// Add the input node to be the next child of this node -// This function is used in doing a deep copy of the IR tree built by parsing the user code. -// This function assumes we walk the tree in DFS left-to-right. -// This is a temporary function to be replaced later by a set of better tree operations. -void DatasetNode::AppendChild(std::shared_ptr child) { - if (child != nullptr) { - if (child->parent_ != nullptr) { - MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent"; - } - children_.push_back(child); - child->parent_ = this; - } +/* + * AppendChild() appending as the last child of this node. The new node must have no parent. + * + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | + * ds1 + * + * ds4->AppendChild(ds6) yields this tree + * + * _ ds4 _ + * / | \ + * ds3 ds2 ds6 + * | + * ds1 + * + */ +Status DatasetNode::AppendChild(std::shared_ptr child) { + CHECK_FAIL_RETURN_UNEXPECTED(child != nullptr, "Node to append must not be a null pointer."); + CHECK_FAIL_RETURN_UNEXPECTED(child->parent_ == nullptr, "Node to append must have no parent."); + CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(), + "This node must be a unary operator with no child or an n-ary operator"); + children_.push_back(child); + child->parent_ = this; + return Status::OK(); } -// Add a node as a parent, node's parent needs to be empty (future use) -Status DatasetNode::InsertAbove(std::shared_ptr node) { - CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); +/* + * InsertChildAt(, ) inserts the to be at the index of the vector of its child nodes. + * As in the convention of C++, starts at position 0. + * If the is a negative number or larger than the size of the vector minus one, an error is raised. + */ +Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr child) { + CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]"); + CHECK_FAIL_RETURN_UNEXPECTED(child != nullptr, "Node to insert must not be a null pointer."); + CHECK_FAIL_RETURN_UNEXPECTED(child->parent_ == nullptr, "Node to insert must have no parent."); + CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(), + "This node must be a unary operator with no child or an n-ary operator"); + children_.insert(children_.begin() + pos, child); + child->parent_ = this; + return Status::OK(); +} - if (node->parent_ != nullptr) { - DatasetNode *parent = node->parent_; - for (auto i = parent->children_.size() - 1; i >= 0; --i) { - if (parent->children_[i] == node) { - parent->children_[i] = static_cast>(this); +/* + * Insert the input above this node + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | + * ds1 + * + * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use + * ds3->InsertAbove(ds5) + * + * ds4 + * / \ + * ds5 ds2 + * | + * ds3 + * | + * ds1 + * + * Case 2: Likewise, ds2->InsertAbove(ds6) yields + * + * ds4 + * / \ + * ds3 ds6 + * | | + * ds1 ds2 + * + * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7) + * + * ds4 + * / \ + * ds3 ds2 + * | + * ds7 + * | + * ds1 + * + * InsertAbove() cannot use on the root node of a tree. + */ +Status DatasetNode::InsertAbove(std::shared_ptr node) { + CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Node to insert must not be a null pointer."); + CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Node to insert must have no parent."); + CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent."); + auto parent = parent_; + + // The following fields of these three nodes are changed in this function: + // 1. parent->children_ + // 2. node->parent_ and node->children_ + // 3. this->parent_ + auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); + *current_node_itr = node; + node->parent_ = parent; + node->children_.push_back(shared_from_this()); + parent_ = node.get(); + + return Status::OK(); +} + +/* + * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op. + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree. + * + * ds7->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * / \ + * ds3 ds2 + * + * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child + * becomes its parent's child. + * + * ds8->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds7 ds5 ds4 ds1 + * / \ + * ds3 ds2 + * + * Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's + * children become its parent's children. + * + * When the input tree is + * + * ds10 + * / \ + * ds9 ds6 + * | | + * ds8 ds4 + * | / \ + * ds7 ds3 ds2 + * + * ds4->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / \ + * ds8 ds3 ds2 + * | + * ds7 + * + * But if ds6 is not an n-ary operator, ds4->Drop() will raise an error because we cannot add the children of an + * n-ary operator (ds4) to a unary operator (ds6). + * + * Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be + * squeezed left. + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * ds5->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / \ + * ds8 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * Case 5: When the node has more than one child and more than one sibling, Drop() will raise an error. + * If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it + * with a combination of Drop(), InsertChildAt() + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * If we want to form this tree below: + * + * ds10 + * / \ + * ds9 ds6_____ + * | / | | \ + * ds8 ds5 ds3 ds2 ds1 + * | + * ds7 + * + */ +Status DatasetNode::Drop() { + CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node to drop must not be the root or a node without parent."); + CHECK_FAIL_RETURN_UNEXPECTED(!(IsNaryOperator() && parent_->IsUnaryOperator()), + "Trying to drop an n-ary operator that is a child of a unary operator"); + CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1), + "This node to drop must not have more than one child and more than one sibling."); + CHECK_FAIL_RETURN_UNEXPECTED(children_.size() == 0 || parent_->children_.size() == 1, + "If this node to drop has children, it must be its parent's only child."); + if (parent_->children_.size() == 1) { + auto parent = parent_; + // Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child + // becomes its parent's child. + // This is the most common use case. + if (children_.size() == 1) { + auto child = children_[0]; + // Move its child to be its parent's child + parent->children_[0] = child; + child->parent_ = parent; + } else if (children_.empty()) { + // Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree. + // Remove this node from its parent's child + parent_->children_.clear(); + } else if (children_.size() > 1) { + // Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and + // the node's children become its parent's children. + // Remove this node from its parent's child + parent->children_.clear(); + // Move its child to be its parent's child + for (auto &child : children_) { + parent->children_.push_back(child); + child->parent_ = parent; } } + // And mark itself as an orphan + parent_ = nullptr; + children_.clear(); + } else if (children_.empty() && parent_->children_.size() > 1) { + // Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will + // be squeezed left. + auto parent = parent_; + // Remove this node from its parent's child + parent->children_.erase(std::remove(parent->children_.begin(), parent->children_.end(), shared_from_this()), + parent->children_.end()); // removal using "erase remove idiom" + // And mark itself as an orphan + parent_ = nullptr; + children_.clear(); + } else { + RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here."); } - children_.push_back(node); - node->parent_ = this; - - return Status::OK(); -} - -// Insert a node as a child of this node -// This node's children become the children of the inserted node. -Status DatasetNode::InsertBelow(std::shared_ptr node) { - CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); - CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children."); - CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent."); - - for (auto child : children_) { - node->children_.push_back(child); - child->parent_ = node.get(); - } - // Then establish the new parent-child relationship with the new parent. - children_.clear(); - children_.push_back(node); - node->parent_ = this; - return Status::OK(); -} - -// Insert a node as a child next to this node (future use) -Status DatasetNode::InsertAfter(std::shared_ptr node) { - CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must have a parent."); - CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Inserted node must not have a parent."); - auto size = parent_->children_.size(); - // Duplicate the last child to increase the size by 1 - parent_->children_.push_back(parent_->children_[size - 1]); - // Shift each child to its right until we found the insertion point, then insert the input node - bool found = false; - for (auto i = parent_->children_.size() - 2; i >= 0; --i) { - if (parent_->children_[i].get() != this) { - parent_->children_[i + 1] = parent_->children_[i]; - } else { - parent_->children_[i + 1] = node; - node->parent_ = parent_; - found = true; - break; - } - } - CHECK_FAIL_RETURN_UNEXPECTED(!found, "Insertion point not found."); - return Status::OK(); -} - -// Remove this node from its parent. Add the child of this node to its parent. -// for now, this remove is limited to node with a single child or no child -Status DatasetNode::Remove() { - CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); - CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); - if (children_.empty()) { // I am a leaf node, remove me from my parent's children list - parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), - parent_->children_.end()); // removal using "erase remove idiom" - } else { // replace my position in my parent's children list with my single child - auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); - CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); - children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent - *itr = std::move(children_[0]); // replace me in my parent's children list with my single child - children_.clear(); // release my single child from my children list - } - parent_ = nullptr; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 97051eed433..1419a173159 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -140,7 +140,7 @@ class DatasetNode : public std::enable_shared_from_this { /// \param out - The output stream to write output to virtual void Print(std::ostream &out) const = 0; - /// \brief Pure virtual function to make a new copy of the node + /// \brief Pure virtual function to clone a new copy of the node /// \return The new copy of the node virtual std::shared_ptr Copy() = 0; @@ -187,26 +187,19 @@ class DatasetNode : public std::enable_shared_from_this { /// \return The parent node (of a node from a cloned IR tree) DatasetNode *const Parent() const { return parent_; } - /// \brief Establish a parent-child relationship between this node and the input node. - /// Used when building the IR tree. - void AddChild(std::shared_ptr child); - /// \brief Establish a parent-child relationship between this node and the input node. /// Used during the cloning of the user-input IR tree (temporary use) - void AppendChild(std::shared_ptr child); + Status AppendChild(std::shared_ptr child); - /// \brief Establish the child-parent relationship between this node and the input node (future use) + /// \brief Insert the input above this node Status InsertAbove(std::shared_ptr node); - /// \brief Insert the input node below this node. This node's children becomes the children of the inserted node. - Status InsertBelow(std::shared_ptr node); - /// \brief Add the input node as the next sibling (future use) - Status InsertAfter(std::shared_ptr node); + Status InsertChildAt(int32_t pos, std::shared_ptr node); /// \brief detach this node from its parent, add its child (if any) to its parent /// \return error code, return error if node has more than 1 children - Status Remove(); + Status Drop(); /// \brief Check if this node has cache /// \return True if the data of this node will be cached @@ -216,13 +209,25 @@ class DatasetNode : public std::enable_shared_from_this { /// \return True if this is a leaf node. const bool IsLeaf() const { return children_.empty(); } + /// \brief Check if this node is a unary operator node. + /// \return True if this node is semantically a unary operator node + const bool IsUnaryOperator() const { return (mappable_ == kNotADataSource && !nary_op_); } + + /// \brief Check if this node is a n-ary operator node. + /// \return True if this node is semantically a n-ary operator node + const bool IsNaryOperator() const { return (mappable_ == kNotADataSource && nary_op_); } + /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes /// \return True if this node is a mappable dataset - const bool IsMappable() const { return (mappable_ == kMappableSource); } + const bool IsMappableDataSource() const { return (mappable_ == kMappableSource); } /// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes /// \return True if this node is a non-mappable dataset - const bool IsNonMappable() const { return (mappable_ == kNonMappableSource); } + const bool IsNonMappableDataSource() const { return (mappable_ == kNonMappableSource); } + + /// \brief Check if this node is a data source node. + /// \return True if this node is a data source node + const bool IsDataSource() const { return (mappable_ == kMappableSource || mappable_ == kNonMappableSource); } /// \brief Check if this node is not a data source node. /// \return True if this node is not a data source node @@ -285,11 +290,15 @@ class DatasetNode : public std::enable_shared_from_this { int32_t rows_per_buffer_; int32_t connector_que_size_; int32_t worker_connector_size_; + // Establish a parent-child relationship between this node and the input node. + // Used only in the constructor of the class and its derived classes. + void AddChild(std::shared_ptr child); std::string PrintColumns(const std::vector &columns) const; Status AddCacheOp(std::vector> *node_ops); void PrintNode(std::ostream &out, int *level) const; enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 }; enum DataSource mappable_; + bool nary_op_; // an indicator of whether the current node supports multiple children, true for concat/zip node bool descendant_of_cache_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 1d7ee6be5ea..4c69f8b7473 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -28,6 +28,7 @@ namespace mindspore { namespace dataset { ZipNode::ZipNode(const std::vector> &datasets) { + nary_op_ = true; for (auto const &child : datasets) AddChild(child); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc index c3cff08d235..981334e0d05 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc @@ -135,7 +135,7 @@ Status CacheValidationPass::Visit(std::shared_ptr node, bool *const // If this node is created to be cached, set the flag. is_cached_ = true; } - if (node->IsLeaf() && node->IsMappable()) { + if (node->IsLeaf() && node->IsMappableDataSource()) { is_mappable_ = true; } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc index 4691d201aa4..7a1c80a97d0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/deep_copy_pass.cc @@ -52,7 +52,7 @@ Status DeepCopyPass::Visit(std::shared_ptr node, bool *const modifi new_node->SetNumWorkers(node->num_workers()); // This method below assumes a DFS walk and from the first child to the last child. // Future: A more robust implementation that does not depend on the above assumption. - parent_->AppendChild(new_node); + RETURN_IF_NOT_OK(parent_->AppendChild(new_node)); // Then set this node to be a new parent to accept a copy of its next child parent_ = new_node.get(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc index 014c7f0147a..545bef5a96b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc @@ -31,7 +31,7 @@ EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr nod // Performs finder work for BuildVocabOp that has special rules about epoch control injection Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr node, bool *const modified) { // The injection is at the child of the root node - injection_point_ = node; + injection_point_ = node->Children()[0]; num_epochs_ = node->num_epochs(); return Status::OK(); } @@ -53,7 +53,7 @@ Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr node, bool *const modified) { // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here. // Move the injection point to the child of this node. - injection_point_ = node; + injection_point_ = node->Children()[0]; return Status::OK(); } @@ -71,12 +71,11 @@ Status EpochCtrlPass::RunOnTree(std::shared_ptr root_ir, bool *cons // The first injection logic is to check if we should inject the epoch control op as the root node. // Do not inject the op if the number of epochs is 1. - std::shared_ptr parent = finder.injection_point(); + std::shared_ptr node = finder.injection_point(); int32_t num_epochs = finder.num_epochs(); - if (num_epochs != 1 && parent != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(parent->Children().size() == 1, "EpochCtrl must be injected on only one child."); + if (num_epochs != 1 && node != nullptr) { auto epoch_ctrl_node = std::make_shared(num_epochs); - RETURN_IF_NOT_OK(parent->InsertBelow(epoch_ctrl_node)); + RETURN_IF_NOT_OK(node->InsertAbove(epoch_ctrl_node)); } MS_LOG(INFO) << "Pre pass: Injection pass complete."; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc index fda92dcd36e..d4019c3c170 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc @@ -27,7 +27,7 @@ Status InputValidationPass::Visit(std::shared_ptr node, bool *const RETURN_IF_NOT_OK(node->ValidateParams()); // A data source node must be a leaf node - if ((node->IsMappable() || node->IsNonMappable()) && !node->IsLeaf()) { + if ((node->IsMappableDataSource() || node->IsNonMappableDataSource()) && !node->IsLeaf()) { std::string err_msg = node->Name() + " is a data source and must be a leaf node."; RETURN_STATUS_UNEXPECTED(err_msg); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc index 5dc2baa893a..73369135917 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc @@ -14,41 +14,15 @@ * limitations under the License. */ -#include -#include #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" -#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h" namespace mindspore { namespace dataset { -NodeRemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} - -// Identifies the subtree below this node as a cached descendant tree. -Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *const modified) { - *modified = false; - MS_LOG(INFO) << "Node removal pass: Operation with cache found, identified descendant tree."; - if (node->IsCached()) { - is_caching_ = true; - } - return Status::OK(); -} - -// Resets the tracking of the cache within the tree -Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr node, bool *const modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: Descendant walk is complete."; - if (is_caching_ && node->IsLeaf()) { - // Mark this leaf node to indicate it is a descendant of an operator with cache. - // This is currently used in non-mappable dataset (leaf) nodes to not add a ShuffleOp in DatasetNode::Build(). - node->HasCacheAbove(); - } - is_caching_ = false; - return Status::OK(); -} +NodeRemovalPass::RemovalNodes::RemovalNodes() {} // Perform RepeatNode removal check. Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *const modified) { @@ -59,12 +33,6 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bo return Status::OK(); } -// Perform ShuffleNode removal check. -Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *const modified) { - *modified = false; - return Status::OK(); -} - // Perform SkipNode removal check. Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *const modified) { *modified = false; @@ -95,7 +63,7 @@ Status NodeRemovalPass::RunOnTree(std::shared_ptr root_ir, bool *co // Then, execute the removal of any nodes that were set up for removal for (auto node : removal_nodes->nodes_to_remove()) { - RETURN_IF_NOT_OK(node->Remove()); + RETURN_IF_NOT_OK(node->Drop()); } MS_LOG(INFO) << "Pre pass: node removal pass complete."; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h index 865902b9cdf..492cab1f6c6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h @@ -31,7 +31,7 @@ class DatasetOp; /// nodes should be removed, and then removes them. class NodeRemovalPass : public IRTreePass { /// \class RemovalNodes - /// \brief This is a NodePass who's job is to identify which nodes should be removed. + /// \brief This is a NodePass whose job is to identify which nodes should be removed. /// It works in conjunction with the removal_pass. class RemovalNodes : public IRNodePass { public: @@ -42,30 +42,12 @@ class NodeRemovalPass : public IRTreePass { /// \brief Destructor ~RemovalNodes() = default; - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status Visit(std::shared_ptr node, bool *const modified) override; - - /// \brief Resets the tracking of the cache within the tree - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status VisitAfter(std::shared_ptr node, bool *const modified) override; - /// \brief Perform RepeatNode removal check /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all /// \return Status The status code returned Status Visit(std::shared_ptr node, bool *const modified) override; - /// \brief Perform ShuffleNode removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status Visit(std::shared_ptr node, bool *const modified) override; - /// \brief Perform SkipNode removal check /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all @@ -83,7 +65,6 @@ class NodeRemovalPass : public IRTreePass { std::vector> nodes_to_remove() { return nodes_to_remove_; } private: - bool is_caching_; std::vector> nodes_to_remove_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index ec7e3e823d2..ed4b3ebdc51 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -187,8 +187,10 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e tree_state_ = kCompileStateOptimized; MS_LOG(INFO) << "Plan after optimization:" << '\n' << *root_ir << '\n'; + // Remember the root node + root_ir_ = root_ir; - RETURN_IF_NOT_OK(Build(root_ir, num_epochs)); + RETURN_IF_NOT_OK(Build(root_ir_, num_epochs)); tree_state_ = kCompileStateReady; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 53c1e3dd549..71644f0f995 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -46,6 +46,9 @@ class TreeAdapter { // the Execution tree. Status Compile(std::shared_ptr root_ir, int32_t num_epochs = -1); + // Return the root node of the IR after cloned from the parsed IR tree + std::shared_ptr RootIRNode() const { return root_ir_; } + // This is the main method TreeConsumer uses to interact with TreeAdapter // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared) // 2. GetNext will return empty row when eoe/eof is obtained @@ -87,6 +90,7 @@ class TreeAdapter { std::unique_ptr cur_db_; std::unordered_map column_name_map_; + std::shared_ptr root_ir_; std::unique_ptr tree_; // current connector capacity of root op, used for profiling bool optimize_; // Flag to enable optional optimization pass std::shared_ptr tracing_; // trace profiling data diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 59f92349c42..1e6caf1bb86 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -131,13 +131,14 @@ SET(DE_UT_SRCS to_float16_op_test.cc tokenizer_op_test.cc treap_test.cc + tree_modifying_function_test.cc trucate_pair_test.cc type_cast_op_test.cc weighted_random_sampler_test.cc zip_op_test.cc ) -if (ENABLE_PYTHON) +if(ENABLE_PYTHON) set(DE_UT_SRCS ${DE_UT_SRCS} filter_op_test.cc @@ -145,13 +146,18 @@ if (ENABLE_PYTHON) voc_op_test.cc sentence_piece_vocab_op_test.cc ) -endif () +endif() add_executable(de_ut_tests ${DE_UT_SRCS}) set_target_properties(de_ut_tests PROPERTIES INSTALL_RPATH "$ORIGIN/../lib:$ORIGIN/../lib64") -target_link_libraries(de_ut_tests PRIVATE _c_dataengine pybind11::embed ${GTEST_LIBRARY} ${SECUREC_LIBRARY} ${SLOG_LIBRARY}) +target_link_libraries(de_ut_tests PRIVATE + _c_dataengine pybind11::embed + ${GTEST_LIBRARY} + ${SECUREC_LIBRARY} + ${SLOG_LIBRARY} + ) gtest_discover_tests(de_ut_tests WORKING_DIRECTORY ${Project_DIR}/tests/dataset) diff --git a/tests/ut/cpp/dataset/tree_modifying_function_test.cc b/tests/ut/cpp/dataset/tree_modifying_function_test.cc new file mode 100644 index 00000000000..713481da1b8 --- /dev/null +++ b/tests/ut/cpp/dataset/tree_modifying_function_test.cc @@ -0,0 +1,567 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include "minddata/dataset/engine/ir/datasetops/skip_node.h" +#include "minddata/dataset/engine/ir/datasetops/take_node.h" +#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" +#include "minddata/dataset/include/datasets.h" + +using namespace mindspore::dataset; + +class MindDataTestTreeModifying : public UT::DatasetOpTesting { + public: + MindDataTestTreeModifying() = default; +}; + +TEST_F(MindDataTestTreeModifying, AppendChild) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-AppendChild"; + /* + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | + * ds1 + * + * ds4->AppendChild(ds6) yields this tree + * + * _ ds4 _ + * / | \ + * ds3 ds2 ds6 + * | + * ds1 + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + Status rc; + + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + // You can inspect the plan by sending *ir_tree->RootIRNode() to std::cout + std::shared_ptr node_to_insert = ds6->IRNode(); + rc = ds4_node->AppendChild(node_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[2] == node_to_insert); + EXPECT_TRUE(node_to_insert->Parent() == ds4_node.get()); +} + +TEST_F(MindDataTestTreeModifying, InsertChildAt01) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-InsertChildAt01"; + /* + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | | + * ds1 ds5 + * + * Case 1: ds4->InsertChildAt(1, ds6) yields this tree + * + * _ ds4 _ + * / | \ + * ds3 ds6 ds2 + * | | + * ds1 ds5 + * + * Case 2: ds4->InsertChildAt(0, ds6) yields this tree + * + * _ ds4 _ + * / | \ + * ds6 ds3 ds2 + * | | + * ds1 ds5 + * + * Case 3: ds4->InsertChildAt(2, ds6) yields this tree + * + * _ ds4 _ + * / | \ + * ds3 ds2 ds6 + * | | + * ds1 ds5 + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ds5->Repeat(4); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + Status rc; + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + + // Case 1: + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds6_to_insert = ds6->IRNode(); + std::shared_ptr ds2_node = ds4_node->Children()[1]; + rc = ds4_node->InsertChildAt(1, ds6_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[1] == ds6_to_insert); + EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get()); + EXPECT_TRUE( ds4_node->Children()[2] == ds2_node); + + // Case 2: + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + ds6_to_insert = ds6->IRNode(); + std::shared_ptr ds3_node = ds4_node->Children()[0]; + rc = ds4_node->InsertChildAt(0, ds6_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[0] == ds6_to_insert); + EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get()); + EXPECT_TRUE( ds4_node->Children()[1] == ds3_node); + + // Case 3: + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + ds6_to_insert = ds6->IRNode(); + rc = ds4_node->InsertChildAt(2, ds6_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[2] == ds6_to_insert); + EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get()); +} + +TEST_F(MindDataTestTreeModifying, InsertChildAt04) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-InsertChildAt04"; + + /* + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | | + * ds1 ds5 + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ds5->Repeat(4); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + Status rc; + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + + // Case 4: ds4->InsertChildAt(3, ds6) raises an error + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds6_to_insert = ds6->IRNode(); + std::shared_ptr ds3_node = ds4_node->Children()[0]; + std::shared_ptr ds2_node = ds4_node->Children()[1]; + rc = ds4_node->InsertChildAt(3, ds6_to_insert); + EXPECT_NE(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[0] == ds3_node); + EXPECT_TRUE( ds4_node->Children()[1] == ds2_node); + + // Case 5: ds4->InsertChildAt(-1, ds6) raises an error + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + ds6_to_insert = ds6->IRNode(); + ds3_node = ds4_node->Children()[0]; + ds2_node = ds4_node->Children()[1]; + rc = ds4_node->InsertChildAt(-1, ds6_to_insert); + EXPECT_NE(rc, Status::OK()); + EXPECT_TRUE( ds4_node->Children()[0] == ds3_node); + EXPECT_TRUE( ds4_node->Children()[1] == ds2_node); +} + +TEST_F(MindDataTestTreeModifying, InsertAbove01) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-InsertAbove01"; + /* + * Insert the input above this node + * Input tree: + * ds4 + * / \ + * ds3 ds2 + * | + * ds1 + * + * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use + * ds3->InsertAbove(ds5) + * + * ds4 + * / \ + * ds5 ds2 + * | + * ds3 + * | + * ds1 + * + * Case 2: Likewise, ds2->InsertAbove(ds6) yields + * + * ds4 + * / \ + * ds3 ds6 + * | | + * ds1 ds2 + * + * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7) + * + * ds4 + * / \ + * ds3 ds2 + * | + * ds7 + * | + * ds1 + * + */ + // Case 1 + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + Status rc; + + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds3_node = ds4_node->Children()[0]; + std::shared_ptr ds5_to_insert = std::make_shared(nullptr, 1); + rc = ds3_node->InsertAbove(ds5_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE(ds5_to_insert->Children()[0] == ds3_node); + EXPECT_TRUE( ds3_node->Parent() == ds5_to_insert.get()); + EXPECT_TRUE( ds4_node->Children()[0] == ds5_to_insert); + EXPECT_TRUE( ds5_to_insert->Parent() == ds4_node.get()); +} + +TEST_F(MindDataTestTreeModifying, InsertAbove02) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-InsertAbove02"; + + // Case 2 + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds4 = ds2 + ds3; // ds2 is the second child and ds3 is the first child!!! + Status rc; + + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds2_node = ds4_node->Children()[1]; + std::shared_ptr ds6_to_insert = std::make_shared(nullptr, 12); + rc = ds2_node->InsertAbove(ds6_to_insert); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE(ds6_to_insert->Children()[0] == ds2_node); + EXPECT_TRUE( ds2_node->Parent() == ds6_to_insert.get()); + EXPECT_TRUE( ds4_node->Children()[1] == ds6_to_insert); + EXPECT_TRUE( ds6_to_insert->Parent() == ds4_node.get()); +} + +TEST_F(MindDataTestTreeModifying, InsertAbove03) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-InsertAbove03"; + + // Case 3 + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds3 = ds1->Take(10); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + Status rc; + + std::shared_ptr root = ds4->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds4_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds3_node = ds4_node->Children()[0]; + std::shared_ptr ds1_node = ds3_node->Children()[0]; + std::shared_ptr ds7_to_insert = std::make_shared(nullptr, 3); + rc = ds1_node->InsertAbove(ds7_to_insert); + EXPECT_TRUE(ds7_to_insert->Children()[0] == ds1_node); + EXPECT_TRUE( ds1_node->Parent() == ds7_to_insert.get()); + EXPECT_TRUE( ds3_node->Children()[0] == ds7_to_insert); + EXPECT_TRUE( ds7_to_insert->Parent() == ds3_node.get()); +} + +TEST_F(MindDataTestTreeModifying, Drop01) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop01"; + /* + * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op. + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree. + * + * ds7->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * / \ + * ds3 ds2 + * + * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child + * becomes its parent's child. + * + * ds8->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds7 ds5 ds4 ds1 + * / \ + * ds3 ds2 + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds7 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds8 = ds7->Take(20); + std::shared_ptr ds9 = ds8->Skip(1); + std::shared_ptr ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + std::shared_ptr ds6 = ds4->Take(13); + std::shared_ptr ds10 = ds6 + ds9; + Status rc; + + std::shared_ptr root = ds10->IRNode(); + auto ir_tree = std::make_shared(); + + // Case 1 + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds9_node = ds10_node->Children()[0]; + std::shared_ptr ds8_node = ds9_node->Children()[0]; + std::shared_ptr ds7_node = ds8_node->Children()[0]; + rc = ds7_node->Drop(); + EXPECT_EQ(rc, Status::OK()); + // ds8 becomes a childless node + EXPECT_TRUE(ds8_node->Children().empty()); + EXPECT_TRUE(ds7_node->Parent() == nullptr); + EXPECT_TRUE(ds7_node->Children().empty()); + + // Case 2 + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + ds9_node = ds10_node->Children()[0]; + ds8_node = ds9_node->Children()[0]; + ds7_node = ds8_node->Children()[0]; + rc = ds8_node->Drop(); + EXPECT_EQ(rc, Status::OK()); + // ds7 becomes a child of ds9 + EXPECT_TRUE(ds9_node->Children()[0] == ds7_node); + EXPECT_TRUE(ds7_node->Parent() == ds9_node.get()); + EXPECT_TRUE(ds8_node->Parent() == nullptr); + EXPECT_TRUE(ds8_node->Children().empty()); +} + +TEST_F(MindDataTestTreeModifying, Drop03) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop03"; + /* Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's + * children become its parent's children. + * + * When the input tree is + * ds10 + * / \ + * ds9 ds6 + * | | + * ds8 ds4 + * | / \ + * ds7 ds3 ds2 + * + * + * ds4->Drop() will raise an error because we cannot add the children of an n-ary operator (ds4) to a unary operator + * (ds6). + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds7 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds8 = ds7->Take(20); + std::shared_ptr ds9 = ds8->Skip(1); + std::shared_ptr ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + std::shared_ptr ds6 = ds4->Take(13); + std::shared_ptr ds10 = ds6 + ds9; + Status rc; + + std::shared_ptr root = ds10->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds6_node = ds10_node->Children()[1]; + std::shared_ptr ds4_node = ds6_node->Children()[0]; + std::shared_ptr ds3_node = ds4_node->Children()[0]; + std::shared_ptr ds2_node = ds4_node->Children()[1]; + rc = ds4_node->Drop(); + EXPECT_NE(rc, Status::OK()); +} + +TEST_F(MindDataTestTreeModifying, Drop04) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop04"; + /* Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be + * squeezed left. + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * ds5->Drop() yields the tree below: + * + * ds10 + * / \ + * ds9 ds6 + * | / \ + * ds8 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds7 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds8 = ds7->Take(20); + std::shared_ptr ds9 = ds8->Skip(1); + std::shared_ptr ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + std::shared_ptr ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!! + std::shared_ptr ds10 = ds6 + ds9; + Status rc; + + std::shared_ptr root = ds10->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds6_node = ds10_node->Children()[1]; + std::shared_ptr ds5_node = ds6_node->Children()[0]; + std::shared_ptr ds4_node = ds6_node->Children()[1]; + EXPECT_TRUE(ds5_node->IsDataSource()); + EXPECT_TRUE(ds6_node->IsNaryOperator()); + rc = ds5_node->Drop(); + EXPECT_EQ(rc, Status::OK()); + EXPECT_TRUE(ds6_node->Children().size() == 2); + EXPECT_TRUE(ds6_node->Children()[0] == ds4_node); + EXPECT_TRUE(ds4_node->Parent() == ds6_node.get()); + EXPECT_TRUE(ds5_node->Parent() == nullptr); + EXPECT_TRUE(ds5_node->Children().empty()); +} + +TEST_F(MindDataTestTreeModifying, Drop05) { + MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop05"; + /* + * Case 5: When the node has more than one child and more than one sibling, Drop() will raise an error. + * If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it + * with a combination of Drop(), InsertChildAt() + * + * Input tree: + * ds10 + * / \ + * ds9 ds6 + * | / | \ + * ds8 ds5 ds4 ds1 + * | / \ + * ds7 ds3 ds2 + * + * If we want to form this tree below: + * + * ds10 + * / \ + * ds9 ds6_____ + * | / | | \ + * ds8 ds5 ds3 ds2 ds1 + * | + * ds7 + * + */ + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds7 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds8 = ds7->Take(20); + std::shared_ptr ds9 = ds8->Skip(1); + std::shared_ptr ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!! + std::shared_ptr ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11)); + std::shared_ptr ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!! + std::shared_ptr ds10 = ds6 + ds9; + Status rc; + + std::shared_ptr root = ds10->IRNode(); + auto ir_tree = std::make_shared(); + rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree + EXPECT_EQ(rc, Status::OK()); + // Descend two levels as Compile adds the root node and the epochctrl node on top of ds4 + std::shared_ptr ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0]; + std::shared_ptr ds6_node = ds10_node->Children()[1]; + std::shared_ptr ds4_node = ds6_node->Children()[1]; + rc = ds4_node->Drop(); + EXPECT_NE(rc, Status::OK()); +}