dataset: Resolve protected parent_access from DeepCopyPass

This commit is contained in:
Cathy Wong 2021-01-22 14:59:25 -05:00
parent fe3473c0cc
commit a3b10213b7
3 changed files with 4 additions and 22 deletions

View File

@ -121,6 +121,9 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int
// The base class of all IR nodes
class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
// Allow DeepCopyPass to access internal members
friend class DeepCopyPass;
public:
/// \brief Constructor
DatasetNode();
@ -183,10 +186,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
/// \brief Getter function for the parent node
/// \return The parent node (of a node from a cloned IR tree)
DatasetNode *const Parent() const { return parent_; }
/// \brief Establish a parent-child relationship between this node and the input node.
/// Used during the cloning of the user-input IR tree (temporary use)
Status AppendChild(std::shared_ptr<DatasetNode> child);

View File

@ -62,7 +62,7 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
Status DeepCopyPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = true;
// After visit the node, move up to its parent
parent_ = parent_->Parent();
parent_ = parent_->parent_;
return Status::OK();
}
} // namespace dataset

View File

@ -67,7 +67,6 @@ TEST_F(MindDataTestTreeModifying, AppendChild) {
rc = ds4_node->AppendChild(node_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE( ds4_node->Children()[2] == node_to_insert);
EXPECT_TRUE(node_to_insert->Parent() == ds4_node.get());
}
TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
@ -126,7 +125,6 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
rc = ds4_node->InsertChildAt(1, ds6_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE( ds4_node->Children()[1] == ds6_to_insert);
EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get());
EXPECT_TRUE( ds4_node->Children()[2] == ds2_node);
// Case 2:
@ -140,7 +138,6 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
rc = ds4_node->InsertChildAt(0, ds6_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE( ds4_node->Children()[0] == ds6_to_insert);
EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get());
EXPECT_TRUE( ds4_node->Children()[1] == ds3_node);
// Case 3:
@ -153,7 +150,6 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
rc = ds4_node->InsertChildAt(2, ds6_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE( ds4_node->Children()[2] == ds6_to_insert);
EXPECT_TRUE(ds6_to_insert->Parent() == ds4_node.get());
}
TEST_F(MindDataTestTreeModifying, InsertChildAt04) {
@ -267,9 +263,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove01) {
rc = ds3_node->InsertAbove(ds5_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE(ds5_to_insert->Children()[0] == ds3_node);
EXPECT_TRUE( ds3_node->Parent() == ds5_to_insert.get());
EXPECT_TRUE( ds4_node->Children()[0] == ds5_to_insert);
EXPECT_TRUE( ds5_to_insert->Parent() == ds4_node.get());
}
TEST_F(MindDataTestTreeModifying, InsertAbove02) {
@ -294,9 +288,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove02) {
rc = ds2_node->InsertAbove(ds6_to_insert);
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE(ds6_to_insert->Children()[0] == ds2_node);
EXPECT_TRUE( ds2_node->Parent() == ds6_to_insert.get());
EXPECT_TRUE( ds4_node->Children()[1] == ds6_to_insert);
EXPECT_TRUE( ds6_to_insert->Parent() == ds4_node.get());
}
TEST_F(MindDataTestTreeModifying, InsertAbove03) {
@ -321,9 +313,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove03) {
std::shared_ptr<RepeatNode> ds7_to_insert = std::make_shared<RepeatNode>(nullptr, 3);
rc = ds1_node->InsertAbove(ds7_to_insert);
EXPECT_TRUE(ds7_to_insert->Children()[0] == ds1_node);
EXPECT_TRUE( ds1_node->Parent() == ds7_to_insert.get());
EXPECT_TRUE( ds3_node->Children()[0] == ds7_to_insert);
EXPECT_TRUE( ds7_to_insert->Parent() == ds3_node.get());
}
TEST_F(MindDataTestTreeModifying, Drop01) {
@ -392,7 +382,6 @@ TEST_F(MindDataTestTreeModifying, Drop01) {
EXPECT_EQ(rc, Status::OK());
// ds8 becomes a childless node
EXPECT_TRUE(ds8_node->Children().empty());
EXPECT_TRUE(ds7_node->Parent() == nullptr);
EXPECT_TRUE(ds7_node->Children().empty());
// Case 2
@ -407,8 +396,6 @@ TEST_F(MindDataTestTreeModifying, Drop01) {
EXPECT_EQ(rc, Status::OK());
// ds7 becomes a child of ds9
EXPECT_TRUE(ds9_node->Children()[0] == ds7_node);
EXPECT_TRUE(ds7_node->Parent() == ds9_node.get());
EXPECT_TRUE(ds8_node->Parent() == nullptr);
EXPECT_TRUE(ds8_node->Children().empty());
}
@ -509,8 +496,6 @@ TEST_F(MindDataTestTreeModifying, Drop04) {
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE(ds6_node->Children().size() == 2);
EXPECT_TRUE(ds6_node->Children()[0] == ds4_node);
EXPECT_TRUE(ds4_node->Parent() == ds6_node.get());
EXPECT_TRUE(ds5_node->Parent() == nullptr);
EXPECT_TRUE(ds5_node->Children().empty());
}
@ -565,8 +550,6 @@ TEST_F(MindDataTestTreeModifying, Drop05) {
EXPECT_EQ(rc, Status::OK());
EXPECT_TRUE(ds6_node->Children().size() == 3);
EXPECT_TRUE(ds6_node->Children()[1] == ds3_node);
EXPECT_TRUE(ds3_node->Parent() == ds6_node.get());
EXPECT_TRUE(ds4_node->Parent() == nullptr);
EXPECT_TRUE(ds4_node->Children().empty());
}