!69442 Fixed the bug of ListToTupleEliminator and TupleToListEliminator.

Merge pull request !69442 from Margaret_wangrui/master_list_tuple
This commit is contained in:
i-robot 2024-05-17 02:04:31 +00:00 committed by Gitee
commit cbe46fb03b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 30 deletions

View File

@ -37,13 +37,14 @@ namespace irpass {
class ListToTupleEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfVisitor::Match(prim::kPrimListToTuple, {IsCNode})(node);
if (!IsPrimitiveCNode(node, prim::kPrimListToTuple)) {
return nullptr;
}
auto fg = node->func_graph();
if (fg != nullptr) {
auto real_node = node->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(real_node);
std::vector<AnfNodePtr> args_{NewValueNode(prim::kPrimMakeTuple)};
if (real_node == nullptr) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(real_node->abstract());
auto input_abs = real_node->abstract()->cast<abstract::AbstractListPtr>();
MS_EXCEPTION_IF_NULL(input_abs);
@ -60,30 +61,20 @@ class ListToTupleEliminator : public AnfVisitor {
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
real_node = cnode;
while (IsPrimitiveCNode(real_node, prim::kPrimDepend)) {
auto depend = real_node->cast<CNodePtr>();
real_node = depend->input(1)->cast<CNodePtr>();
}
}
private:
CNodePtr real_node{nullptr};
};
// {prim::kPrimTupleToList, data} => {prim::kPrimMakeList, {prim::kPrimTupleGetItem, data, 0}, ...}
class TupleToListEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfVisitor::Match(prim::kPrimTupleToList, {IsCNode})(node);
if (!IsPrimitiveCNode(node, prim::kPrimTupleToList)) {
return nullptr;
}
auto fg = node->func_graph();
if (fg != nullptr) {
auto real_node = node->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(real_node);
std::vector<AnfNodePtr> args_{NewValueNode(prim::kPrimMakeList)};
if (real_node == nullptr) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(real_node->abstract());
auto input_abs = real_node->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_abs);
@ -100,17 +91,6 @@ class TupleToListEliminator : public AnfVisitor {
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
real_node = cnode;
while (IsPrimitiveCNode(real_node, prim::kPrimDepend)) {
auto depend = real_node->cast<CNodePtr>();
real_node = depend->input(1)->cast<CNodePtr>();
}
}
private:
CNodePtr real_node{nullptr};
};
} // namespace irpass
} // namespace opt