forked from mindspore-Ecosystem/mindspore
Fix IR node Drop()
This commit is contained in:
parent
cf2734da8e
commit
3c5917a03d
|
@ -308,8 +308,7 @@ void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
|
|||
*
|
||||
*/
|
||||
Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(child != nullptr, "Node to append must not be a null pointer.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(child->parent_ == nullptr, "Node to append must have no parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
|
||||
"This node must be a unary operator with no child or an n-ary operator");
|
||||
children_.push_back(child);
|
||||
|
@ -324,8 +323,7 @@ Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
|
|||
*/
|
||||
Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> child) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(child != nullptr, "Node to insert must not be a null pointer.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(child->parent_ == nullptr, "Node to insert must have no parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
|
||||
"This node must be a unary operator with no child or an n-ary operator");
|
||||
children_.insert(children_.begin() + pos, child);
|
||||
|
@ -374,8 +372,7 @@ Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> chil
|
|||
* InsertAbove() cannot use on the root node of a tree.
|
||||
*/
|
||||
Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Node to insert must not be a null pointer.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_ == nullptr, "Node to insert must have no parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(node), "Node to insert must be an orphan node.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent.");
|
||||
auto parent = parent_;
|
||||
|
||||
|
@ -384,10 +381,10 @@ Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
|
|||
// 2. node->parent_ and node->children_
|
||||
// 3. this->parent_
|
||||
auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
|
||||
*current_node_itr = node;
|
||||
node->parent_ = parent;
|
||||
node->children_.push_back(shared_from_this());
|
||||
parent_ = node.get();
|
||||
*current_node_itr = node; // replace me in my parent's children list with the newly inserted node
|
||||
node->parent_ = parent; // set the newly inserted node's parent ptr to my parent
|
||||
node->children_.push_back(shared_from_this()); // add myself to the newly inserted node's children list
|
||||
parent_ = node.get(); // set my parent ptr to the newly inserted node
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -477,7 +474,29 @@ Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
|
|||
* | / \
|
||||
* ds7 ds3 ds2
|
||||
*
|
||||
* Case 5: When the node has more than one child and more than one sibling, Drop() will raise an error.
|
||||
* Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
|
||||
* children become its parent's children.
|
||||
*
|
||||
* Input tree:
|
||||
* ds10
|
||||
* / \
|
||||
* ds9 ds6
|
||||
* | / | \
|
||||
* ds8 ds5 ds4 ds1
|
||||
* | |
|
||||
* ds7 ds3
|
||||
*
|
||||
* ds4->Drop() yields the tree below:
|
||||
*
|
||||
* ds10
|
||||
* / \
|
||||
* ds9 ds6
|
||||
* | / | \
|
||||
* ds8 ds5 ds3 ds1
|
||||
* |
|
||||
* ds7
|
||||
*
|
||||
* Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error.
|
||||
* If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it
|
||||
* with a combination of Drop(), InsertChildAt()
|
||||
*
|
||||
|
@ -507,8 +526,6 @@ Status DatasetNode::Drop() {
|
|||
"Trying to drop an n-ary operator that is a child of a unary operator");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1),
|
||||
"This node to drop must not have more than one child and more than one sibling.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(children_.size() == 0 || parent_->children_.size() == 1,
|
||||
"If this node to drop has children, it must be its parent's only child.");
|
||||
if (parent_->children_.size() == 1) {
|
||||
auto parent = parent_;
|
||||
// Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
|
||||
|
@ -547,6 +564,16 @@ Status DatasetNode::Drop() {
|
|||
// And mark itself as an orphan
|
||||
parent_ = nullptr;
|
||||
children_.clear();
|
||||
} else if (children_.size() == 1 && parent_->children_.size() > 1) {
|
||||
// Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
|
||||
// children become its parent's children.
|
||||
auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
|
||||
*itr = children_[0]; // replace this node in its parent's children list with its single child
|
||||
children_[0]->parent_ = parent_; // set its single child's parent ptr to its parent
|
||||
// And mark itself as an orphan
|
||||
parent_ = nullptr;
|
||||
children_.clear();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here.");
|
||||
}
|
||||
|
|
|
@ -237,6 +237,12 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \return True if a cache-enabled operator is an ancestor of this node
|
||||
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
|
||||
|
||||
/// \brief Check if this node is an orphan node
|
||||
/// \return True if this node isn't nullptr nor does it have any children and a parent
|
||||
static bool IsOrphanNode(std::shared_ptr<DatasetNode> node) {
|
||||
return node != nullptr && node->parent_ == nullptr && node->Children().empty();
|
||||
}
|
||||
|
||||
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
|
||||
void HasCacheAbove() { descendant_of_cache_ = true; }
|
||||
|
||||
|
|
|
@ -517,7 +517,63 @@ TEST_F(MindDataTestTreeModifying, Drop04) {
|
|||
TEST_F(MindDataTestTreeModifying, Drop05) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop05";
|
||||
/*
|
||||
* Case 5: When the node has more than one child and more than one sibling, Drop() will raise an error.
|
||||
* Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
|
||||
* children become its parent's children.
|
||||
*
|
||||
* Input tree:
|
||||
* ds10
|
||||
* / \
|
||||
* ds9 ds6
|
||||
* | / | \
|
||||
* ds8 ds5 ds4 ds1
|
||||
* | |
|
||||
* ds7 ds3
|
||||
*
|
||||
* ds4->Drop() yields the tree below:
|
||||
*
|
||||
* ds10
|
||||
* / \
|
||||
* ds9 ds6
|
||||
* | / | \
|
||||
* ds8 ds5 ds3 ds1
|
||||
* |
|
||||
* ds7
|
||||
*
|
||||
*/
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds7 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds8 = ds7->Take(20);
|
||||
std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
|
||||
std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Skip(1);
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!!
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
auto ir_tree = std::make_shared<TreeAdapter>();
|
||||
rc = ir_tree->Compile(root); // Compile adds a new RootNode to the top of the tree
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
// Descend two levels as Compile adds the root node and the epochctrl node on top of ds4
|
||||
std::shared_ptr<DatasetNode> ds10_node = ir_tree->RootIRNode()->Children()[0]->Children()[0];
|
||||
std::shared_ptr<DatasetNode> ds6_node = ds10_node->Children()[1];
|
||||
std::shared_ptr<DatasetNode> ds4_node = ds6_node->Children()[1];
|
||||
std::shared_ptr<DatasetNode> ds3_node = ds4_node->Children()[0];
|
||||
rc = ds4_node->Drop();
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
EXPECT_TRUE(ds6_node->Children().size() == 3);
|
||||
EXPECT_TRUE(ds6_node->Children()[1] == ds3_node);
|
||||
EXPECT_TRUE(ds3_node->Parent() == ds6_node.get());
|
||||
EXPECT_TRUE(ds4_node->Parent() == nullptr);
|
||||
EXPECT_TRUE(ds4_node->Children().empty());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeModifying, Drop06) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTreeModifying-Drop06";
|
||||
/*
|
||||
* Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error.
|
||||
* If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it
|
||||
* with a combination of Drop(), InsertChildAt()
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue