!13268 add SetitemTupleEliminator to item_tuple_or_list_eliminate pass

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-14 10:05:55 +08:00 committed by Gitee
commit a018390e40
1 changed files with 27 additions and 3 deletions

View File

@ -191,14 +191,18 @@ class GetitemConstEliminator : public AnfVisitor {
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, a, b, c, ...}, 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
class SetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsVNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsVNode, IsVNode, IsNode})(node);
auto fg = node->func_graph();
if (fg != nullptr && z_ != nullptr) {
@ -225,7 +229,27 @@ class SetitemEliminator : public AnfVisitor {
}
void Visit(const ValueNodePtr &vnode) override {
if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) {
if (args_.empty() && IsValueNode<ValueTuple>(vnode)) {
auto tuple = GetValueNode<ValueTuplePtr>(vnode);
if (tuple != nullptr) {
args_.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &val : tuple->value()) {
auto val_node = std::make_shared<ValueNode>(val);
val_node->set_abstract(val->ToAbstract());
args_.emplace_back(val_node);
}
}
} else if (args_.empty() && IsValueNode<ValueList>(vnode)) {
auto list = GetValueNode<ValueListPtr>(vnode);
if (list != nullptr) {
args_.emplace_back(NewValueNode(prim::kPrimMakeList));
for (auto &val : list->value()) {
auto val_node = std::make_shared<ValueNode>(val);
val_node->set_abstract(val->ToAbstract());
args_.emplace_back(val_node);
}
}
} else if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) {
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
idx = idx + args_.size() - 1;